diff --git a/.gitignore b/.gitignore index df8c680efb2df4ec5850d58b035038c36afacf36..bfbd461b232f026844f5d7d6fd2b5fb7647fa941 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ project/boot project/build project/target target +/leon diff --git a/mytest/Abs.scala b/mytest/Abs.scala new file mode 100644 index 0000000000000000000000000000000000000000..5efb41797b5dc4401407c5fea7c3fc6a7591f4f2 --- /dev/null +++ b/mytest/Abs.scala @@ -0,0 +1,51 @@ +import leon.Utils._ + +object Abs { + + + def abs(tab: Map[Int, Int], size: Int): Map[Int, Int] = ({ + require(size <= 5 && isArray(tab, size)) + var k = 0 + var tabres = Map.empty[Int, Int] + (while(k < size) { + if(tab(k) < 0) + tabres = tabres.updated(k, -tab(k)) + else + tabres = tabres.updated(k, tab(k)) + k = k + 1 + }) invariant(isArray(tabres, k) && k >= 0 && k <= size && isPositive(tabres, k)) + tabres + }) ensuring(res => isArray(res, size) && isPositive(res, size)) + + def isPositive(a: Map[Int, Int], size: Int): Boolean = { + require(size <= 10 && isArray(a, size)) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) + true + else { + if(a(i) < 0) + false + else + rec(i+1) + } + } + rec(0) + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + } + + if(size < 0) + false + else + rec(0) + } + +} diff --git a/mytest/Add.scala b/mytest/Add.scala new file mode 100644 index 0000000000000000000000000000000000000000..5d125e3c72f29d96fdaa7176756efe11ca36b98a --- /dev/null +++ b/mytest/Add.scala @@ -0,0 +1,25 @@ +import leon.Utils._ + +/* VSTTE 2008 - Dafny paper */ + +object Add { + + def add(x : Int, y : Int): Int = ({ + var r = x + if(y < 0) { + var n = y + (while(n != 0) { + r = r - 1 + n = n + 1 + }) invariant(r == x + y - n && 0 <= -n) + } else { + var n = y + (while(n != 0) { + r = r + 1 + n = n - 1 + }) invariant(r == x + y - n && 0 <= n) + } + r + }) ensuring(_ == x+y) + +} diff --git a/mytest/Assign1.scala b/mytest/Assign1.scala new file mode 100644 index 0000000000000000000000000000000000000000..ef346a02c6f51f7a1fad605ef4a42cd11fe889a3 --- /dev/null +++ b/mytest/Assign1.scala @@ -0,0 +1,18 @@ +object Assign1 { + + def foo(): Int = { + var a = 0 + val tmp = a + 1 + a = a + 2 + a = a + 3 + a = a + 4 + //var j = 0 + //var sortedArray = Map.empty[Int, Int] + //val tmp = sortedArray(j) + //sortedArray = sortedArray.updated(j, sortedArray(j+1)) + //sortedArray = sortedArray.updated(j+1, tmp) + //sortedArray(j) + a + } + +} diff --git a/mytest/BinarySearch.scala b/mytest/BinarySearch.scala new file mode 100644 index 0000000000000000000000000000000000000000..92d4cf0667fe0175b80ae0da34679e991efcc633 --- /dev/null +++ b/mytest/BinarySearch.scala @@ -0,0 +1,73 @@ +import leon.Utils._ + +/* VSTTE 2008 - Dafny paper */ + +object BinarySearch { + + def binarySearch(a: Map[Int, Int], size: Int, key: Int): Int = ({ + require(isArray(a, size) && size < 5 && sorted(a, size, 0, size - 1)) + var low = 0 + var high = size - 1 + var res = -1 + + (while(low <= high && res == -1) { + val i = (high + low) / 2 + val v = a(i) + + if(v == key) + res = i + + if(v > key) + high = i - 1 + else if(v < key) + low = i + 1 + }) invariant(0 <= low && low <= high + 1 && high < size && (if(res >= 0) a(res) == key else (!occurs(a, 0, low, key) && !occurs(a, high + 1, size, key)))) + res + }) ensuring(res => res >= -1 && res < size && (if(res >= 0) a(res) == key else !occurs(a, 0, size, key))) + + + def occurs(a: Map[Int, Int], from: Int, to: Int, key: Int): Boolean = { + require(isArray(a, to) && to < 5 && from >= 0) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= to) + false + else { + if(a(i) == key) true else rec(i+1) + } + } + if(from >= to) + false + else + rec(from) + } + + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + } + + if(size < 0) + false + else + rec(0) + } + + def sorted(a: Map[Int,Int], size: Int, l: Int, u: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size) + val t = sortedWhile(true, l, l, u, a, size) + t._1 + } + + def sortedWhile(isSorted: Boolean, k: Int, l: Int, u: Int, a: Map[Int,Int], size: Int) : (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size && k >= l && k <= u) + if(k < u) { + sortedWhile(if(a(k) > a(k + 1)) false else isSorted, k + 1, l, u, a, size) + } else (isSorted, k) + } +} diff --git a/mytest/Bubble.scala b/mytest/Bubble.scala new file mode 100644 index 0000000000000000000000000000000000000000..4e274e2707b0cff262e712ab4b78b53eee0db990 --- /dev/null +++ b/mytest/Bubble.scala @@ -0,0 +1,132 @@ +import leon.Utils._ + +/* The calculus of Computation textbook */ + +object Bubble { + + def sort(a: Map[Int, Int], size: Int): Map[Int, Int] = ({ + require(size < 5 && isArray(a, size)) + var i = size - 1 + var j = 0 + var sortedArray = a + (while(i > 0) { + j = 0 + (while(j < i) { + if(sortedArray(j) > sortedArray(j+1)) { + val tmp = sortedArray(j) + sortedArray = sortedArray.updated(j, sortedArray(j+1)) + sortedArray = sortedArray.updated(j+1, tmp) + } + j = j + 1 + }) invariant( + j >= 0 && + j <= i && + i < size && + isArray(sortedArray, size) && + partitioned(sortedArray, size, 0, i, i+1, size-1) && + sorted(sortedArray, size, i, size-1) && + partitioned(sortedArray, size, 0, j-1, j, j) + ) + i = i - 1 + }) invariant( + i >= 0 && + i < size && + isArray(sortedArray, size) && + partitioned(sortedArray, size, 0, i, i+1, size-1) && + sorted(sortedArray, size, i, size-1) + ) + sortedArray + }) ensuring(res => sorted(res, size, 0, size-1)) + + def sorted(a: Map[Int, Int], size: Int, l: Int, u: Int): Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && u < size && l <= u) + var k = l + var isSorted = true + (while(k < u) { + if(a(k) > a(k+1)) + isSorted = false + k = k + 1 + }) invariant(k <= u && k >= l) + isSorted + } + /* + // --------------------- sorted -------------------- + def sorted(a: Map[Int,Int], size: Int, l: Int, u: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size) + val t = sortedWhile(true, l, l, u, a, size) + t._1 + } + + def sortedWhile(isSorted: Boolean, k: Int, l: Int, u: Int, a: Map[Int,Int], size: Int) : (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size && k >= l && k <= u) + if(k < u) { + sortedWhile(if(a(k) > a(k + 1)) false else isSorted, k + 1, l, u, a, size) + } else (isSorted, k) + } + */ + + /* + // ------------- partitioned ------------------ + def partitioned(a: Map[Int,Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l1 >= 0 && u1 < l2 && u2 < size) + if(l2 > u2 || l1 > u1) + true + else { + val t = partitionedWhile(l2, true, l1, l1, size, u2, l2, u1, a) + t._2 + } + } + def partitionedWhile(j: Int, isPartitionned: Boolean, i: Int, l1: Int, size: Int, u2: Int, l2: Int, u1: Int, a: Map[Int,Int]) : (Int, Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && i >= l1) + + if(i <= u1) { + val t = partitionedNestedWhile(isPartitionned, l2, i, l1, u1, size, u2, a, l2) + partitionedWhile(t._2, t._1, i + 1, l1, size, u2, l2, u1, a) + } else (j, isPartitionned, i) + } + def partitionedNestedWhile(isPartitionned: Boolean, j: Int, i: Int, l1: Int, u1: Int, size: Int, u2: Int, a: Map[Int,Int], l2: Int): (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && j >= l2 && i >= l1 && i <= u1) + + if (j <= u2) { + partitionedNestedWhile( + (if (a(i) > a(j)) + false + else + isPartitionned), + j + 1, i, l1, u1, size, u2, a, l2) + } else (isPartitionned, j) + } + */ + + def partitioned(a: Map[Int, Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int): Boolean = { + require(l1 >= 0 && u1 < l2 && u2 < size && isArray(a, size) && size < 5) + if(l2 > u2 || l1 > u1) + true + else { + var i = l1 + var j = l2 + var isPartitionned = true + (while(i <= u1) { + j = l2 + (while(j <= u2) { + if(a(i) > a(j)) + isPartitionned = false + j = j + 1 + }) invariant(j >= l2 && i <= u1) + i = i + 1 + }) invariant(i >= l1) + isPartitionned + } + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + def rec(i: Int): Boolean = if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + if(size <= 0) + false + else + rec(0) + } + +} diff --git a/mytest/BubbleFun.scala b/mytest/BubbleFun.scala new file mode 100644 index 0000000000000000000000000000000000000000..5962eea28e93cbd3d496ae9c7dd92af3a1297d15 --- /dev/null +++ b/mytest/BubbleFun.scala @@ -0,0 +1,153 @@ +object BubbleFun { + + // --------------------- sort ---------------------- + + def sort(a: Map[Int,Int], size: Int): Map[Int,Int] = ({ + require(isArray(a, size) && size < 5) + + val i = size - 1 + val t = sortWhile(0, a, i, size) + t._2 + }) ensuring(res => isArray(res, size) && sorted(res, size, 0, size-1) /*&& content(res, size) == content(a, size)*/) + + def sortWhile(j: Int, sortedArray: Map[Int,Int], i: Int, size: Int) : (Int, Map[Int,Int], Int) = ({ + require(i >= 0 && i < size && isArray(sortedArray, size) && size < 5 && + sorted(sortedArray, size, i, size - 1) && + partitioned(sortedArray, size, 0, i, i+1, size-1)) + + if (i > 0) { + val t = sortNestedWhile(sortedArray, 0, i, size) + sortWhile(t._2, t._1, i - 1, size) + } else (j, sortedArray, i) + }) ensuring(res => isArray(res._2, size) && + sorted(res._2, size, res._3, size - 1) && + partitioned(res._2, size, 0, res._3, res._3+1, size-1) && + res._3 >= 0 && res._3 <= 0 /*&& content(res._2, size) == content(sortedArray, size)*/ + ) + + + def sortNestedWhile(sortedArray: Map[Int,Int], j: Int, i: Int, size: Int) : (Map[Int,Int], Int) = ({ + require(j >= 0 && j <= i && i < size && isArray(sortedArray, size) && size < 5 && + sorted(sortedArray, size, i, size - 1) && + partitioned(sortedArray, size, 0, i, i+1, size-1) && + partitioned(sortedArray, size, 0, j-1, j, j)) + if(j < i) { + val newSortedArray = + if(sortedArray(j) > sortedArray(j + 1)) + sortedArray.updated(j, sortedArray(j + 1)).updated(j+1, sortedArray(j)) + else + sortedArray + sortNestedWhile(newSortedArray, j + 1, i, size) + } else (sortedArray, j) + }) ensuring(res => isArray(res._1, size) && + sorted(res._1, size, i, size - 1) && + partitioned(res._1, size, 0, i, i+1, size-1) && + partitioned(res._1, size, 0, res._2-1, res._2, res._2) && + res._2 >= i && res._2 >= 0 && res._2 <= i /*&& content(res._1, size) == content(sortedArray, size)*/) + + + //some intermediate results + def lemma1(a: Map[Int, Int], size: Int, i: Int): Boolean = ({ + require(isArray(a, size) && size < 5 && sorted(a, size, i, size-1) && partitioned(a, size, 0, i, i+1, size-1) && i >= 0 && i < size) + val t = sortNestedWhile(a, 0, i, size) + val newJ = t._2 + i == newJ + }) ensuring(_ == true) + def lemma2(a: Map[Int, Int], size: Int, i: Int): Boolean = ({ + require(isArray(a, size) && size < 5 && sorted(a, size, i, size-1) && partitioned(a, size, 0, i, i+1, size-1) && i >= 0 && i < size) + val t = sortNestedWhile(a, 0, i, size) + val newA = t._1 + val newJ = t._2 + partitioned(newA, size, 0, i-1, i, i) + }) ensuring(_ == true) + def lemma3(a: Map[Int, Int], size: Int, i: Int): Boolean = ({ + require(partitioned(a, size, 0, i, i+1, size-1) && partitioned(a, size, 0, i-1, i, i) && isArray(a, size) && size < 5 && i >= 0 && i < size) + partitioned(a, size, 0, i-1, i, size-1) + }) ensuring(_ == true) + def lemma4(a: Map[Int, Int], size: Int, i: Int): Boolean = ({ + require(isArray(a, size) && size < 5 && sorted(a, size, i, size-1) && partitioned(a, size, 0, i, i+1, size-1) && i >= 0 && i < size) + val t = sortNestedWhile(a, 0, i, size) + val newA = t._1 + partitioned(newA, size, 0, i-1, i, size-1) + }) ensuring(_ == true) + + + // --------------------- sorted -------------------- + def sorted(a: Map[Int,Int], size: Int, l: Int, u: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size) + val t = sortedWhile(true, l, l, u, a, size) + t._1 + } + + def sortedWhile(isSorted: Boolean, k: Int, l: Int, u: Int, a: Map[Int,Int], size: Int) : (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size && k >= l && k <= u) + if(k < u) { + sortedWhile(if(a(k) > a(k + 1)) false else isSorted, k + 1, l, u, a, size) + } else (isSorted, k) + } + + + + // ------------- partitioned ------------------ + def partitioned(a: Map[Int,Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l1 >= 0 && u1 < l2 && u2 < size) + if(l2 > u2 || l1 > u1) + true + else { + val t = partitionedWhile(l2, true, l1, l1, size, u2, l2, u1, a) + t._2 + } + } + def partitionedWhile(j: Int, isPartitionned: Boolean, i: Int, l1: Int, size: Int, u2: Int, l2: Int, u1: Int, a: Map[Int,Int]) : (Int, Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && i >= l1) + + if(i <= u1) { + val t = partitionedNestedWhile(isPartitionned, l2, i, l1, u1, size, u2, a, l2) + partitionedWhile(t._2, t._1, i + 1, l1, size, u2, l2, u1, a) + } else (j, isPartitionned, i) + } + def partitionedNestedWhile(isPartitionned: Boolean, j: Int, i: Int, l1: Int, u1: Int, size: Int, u2: Int, a: Map[Int,Int], l2: Int): (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && j >= l2 && i >= l1 && i <= u1) + + if (j <= u2) { + partitionedNestedWhile( + (if (a(i) > a(j)) + false + else + isPartitionned), + j + 1, i, l1, u1, size, u2, a, l2) + } else (isPartitionned, j) + } + + + //------------ isArray ------------------- + def isArray(a: Map[Int,Int], size: Int): Boolean = + if(size <= 0) + false + else + isArrayRec(0, size, a) + + def isArrayRec(i: Int, size: Int, a: Map[Int,Int]): Boolean = + if (i >= size) + true + else { + if(a.isDefinedAt(i)) + isArrayRec(i + 1, size, a) + else + false + } + + + // ----------------- content ------------------ + def content(a: Map[Int, Int], size: Int): Set[Int] = { + require(isArray(a, size) && size < 5) + var i = 0 + var s = Set.empty[Int] + while(i < size) { + s = s ++ Set(a(i)) + i = i + 1 + } + s + } + +} diff --git a/mytest/BubbleWeakInvariant.scala b/mytest/BubbleWeakInvariant.scala new file mode 100644 index 0000000000000000000000000000000000000000..d335cb81ffce8831f0dc1c203ab47c227d761e7f --- /dev/null +++ b/mytest/BubbleWeakInvariant.scala @@ -0,0 +1,132 @@ +import leon.Utils._ + +/* The calculus of Computation textbook */ + +object Bubble { + + def sort(a: Map[Int, Int], size: Int): Map[Int, Int] = ({ + require(size < 5 && isArray(a, size)) + var i = size - 1 + var j = 0 + var sortedArray = a + (while(i > 0) { + j = 0 + (while(j < i) { + if(sortedArray(j) > sortedArray(j+1)) { + val tmp = sortedArray(j) + sortedArray = sortedArray.updated(j, sortedArray(j+1)) + sortedArray = sortedArray.updated(j+1, tmp) + } + j = j + 1 + }) invariant( + j >= 0 && + j <= i && + i < size && + isArray(sortedArray, size) && + partitioned(sortedArray, size, 0, i, i+1, size-1) && + //partitioned(sortedArray, size, 0, j-1, j, j) && + sorted(sortedArray, size, i, size-1) + ) + i = i - 1 + }) invariant( + i >= 0 && + i < size && + isArray(sortedArray, size) && + partitioned(sortedArray, size, 0, i, i+1, size-1) && + sorted(sortedArray, size, i, size-1) + ) + sortedArray + }) ensuring(res => sorted(res, size, 0, size-1)) + + def sorted(a: Map[Int, Int], size: Int, l: Int, u: Int): Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && u < size && l <= u) + var k = l + var isSorted = true + (while(k < u) { + if(a(k) > a(k+1)) + isSorted = false + k = k + 1 + }) invariant(k <= u && k >= l) + isSorted + } + /* + // --------------------- sorted -------------------- + def sorted(a: Map[Int,Int], size: Int, l: Int, u: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size) + val t = sortedWhile(true, l, l, u, a, size) + t._1 + } + + def sortedWhile(isSorted: Boolean, k: Int, l: Int, u: Int, a: Map[Int,Int], size: Int) : (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size && k >= l && k <= u) + if(k < u) { + sortedWhile(if(a(k) > a(k + 1)) false else isSorted, k + 1, l, u, a, size) + } else (isSorted, k) + } + */ + + /* + // ------------- partitioned ------------------ + def partitioned(a: Map[Int,Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int) : Boolean = { + require(isArray(a, size) && size < 5 && l1 >= 0 && u1 < l2 && u2 < size) + if(l2 > u2 || l1 > u1) + true + else { + val t = partitionedWhile(l2, true, l1, l1, size, u2, l2, u1, a) + t._2 + } + } + def partitionedWhile(j: Int, isPartitionned: Boolean, i: Int, l1: Int, size: Int, u2: Int, l2: Int, u1: Int, a: Map[Int,Int]) : (Int, Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && i >= l1) + + if(i <= u1) { + val t = partitionedNestedWhile(isPartitionned, l2, i, l1, u1, size, u2, a, l2) + partitionedWhile(t._2, t._1, i + 1, l1, size, u2, l2, u1, a) + } else (j, isPartitionned, i) + } + def partitionedNestedWhile(isPartitionned: Boolean, j: Int, i: Int, l1: Int, u1: Int, size: Int, u2: Int, a: Map[Int,Int], l2: Int): (Boolean, Int) = { + require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && j >= l2 && i >= l1 && i <= u1) + + if (j <= u2) { + partitionedNestedWhile( + (if (a(i) > a(j)) + false + else + isPartitionned), + j + 1, i, l1, u1, size, u2, a, l2) + } else (isPartitionned, j) + } + */ + + def partitioned(a: Map[Int, Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int): Boolean = { + require(l1 >= 0 && u1 < l2 && u2 < size && isArray(a, size) && size < 5) + if(l2 > u2 || l1 > u1) + true + else { + var i = l1 + var j = l2 + var isPartitionned = true + (while(i <= u1) { + j = l2 + (while(j <= u2) { + if(a(i) > a(j)) + isPartitionned = false + j = j + 1 + }) invariant(j >= l2 && i <= u1) + i = i + 1 + }) invariant(i >= l1) + isPartitionned + } + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + def rec(i: Int): Boolean = if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + if(size <= 0) + false + else + rec(0) + } + +} diff --git a/mytest/Capture.scala b/mytest/Capture.scala new file mode 100644 index 0000000000000000000000000000000000000000..c1c4c7e1ef466b27f8232afbb67d22e98fb48077 --- /dev/null +++ b/mytest/Capture.scala @@ -0,0 +1,10 @@ +object Capture { + + def foo(i: Int): Int = { + val a = 3 + def rec(j: Int): Int = if(j == a) 0 else 1 + rec(3) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/mytest/IfExpr1.scala b/mytest/IfExpr1.scala new file mode 100644 index 0000000000000000000000000000000000000000..82db13d5285bc3c779010fd3e62b6e9e68e84e84 --- /dev/null +++ b/mytest/IfExpr1.scala @@ -0,0 +1,13 @@ +object IfExpr1 { + + def foo(): Int = { + var a = 1 + var b = 2 + if({a = a + 1; a != b}) + a = a + 3 + else + b = a + b + a + } ensuring(_ == 2) + +} diff --git a/mytest/LinearSearch.scala b/mytest/LinearSearch.scala new file mode 100644 index 0000000000000000000000000000000000000000..e91a3f2679b33b8b35e7d64989db1b7ade451d14 --- /dev/null +++ b/mytest/LinearSearch.scala @@ -0,0 +1,53 @@ +import leon.Utils._ + +/* The calculus of Computation textbook */ + +object LinearSearch { + + def linearSearch(a: Map[Int, Int], size: Int, c: Int): Boolean = ({ + require(size <= 5 && isArray(a, size)) + var i = 0 + var found = false + (while(i < size) { + if(a(i) == c) + found = true + i = i + 1 + }) invariant(i <= size && + i >= 0 && + (if(found) contains(a, i, c) else !contains(a, i, c)) + ) + found + }) ensuring(res => if(res) contains(a, size, c) else !contains(a, size, c)) + + def contains(a: Map[Int, Int], size: Int, c: Int): Boolean = { + require(isArray(a, size) && size <= 5) + content(a, size).contains(c) + } + + + def content(a: Map[Int, Int], size: Int): Set[Int] = { + require(isArray(a, size) && size <= 5) + var set = Set.empty[Int] + var i = 0 + (while(i < size) { + set = set ++ Set(a(i)) + i = i + 1 + }) invariant(i >= 0 && i <= size) + set + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + } + + if(size < 0) + false + else + rec(0) + } +} diff --git a/mytest/Match.scala b/mytest/Match.scala new file mode 100644 index 0000000000000000000000000000000000000000..5376cc61d1df009da2c8ba09ea2c57c3bf248cd7 --- /dev/null +++ b/mytest/Match.scala @@ -0,0 +1,18 @@ +object Match { + + sealed abstract class A + case class B(b: Int) extends A + case class C(c: Int) extends A + + def foo(a: A): Int = ({ + + var i = 0 + var j = 0 + + {i = i + 1; a} match { + case B(b) => {i = i + 1; b} + case C(c) => {j = j + 1; i = i + 1; c} + } + i + }) ensuring(_ == 2) +} diff --git a/mytest/MaxSum.scala b/mytest/MaxSum.scala new file mode 100644 index 0000000000000000000000000000000000000000..07c5d12ab2450cc95af964c7078759a13fe8f0cb --- /dev/null +++ b/mytest/MaxSum.scala @@ -0,0 +1,70 @@ +import leon.Utils._ + +/* VSTTE 2010 challenge 1 */ + +object MaxSum { + + + def maxSum(a: Map[Int, Int], size: Int): (Int, Int) = ({ + require(isArray(a, size) && size < 5 && isPositive(a, size)) + var sum = 0 + var max = 0 + var i = 0 + (while(i < size) { + if(max < a(i)) + max = a(i) + sum = sum + a(i) + i = i + 1 + }) invariant (sum <= i * max && /*isGreaterEq(a, i, max) &&*/ /*(if(i == 0) max == 0 else true) &&*/ i >= 0 && i <= size) + (sum, max) + }) ensuring(res => res._1 <= size * res._2) + +/* + def isGreaterEq(a: Map[Int, Int], size: Int, n: Int): Boolean = { + require(size <= 5 && isArray(a, size)) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) + true + else { + if(a(i) > n) + false + else + rec(i+1) + } + } + rec(0) + } + */ + + def isPositive(a: Map[Int, Int], size: Int): Boolean = { + require(size <= 5 && isArray(a, size)) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) + true + else { + if(a(i) < 0) + false + else + rec(i+1) + } + } + rec(0) + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + } + + if(size < 0) + false + else + rec(0) + } +} diff --git a/mytest/Mult.scala b/mytest/Mult.scala new file mode 100644 index 0000000000000000000000000000000000000000..417e9e9cdc74440ccac679b2d55bac33a6381e4a --- /dev/null +++ b/mytest/Mult.scala @@ -0,0 +1,26 @@ +import leon.Utils._ + +/* VSTTE 2008 - Dafny paper */ + +object Mult { + + def mult(x : Int, y : Int): Int = ({ + var r = 0 + if(y < 0) { + var n = y + (while(n != 0) { + r = r - x + n = n + 1 + }) invariant(r == x * (y - n) && 0 <= -n) + } else { + var n = y + (while(n != 0) { + r = r + x + n = n - 1 + }) invariant(r == x * (y - n) && 0 <= n) + } + r + }) ensuring(_ == x*y) + +} + diff --git a/mytest/MyTuple1.scala b/mytest/MyTuple1.scala deleted file mode 100644 index 029eff9907d1053e7ec7aaff9c1f0bb659be616b..0000000000000000000000000000000000000000 --- a/mytest/MyTuple1.scala +++ /dev/null @@ -1,11 +0,0 @@ -object MyTuple1 { - - def foo(): Int = { - val t = (1, true) - val a1 = t._1 - val a2 = t._2 - a1 - } ensuring( _ > 0) - -} - diff --git a/mytest/MyTuple2.scala b/mytest/MyTuple2.scala deleted file mode 100644 index bc4289ffc26397f61eaf4a377b531b3a4d6929f4..0000000000000000000000000000000000000000 --- a/mytest/MyTuple2.scala +++ /dev/null @@ -1,10 +0,0 @@ -object MyTuple2 { - - def foo(): Int = { - val t = (1, true) - val a1 = t._1 - val a2 = t._2 - a1 - } ensuring( _ < 0) - -} diff --git a/mytest/MyTupleWrong.scala b/mytest/MyTupleWrong.scala deleted file mode 100644 index 23368d31ba73e2706ca98e87d663177790afed45..0000000000000000000000000000000000000000 --- a/mytest/MyTupleWrong.scala +++ /dev/null @@ -1,14 +0,0 @@ -object MyTupleWrong { - - def foo(): Int = { - val t = (1, true) - val a1 = t._1 - val a2 = t._2 - val a3 = t._3 - a1 - } - -} - - -// vim: set ts=4 sw=4 et: diff --git a/mytest/NAryOp.scala b/mytest/NAryOp.scala new file mode 100644 index 0000000000000000000000000000000000000000..adf6c84bfff520ebe720011612b52ec66d9c12dc --- /dev/null +++ b/mytest/NAryOp.scala @@ -0,0 +1,12 @@ +object NAryOp { + + def foo(): Int = ({ + + var a = 2 + bar({a = a + 1; a}, {a = 5 - a; a}, {a = a + 2; a}) + }) ensuring(_ == 9) + + + def bar(i1: Int, i2: Int, i3: Int): Int = i1 + i2 + i3 + +} diff --git a/mytest/Plus.scala b/mytest/Plus.scala new file mode 100644 index 0000000000000000000000000000000000000000..b0f08f1686c15b7126231fcf0f94e194cf7eb248 --- /dev/null +++ b/mytest/Plus.scala @@ -0,0 +1,13 @@ +object Plus { + + def foo(): Int = ({ + + var a = 2 + var b = 1 + + a = {b = b + 2; a = a + 1; a} + {a = 5 - a; a} + a + b + }) ensuring(_ == 8) + + +} diff --git a/mytest/ValSideEffect.scala b/mytest/ValSideEffect.scala new file mode 100644 index 0000000000000000000000000000000000000000..5beefdaffa48d1cdff604c70dcfba5e2fa481db5 --- /dev/null +++ b/mytest/ValSideEffect.scala @@ -0,0 +1,17 @@ +object ValSideEffect { + + def foo(): Int = ({ + + var a = 2 + var a2 = 1 + + val b = {a = a + 1; a2 = a2 + 1; a} + {a = 5 - a; a} + a = a + 1 + a2 = a2 + 3 + a + a2 + b + }) ensuring(_ == 13) + + +} + +// vim: set ts=4 sw=4 et: diff --git a/mytest/While1.scala b/mytest/While1.scala new file mode 100644 index 0000000000000000000000000000000000000000..d7ab085dec990927da0afc226a4417d682b62868 --- /dev/null +++ b/mytest/While1.scala @@ -0,0 +1,13 @@ +object While1 { + + def foo(): Int = { + var a = 0 + var i = 0 + while(i < 10) { + a = a + i + i = i + 1 + } + a + } + +} diff --git a/mytest/While2.scala b/mytest/While2.scala new file mode 100644 index 0000000000000000000000000000000000000000..de724d9baf6885b51b3801dcd6fb0c5826acaa81 --- /dev/null +++ b/mytest/While2.scala @@ -0,0 +1,15 @@ +object While1 { + + def foo(): Int = { + var a = 0 + var i = 0 + while({i = i+2; i <= 10}) { + a = a + i + i = i - 1 + } + a + } ensuring(_ == 54) + +} + +// vim: set ts=4 sw=4 et: diff --git a/mytest/WhileTest.scala b/mytest/WhileTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..c080552e54f5861c1ad368ce4c132e957e9e8554 --- /dev/null +++ b/mytest/WhileTest.scala @@ -0,0 +1,20 @@ +import leon.Utils._ + +object WhileTest { +// object InvariantFunction { +// def invariant(x: Boolean): Unit = () +// } +// implicit def while2Invariant(u: Unit) = InvariantFunction + def foo(x : Int) : Int = { + require(x >= 0) + + var y : Int = x + + (while (y >= 0) { + y = y - 1 + // assert(y >= -1) + }) invariant(y >= -1) + + y + 1 + } ensuring(_ == 0) +} diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala index 40468a4bef174cd93278452a038cc2b711710328..6458079be4ef662a49bd25ea2f118f8df73ce56f 100644 --- a/src/main/scala/leon/Analysis.scala +++ b/src/main/scala/leon/Analysis.scala @@ -7,7 +7,7 @@ import purescala.TypeTrees._ import Extensions._ import scala.collection.mutable.{Set => MutableSet} -class Analysis(val program: Program, val reporter: Reporter = Settings.reporter) { +class Analysis(val program : Program, val reporter: Reporter = Settings.reporter) { Extensions.loadAll(reporter) val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions @@ -71,6 +71,11 @@ class Analysis(val program: Program, val reporter: Reporter = Settings.reporter) allVCs ++= tactic.generatePostconditions(funDef).sortWith(vcSort) allVCs ++= tactic.generateMiscCorrectnessConditions(funDef).sortWith(vcSort) } + allVCs = allVCs.sortWith((vc1, vc2) => { + val id1 = vc1.funDef.id.name + val id2 = vc2.funDef.id.name + if(id1 != id2) id1 < id2 else vc1 < vc2 + }) } val notFound: Set[String] = Settings.functionsToAnalyse -- analysedFunctions diff --git a/src/main/scala/leon/DefaultTactic.scala b/src/main/scala/leon/DefaultTactic.scala index 6f31b9ff59bc4b4681e7284c4dc3479400585534..a01d025431c2ec3c704a1950410334fd322abae9 100644 --- a/src/main/scala/leon/DefaultTactic.scala +++ b/src/main/scala/leon/DefaultTactic.scala @@ -78,7 +78,10 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { expr2 } } - Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this.asInstanceOf[DefaultTactic])) + if(functionDefinition.fromLoop) + Seq(new VerificationCondition(theExpr, functionDefinition.parent.get, VCKind.InvariantPost, this.asInstanceOf[DefaultTactic]).setPosInfo(functionDefinition)) + else + Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this.asInstanceOf[DefaultTactic])) } } @@ -107,11 +110,18 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { val newBody : Expr = replace(substMap, prec) val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - new VerificationCondition( - withPrecIfDefined(path, newCall), - function, - VCKind.Precondition, - this.asInstanceOf[DefaultTactic]).setPosInfo(fi) + if(fd.fromLoop) + new VerificationCondition( + withPrecIfDefined(path, newCall), + fd.parent.get, + if(fd == function) VCKind.InvariantInd else VCKind.InvariantInit, + this.asInstanceOf[DefaultTactic]).setPosInfo(fd) + else + new VerificationCondition( + withPrecIfDefined(path, newCall), + function, + VCKind.Precondition, + this.asInstanceOf[DefaultTactic]).setPosInfo(fi) }).toSeq } else { Seq.empty @@ -141,7 +151,7 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { allPathConds.map(pc => new VerificationCondition( withPrecIfDefined(pc._1), - function, + if(function.fromLoop) function.parent.get else function, VCKind.ExhaustiveMatch, this.asInstanceOf[DefaultTactic]).setPosInfo(pc._2.asInstanceOf[Error]) ).toSeq @@ -173,7 +183,7 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { allPathConds.map(pc => new VerificationCondition( withPrecIfDefined(pc._1), - function, + if(function.fromLoop) function.parent.get else function, VCKind.MapAccess, this.asInstanceOf[DefaultTactic]).setPosInfo(pc._2.asInstanceOf[Error]) ).toSeq diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala index bd01412c2606d1e86dec168c8c493b2b64f9de7e..72c62ef152b58b54a563ceb9e238c129af5a565c 100644 --- a/src/main/scala/leon/Evaluator.scala +++ b/src/main/scala/leon/Evaluator.scala @@ -51,6 +51,14 @@ object Evaluator { throw RuntimeErrorEx("No value for identifier " + id.name + " in context.") } } + case Tuple(ts) => { + val tsRec = ts.map(rec(ctx, _)) + Tuple(tsRec) + } + case TupleSelect(t, i) => { + val Tuple(rs) = rec(ctx, t) + rs(i-1) + } case Let(i,e,b) => { val first = rec(ctx, e) rec(ctx + ((i -> first)), b) diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index 280265426a29c3c89034829d1642457f615490c6..d70d665fdb4e1281e490ab7640794c9b36360373 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -175,6 +175,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S } case class UntranslatableTypeException(msg: String) extends Exception(msg) + // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. private def prepareSorts: Unit = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} // NOTE THAT abstract classes that extend abstract classes are not @@ -402,7 +403,8 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case Some(s) => s case None => { val tpesSorts = tpes.map(typeToSort) - val (tupleSort, consTuple, projsTuple) = z3.mkTupleSort(tpes.map(_.toString).mkString("Tuple2(",", ", ")"), tpesSorts: _*) + val sortSymbol = z3.mkFreshStringSymbol("TupleSort") + val (tupleSort, consTuple, projsTuple) = z3.mkTupleSort(sortSymbol, tpesSorts: _*) tupleSorts += (tt -> tupleSort) tupleConstructors += (tt -> consTuple) tupleSelectors += (tt -> projsTuple) @@ -859,11 +861,14 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S //println(ex) val recResult = (ex match { case tu@Tuple(args) => { + // This call is required, because the Z3 sort may not have been generated yet. + // If it has, it's just a map lookup and instant return. typeToSort(tu.getType) val constructor = tupleConstructors(tu.getType) constructor(args.map(rec(_)): _*) } case ts@TupleSelect(tu, i) => { + // See comment above for similar code. typeToSort(tu.getType) val selector = tupleSelectors(tu.getType)(i-1) selector(rec(tu)) @@ -973,6 +978,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case errorType => scala.sys.error("Unexpected type for singleton map: " + (ex, errorType)) } case e @ EmptyMap(fromType, toType) => { + typeToSort(e.getType) //had to add this here because the mapRangeNoneConstructors was not yet constructed... val fromSort = typeToSort(fromType) val toSort = typeToSort(toType) z3.mkConstArray(fromSort, mapRangeNoneConstructors(toType)()) @@ -1221,7 +1227,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S val startingVar : Identifier = FreshIdentifier("start", true).setType(BooleanType) val result = treatFunctionInvocationSet(startingVar, true, functionCallsOf(formula)) - reporter.info(result) + //reporter.info(result) (Variable(startingVar) +: formula +: result._1, result._2) } } diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala new file mode 100644 index 0000000000000000000000000000000000000000..f4ec3f1b0fd564113972e687d3082904d46eb050 --- /dev/null +++ b/src/main/scala/leon/FunctionClosure.scala @@ -0,0 +1,142 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object FunctionClosure extends Pass { + + val description = "Closing function with its scoping variables" + + private var enclosingPreconditions: List[Expr] = Nil + + private var pathConstraints: List[Expr] = Nil + private var newFunDefs: Map[FunDef, FunDef] = Map() + //private var + + def apply(program: Program): Program = { + newFunDefs = Map() + val funDefs = program.definedFunctions + funDefs.foreach(fd => { + enclosingPreconditions = fd.precondition.toList + pathConstraints = fd.precondition.toList + fd.body = Some(functionClosure(fd.getBody, fd.args.map(_.id).toSet)) + }) + program + } + + private def functionClosure(expr: Expr, bindedVars: Set[Identifier]): Expr = expr match { + case l @ LetDef(fd, rest) => { + + val id = fd.id + val rt = fd.returnType + val varDecl = fd.args + val funBody = fd.getBody + val precondition = fd.precondition + val postcondition = fd.postcondition + + val bodyVars: Set[Identifier] = variablesOf(funBody) ++ variablesOf(precondition.getOrElse(BooleanLiteral(true))) + val capturedVars = bodyVars.intersect(bindedVars)// this should be the variable used that are in the scope + val (constraints, allCapturedVars) = filterConstraints(capturedVars) //all relevant path constraints + val capturedVarsWithConstraints = allCapturedVars.toSeq + + val freshVars: Map[Identifier, Identifier] = capturedVarsWithConstraints.map(v => (v, FreshIdentifier(v.name).setType(v.getType))).toMap + val freshVarsExpr: Map[Expr, Expr] = freshVars.map(p => (p._1.toVariable, p._2.toVariable)) + + val extraVarDecls = freshVars.map{ case (_, v2) => VarDecl(v2, v2.getType) } + val newVarDecls = varDecl ++ extraVarDecls + val newFunId = FreshIdentifier(id.name) + val newFunDef = new FunDef(newFunId, rt, newVarDecls).setPosInfo(fd) + newFunDef.fromLoop = fd.fromLoop + newFunDef.parent = fd.parent + + val freshPrecondition = precondition.map(expr => replace(freshVarsExpr, expr)) + val freshConstraints = constraints.map(expr => replace(freshVarsExpr, expr)) + newFunDef.precondition = freshConstraints match { + case List() => freshPrecondition + case precs => Some(And(freshPrecondition.getOrElse(BooleanLiteral(true)) +: precs)) + } + newFunDef.postcondition = postcondition.map(expr => replace(freshVarsExpr, expr)) + + def substFunInvocInDef(expr: Expr): Option[Expr] = expr match { + case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ extraVarDecls.map(_.id.toVariable)).setPosInfo(fi)) + case _ => None + } + val freshBody = replace(freshVarsExpr, funBody) + val oldPathConstraints = pathConstraints + pathConstraints = (precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints).map(e => replace(freshVarsExpr, e)) + val recBody = functionClosure(freshBody, bindedVars ++ newVarDecls.map(_.id)) + pathConstraints = oldPathConstraints + val recBody2 = searchAndReplaceDFS(substFunInvocInDef)(recBody) + newFunDef.body = Some(recBody2) + + def substFunInvocInRest(expr: Expr): Option[Expr] = expr match { + case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ capturedVarsWithConstraints.map(_.toVariable)).setPosInfo(fi)) + case _ => None + } + val recRest = functionClosure(rest, bindedVars) + val recRest2 = searchAndReplaceDFS(substFunInvocInRest)(recRest) + LetDef(newFunDef, recRest2).setType(l.getType) + } + case l @ Let(i,e,b) => { + val re = functionClosure(e, bindedVars) + pathConstraints ::= Equals(Variable(i), re) + val rb = functionClosure(b, bindedVars + i) + pathConstraints = pathConstraints.tail + Let(i, re, rb).setType(l.getType) + } + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => functionClosure(a, bindedVars)) + recons(rargs).setType(n.getType) + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = functionClosure(t1, bindedVars) + val r2 = functionClosure(t2, bindedVars) + recons(r1,r2).setType(b.getType) + } + case u @ UnaryOperator(t,recons) => { + val r = functionClosure(t, bindedVars) + recons(r).setType(u.getType) + } + case i @ IfExpr(cond,then,elze) => { + val rCond = functionClosure(cond, bindedVars) + pathConstraints ::= rCond + val rThen = functionClosure(then, bindedVars) + pathConstraints = pathConstraints.tail + pathConstraints ::= Not(rCond) + val rElze = functionClosure(elze, bindedVars) + pathConstraints = pathConstraints.tail + IfExpr(rCond, rThen, rElze).setType(i.getType) + } + case m @ MatchExpr(scrut,cses) => { //TODO: will not work if there are actual nested function in cases + //val rScrut = functionClosure(scrut, bindedVars) + m + } + case t if t.isInstanceOf[Terminal] => t + case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) + } + + //filter the list of constraints, only keeping those relevant to the set of variables + def filterConstraints(vars: Set[Identifier]): (List[Expr], Set[Identifier]) = { + var allVars = vars + var newVars: Set[Identifier] = Set() + var constraints = pathConstraints + var filteredConstraints: List[Expr] = Nil + do { + allVars ++= newVars + newVars = Set() + constraints = pathConstraints.filterNot(filteredConstraints.contains(_)) + constraints.foreach(expr => { + val vs = variablesOf(expr) + if(!vs.intersect(allVars).isEmpty) { + filteredConstraints ::= expr + newVars ++= vs.diff(allVars) + } + }) + } while(newVars != Set()) + (filteredConstraints, allVars) + } + +} diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala new file mode 100644 index 0000000000000000000000000000000000000000..1cfdffdc3268aa7f40986d8f00a3bac32efdf776 --- /dev/null +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -0,0 +1,76 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object FunctionHoisting extends Pass { + + val description = "Hoist function at the top level" + + def apply(program: Program): Program = { + val funDefs = program.definedFunctions + var topLevelFuns: Set[FunDef] = Set() + funDefs.foreach(fd => { + val (newBody, additionalTopLevelFun) = hoist(fd.getBody) + fd.body = Some(newBody) + topLevelFuns ++= additionalTopLevelFun + }) + val Program(id, ObjectDef(objId, defs, invariants)) = program + Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants)) + } + + private def hoist(expr: Expr): (Expr, Set[FunDef]) = expr match { + case l @ LetDef(fd, rest) => { + val (e, s) = hoist(rest) + val (e2, s2) = hoist(fd.getBody) + fd.body = Some(e2) + + (e, (s ++ s2) + fd) + } + case l @ Let(i,e,b) => { + val (re, s1) = hoist(e) + val (rb, s2) = hoist(b) + (Let(i, re, rb), s1 ++ s2) + } + case n @ NAryOperator(args, recons) => { + val rargs = args.map(a => hoist(a)) + (recons(rargs.map(_._1)).setType(n.getType), rargs.flatMap(_._2).toSet) + } + case b @ BinaryOperator(t1,t2,recons) => { + val (r1, s1) = hoist(t1) + val (r2, s2) = hoist(t2) + (recons(r1,r2).setType(b.getType), s1 ++ s2) + } + case u @ UnaryOperator(t,recons) => { + val (r, s) = hoist(t) + (recons(r).setType(u.getType), s) + } + case i @ IfExpr(t1,t2,t3) => { + val (r1, s1) = hoist(t1) + val (r2, s2) = hoist(t2) + val (r3, s3) = hoist(t3) + (IfExpr(r1, r2, r3).setType(i.getType), s1 ++ s2 ++ s3) + } + case m @ MatchExpr(scrut,cses) => { + val (scrutRes, scrutSet) = hoist(scrut) + val (csesRes, csesSets) = cses.map{ + case SimpleCase(pat, rhs) => { + val (r, s) = hoist(rhs) + (SimpleCase(pat, r), s) + } + case GuardedCase(pat, guard, rhs) => { + val (r, s) = hoist(rhs) + (GuardedCase(pat, guard, r), s) + } + }.unzip + (MatchExpr(scrutRes, csesRes).setType(m.getType).setPosInfo(m), csesSets.toSet.flatten ++ scrutSet) + } + case t if t.isInstanceOf[Terminal] => (t, Set()) + case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) + } + +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala new file mode 100644 index 0000000000000000000000000000000000000000..b006520ebd0eee5b1d49861534d5dc686f8dea57 --- /dev/null +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -0,0 +1,267 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object ImperativeCodeElimination extends Pass { + + val description = "Transform imperative constructs into purely functional code" + + private var varInScope = Set[Identifier]() + private var parent: FunDef = null //the enclosing fundef + + def apply(pgm: Program): Program = { + val allFuns = pgm.definedFunctions + allFuns.foreach(fd => { + parent = fd + val (res, scope, _) = toFunction(fd.getBody) + fd.body = Some(scope(res)) + }) + pgm + } + + //return a "scope" consisting of purely functional code that defines potentially needed + //new variables (val, not var) and a mapping for each modified variable (var, not val :) ) + //to their new name defined in the scope. The first returned valued is the value of the expression + //that should be introduced as such in the returned scope (the val already refers to the new names) + private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { + val res = expr match { + case LetVar(id, e, b) => { + val newId = FreshIdentifier(id.name).setType(id.getType) + val (rhsVal, rhsScope, rhsFun) = toFunction(e) + varInScope += id + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + varInScope -= id + val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body)))) + (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun) + } + case Assignment(id, e) => { + assert(varInScope.contains(id)) + val newId = FreshIdentifier(id.name).setType(id.getType) + val (rhsVal, rhsScope, rhsFun) = toFunction(e) + val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body)) + (UnitLiteral, scope, rhsFun + (id -> newId)) + } + + case ite@IfExpr(cond, tExpr, eExpr) => { + val (cRes, cScope, cFun) = toFunction(cond) + val (tRes, tScope, tFun) = toFunction(tExpr) + val (eRes, eScope, eFun) = toFunction(eExpr) + + val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq + val resId = FreshIdentifier("res").setType(ite.getType) + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val iteType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + + val thenVal = if(modifiedVars.isEmpty) tRes else Tuple(tRes +: modifiedVars.map(vId => tFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })) + thenVal.setType(iteType) + + val elseVal = if(modifiedVars.isEmpty) eRes else Tuple(eRes +: modifiedVars.map(vId => eFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })) + elseVal.setType(iteType) + + val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType) + + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(iteType) + cScope( + Let(tupleId, iteExpr, + if(freshIds.isEmpty) + Let(resId, tupleId.toVariable, body) + else + Let(resId, TupleSelect(tupleId.toVariable, 1), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + b))))) + }) + + (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) + } + + case m @ MatchExpr(scrut, cses) => { + val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure + val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 + val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) + + val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).intersect(varInScope).toSeq + val resId = FreshIdentifier("res").setType(m.getType) + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + + val csesVals = csesRes.zip(csesFun).map{ + case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + }))).setType(matchType) + } + + val newRhs = csesVals.zip(csesScope).map{ + case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType) + } + val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ + case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) + case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) + }).setType(matchType).setPosInfo(m) + + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(matchType) + scrutScope( + Let(tupleId, matchExpr, + if(freshIds.isEmpty) + Let(resId, tupleId.toVariable, body) + else + Let(resId, TupleSelect(tupleId.toVariable, 1), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + b))))) + }) + + (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) + } + case wh@While(cond, body) => { + val (condRes, condScope, condFun) = toFunction(cond) + val (_, bodyScope, bodyFun) = toFunction(body) + val condBodyFun = condFun ++ bodyFun + + val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSet.intersect(varInScope).toSeq + + if(modifiedVars.isEmpty) + (UnitLiteral, (b: Expr) => b, Map()) + else { + val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap + val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType)) + val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) + val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls).setPosInfo(wh) + whileFunDef.fromLoop = true + whileFunDef.parent = Some(parent) + + val whileFunCond = condRes + val whileFunRecursiveCall = replaceNames(condFun, + bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => condBodyFun(id).toVariable)).setPosInfo(wh))) + val whileFunBaseCase = + (if(whileFunVars.size == 1) + condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable + else + Tuple(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable)) + ).setType(whileFunReturnType) + val whileFunBody = replaceNames(modifiedVars2WhileFunVars, + condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType))) + whileFunDef.body = Some(whileFunBody) + + val resVar = ResultVariable().setType(whileFunReturnType) + val whileFunVars2ResultVars: Map[Expr, Expr] = + if(whileFunVars.size == 1) + Map(whileFunVars.head.toVariable -> resVar) + else + whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1).setType(v.getType)) }.toMap + val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id => + (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap + + //the mapping of the trivial post condition variables depends on whether the condition has had some side effect + val trivialPostcondition: Option[Expr] = Some(Not(replace( + modifiedVars.map(id => (condFun.get(id).getOrElse(id).toVariable, modifiedVars2ResultVars(id.toVariable))).toMap, + whileFunCond))) + val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr)) + val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr)) + whileFunDef.precondition = invariantPrecondition + whileFunDef.postcondition = trivialPostcondition.map(expr => + And(expr, invariantPostcondition match { + case Some(e) => e + case None => BooleanLiteral(true) + })) + + val finalVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val finalScope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(whileFunReturnType) + LetDef( + whileFunDef, + Let(tupleId, + FunctionInvocation(whileFunDef, modifiedVars.map(_.toVariable)).setPosInfo(wh), + if(finalVars.size == 1) + Let(finalVars.head, tupleId.toVariable, body) + else + finalVars.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), + b)))) + }) + + (UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap) + } + } + + case Block(Seq(), expr) => toFunction(expr) + case Block(exprs, expr) => { + val (scope, fun) = exprs.foldRight((body: Expr) => body, Map[Identifier, Identifier]())((e, acc) => { + val (accScope, accFun) = acc + val (_, rScope, rFun) = toFunction(e) + val scope = (body: Expr) => rScope(replaceNames(rFun, accScope(body))) + (scope, rFun ++ accFun) + }) + val (lastRes, lastScope, lastFun) = toFunction(expr) + val finalFun = fun ++ lastFun + (replaceNames(finalFun, lastRes), + (body: Expr) => scope(replaceNames(fun, lastScope(body))), + finalFun) + } + + //pure expression (that could still contain side effects as a subexpression) (evaluation order is from left to right) + case Let(id, e, b) => { + val (bindRes, bindScope, bindFun) = toFunction(e) + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, + (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2)))), + bindFun ++ bodyFun) + } + case LetDef(fd, b) => { + //Recall that here the nested function should not access mutable variables from an outside scope + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)), bodyFun) + } + case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map()) + case n @ NAryOperator(args, recons) => { + val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { + val (accArgs, accScope, accFun) = acc + val (argVal, argScope, argFun) = toFunction(arg) + val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) + (argVal +: accArgs, newScope, argFun ++ accFun) + }) + (recons(recArgs).setType(n.getType), scope, fun) + } + case b @ BinaryOperator(a1, a2, recons) => { + val (argVal1, argScope1, argFun1) = toFunction(a1) + val (argVal2, argScope2, argFun2) = toFunction(a2) + val scope = (body: Expr) => { + val rhs = argScope2(replaceNames(argFun2, body)) + val lhs = argScope1(replaceNames(argFun1, rhs)) + lhs + } + (recons(argVal1, argVal2).setType(b.getType), scope, argFun1 ++ argFun2) + } + case u @ UnaryOperator(a, recons) => { + val (argVal, argScope, argFun) = toFunction(a) + (recons(argVal).setType(u.getType), argScope, argFun) + } + case (t: Terminal) => (t, (body: Expr) => body, Map()) + + + case _ => sys.error("not supported: " + expr) + } + //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1)) + //println("res of toFunction on: " + expr + " IS: " + codeRepresentation) + res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] //need cast because it seems that res first map type is _ <: Identifier instead of Identifier + } + + def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr) + +} diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index aefcf528a0c83593f806374bdc21688c5d2a15bc..72623c96a2c732b4dac1995392545474c7f3f9c7 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -31,7 +31,9 @@ object Main { } private def defaultAction(program: Program, reporter: Reporter) : Unit = { - val analysis = new Analysis(program, reporter) + val passManager = new PassManager(Seq(ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator)) + val program2 = passManager.run(program) + val analysis = new Analysis(program2, reporter) analysis.analyse } diff --git a/src/main/scala/leon/Pass.scala b/src/main/scala/leon/Pass.scala new file mode 100644 index 0000000000000000000000000000000000000000..4bbc88856c20b96bd5c9c3fd9e1815b57b85d976 --- /dev/null +++ b/src/main/scala/leon/Pass.scala @@ -0,0 +1,14 @@ +package leon + +import purescala.Definitions._ + +abstract class Pass { + + def apply(pgm: Program): Program + + val description: String + + + //Maybe adding some dependency declaration and tree checking methods + +} diff --git a/src/main/scala/leon/PassManager.scala b/src/main/scala/leon/PassManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..f2830df79bfa129ce93d58157110a2107453595a --- /dev/null +++ b/src/main/scala/leon/PassManager.scala @@ -0,0 +1,16 @@ +package leon + +import purescala.Definitions._ + +class PassManager(passes: Seq[Pass]) { + + def run(program: Program): Program = { + passes.foldLeft(program)((pgm, pass) => { + //println("Running Pass: " + pass.description) + val newPgm = pass(pgm) + //println("Resulting program: " + newPgm) + newPgm + }) + } + +} diff --git a/src/main/scala/leon/RandomSolver.scala b/src/main/scala/leon/RandomSolver.scala index f153818924cd47d79f7470b8487767df5f30f7bc..4cee5b06181bbc803cc2061af9f62c33c5d04386 100644 --- a/src/main/scala/leon/RandomSolver.scala +++ b/src/main/scala/leon/RandomSolver.scala @@ -10,8 +10,6 @@ import Evaluator._ import scala.util.Random -import scala.sys.error - class RandomSolver(reporter: Reporter, val nbTrial: Option[Int] = None) extends Solver(reporter) { require(nbTrial.forall(i => i >= 0)) @@ -68,13 +66,14 @@ class RandomSolver(reporter: Reporter, val nbTrial: Option[Int] = None) extends case AnyType => randomValue(randomType(), size) case SetType(base) => EmptySet(base) case MultisetType(base) => EmptyMultiset(base) - case Untyped => error("I don't know what to do") - case BottomType => error("I don't know what to do") - case ListType(base) => error("I don't know what to do") - case TupleType(bases) => error("I don't know what to do") - case MapType(from, to) => error("I don't know what to do") - case OptionType(base) => error("I don't know what to do") - case f: FunctionType => error("I don't know what to do") + case Untyped => sys.error("I don't know what to do") + case BottomType => sys.error("I don't know what to do") + case ListType(base) => sys.error("I don't know what to do") + case TupleType(bases) => sys.error("I don't know what to do") + case MapType(from, to) => sys.error("I don't know what to do") + case OptionType(base) => sys.error("I don't know what to do") + case f: FunctionType => sys.error("I don't know what to do") + case _ => sys.error("Unexpected type: " + t) } def solve(expression: Expr) : Option[Boolean] = { diff --git a/src/main/scala/leon/Settings.scala b/src/main/scala/leon/Settings.scala index 6593e418c28a4f49f5b8a798b527b468e9f3040c..797a209e57d08d863f24dcd5f97008d4fc875bf7 100644 --- a/src/main/scala/leon/Settings.scala +++ b/src/main/scala/leon/Settings.scala @@ -23,4 +23,5 @@ object Settings { var useParallel : Boolean = false // When this is None, use real integers var bitvectorBitwidth : Option[Int] = None + var verbose : Boolean = false } diff --git a/src/main/scala/leon/Simplificator.scala b/src/main/scala/leon/Simplificator.scala new file mode 100644 index 0000000000000000000000000000000000000000..e8a7e2e005b68ce54a6482cb4a06f942d54ba58f --- /dev/null +++ b/src/main/scala/leon/Simplificator.scala @@ -0,0 +1,22 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object Simplificator extends Pass { + + val description = "Some safe and minimal simplification" + + def apply(pgm: Program): Program = { + + val allFuns = pgm.definedFunctions + allFuns.foreach(fd => { + fd.body = Some(simplifyLets(fd.getBody)) + }) + pgm + } + +} + diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala new file mode 100644 index 0000000000000000000000000000000000000000..7fe181fba57f7c80d0aa878896641b8195742485 --- /dev/null +++ b/src/main/scala/leon/UnitElimination.scala @@ -0,0 +1,146 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object UnitElimination extends Pass { + + val description = "Remove all usage of the Unit type and value" + + private var fun2FreshFun: Map[FunDef, FunDef] = Map() + private var id2FreshId: Map[Identifier, Identifier] = Map() + + def apply(pgm: Program): Program = { + fun2FreshFun = Map() + val allFuns = pgm.definedFunctions + + //first introduce new signatures without Unit parameters + allFuns.foreach(fd => { + if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) { + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) + freshFunDef.fromLoop = fd.fromLoop + freshFunDef.parent = fd.parent + freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. + freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. + fun2FreshFun += (fd -> freshFunDef) + } else { + fun2FreshFun += (fd -> fd) //this will make the next step simpler + } + }) + + //then apply recursively to the bodies + val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else { + val body = fd.getBody + val newFd = fun2FreshFun(fd) + newFd.body = Some(removeUnit(body)) + Seq(newFd) + }) + + val Program(id, ObjectDef(objId, _, invariants)) = pgm + val allClasses = pgm.definedClasses + Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) + } + + private def simplifyType(tpe: TypeTree): TypeTree = tpe match { + case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match { + case Seq() => UnitType + case Seq(tpe) => tpe + case tpes => TupleType(tpes) + } + case t => t + } + + //remove unit value as soon as possible, so expr should never be equal to a unit + private def removeUnit(expr: Expr): Expr = { + assert(expr.getType != UnitType) + expr match { + case fi@FunctionInvocation(fd, args) => { + val newArgs = args.filterNot(arg => arg.getType == UnitType) + FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi) + } + case t@Tuple(args) => { + val TupleType(tpes) = t.getType + val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip + Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes)) + } + case ts@TupleSelect(t, index) => { + val TupleType(tpes) = t.getType + val selectionType = tpes(index-1) + val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){ + case ((nbUnit, newIndex), (tpe, i)) => + if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex) + } + TupleSelect(removeUnit(t), newIndex).setType(selectionType) + } + case Let(id, e, b) => { + if(id.getType == UnitType) + removeUnit(b) + else { + id.getType match { + case TupleType(tpes) if tpes.exists(_ == UnitType) => { + val newTupleType = TupleType(tpes.filterNot(_ == UnitType)) + val freshId = FreshIdentifier(id.name).setType(newTupleType) + id2FreshId += (id -> freshId) + val newBody = removeUnit(b) + id2FreshId -= id + Let(freshId, removeUnit(e), newBody) + } + case _ => Let(id, removeUnit(e), removeUnit(b)) + } + } + } + case LetDef(fd, b) => { + if(fd.returnType == UnitType) + removeUnit(b) + else { + val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) { + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) + freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. + freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. + fun2FreshFun += (fd -> freshFunDef) + freshFunDef.body = Some(removeUnit(fd.getBody)) + val restRec = removeUnit(b) + fun2FreshFun -= fd + (freshFunDef, restRec) + } else { + fun2FreshFun += (fd -> fd) + fd.body = Some(removeUnit(fd.getBody)) + val restRec = removeUnit(b) + fun2FreshFun -= fd + (fd, restRec) + } + LetDef(newFd, rest) + } + } + case ite@IfExpr(cond, tExpr, eExpr) => { + val thenRec = removeUnit(tExpr) + val elseRec = removeUnit(eExpr) + IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType) + } + case n @ NAryOperator(args, recons) => { + recons(args.map(removeUnit(_))).setType(n.getType) + } + case b @ BinaryOperator(a1, a2, recons) => { + recons(removeUnit(a1), removeUnit(a2)).setType(b.getType) + } + case u @ UnaryOperator(a, recons) => { + recons(removeUnit(a)).setType(u.getType) + } + case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v + case (t: Terminal) => t + case m @ MatchExpr(scrut, cses) => { + val scrutRec = removeUnit(scrut) + val csesRec = cses.map{ + case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + } + val tpe = csesRec.head.rhs.getType + MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m) + } + case _ => sys.error("not supported: " + expr) + } + } + +} diff --git a/src/main/scala/leon/Utils.scala b/src/main/scala/leon/Utils.scala index 6d7f019ce669a9e81f7da0380048d0071d6713bc..9ed8ef400ea0f143af078149c0d3b3034394c3a4 100644 --- a/src/main/scala/leon/Utils.scala +++ b/src/main/scala/leon/Utils.scala @@ -9,4 +9,10 @@ object Utils { } implicit def any2IsValid(x: Boolean) : IsValid = new IsValid(x) + + + object InvariantFunction { + def invariant(x: Boolean): Unit = () + } + implicit def while2Invariant(u: Unit) = InvariantFunction } diff --git a/src/main/scala/leon/VerificationCondition.scala b/src/main/scala/leon/VerificationCondition.scala index 4150536fa441352761a7542019caa64d478b0d6a..01545bc3e87bd3606591b464f59098a5b0d8dd8c 100644 --- a/src/main/scala/leon/VerificationCondition.scala +++ b/src/main/scala/leon/VerificationCondition.scala @@ -49,4 +49,7 @@ object VCKind extends Enumeration { val Postcondition = Value("postcond.") val ExhaustiveMatch = Value("match.") val MapAccess = Value("map acc.") + val InvariantInit = Value("inv init.") + val InvariantInd = Value("inv ind.") + val InvariantPost = Value("inv post.") } diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index f2bfec68113ac6c1e8cbdd40c736d3d9440a0cc8..cbf4424bf7ad4d83ecb99f7a68302105b777f4d6 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -5,7 +5,7 @@ import scala.tools.nsc._ import scala.tools.nsc.plugins._ import purescala.Definitions._ -import purescala.Trees._ +import purescala.Trees.{Block => PBlock, _} import purescala.TypeTrees._ import purescala.Common._ @@ -53,6 +53,8 @@ trait CodeExtraction extends Extractors { sym == function1TraitSym } + private val mutableVarSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = + scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] private val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = @@ -60,7 +62,7 @@ trait CodeExtraction extends Extractors { private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = scala.collection.mutable.Map.empty[Symbol,FunDef] - def extractCode(unit: CompilationUnit): Program = { + def extractCode(unit: CompilationUnit): Program = { import scala.collection.mutable.HashMap def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { @@ -278,7 +280,7 @@ trait CodeExtraction extends Extractors { } val bodyAttempt = try { - Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody)) + Some(flattenBlocks(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody))) } catch { case e: ImpureCodeEncounteredException => None } @@ -375,304 +377,454 @@ trait CodeExtraction extends Extractors { if(cd.guard == EmptyTree) { SimpleCase(pat2pat(cd.pat), rec(cd.body)) } else { - GuardedCase(pat2pat(cd.pat), rec(cd.guard), rec(cd.body)) + val recPattern = pat2pat(cd.pat) + val recGuard = rec(cd.guard) + val recBody = rec(cd.body) + if(!isPure(recGuard)) { + unit.error(cd.guard.pos, "Guard expression must be pure") + throw ImpureCodeEncounteredException(cd) + } + GuardedCase(recPattern, recGuard, recBody) } } - def rec(tr: Tree): Expr = tr match { - case ExTuple(tpes, exprs) => { - println("getting ExTuple with " + tpes + " and " + exprs) - val tupleType = TupleType(tpes.map(tpe => scalaType2PureScala(unit, silent)(tpe))) - val tupleExprs = exprs.map(e => rec(e)) - Tuple(tupleExprs).setType(tupleType) - } - case ExTupleExtract(tuple, index) => { - val tupleExpr = rec(tuple) - val TupleType(tpes) = tupleExpr.getType - if(tpes.size < index) - throw ImpureCodeEncounteredException(tree) - else - TupleSelect(tupleExpr, index).setType(tpes(index-1)) - } - case ExValDef(vs, tpt, bdy, rst) => { - val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) - val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) - val oldSubsts = varSubsts - val valTree = rec(bdy) - varSubsts(vs) = (() => Variable(newID)) - val restTree = rec(rst) - varSubsts.remove(vs) - Let(newID, valTree, restTree) + def extractFunSig(nameStr: String, params: Seq[ValDef], tpt: Tree): FunDef = { + val newParams = params.map(p => { + val ptpe = scalaType2PureScala(unit, silent) (p.tpt.tpe) + val newID = FreshIdentifier(p.name.toString).setType(ptpe) + varSubsts(p.symbol) = (() => Variable(newID)) + VarDecl(newID, ptpe) + }) + new FunDef(FreshIdentifier(nameStr), scalaType2PureScala(unit, silent)(tpt.tpe), newParams) + } + + def extractFunDef(funDef: FunDef, body: Tree): FunDef = { + var realBody = body + var reqCont: Option[Expr] = None + var ensCont: Option[Expr] = None + + realBody match { + case ExEnsuredExpression(body2, resSym, contract) => { + varSubsts(resSym) = (() => ResultVariable().setType(funDef.returnType)) + val c1 = scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies) (contract) + // varSubsts.remove(resSym) + realBody = body2 + ensCont = Some(c1) + } + case ExHoldsExpression(body2) => { + realBody = body2 + ensCont = Some(ResultVariable().setType(BooleanType)) + } + case _ => ; } - case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) - case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) - case ExTyped(e,tpt) => rec(e) - case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { - case Some(fun) => fun() - case None => { - unit.error(tr.pos, "Unidentified variable.") - throw ImpureCodeEncounteredException(tr) + + realBody match { + case ExRequiredExpression(body3, contract) => { + realBody = body3 + reqCont = Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies) (contract)) } + case _ => ; } - case ExSomeConstruction(tpe, arg) => { - println("Got Some !" + tpe + ":" + arg) - val underlying = scalaType2PureScala(unit, silent)(tpe) - OptionSome(rec(arg)).setType(OptionType(underlying)) + + val bodyAttempt = try { + Some(flattenBlocks(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody))) + } catch { + case e: ImpureCodeEncounteredException => None } - case ExCaseClassConstruction(tpt, args) => { - val cctype = scalaType2PureScala(unit, silent)(tpt.tpe) - if(!cctype.isInstanceOf[CaseClassType]) { - if(!silent) { - unit.error(tr.pos, "Construction of a non-case class.") + + funDef.body = bodyAttempt + funDef.precondition = reqCont + funDef.postcondition = ensCont + funDef + } + + def rec(tr: Tree): Expr = { + + val (nextExpr, rest) = tr match { + case Block(Block(e :: es1, l1) :: es2, l2) => (e, Some(Block(es1 ++ Seq(l1) ++ es2, l2))) + case Block(e :: Nil, last) => (e, Some(last)) + case Block(e :: es, last) => (e, Some(Block(es, last))) + case _ => (tr, None) + } + + var handleRest = true + val psExpr = nextExpr match { + case ExTuple(tpes, exprs) => { + val tupleType = TupleType(tpes.map(tpe => scalaType2PureScala(unit, silent)(tpe))) + val tupleExprs = exprs.map(e => rec(e)) + Tuple(tupleExprs).setType(tupleType) + } + case ExTupleExtract(tuple, index) => { + val tupleExpr = rec(tuple) + val TupleType(tpes) = tupleExpr.getType + if(tpes.size < index) + throw ImpureCodeEncounteredException(tree) + else + TupleSelect(tupleExpr, index).setType(tpes(index-1)) + } + case ExValDef(vs, tpt, bdy) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val valTree = rec(bdy) + val restTree = rest match { + case Some(rst) => { + varSubsts(vs) = (() => Variable(newID)) + val res = rec(rst) + varSubsts.remove(vs) + res + } + case None => UnitLiteral } - throw ImpureCodeEncounteredException(tree) + handleRest = false + val res = Let(newID, valTree, restTree) + res + } + case dd@ExFunctionDef(n, p, t, b) => { + val funDef = extractFunSig(n, p, t).setPosInfo(dd.pos.line, dd.pos.column) + defsToDefs += (dd.symbol -> funDef) + val oldMutableVarSubst = mutableVarSubsts.toMap //take an immutable snapshot of the map + mutableVarSubsts.clear //reseting the visible mutable vars, we do not handle mutable variable closure in nested functions + val funDefWithBody = extractFunDef(funDef, b) + mutableVarSubsts ++= oldMutableVarSubst + val restTree = rest match { + case Some(rst) => rec(rst) + case None => UnitLiteral + } + defsToDefs.remove(dd.symbol) + handleRest = false + LetDef(funDefWithBody, restTree) + } + case ExVarDef(vs, tpt, bdy) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val valTree = rec(bdy) + mutableVarSubsts += (vs -> (() => Variable(newID))) + val restTree = rest match { + case Some(rst) => { + varSubsts(vs) = (() => Variable(newID)) + val res = rec(rst) + varSubsts.remove(vs) + res + } + case None => UnitLiteral + } + handleRest = false + val res = LetVar(newID, valTree, restTree) + res } - val nargs = args.map(rec(_)) - val cct = cctype.asInstanceOf[CaseClassType] - CaseClass(cct.classDef, nargs).setType(cct) - } - case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) - case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) - case ExNot(e) => Not(rec(e)).setType(BooleanType) - case ExUMinus(e) => UMinus(rec(e)).setType(Int32Type) - case ExPlus(l, r) => Plus(rec(l), rec(r)).setType(Int32Type) - case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) - case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) - case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) - case ExMod(l, r) => Modulo(rec(l), rec(r)).setType(Int32Type) - case ExEquals(l, r) => { - val rl = rec(l) - val rr = rec(r) - ((rl.getType,rr.getType) match { - case (SetType(_), SetType(_)) => SetEquals(rl, rr) - case (BooleanType, BooleanType) => Iff(rl, rr) - case (_, _) => Equals(rl, rr) - }).setType(BooleanType) - } - case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) - case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) - case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) - case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType) - case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) - case ExFiniteSet(tt, args) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - FiniteSet(args.map(rec(_))).setType(SetType(underlying)) - } - case ExFiniteMultiset(tt, args) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) - } - case ExEmptySet(tt) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptySet(underlying).setType(SetType(underlying)) - } - case ExEmptyMultiset(tt) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptyMultiset(underlying).setType(MultisetType(underlying)) - } - case ExEmptyMap(ft, tt) => { - val fromUnderlying = scalaType2PureScala(unit, silent)(ft.tpe) - val toUnderlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptyMap(fromUnderlying, toUnderlying).setType(MapType(fromUnderlying, toUnderlying)) - } - case ExSetMin(t) => { - val set = rec(t) - if(!set.getType.isInstanceOf[SetType]) { - if(!silent) unit.error(t.pos, "Min should be computed on a set.") - throw ImpureCodeEncounteredException(tree) + + case ExAssign(sym, rhs) => mutableVarSubsts.get(sym) match { + case Some(fun) => { + val Variable(id) = fun() + val rhsTree = rec(rhs) + Assignment(id, rhsTree) + } + case None => { + unit.error(tr.pos, "Undeclared variable.") + throw ImpureCodeEncounteredException(tr) + } } - SetMin(set).setType(set.getType.asInstanceOf[SetType].base) - } - case ExSetMax(t) => { - val set = rec(t) - if(!set.getType.isInstanceOf[SetType]) { - if(!silent) unit.error(t.pos, "Max should be computed on a set.") - throw ImpureCodeEncounteredException(tree) + case wh@ExWhile(cond, body) => { + val condTree = rec(cond) + val bodyTree = rec(body) + While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) + } + case wh@ExWhileWithInvariant(cond, body, inv) => { + val condTree = rec(cond) + val bodyTree = rec(body) + val invTree = rec(inv) + val w = While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) + w.invariant = Some(invTree) + w + } + + case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) + case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) + case ExUnitLiteral() => UnitLiteral + + case ExTyped(e,tpt) => rec(e) + case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { + case Some(fun) => fun() + case None => mutableVarSubsts.get(sym) match { + case Some(fun) => fun() + case None => { + unit.error(tr.pos, "Unidentified variable.") + throw ImpureCodeEncounteredException(tr) + } + } } - SetMax(set).setType(set.getType.asInstanceOf[SetType].base) - } - case ExUnion(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetUnion(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + case ExSomeConstruction(tpe, arg) => { + // println("Got Some !" + tpe + ":" + arg) + val underlying = scalaType2PureScala(unit, silent)(tpe) + OptionSome(rec(arg)).setType(OptionType(underlying)) + } + case ExCaseClassConstruction(tpt, args) => { + val cctype = scalaType2PureScala(unit, silent)(tpt.tpe) + if(!cctype.isInstanceOf[CaseClassType]) { + if(!silent) { + unit.error(tr.pos, "Construction of a non-case class.") + } throw ImpureCodeEncounteredException(tree) } - } - } - case ExIntersection(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetIntersection(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + val nargs = args.map(rec(_)) + val cct = cctype.asInstanceOf[CaseClassType] + CaseClass(cct.classDef, nargs).setType(cct) + } + case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) + case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) + case ExNot(e) => Not(rec(e)).setType(BooleanType) + case ExUMinus(e) => UMinus(rec(e)).setType(Int32Type) + case ExPlus(l, r) => Plus(rec(l), rec(r)).setType(Int32Type) + case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) + case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) + case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) + case ExMod(l, r) => Modulo(rec(l), rec(r)).setType(Int32Type) + case ExEquals(l, r) => { + val rl = rec(l) + val rr = rec(r) + ((rl.getType,rr.getType) match { + case (SetType(_), SetType(_)) => SetEquals(rl, rr) + case (BooleanType, BooleanType) => Iff(rl, rr) + case (_, _) => Equals(rl, rr) + }).setType(BooleanType) + } + case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) + case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) + case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) + case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType) + case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) + case ExFiniteSet(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteSet(args.map(rec(_))).setType(SetType(underlying)) + } + case ExFiniteMultiset(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) + } + case ExEmptySet(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptySet(underlying).setType(SetType(underlying)) + } + case ExEmptyMultiset(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMultiset(underlying).setType(MultisetType(underlying)) + } + case ExEmptyMap(ft, tt) => { + val fromUnderlying = scalaType2PureScala(unit, silent)(ft.tpe) + val toUnderlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMap(fromUnderlying, toUnderlying).setType(MapType(fromUnderlying, toUnderlying)) + } + case ExSetMin(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Min should be computed on a set.") throw ImpureCodeEncounteredException(tree) } + SetMin(set).setType(set.getType.asInstanceOf[SetType].base) } - } - case ExSetContains(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => ElementOfSet(rr, rl) - case _ => { - if(!silent) unit.error(tree.pos, ".contains on non set expression.") + case ExSetMax(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Max should be computed on a set.") throw ImpureCodeEncounteredException(tree) } + SetMax(set).setType(set.getType.asInstanceOf[SetType].base) + } + case ExUnion(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetUnion(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } } - } - case ExSetSubset(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SubsetOf(rl, rr) - case _ => { - if(!silent) unit.error(tree.pos, "Subset on non set expression.") - throw ImpureCodeEncounteredException(tree) + case ExIntersection(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetIntersection(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExSetMinus(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetDifference(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") - throw ImpureCodeEncounteredException(tree) + case ExSetContains(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => ElementOfSet(rr, rl) + case _ => { + if(!silent) unit.error(tree.pos, ".contains on non set expression.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExSetCard(t) => { - val rt = rec(t) - rt.getType match { - case s @ SetType(_) => SetCardinality(rt) - case m @ MultisetType(_) => MultisetCardinality(rt) - case _ => { - if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") - throw ImpureCodeEncounteredException(tree) + case ExSetSubset(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SubsetOf(rl, rr) + case _ => { + if(!silent) unit.error(tree.pos, "Subset on non set expression.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExMultisetToSet(t) => { - val rt = rec(t) - rt.getType match { - case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) - case _ => { - if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") - throw ImpureCodeEncounteredException(tree) + case ExSetMinus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetDifference(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExSetCard(t) => { + val rt = rec(t) + rt.getType match { + case s @ SetType(_) => SetCardinality(rt) + case m @ MultisetType(_) => MultisetCardinality(rt) + case _ => { + if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExMapUpdated(m,f,t) => { - val rm = rec(m) - val rf = rec(f) - val rt = rec(t) - val newSingleton = SingletonMap(rf, rt).setType(rm.getType) - rm.getType match { - case MapType(ft, tt) => - MapUnion(rm, FiniteMap(Seq(newSingleton)).setType(rm.getType)).setType(rm.getType) - case _ => { - if (!silent) unit.error(tree.pos, "updated can only be applied to maps.") - throw ImpureCodeEncounteredException(tree) + case ExMultisetToSet(t) => { + val rt = rec(t) + rt.getType match { + case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) + case _ => { + if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExMapIsDefinedAt(m,k) => { - val rm = rec(m) - val rk = rec(k) - MapIsDefinedAt(rm, rk) - } - - case ExPlusPlusPlus(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - MultisetPlus(rl, rr).setType(rl.getType) - } - case ExApply(lhs,args) => { - val rlhs = rec(lhs) - val rargs = args map rec - rlhs.getType match { - case MapType(_,tt) => - assert(rargs.size == 1) - MapGet(rlhs, rargs.head).setType(tt) - case FunctionType(fts, tt) => { - rlhs match { - case Variable(id) => - AnonymousFunctionInvocation(id, rargs).setType(tt) - case _ => { - if (!silent) unit.error(tree.pos, "apply on non-variable or non-map expression") - throw ImpureCodeEncounteredException(tree) - } + case ExMapUpdated(m,f,t) => { + val rm = rec(m) + val rf = rec(f) + val rt = rec(t) + val newSingleton = SingletonMap(rf, rt).setType(rm.getType) + rm.getType match { + case MapType(ft, tt) => + MapUnion(rm, FiniteMap(Seq(newSingleton)).setType(rm.getType)).setType(rm.getType) + case _ => { + if (!silent) unit.error(tree.pos, "updated can only be applied to maps.") + throw ImpureCodeEncounteredException(tree) } } - case _ => { - if (!silent) unit.error(tree.pos, "apply on unexpected type") - throw ImpureCodeEncounteredException(tree) + } + case ExMapIsDefinedAt(m,k) => { + val rm = rec(m) + val rk = rec(k) + MapIsDefinedAt(rm, rk) + } + + case ExPlusPlusPlus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + MultisetPlus(rl, rr).setType(rl.getType) + } + case app@ExApply(lhs,args) => { + val rlhs = rec(lhs) + val rargs = args map rec + rlhs.getType match { + case MapType(_,tt) => + assert(rargs.size == 1) + MapGet(rlhs, rargs.head).setType(tt).setPosInfo(app.pos.line, app.pos.column) + case FunctionType(fts, tt) => { + rlhs match { + case Variable(id) => + AnonymousFunctionInvocation(id, rargs).setType(tt) + case _ => { + if (!silent) unit.error(tree.pos, "apply on non-variable or non-map expression") + throw ImpureCodeEncounteredException(tree) + } + } + } + case _ => { + if (!silent) unit.error(tree.pos, "apply on unexpected type") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExIfThenElse(t1,t2,t3) => { - val r1 = rec(t1) - val r2 = rec(t2) - val r3 = rec(t3) - IfExpr(r1, r2, r3).setType(leastUpperBound(r2.getType, r3.getType)) - } - case lc @ ExLocalCall(sy,nm,ar) => { - if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { - if(!silent) - unit.error(tr.pos, "Invoking an invalid function.") - throw ImpureCodeEncounteredException(tr) + case ExIfThenElse(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + IfExpr(r1, r2, r3).setType(leastUpperBound(r2.getType, r3.getType)) } - val fd = defsToDefs(sy) - FunctionInvocation(fd, ar.map(rec(_))).setType(fd.returnType).setPosInfo(lc.pos.line,lc.pos.column) - } - case pm @ ExPatternMatching(sel, cses) => { - val rs = rec(sel) - val rc = cses.map(rewriteCaseDef(_)) - val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_)) - MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) - } - - // this one should stay after all others, cause it also catches UMinus - // and Not, for instance. - case ExParameterlessMethodCall(t,n) => { - val selector = rec(t) - val selType = selector.getType - - if(!selType.isInstanceOf[CaseClassType]) { - if(!silent) - unit.error(tr.pos, "Invalid method or field invocation (not purescala?)") - throw ImpureCodeEncounteredException(tr) + case lc @ ExLocalCall(sy,nm,ar) => { + if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { + if(!silent) + unit.error(tr.pos, "Invoking an invalid function.") + throw ImpureCodeEncounteredException(tr) + } + val fd = defsToDefs(sy) + FunctionInvocation(fd, ar.map(rec(_))).setType(fd.returnType).setPosInfo(lc.pos.line,lc.pos.column) + } + case pm @ ExPatternMatching(sel, cses) => { + val rs = rec(sel) + val rc = cses.map(rewriteCaseDef(_)) + val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_)) + MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) } - val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef + // this one should stay after all others, cause it also catches UMinus + // and Not, for instance. + case ExParameterlessMethodCall(t,n) => { + val selector = rec(t) + val selType = selector.getType - val fieldID = selDef.fields.find(_.id.name == n.toString) match { - case None => { + if(!selType.isInstanceOf[CaseClassType]) { if(!silent) - unit.error(tr.pos, "Invalid method or field invocation (not a case class arg?)") + unit.error(tr.pos, "Invalid method or field invocation (not purescala?)") throw ImpureCodeEncounteredException(tr) } - case Some(vd) => vd.id - } - CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) + val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef + + val fieldID = selDef.fields.find(_.id.name == n.toString) match { + case None => { + if(!silent) + unit.error(tr.pos, "Invalid method or field invocation (not a case class arg?)") + throw ImpureCodeEncounteredException(tr) + } + case Some(vd) => vd.id + } + + CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) + } + + // default behaviour is to complain :) + case _ => { + if(!silent) { + println(tr) + reporter.info(tr.pos, "Could not extract as PureScala.", true) + } + throw ImpureCodeEncounteredException(tree) + } } - - // default behaviour is to complain :) - case _ => { - if(!silent) { - println(tr) - reporter.info(tr.pos, "Could not extract as PureScala.", true) + + if(handleRest) { + rest match { + case Some(rst) => { + val recRst = rec(rst) + PBlock(Seq(psExpr), recRst).setType(recRst.getType) + } + case None => psExpr } - throw ImpureCodeEncounteredException(tree) + } else { + psExpr } } rec(tree) diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index 08cdfe37973e3fc2ac9e641a4bf6344f63fcb91d..4bd9f9bf3bb34940a0bf0ea552a89df60b177ec3 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -31,61 +31,6 @@ trait Extractors { } } - object ExTuple { - def unapply(tree: Apply): Option[(Seq[Type], Seq[Tree])] = tree match { - case Apply( - Select(New(tupleType), _), - List(e1, e2) - ) if tupleType.symbol == tuple2Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2)) => Some((Seq(t1, t2), Seq(e1, e2))) - case _ => None - } - - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3) - ) if tupleType.symbol == tuple3Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3)) => Some((Seq(t1, t2, t3), Seq(e1, e2, e3))) - case _ => None - } - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3, e4) - ) if tupleType.symbol == tuple4Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3, t4)) => Some((Seq(t1, t2, t3, t4), Seq(e1, e2, e3, e4))) - case _ => None - } - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3, e4, e5) - ) if tupleType.symbol == tuple5Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3, t4, t5)) => Some((Seq(t1, t2, t3, t4, t5), Seq(e1, e2, e3, e4, e5))) - case _ => None - } - case _ => None - } - } - - object ExTupleExtract { - def unapply(tree: Select) : Option[(Tree,Int)] = tree match { - case Select(lhs, n) => { - val methodName = n.toString - if(methodName.head == '_') { - val indexString = methodName.tail - try { - val index = indexString.toInt - if(index > 0) { - Some((lhs, index)) - } else None - } catch { - case _ => None - } - } else None - } - case _ => None - } - } - object ExEnsuredExpression { /** Extracts the 'ensuring' contract from an expression. */ def unapply(tree: Apply): Option[(Tree,Symbol,Tree)] = tree match { @@ -126,17 +71,6 @@ trait Extractors { } } - object ExValDef { - /** Extracts val's in the head of blocks. */ - def unapply(tree: Block): Option[(Symbol,Tree,Tree,Tree)] = tree match { - case Block((vd @ ValDef(_, _, tpt, rhs)) :: rest, expr) => - if(rest.isEmpty) - Some((vd.symbol, tpt, rhs, expr)) - else - Some((vd.symbol, tpt, rhs, Block(rest, expr))) - case _ => None - } - } object ExObjectDef { /** Matches an object with no type parameters, and regardless of its @@ -209,9 +143,118 @@ trait Extractors { case _ => None } } + } object ExpressionExtractors { + + //object ExLocalFunctionDef { + // def unapply(tree: Block): Option[(DefDef,String,Seq[ValDef],Tree,Tree,Tree)] = tree match { + // case Block((dd @ DefDef(_, name, tparams, vparamss, tpt, rhs)) :: rest, expr) if(tparams.isEmpty && vparamss.size == 1 && name != nme.CONSTRUCTOR) => { + // if(rest.isEmpty) + // Some((dd,name.toString, vparamss(0), tpt, rhs, expr)) + // else + // Some((dd,name.toString, vparamss(0), tpt, rhs, Block(rest, expr))) + // } + // case _ => None + // } + //} + + object ExValDef { + /** Extracts val's in the head of blocks. */ + def unapply(tree: ValDef): Option[(Symbol,Tree,Tree)] = tree match { + case vd @ ValDef(mods, _, tpt, rhs) if !mods.isMutable => Some((vd.symbol, tpt, rhs)) + case _ => None + } + } + object ExVarDef { + /** Extracts var's in the head of blocks. */ + def unapply(tree: ValDef): Option[(Symbol,Tree,Tree)] = tree match { + case vd @ ValDef(mods, _, tpt, rhs) if mods.isMutable => Some((vd.symbol, tpt, rhs)) + case _ => None + } + } + + object ExAssign { + def unapply(tree: Assign): Option[(Symbol,Tree)] = tree match { + case Assign(id@Ident(_), rhs) => Some((id.symbol, rhs)) + case _ => None + } + } + + object ExWhile { + def unapply(tree: LabelDef): Option[(Tree,Tree)] = tree match { + case (label@LabelDef( + _, _, If(cond, Block(body, jump@Apply(_, _)), unit@ExUnitLiteral()))) + if label.symbol == jump.symbol && unit.symbol == null => Some((cond, Block(body, unit))) + case _ => None + } + } + object ExWhileWithInvariant { + def unapply(tree: Apply): Option[(Tree, Tree, Tree)] = tree match { + case Apply( + Select( + Apply(while2invariant, List(ExWhile(cond, body))), + invariantSym), + List(invariant)) if invariantSym.toString == "invariant" => Some((cond, body, invariant)) + case _ => None + } + } + object ExTuple { + def unapply(tree: Apply): Option[(Seq[Type], Seq[Tree])] = tree match { + case Apply( + Select(New(tupleType), _), + List(e1, e2) + ) if tupleType.symbol == tuple2Sym => tupleType.tpe match { + case TypeRef(_, sym, List(t1, t2)) => Some((Seq(t1, t2), Seq(e1, e2))) + case _ => None + } + + case Apply( + Select(New(tupleType), _), + List(e1, e2, e3) + ) if tupleType.symbol == tuple3Sym => tupleType.tpe match { + case TypeRef(_, sym, List(t1, t2, t3)) => Some((Seq(t1, t2, t3), Seq(e1, e2, e3))) + case _ => None + } + case Apply( + Select(New(tupleType), _), + List(e1, e2, e3, e4) + ) if tupleType.symbol == tuple4Sym => tupleType.tpe match { + case TypeRef(_, sym, List(t1, t2, t3, t4)) => Some((Seq(t1, t2, t3, t4), Seq(e1, e2, e3, e4))) + case _ => None + } + case Apply( + Select(New(tupleType), _), + List(e1, e2, e3, e4, e5) + ) if tupleType.symbol == tuple5Sym => tupleType.tpe match { + case TypeRef(_, sym, List(t1, t2, t3, t4, t5)) => Some((Seq(t1, t2, t3, t4, t5), Seq(e1, e2, e3, e4, e5))) + case _ => None + } + case _ => None + } + } + + object ExTupleExtract { + def unapply(tree: Select) : Option[(Tree,Int)] = tree match { + case Select(lhs, n) => { + val methodName = n.toString + if(methodName.head == '_') { + val indexString = methodName.tail + try { + val index = indexString.toInt + if(index > 0) { + Some((lhs, index)) + } else None + } catch { + case _ => None + } + } else None + } + case _ => None + } + } + object ExIfThenElse { def unapply(tree: If): Option[(Tree,Tree,Tree)] = tree match { case If(t1,t2,t3) => Some((t1,t2,t3)) @@ -234,6 +277,13 @@ trait Extractors { } } + object ExUnitLiteral { + def unapply(tree: Literal): Boolean = tree match { + case Literal(c @ Constant(_)) if c.tpe == UnitClass.tpe => true + case _ => false + } + } + object ExSomeConstruction { def unapply(tree: Apply) : Option[(Type,Tree)] = tree match { case Apply(s @ Select(New(tpt), n), arg) if (arg.size == 1 && n == nme.CONSTRUCTOR && tpt.symbol.name.toString == "Some") => tpt.tpe match { @@ -384,6 +434,7 @@ trait Extractors { object ExLocalCall { def unapply(tree: Apply): Option[(Symbol,String,List[Tree])] = tree match { case a @ Apply(Select(This(_), nme), args) => Some((a.symbol, nme.toString, args)) + case a @ Apply(Ident(nme), args) => Some((a.symbol, nme.toString, args)) case _ => None } } diff --git a/src/main/scala/leon/plugin/LeonPlugin.scala b/src/main/scala/leon/plugin/LeonPlugin.scala index b059ad59dbf218ea9042eb6923484778c3ae5c46..5de24fa44bbf489350371d2bce9775231ebf3e1d 100644 --- a/src/main/scala/leon/plugin/LeonPlugin.scala +++ b/src/main/scala/leon/plugin/LeonPlugin.scala @@ -39,7 +39,8 @@ class LeonPlugin(val global: Global, val actionAfterExtraction : Option[Program= " --cores Use UNSAT cores in the unrolling/refinement step" + "\n" + " --quickcheck Use QuickCheck-like random search" + "\n" + " --parallel Run all solvers in parallel" + "\n" + - " --noLuckyTests Do not perform additional tests to potentially find models early" + " --noLuckyTests Do not perform additional tests to potentially find models early" + "\n" + + " --verbose Print debugging informations" ) /** Processes the command-line options. */ @@ -62,6 +63,7 @@ class LeonPlugin(val global: Global, val actionAfterExtraction : Option[Program= case "quickcheck" => leon.Settings.useQuickCheck = true case "parallel" => leon.Settings.useParallel = true case "noLuckyTests" => leon.Settings.luckyTest = false + case "verbose" => leon.Settings.verbose = true case s if s.startsWith("unrolling=") => leon.Settings.unrollingLevel = try { s.substring("unrolling=".length, s.length).toInt } catch { case _ => 0 } case s if s.startsWith("functions=") => leon.Settings.functionsToAnalyse = Set(splitList(s.substring("functions=".length, s.length)): _*) case s if s.startsWith("extensions=") => leon.Settings.extensionNames = splitList(s.substring("extensions=".length, s.length)) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index efd8b52fb31e7cd4d61ff868bba3c836ca1f0e2e..f2da8a7cef1f1f3b0290a35a260f40ceefbf6705 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -283,6 +283,11 @@ object Definitions { var precondition: Option[Expr] = None var postcondition: Option[Expr] = None + //true if this function has been generated from a while loop + var fromLoop = false + //the fundef where the loop was defined (if applies) + var parent: Option[FunDef] = None + def hasImplementation : Boolean = body.isDefined def hasBody = hasImplementation def hasPrecondition : Boolean = precondition.isDefined diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 9ed2e564646559bb0b62bfc6078c20fae7e028b9..9cc5b7b626a866430398cbd38a0faadb7b54cdee 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -69,7 +69,32 @@ object PrettyPrinter { case Variable(id) => sb.append(id) case DeBruijnIndex(idx) => sb.append("_" + idx) case Let(b,d,e) => { - pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") + //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") + sb.append("(let (" + b + " := "); + pp(d, sb, lvl) + sb.append(") in\n") + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append(")") + sb + } + case LetVar(b,d,e) => { + sb.append("(letvar (" + b + " := "); + pp(d, sb, lvl) + sb.append(") in\n") + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append(")") + sb + } + case LetDef(fd,e) => { + sb.append("\n") + pp(fd, sb, lvl+1) + sb.append("\n") + sb.append("\n") + ind(sb, lvl) + pp(e, sb, lvl) + sb } case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl) // \land case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl) // \lor @@ -81,6 +106,39 @@ object PrettyPrinter { case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") + case UnitLiteral => sb.append("()") + case Block(exprs, last) => { + sb.append("{\n") + (exprs :+ last).foreach(e => { + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append("\n") + }) + ind(sb, lvl) + sb.append("}\n") + sb + } + case Assignment(lhs, rhs) => ppBinary(sb, lhs.toVariable, rhs, " = ", lvl) + case wh@While(cond, body) => { + wh.invariant match { + case Some(inv) => { + sb.append("\n") + ind(sb, lvl) + sb.append("@invariant: ") + pp(inv, sb, lvl) + sb.append("\n") + ind(sb, lvl) + } + case None => + } + sb.append("while(") + pp(cond, sb, lvl) + sb.append(")\n") + ind(sb, lvl+1) + pp(body, sb, lvl+1) + sb.append("\n") + } + case Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")", lvl) case TupleSelect(t, i) => { pp(t, sb, lvl) @@ -194,18 +252,27 @@ object PrettyPrinter { var nsb = sb nsb.append("if (") nsb = pp(c, nsb, lvl) - nsb.append(") {\n") + nsb.append(")\n") ind(nsb, lvl+1) nsb = pp(t, nsb, lvl+1) nsb.append("\n") ind(nsb, lvl) - nsb.append("} else {\n") + nsb.append("else\n") ind(nsb, lvl+1) nsb = pp(e, nsb, lvl+1) - nsb.append("\n") - ind(nsb, lvl) - nsb.append("}") nsb + //nsb.append(") {\n") + //ind(nsb, lvl+1) + //nsb = pp(t, nsb, lvl+1) + //nsb.append("\n") + //ind(nsb, lvl) + //nsb.append("} else {\n") + //ind(nsb, lvl+1) + //nsb = pp(e, nsb, lvl+1) + //nsb.append("\n") + //ind(nsb, lvl) + //nsb.append("}") + //nsb } case mex @ MatchExpr(s, csc) => { @@ -287,6 +354,7 @@ object PrettyPrinter { private def pp(tpe: TypeTree, sb: StringBuffer, lvl: Int): StringBuffer = tpe match { case Untyped => sb.append("???") + case UnitType => sb.append("Unit") case Int32Type => sb.append("Int") case BooleanType => sb.append("Boolean") case SetType(bt) => pp(bt, sb.append("Set["), lvl).append("]") diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 7d6d885a7dfc90a4f69c45cf7b02c54622479638..e738708a6c0c288039f85ff01dd17bfd99d82517 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -17,6 +17,24 @@ object Trees { self: Expr => } + case class Block(exprs: Seq[Expr], last: Expr) extends Expr { + //val t = last.getType + //if(t != Untyped) + // setType(t) + } + + case class Assignment(varId: Identifier, expr: Expr) extends Expr with FixedType { + val fixedType = UnitType + } + case class While(cond: Expr, body: Expr) extends Expr with FixedType with ScalacPositional { + val fixedType = UnitType + var invariant: Option[Expr] = None + + def getInvariant: Expr = invariant.get + def setInvariant(inv: Expr) = { invariant = Some(inv); this } + def setInvariant(inv: Option[Expr]) = { invariant = inv; this } + } + /* This describes computational errors (unmatched case, taking min of an * empty set, division by zero, etc.). It should always be typed according to * the expected type. */ @@ -29,6 +47,26 @@ object Trees { if(et != Untyped) setType(et) } + //same as let, buf for mutable variable declaration + case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr { + binder.markAsLetBinder + val et = body.getType + if(et != Untyped) + setType(et) + } + + //case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { + // binders.foreach(_.markAsLetBinder) + // val et = body.getType + // if(et != Untyped) + // setType(et) + //} + + case class LetDef(value: FunDef, body: Expr) extends Expr { + val et = body.getType + if(et != Untyped) + setType(et) + } /* Control flow */ case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional { @@ -243,6 +281,10 @@ object Trees { val fixedType = BooleanType } case class StringLiteral(value: String) extends Literal[String] + case object UnitLiteral extends Literal[Unit] with FixedType { + val fixedType = UnitType + val value = () + } case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType { val fixedType = CaseClassType(classDef) @@ -329,7 +371,7 @@ object Trees { case class SingletonMap(from: Expr, to: Expr) extends Expr case class FiniteMap(singletons: Seq[SingletonMap]) extends Expr - case class MapGet(map: Expr, key: Expr) extends Expr + case class MapGet(map: Expr, key: Expr) extends Expr with ScalacPositional case class MapUnion(map1: Expr, map2: Expr) extends Expr case class MapDifference(map: Expr, keys: Expr) extends Expr case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr with FixedType { @@ -366,6 +408,7 @@ object Trees { case SetMax(s) => Some((s,SetMax)) case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) + case Assignment(id, e) => Some((e, Assignment(id, _))) case TupleSelect(t, i) => Some((t, TupleSelect(_, i))) case _ => None } @@ -397,12 +440,13 @@ object Trees { case MultisetPlus(t1,t2) => Some((t1,t2,MultisetPlus)) case MultisetDifference(t1,t2) => Some((t1,t2,MultisetDifference)) case SingletonMap(t1,t2) => Some((t1,t2,SingletonMap)) - case MapGet(t1,t2) => Some((t1,t2,MapGet)) + case mg@MapGet(t1,t2) => Some((t1,t2, (t1, t2) => MapGet(t1, t2).setPosInfo(mg))) case MapUnion(t1,t2) => Some((t1,t2,MapUnion)) case MapDifference(t1,t2) => Some((t1,t2,MapDifference)) case MapIsDefinedAt(t1,t2) => Some((t1,t2, MapIsDefinedAt)) case Concat(t1,t2) => Some((t1,t2,Concat)) case ListAt(t1,t2) => Some((t1,t2,ListAt)) + case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh))) case _ => None } } @@ -418,6 +462,7 @@ object Trees { case FiniteMap(args) => Some((args, (as : Seq[Expr]) => FiniteMap(as.asInstanceOf[Seq[SingletonMap]]))) case FiniteMultiset(args) => Some((args, FiniteMultiset)) case Distinct(args) => Some((args, Distinct)) + case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) case Tuple(args) => Some((args, Tuple)) case _ => None } @@ -470,6 +515,21 @@ object Trees { else l } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + if(re != e || rb != b) + LetVar(i, re, rb).setType(l.getType) + else + l + } + case l @ LetDef(fd, b) => { + //TODO, not sure, see comment for the next LetDef + fd.body = fd.body.map(rec(_)) + fd.precondition = fd.precondition.map(rec(_)) + fd.postcondition = fd.postcondition.map(rec(_)) + LetDef(fd, rec(b)).setType(l.getType) + } case n @ NAryOperator(args, recons) => { var change = false val rargs = args.map(a => { @@ -553,6 +613,23 @@ object Trees { l }) } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + LetVar(i,re,rb).setType(l.getType) + } else { + l + }) + } + case l @ LetDef(fd,b) => { + //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct + fd.body = fd.body.map(rec(_)) + fd.precondition = fd.precondition.map(rec(_)) + fd.postcondition = fd.postcondition.map(rec(_)) + val rl = LetDef(fd, rec(b)).setType(l.getType) + applySubst(rl) + } case n @ NAryOperator(args, recons) => { var change = false val rargs = args.map(a => { @@ -676,6 +753,8 @@ object Trees { def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { def rec(expr: Expr) : A = expr match { case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) + case l @ LetVar(_, e, b) => compute(l, combine(rec(e), rec(b))) + case l @ LetDef(fd, b) => compute(l, combine(rec(fd.getBody), rec(b))) //TODO, still not sure about the semantic case n @ NAryOperator(args, _) => { if(args.size == 0) compute(n, convert(n)) @@ -694,6 +773,46 @@ object Trees { rec(expression) } + def flattenBlocks(expr: Expr): Expr = { + def applyToTree(expr: Expr): Option[Expr] = expr match { + case Block(exprs, last) => { + val nexprs = (exprs :+ last).flatMap{ + case Block(es2, el) => es2 :+ el + case UnitLiteral => Seq() + case e2 => Seq(e2) + } + val fexpr = nexprs match { + case Seq() => UnitLiteral + case Seq(e) => e + case es => Block(es.init, es.last).setType(es.last.getType) + } + Some(fexpr) + } + case _ => None + } + searchAndReplaceDFS(applyToTree)(expr) + } + + //checking whether the expr is pure, that is do not contains any non-pure construct: assign, while and blocks + def isPure(expr: Expr): Boolean = { + def convert(t: Expr) : Boolean = t match { + case Block(_, _) => false + case Assignment(_, _) => false + case While(_, _) => false + case LetVar(_, _, _) => false + case _ => true + } + def combine(b1: Boolean, b2: Boolean) = b1 && b2 + def compute(e: Expr, b: Boolean) = e match { + case Block(_, _) => false + case Assignment(_, _) => false + case While(_, _) => false + case LetVar(_, _, _) => false + case _ => true + } + treeCatamorphism(convert, combine, compute, expr) + } + def variablesOf(expr: Expr) : Set[Identifier] = { def convert(t: Expr) : Set[Identifier] = t match { case Variable(i) => Set(i) @@ -773,6 +892,8 @@ object Trees { def allIdentifiers(expr: Expr) : Set[Identifier] = expr match { case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case l @ LetVar(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case l @ LetDef(fd, b) => allIdentifiers(fd.getBody) ++ allIdentifiers(b) + fd.id case n @ NAryOperator(args, _) => (args map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) case b @ BinaryOperator(a1,a2,_) => allIdentifiers(a1) ++ allIdentifiers(a2) @@ -1168,7 +1289,7 @@ object Trees { def rewriteMapGet(e: Expr) : Option[Expr] = e match { case mg @ MapGet(m,k) => val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType)).setType(mg.getType)) + Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType)) case _ => None } diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index e948a433f42e58cb452877414fada6fae5f8da5d..31947bf28177682e1376f0fd013944ef9dc37c05 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -97,6 +97,7 @@ object TypeTrees { case AnyType => InfiniteSize case BottomType => FiniteSize(0) case BooleanType => FiniteSize(2) + case UnitType => FiniteSize(1) case Int32Type => InfiniteSize case ListType(_) => InfiniteSize case TupleType(bases) => { @@ -138,6 +139,7 @@ object TypeTrees { case object BottomType extends TypeTree // This type is useful when we need an underlying type for None, Set.empty, etc. It should always be removed after parsing, though. case object BooleanType extends TypeTree case object Int32Type extends TypeTree + case object UnitType extends TypeTree case class ListType(base: TypeTree) extends TypeTree case class TupleType(bases: Seq[TypeTree]) extends TypeTree { lazy val dimension: Int = bases.length }