diff --git a/demo/Arith.scala b/demo/Arith.scala new file mode 100644 index 0000000000000000000000000000000000000000..d4b02f51b5e532a15b9fb3571534316e0cc962ce --- /dev/null +++ b/demo/Arith.scala @@ -0,0 +1,41 @@ +import leon.Utils._ + +object Arith { + + def mult(x : Int, y : Int): Int = ({ + var r = 0 + if(y < 0) { + var n = y + (while(n != 0) { + r = r - x + n = n + 1 + }) invariant(r == x * (y - n) && 0 <= -n) + } else { + var n = y + (while(n != 0) { + r = r + x + n = n - 1 + }) invariant(r == x * (y - n) && 0 <= n) + } + r + }) ensuring(_ == x*y) + + def add(x : Int, y : Int): Int = ({ + var r = x + if(y < 0) { + var n = y + (while(n != 0) { + r = r - 1 + n = n + 1 + }) invariant(r == x + y - n && 0 <= -n) + } else { + var n = y + (while(n != 0) { + r = r + 1 + n = n - 1 + }) invariant(r == x + y - n && 0 <= n) + } + r + }) ensuring(_ == x+y) + +} diff --git a/demo/BubbleSortBug.scala b/demo/BubbleSortBug.scala new file mode 100644 index 0000000000000000000000000000000000000000..a25554f0305d57c53774cdc0b4d514393217b4f2 --- /dev/null +++ b/demo/BubbleSortBug.scala @@ -0,0 +1,39 @@ +import leon.Utils._ + +/* The calculus of Computation textbook */ + +object BubbleSortBug { + + def sort(a: Array[Int]): Array[Int] = ({ + require(a.length >= 1) + var i = a.length - 1 + var j = 0 + val sa = a.clone + (while(i > 0) { + j = 0 + (while(j < i) { + if(sa(j) < sa(j+1)) { + val tmp = sa(j) + sa(j) = sa(j+1) + sa(j+1) = tmp + } + j = j + 1 + }) invariant(j >= 0 && j <= i && i < sa.length) + i = i - 1 + }) invariant(i >= 0 && i < sa.length) + sa + }) ensuring(res => sorted(res, 0, a.length-1)) + + def sorted(a: Array[Int], l: Int, u: Int): Boolean = { + require(a.length >= 0 && l >= 0 && u < a.length && l <= u) + var k = l + var isSorted = true + (while(k < u) { + if(a(k) > a(k+1)) + isSorted = false + k = k + 1 + }) invariant(k <= u && k >= l) + isSorted + } + +} diff --git a/demo/List.scala b/demo/List.scala new file mode 100644 index 0000000000000000000000000000000000000000..302e86774c7d1e68e08d48ce957c87ec483b3d56 --- /dev/null +++ b/demo/List.scala @@ -0,0 +1,20 @@ +import leon.Utils._ + +object List { + + abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + def size(l: List): Int = waypoint(1, (l match { + case Cons(_, tail) => sizeTail(tail, 1) + case Nil() => 0 + })) ensuring(_ >= 0) + + + def sizeTail(l2: List, acc: Int): Int = l2 match { + case Cons(_, tail) => sizeTail(tail, acc+1) + case Nil() => acc + } + +} diff --git a/demo/ListOperations.scala b/demo/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..a4fc4f8dc44a90f59a772b52b1a05053316e94d2 --- /dev/null +++ b/demo/ListOperations.scala @@ -0,0 +1,107 @@ +import scala.collection.immutable.Set +import leon.Annotations._ +import leon.Utils._ + +object ListOperations { + sealed abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + sealed abstract class IntPairList + case class IPCons(head: IntPair, tail: IntPairList) extends IntPairList + case class IPNil() extends IntPairList + + sealed abstract class IntPair + case class IP(fst: Int, snd: Int) extends IntPair + + def size(l: List) : Int = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) ensuring(res => res >= 0) + + def iplSize(l: IntPairList) : Int = (l match { + case IPNil() => 0 + case IPCons(_, xs) => 1 + iplSize(xs) + }) ensuring(_ >= 0) + + def zip(l1: List, l2: List) : IntPairList = { + // try to comment this and see how pattern-matching becomes + // non-exhaustive and post-condition fails + require(size(l1) == size(l2)) + + l1 match { + case Nil() => IPNil() + case Cons(x, xs) => l2 match { + case Cons(y, ys) => IPCons(IP(x, y), zip(xs, ys)) + } + } + } ensuring(iplSize(_) == size(l1)) + + def sizeTailRec(l: List) : Int = sizeTailRecAcc(l, 0) + def sizeTailRecAcc(l: List, acc: Int) : Int = { + require(acc >= 0) + l match { + case Nil() => acc + case Cons(_, xs) => sizeTailRecAcc(xs, acc+1) + } + } ensuring(res => res == size(l) + acc) + + def sizesAreEquiv(l: List) : Boolean = { + size(l) == sizeTailRec(l) + } holds + + def content(l: List) : Set[Int] = l match { + case Nil() => Set.empty[Int] + case Cons(x, xs) => Set(x) ++ content(xs) + } + + def sizeAndContent(l: List) : Boolean = { + size(l) == 0 || content(l) != Set.empty[Int] + } holds + + def drunk(l : List) : List = (l match { + case Nil() => Nil() + case Cons(x,l1) => Cons(x,Cons(x,drunk(l1))) + }) ensuring (size(_) == 2 * size(l)) + + def reverse(l: List) : List = reverse0(l, Nil()) ensuring(content(_) == content(l)) + def reverse0(l1: List, l2: List) : List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverse0(xs, Cons(x, l2)) + }) ensuring(content(_) == content(l1) ++ content(l2)) + + def append(l1 : List, l2 : List) : List = (l1 match { + case Nil() => l2 + case Cons(x,xs) => Cons(x, append(xs, l2)) + }) ensuring(content(_) == content(l1) ++ content(l2)) + + @induct + def nilAppend(l : List) : Boolean = (append(l, Nil()) == l) holds + + @induct + def appendAssoc(xs : List, ys : List, zs : List) : Boolean = + (append(append(xs, ys), zs) == append(xs, append(ys, zs))) holds + + def revAuxBroken(l1 : List, e : Int, l2 : List) : Boolean = { + (append(reverse(l1), Cons(e,l2)) == reverse0(l1, l2)) + } holds + + @induct + def sizeAppend(l1 : List, l2 : List) : Boolean = + (size(append(l1, l2)) == size(l1) + size(l2)) holds + + @induct + def concat(l1: List, l2: List) : List = + concat0(l1, l2, Nil()) ensuring(content(_) == content(l1) ++ content(l2)) + + @induct + def concat0(l1: List, l2: List, l3: List) : List = (l1 match { + case Nil() => l2 match { + case Nil() => reverse(l3) + case Cons(y, ys) => { + concat0(Nil(), ys, Cons(y, l3)) + } + } + case Cons(x, xs) => concat0(xs, l2, Cons(x, l3)) + }) ensuring(content(_) == content(l1) ++ content(l2) ++ content(l3)) +} diff --git a/demo/MaxSum.scala b/demo/MaxSum.scala new file mode 100644 index 0000000000000000000000000000000000000000..ba724d255c23b782e1b4de716ccc8d50043e2059 --- /dev/null +++ b/demo/MaxSum.scala @@ -0,0 +1,38 @@ +import leon.Utils._ + +/* VSTTE 2010 challenge 1 */ + +object MaxSum { + + def maxSum(a: Array[Int]): (Int, Int) = ({ + require(a.length >= 0 && isPositive(a)) + var sum = 0 + var max = 0 + var i = 0 + (while(i < a.length) { + if(max < a(i)) + max = a(i) + sum = sum + a(i) + i = i + 1 + }) invariant (sum <= i * max && i >= 0 && i <= a.length) + (sum, max) + }) ensuring(res => res._1 <= a.length * res._2) + + + def isPositive(a: Array[Int]): Boolean = { + require(a.length >= 0) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= a.length) + true + else { + if(a(i) < 0) + false + else + rec(i+1) + } + } + rec(0) + } + +} diff --git a/demo/RedBlackTree.scala b/demo/RedBlackTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..bc2de6ba96ee699736d4558932b752eea9ebba9f --- /dev/null +++ b/demo/RedBlackTree.scala @@ -0,0 +1,117 @@ +import scala.collection.immutable.Set +import leon.Annotations._ +import leon.Utils._ + +object RedBlackTree { + sealed abstract class Color + case class Red() extends Color + case class Black() extends Color + + sealed abstract class Tree + case class Empty() extends Tree + case class Node(color: Color, left: Tree, value: Int, right: Tree) extends Tree + + sealed abstract class OptionInt + case class Some(v : Int) extends OptionInt + case class None() extends OptionInt + + def content(t: Tree) : Set[Int] = t match { + case Empty() => Set.empty + case Node(_, l, v, r) => content(l) ++ Set(v) ++ content(r) + } + + def size(t: Tree) : Int = (t match { + case Empty() => 0 + case Node(_, l, v, r) => size(l) + 1 + size(r) + }) ensuring(_ >= 0) + + /* We consider leaves to be black by definition */ + def isBlack(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(),_,_,_) => true + case _ => false + } + + def redNodesHaveBlackChildren(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(), l, _, r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case Node(Red(), l, _, r) => isBlack(l) && isBlack(r) && redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + } + + def redDescHaveBlackChildren(t: Tree) : Boolean = t match { + case Empty() => true + case Node(_,l,_,r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + } + + def blackBalanced(t : Tree) : Boolean = t match { + case Node(_,l,_,r) => blackBalanced(l) && blackBalanced(r) && blackHeight(l) == blackHeight(r) + case Empty() => true + } + + def blackHeight(t : Tree) : Int = t match { + case Empty() => 1 + case Node(Black(), l, _, _) => blackHeight(l) + 1 + case Node(Red(), l, _, _) => blackHeight(l) + } + + // <<insert element x into the tree t>> + def ins(x: Int, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t)) + t match { + case Empty() => Node(Red(),Empty(),x,Empty()) + case Node(c,a,y,b) => + if (x < y) balance(c, ins(x, a), y, b) + else if (x == y) Node(c,a,y,b) + else balance(c,a,y,ins(x, b)) + } + } ensuring (res => content(res) == content(t) ++ Set(x) + && size(t) <= size(res) && size(res) <= size(t) + 1 + && redDescHaveBlackChildren(res) + && blackBalanced(res)) + + def makeBlack(n: Tree): Tree = { + require(redDescHaveBlackChildren(n) && blackBalanced(n)) + n match { + case Node(Red(),l,v,r) => Node(Black(),l,v,r) + case _ => n + } + } ensuring(res => redNodesHaveBlackChildren(res) && blackBalanced(res)) + + def add(x: Int, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t)) + makeBlack(ins(x, t)) + } ensuring (res => content(res) == content(t) ++ Set(x) && redNodesHaveBlackChildren(res) && blackBalanced(res)) + + def buggyAdd(x: Int, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t)) + ins(x, t) + } ensuring (res => content(res) == content(t) ++ Set(x) && redNodesHaveBlackChildren(res)) + + def balance(c: Color, a: Tree, x: Int, b: Tree): Tree = { + Node(c,a,x,b) match { + case Node(Black(),Node(Red(),Node(Red(),a,xV,b),yV,c),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),Node(Red(),a,xV,Node(Red(),b,yV,c)),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),Node(Red(),b,yV,c),zV,d)) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),b,yV,Node(Red(),c,zV,d))) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(c,a,xV,b) => Node(c,a,xV,b) + } + } ensuring (res => content(res) == content(Node(c,a,x,b)))// && redDescHaveBlackChildren(res)) + + def buggyBalance(c: Color, a: Tree, x: Int, b: Tree): Tree = { + Node(c,a,x,b) match { + case Node(Black(),Node(Red(),Node(Red(),a,xV,b),yV,c),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),Node(Red(),a,xV,Node(Red(),b,yV,c)),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),Node(Red(),b,yV,c),zV,d)) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),b,yV,Node(Red(),c,zV,d))) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + // case Node(c,a,xV,b) => Node(c,a,xV,b) + } + } ensuring (res => content(res) == content(Node(c,a,x,b)))// && redDescHaveBlackChildren(res)) +} diff --git a/run-demo b/run-demo new file mode 100755 index 0000000000000000000000000000000000000000..a3db30e6682f1483ccc66b79cc65d5222e82bb90 --- /dev/null +++ b/run-demo @@ -0,0 +1 @@ +./leon --timeout=10 --noLuckyTests $@ diff --git a/run-demo-testgen b/run-demo-testgen new file mode 100755 index 0000000000000000000000000000000000000000..90b0479f2dc548a4b2df35042cc802604aca3ee3 --- /dev/null +++ b/run-demo-testgen @@ -0,0 +1 @@ +./leon --nodefaults --extensions=leon.TestGeneration $@ diff --git a/src/main/scala/leon/Annotations.scala b/src/main/scala/leon/Annotations.scala index 0831c8ff81ae212d6e0fd10d5a847c77c88186f4..85e8c00e0379c42f000b61ac14861190bb46963a 100644 --- a/src/main/scala/leon/Annotations.scala +++ b/src/main/scala/leon/Annotations.scala @@ -3,4 +3,5 @@ package leon object Annotations { class induct extends StaticAnnotation class axiomatize extends StaticAnnotation + class main extends StaticAnnotation } diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala index c3e0ca1a6a40fdbf59da5508d076399346ceffe7..d5c5745c19bcb8b3c58c1496c6a63547d179dcee 100644 --- a/src/main/scala/leon/Evaluator.scala +++ b/src/main/scala/leon/Evaluator.scala @@ -77,6 +77,7 @@ object Evaluator { case _ => throw TypeErrorEx(TypeError(first, BooleanType)) } } + case Waypoint(_, arg) => rec(ctx, arg) case FunctionInvocation(fd, args) => { val evArgs = args.map(a => rec(ctx, a)) // build a context for the function... diff --git a/src/main/scala/leon/Extensions.scala b/src/main/scala/leon/Extensions.scala index 1ddb97e7c774b029c421c551129d015c880a8062..e9fdde89c3555b41b47b696c542733b0bf411830 100644 --- a/src/main/scala/leon/Extensions.scala +++ b/src/main/scala/leon/Extensions.scala @@ -99,6 +99,7 @@ object Extensions { allLoaded = defaultExtensions ++ loaded analysisExtensions = allLoaded.filter(_.isInstanceOf[Analyser]).map(_.asInstanceOf[Analyser]) + //analysisExtensions = new TestGeneration(extensionsReporter) +: analysisExtensions val solverExtensions0 = allLoaded.filter(_.isInstanceOf[Solver]).map(_.asInstanceOf[Solver]) val solverExtensions1 = if(Settings.useQuickCheck) new RandomSolver(extensionsReporter) +: solverExtensions0 else solverExtensions0 diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index 4bfe5c0484a48fc34b71b13fde7a08635cd0f5cc..2c429bab38594f4109c467a89e62106d1cc60d6b 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -964,6 +964,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S z3Vars = z3Vars - i rb } + case Waypoint(_, e) => rec(e) case e @ Error(_) => { val tpe = e.getType val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) diff --git a/src/main/scala/leon/Utils.scala b/src/main/scala/leon/Utils.scala index ae60002b693a0b74657fa8640a7eef56a641703f..15006c7bc0820fae25274411ac4c5f5c9ef6a70f 100644 --- a/src/main/scala/leon/Utils.scala +++ b/src/main/scala/leon/Utils.scala @@ -16,4 +16,7 @@ object Utils { def invariant(x: Boolean): Unit = () } implicit def while2Invariant(u: Unit) = InvariantFunction + + + def waypoint[A](i: Int, expr: A): A = expr } diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 4649532adafd24739a9c8082036929c5906abf95..1ff95e4c922d08cf0b400c2049dc0d8139da6380 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -203,6 +203,7 @@ trait CodeExtraction extends Extractors { a.atp.safeToString match { case "leon.Annotations.induct" => funDef.addAnnotation("induct") case "leon.Annotations.axiomatize" => funDef.addAnnotation("axiomatize") + case "leon.Annotations.main" => funDef.addAnnotation("main") case _ => ; } } @@ -680,6 +681,11 @@ trait CodeExtraction extends Extractors { } Epsilon(c1).setType(pstpe).setPosInfo(epsi.pos.line, epsi.pos.column) } + case ExWaypointExpression(tpe, i, tree) => { + val pstpe = scalaType2PureScala(unit, silent)(tpe) + val IntLiteral(ri) = rec(i) + Waypoint(ri, rec(tree)).setType(pstpe) + } case ExSomeConstruction(tpe, arg) => { // println("Got Some !" + tpe + ":" + arg) val underlying = scalaType2PureScala(unit, silent)(tpe) diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index a3313663baea11ad915e0229511f0a4426a27b5c..eae46482780a0362946b5ddacce1eae632ffc1e9 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -174,6 +174,19 @@ trait Extractors { case _ => None } } + object ExWaypointExpression { + def unapply(tree: Apply) : Option[(Type, Tree, Tree)] = tree match { + case Apply( + TypeApply(Select(Select(funcheckIdent, utilsName), waypoint), typeTree :: Nil), + List(i, expr)) => { + if (utilsName.toString == "Utils" && waypoint.toString == "waypoint") + Some((typeTree.tpe, i, expr)) + else + None + } + case _ => None + } + } object ExValDef { diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index bc497dcdd318cdbd7809c357b811002f986908da..3c18e858fcfeba8fa7d02421536e1080c5999847 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -54,6 +54,15 @@ object Definitions { def caseClassDef(name: String) = mainObject.caseClassDef(name) def allIdentifiers : Set[Identifier] = mainObject.allIdentifiers + id def isPure: Boolean = definedFunctions.forall(fd => fd.body.forall(Trees.isPure) && fd.precondition.forall(Trees.isPure) && fd.postcondition.forall(Trees.isPure)) + + def writeScalaFile(filename: String) { + import java.io.FileWriter + import java.io.BufferedWriter + val fstream = new FileWriter(filename) + val out = new BufferedWriter(fstream) + out.write(ScalaPrinter(this)) + out.close + } } /** Objects work as containers for class definitions, functions (def's) and diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 21a9393b142b177d49ba985988512c5dbe885c5f..b7b5bd343c2adc3bf973440776951a124a0ddd08 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -154,6 +154,12 @@ object PrettyPrinter { nsb } + case Waypoint(i, expr) => { + sb.append("waypoint_" + i + "(") + pp(expr, sb, lvl) + sb.append(")") + } + case OptionSome(a) => { var nsb = sb nsb.append("Some(") diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala new file mode 100644 index 0000000000000000000000000000000000000000..501096c45db7c82d88c441e53a6564065423944d --- /dev/null +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -0,0 +1,563 @@ +package leon.purescala + +/** This pretty-printer only print valid scala syntax */ +object ScalaPrinter { + import Common._ + import Trees._ + import TypeTrees._ + import Definitions._ + + import java.lang.StringBuffer + + def apply(tree: Expr): String = { + val retSB = pp(tree, new StringBuffer, 0) + retSB.toString + } + + def apply(tpe: TypeTree): String = { + val retSB = pp(tpe, new StringBuffer, 0) + retSB.toString + } + + def apply(defn: Definition): String = { + val retSB = pp(defn, new StringBuffer, 0) + retSB.toString + } + + private def ind(sb: StringBuffer, lvl: Int) : StringBuffer = { + sb.append(" " * lvl) + sb + } + + // EXPRESSIONS + // all expressions are printed in-line + private def ppUnary(sb: StringBuffer, expr: Expr, op1: String, op2: String, lvl: Int): StringBuffer = { + var nsb: StringBuffer = sb + nsb.append(op1) + nsb = pp(expr, nsb, lvl) + nsb.append(op2) + nsb + } + + private def ppBinary(sb: StringBuffer, left: Expr, right: Expr, op: String, lvl: Int): StringBuffer = { + var nsb: StringBuffer = sb + nsb.append("(") + nsb = pp(left, nsb, lvl) + nsb.append(op) + nsb = pp(right, nsb, lvl) + nsb.append(")") + nsb + } + + private def ppNary(sb: StringBuffer, exprs: Seq[Expr], pre: String, op: String, post: String, lvl: Int): StringBuffer = { + var nsb = sb + nsb.append(pre) + val sz = exprs.size + var c = 0 + + exprs.foreach(ex => { + nsb = pp(ex, nsb, lvl) ; c += 1 ; if(c < sz) nsb.append(op) + }) + + nsb.append(post) + nsb + } + + private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match { + case Variable(id) => sb.append(id) + case DeBruijnIndex(idx) => sys.error("Not Valid Scala") + case Let(b,d,e) => { + sb.append("locally {\n") + ind(sb, lvl+1) + sb.append("val " + b + " = ") + pp(d, sb, lvl+1) + sb.append("\n") + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append("\n") + ind(sb, lvl) + sb.append("}\n") + ind(sb, lvl) + sb + } + case LetVar(b,d,e) => { + sb.append("locally {\n") + ind(sb, lvl+1) + sb.append("var " + b + " = ") + pp(d, sb, lvl+1) + sb.append("\n") + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append("\n") + ind(sb, lvl) + sb.append("}\n") + ind(sb, lvl) + sb + } + case LetDef(fd,e) => { + sb.append("\n") + pp(fd, sb, lvl+1) + sb.append("\n") + sb.append("\n") + ind(sb, lvl) + pp(e, sb, lvl) + sb + } + case And(exprs) => ppNary(sb, exprs, "(", " && ", ")", lvl) // \land + case Or(exprs) => ppNary(sb, exprs, "(", " || ", ")", lvl) // \lor + case Not(Equals(l, r)) => ppBinary(sb, l, r, " != ", lvl) // \neq + case Iff(l,r) => sys.error("Not Scala Code") + case Implies(l,r) => sys.error("Not Scala Code") + case UMinus(expr) => ppUnary(sb, expr, "-(", ")", lvl) + case Equals(l,r) => ppBinary(sb, l, r, " == ", lvl) + case IntLiteral(v) => sb.append(v) + case BooleanLiteral(v) => sb.append(v) + case StringLiteral(s) => sb.append("\"" + s + "\"") + case UnitLiteral => sb.append("()") + case Block(exprs, last) => { + sb.append("{\n") + (exprs :+ last).foreach(e => { + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append("\n") + }) + ind(sb, lvl) + sb.append("}\n") + sb + } + case Assignment(lhs, rhs) => ppBinary(sb, lhs.toVariable, rhs, " = ", lvl) + case wh@While(cond, body) => { + wh.invariant match { + case Some(inv) => { + sb.append("\n") + ind(sb, lvl) + sb.append("@invariant: ") + pp(inv, sb, lvl) + sb.append("\n") + ind(sb, lvl) + } + case None => + } + sb.append("while(") + pp(cond, sb, lvl) + sb.append(")\n") + ind(sb, lvl+1) + pp(body, sb, lvl+1) + sb.append("\n") + } + + case t@Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")", lvl) + case s@TupleSelect(t, i) => { + pp(t, sb, lvl) + sb.append("._" + i) + sb + } + + case e@Epsilon(pred) => sys.error("Not Scala Code") + case Waypoint(i, expr) => pp(expr, sb, lvl) + + case OptionSome(a) => { + var nsb = sb + nsb.append("Some(") + nsb = pp(a, nsb, lvl) + nsb.append(")") + nsb + } + + case OptionNone(_) => sb.append("None") + + case CaseClass(cd, args) => { + var nsb = sb + nsb.append(cd.id) + nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + nsb + } + case CaseClassInstanceOf(cd, e) => { + var nsb = sb + nsb = pp(e, nsb, lvl) + nsb.append(".isInstanceOf[" + cd.id + "]") + nsb + } + case CaseClassSelector(_, cc, id) => pp(cc, sb, lvl).append("." + id) + case FunctionInvocation(fd, args) => { + var nsb = sb + nsb.append(fd.id) + nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + nsb + } + case AnonymousFunction(es, ev) => { + var nsb = sb + nsb.append("{") + es.foreach { + case (as, res) => + nsb = ppNary(nsb, as, "", " ", "", lvl) + nsb.append(" -> ") + nsb = pp(res, nsb, lvl) + nsb.append(", ") + } + nsb.append("else -> ") + nsb = pp(ev, nsb, lvl) + nsb.append("}") + } + case AnonymousFunctionInvocation(id, args) => { + var nsb = sb + nsb.append(id) + nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + nsb + } + case Plus(l,r) => ppBinary(sb, l, r, " + ", lvl) + case Minus(l,r) => ppBinary(sb, l, r, " - ", lvl) + case Times(l,r) => ppBinary(sb, l, r, " * ", lvl) + case Division(l,r) => ppBinary(sb, l, r, " / ", lvl) + case Modulo(l,r) => ppBinary(sb, l, r, " % ", lvl) + case LessThan(l,r) => ppBinary(sb, l, r, " < ", lvl) + case GreaterThan(l,r) => ppBinary(sb, l, r, " > ", lvl) + case LessEquals(l,r) => ppBinary(sb, l, r, " <= ", lvl) // \leq + case GreaterEquals(l,r) => ppBinary(sb, l, r, " >= ", lvl) // \geq + case FiniteSet(rs) => ppNary(sb, rs, "{", ", ", "}", lvl) + case FiniteMultiset(rs) => ppNary(sb, rs, "{|", ", ", "|}", lvl) + case EmptySet(bt) => sb.append("Set()") // Ø + case EmptyMultiset(_) => sys.error("Not Valid Scala") + case Not(ElementOfSet(s,e)) => sys.error("TODO") + //case ElementOfSet(s,e) => ppBinary(sb, s, e, " \u2208 ", lvl) // \in + //case SubsetOf(l,r) => ppBinary(sb, l, r, " \u2286 ", lvl) // \subseteq + //case Not(SubsetOf(l,r)) => ppBinary(sb, l, r, " \u2288 ", lvl) // \notsubseteq + case SetMin(s) => pp(s, sb, lvl).append(".min") + case SetMax(s) => pp(s, sb, lvl).append(".max") + // case SetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup + // case MultisetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup + // case MapUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup + // case SetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) + // case MultisetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) + // case SetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap + // case MultisetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap + // case SetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) + // case MultisetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) + // case MultisetPlus(l,r) => ppBinary(sb, l, r, " \u228E ", lvl) // U+ + // case MultisetToSet(e) => pp(e, sb, lvl).append(".toSet") + case EmptyMap(_,_) => sb.append("Map()") + case SingletonMap(f,t) => ppBinary(sb, f, t, " -> ", lvl) + case FiniteMap(rs) => ppNary(sb, rs, "Map(", ", ", ")", lvl) + case MapGet(m,k) => { + var nsb = sb + pp(m, nsb, lvl) + nsb = ppNary(nsb, Seq(k), "(", ",", ")", lvl) + nsb + } + case MapIsDefinedAt(m,k) => { + var nsb = sb + pp(m, nsb, lvl) + nsb.append(".isDefinedAt") + nsb = ppNary(nsb, Seq(k), "(", ",", ")", lvl) + nsb + } + case ArrayLength(a) => { + pp(a, sb, lvl) + sb.append(".length") + } + case ArrayClone(a) => { + pp(a, sb, lvl) + sb.append(".clone") + } + case fill@ArrayFill(size, v) => { + sb.append("Array.fill(") + pp(size, sb, lvl) + sb.append(")(") + pp(v, sb, lvl) + sb.append(")") + } + case am@ArrayMake(v) => sys.error("Not Scala Code") + case sel@ArraySelect(ar, i) => { + pp(ar, sb, lvl) + sb.append("(") + pp(i, sb, lvl) + sb.append(")") + } + case up@ArrayUpdate(ar, i, v) => { + pp(ar, sb, lvl) + sb.append("(") + pp(i, sb, lvl) + sb.append(") = ") + pp(v, sb, lvl) + } + case up@ArrayUpdated(ar, i, v) => { + pp(ar, sb, lvl) + sb.append(".updated(") + pp(i, sb, lvl) + sb.append(", ") + pp(v, sb, lvl) + sb.append(")") + } + case FiniteArray(exprs) => { + ppNary(sb, exprs, "Array(", ", ", ")", lvl) + } + + case Distinct(exprs) => { + var nsb = sb + nsb.append("distinct") + nsb = ppNary(nsb, exprs, "(", ", ", ")", lvl) + nsb + } + + case IfExpr(c, t, e) => { + var nsb = sb + nsb.append("(if (") + nsb = pp(c, nsb, lvl) + nsb.append(")\n") + ind(nsb, lvl+1) + nsb = pp(t, nsb, lvl+1) + nsb.append("\n") + ind(nsb, lvl) + nsb.append("else\n") + ind(nsb, lvl+1) + nsb = pp(e, nsb, lvl+1) + nsb.append(")") + nsb + } + + case mex @ MatchExpr(s, csc) => { + def ppc(sb: StringBuffer, p: Pattern): StringBuffer = p match { + //case InstanceOfPattern(None, ctd) => + //case InstanceOfPattern(Some(id), ctd) => + case CaseClassPattern(bndr, ccd, subps) => { + var nsb = sb + bndr.foreach(b => nsb.append(b + " @ ")) + nsb.append(ccd.id).append("(") + var c = 0 + val sz = subps.size + subps.foreach(sp => { + nsb = ppc(nsb, sp) + if(c < sz - 1) + nsb.append(", ") + c = c + 1 + }) + nsb.append(")") + } + case WildcardPattern(None) => sb.append("_") + case WildcardPattern(Some(id)) => sb.append(id) + case TuplePattern(bndr, subPatterns) => { + bndr.foreach(b => sb.append(b + " @ ")) + sb.append("(") + subPatterns.init.foreach(p => { + ppc(sb, p) + sb.append(", ") + }) + ppc(sb, subPatterns.last) + sb.append(")") + } + case _ => sb.append("Pattern?") + } + + var nsb = sb + nsb.append("(") + nsb == pp(s, nsb, lvl) + // if(mex.posInfo != "") { + // nsb.append(" match@(" + mex.posInfo + ") {\n") + // } else { + nsb.append(" match {\n") + // } + + csc.foreach(cs => { + ind(nsb, lvl+1) + nsb.append("case ") + nsb = ppc(nsb, cs.pattern) + cs.theGuard.foreach(g => { + nsb.append(" if ") + nsb = pp(g, nsb, lvl+1) + }) + nsb.append(" => ") + nsb = pp(cs.rhs, nsb, lvl+1) + nsb.append("\n") + }) + ind(nsb, lvl).append("}") + nsb.append(")") + nsb + } + + case ResultVariable() => sb.append("res") + case EpsilonVariable((row, col)) => sb.append("x" + row + "_" + col) + case Not(expr) => ppUnary(sb, expr, "\u00AC(", ")", lvl) // \neg + + case e @ Error(desc) => { + var nsb = sb + nsb.append("error(\"" + desc + "\")[") + nsb = pp(e.getType, nsb, lvl) + nsb.append("]") + nsb + } + + case _ => sb.append("Expr?") + } + + // TYPE TREES + // all type trees are printed in-line + private def ppNaryType(sb: StringBuffer, tpes: Seq[TypeTree], pre: String, op: String, post: String, lvl: Int): StringBuffer = { + var nsb = sb + nsb.append(pre) + val sz = tpes.size + var c = 0 + + tpes.foreach(t => { + nsb = pp(t, nsb, lvl) ; c += 1 ; if(c < sz) nsb.append(op) + }) + + nsb.append(post) + nsb + } + + private def pp(tpe: TypeTree, sb: StringBuffer, lvl: Int): StringBuffer = tpe match { + case Untyped => sb.append("???") + case UnitType => sb.append("Unit") + case Int32Type => sb.append("Int") + case BooleanType => sb.append("Boolean") + case ArrayType(bt) => pp(bt, sb.append("Array["), lvl).append("]") + case SetType(bt) => pp(bt, sb.append("Set["), lvl).append("]") + case MapType(ft,tt) => pp(tt, pp(ft, sb.append("Map["), lvl).append(","), lvl).append("]") + case MultisetType(bt) => pp(bt, sb.append("Multiset["), lvl).append("]") + case OptionType(bt) => pp(bt, sb.append("Option["), lvl).append("]") + case FunctionType(fts, tt) => { + var nsb = sb + if (fts.size > 1) + nsb = ppNaryType(nsb, fts, "(", ", ", ")", lvl) + else if (fts.size == 1) + nsb = pp(fts.head, nsb, lvl) + nsb.append(" => ") + pp(tt, nsb, lvl) + } + case TupleType(tpes) => ppNaryType(sb, tpes, "(", ", ", ")", lvl) + case c: ClassType => sb.append(c.classDef.id) + case _ => sb.append("Type?") + } + + // DEFINITIONS + // all definitions are printed with an end-of-line + private def pp(defn: Definition, sb: StringBuffer, lvl: Int): StringBuffer = { + + defn match { + case Program(id, mainObj) => { + assert(lvl == 0) + pp(mainObj, sb, lvl) + } + + case ObjectDef(id, defs, invs) => { + var nsb = sb + ind(nsb, lvl) + nsb.append("object ") + nsb.append(id) + nsb.append(" {\n") + + var c = 0 + val sz = defs.size + + defs.foreach(df => { + nsb = pp(df, nsb, lvl+1) + if(c < sz - 1) { + nsb.append("\n\n") + } + c = c + 1 + }) + + nsb.append("\n") + ind(nsb, lvl).append("}\n") + } + + case AbstractClassDef(id, parent) => { + var nsb = sb + ind(nsb, lvl) + nsb.append("sealed abstract class ") + nsb.append(id) + parent.foreach(p => nsb.append(" extends " + p.id)) + nsb + } + + case CaseClassDef(id, parent, varDecls) => { + var nsb = sb + ind(nsb, lvl) + nsb.append("case class ") + nsb.append(id) + nsb.append("(") + var c = 0 + val sz = varDecls.size + + varDecls.foreach(vd => { + nsb.append(vd.id) + nsb.append(": ") + nsb = pp(vd.tpe, nsb, lvl) + if(c < sz - 1) { + nsb.append(", ") + } + c = c + 1 + }) + nsb.append(")") + parent.foreach(p => nsb.append(" extends " + p.id)) + nsb + } + + case fd @ FunDef(id, rt, args, body, pre, post) => { + var nsb = sb + + //for(a <- fd.annotations) { + // ind(nsb, lvl) + // nsb.append("@" + a + "\n") + //} + + ind(nsb, lvl) + nsb.append("def ") + nsb.append(id) + nsb.append("(") + + val sz = args.size + var c = 0 + + args.foreach(arg => { + nsb.append(arg.id) + nsb.append(" : ") + nsb = pp(arg.tpe, nsb, lvl) + + if(c < sz - 1) { + nsb.append(", ") + } + c = c + 1 + }) + + nsb.append(") : ") + nsb = pp(rt, nsb, lvl) + nsb.append(" = (") + if(body.isDefined) { + pre match { + case None => pp(body.get, nsb, lvl) + case Some(prec) => { + nsb.append("{\n") + ind(nsb, lvl+1) + nsb.append("require(") + nsb = pp(prec, nsb, lvl+1) + nsb.append(")\n") + pp(body.get, nsb, lvl+1) + nsb.append("\n") + ind(nsb, lvl) + nsb.append("}") + } + } + } else + nsb.append("[unknown function implementation]") + + post match { + case None => { + nsb.append(")") + } + case Some(postc) => { + nsb.append(" ensuring(res => ") //TODO, not very general solution... + nsb = pp(postc, nsb, lvl) + nsb.append("))") + } + } + + nsb + } + + case _ => sb.append("Defn?") + } + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 612bcb9b08455e04440e7502d56866038381d306..bfad43aae71ffb22c33c09cf8667d7920acca9ab 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -85,6 +85,8 @@ object Trees { } case class TupleSelect(tuple: Expr, index: Int) extends Expr + case class Waypoint(i: Int, expr: Expr) extends Expr + object MatchExpr { def apply(scrutinee: Expr, cases: Seq[MatchCase]) : MatchExpr = { scrutinee.getType match { @@ -154,6 +156,7 @@ object Trees { case class TuplePattern(binder: Option[Identifier], subPatterns: Seq[Pattern]) extends Pattern + /* Propositional logic */ object And { def apply(l: Expr, r: Expr) : Expr = (l,r) match { @@ -439,6 +442,7 @@ object Trees { case ArrayLength(a) => Some((a, ArrayLength)) case ArrayClone(a) => Some((a, ArrayClone)) case ArrayMake(t) => Some((t, ArrayMake)) + case Waypoint(i, t) => Some((t, (expr: Expr) => Waypoint(i, expr))) case e@Epsilon(t) => Some((t, (expr: Expr) => Epsilon(expr).setType(e.getType).setPosInfo(e))) case _ => None } @@ -886,6 +890,18 @@ object Trees { } treeCatamorphism(convert, combine, compute, expr) } + def containsIfExpr(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (i: IfExpr) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (i: IfExpr) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } def variablesOf(expr: Expr) : Set[Identifier] = { def convert(t: Expr) : Set[Identifier] = t match { @@ -1450,4 +1466,34 @@ object Trees { case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe) case _ => throw new Exception("I can't choose simplest value for type " + tpe) } + + //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions + //require no-match, no-ets and only pure code + def hoistIte(expr: Expr): Expr = { + def transform(expr: Expr): Option[Expr] = expr match { + case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) + case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) + case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) + case nop@NAryOperator(ts, op) => { + val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } + if(iteIndex == -1) None else { + val (beforeIte, startIte) = ts.splitAt(iteIndex) + val afterIte = startIte.tail + val IfExpr(c, t, e) = startIte.head + Some(IfExpr(c, + op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), + op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) + ).setType(nop.getType)) + } + } + case _ => None + } + + def fix[A](f: (A) => A, a: A): A = { + val na = f(a) + if(a == na) a else fix(f, na) + } + fix(searchAndReplaceDFS(transform), expr) + } + } diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala new file mode 100644 index 0000000000000000000000000000000000000000..16e31057390e2cea0fd1714d7631c819593ecc04 --- /dev/null +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -0,0 +1,388 @@ +package leon.testgen + +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ +import leon.purescala.Common._ +import leon.FairZ3Solver + +class CallGraph(val program: Program) { + + sealed abstract class ProgramPoint + case class FunctionStart(fd: FunDef) extends ProgramPoint + case class ExpressionPoint(wp: Expr, id: Int) extends ProgramPoint + private var epid = -1 + private def freshExpressionPoint(wp: Expr) = {epid += 1; ExpressionPoint(wp, epid)} + + case class TransitionLabel(cond: Expr, assignment: Map[Variable, Expr]) + + private lazy val graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = buildGraph + private lazy val programPoints: Set[ProgramPoint] = { + graph.flatMap(pair => pair._2.map(edge => edge._1).toSet + pair._1).toSet + } + + private def buildGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var callGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = Map() + + program.definedFunctions.foreach(fd => { + val body = fd.body.get + //val cleanBody = hoistIte(expandLets(matchToIfThenElse(body))) + val cleanBody = expandLets(matchToIfThenElse(body)) + val subgraph = collectWithPathCondition(cleanBody, FunctionStart(fd)) + callGraph ++= subgraph + }) + + callGraph = addFunctionInvocationsEdges(callGraph) + + callGraph = simplifyGraph(callGraph) + + callGraph + } + + private def simplifyGraph(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + def fix[A](f: (A) => A, a: A): A = { + val na = f(a) + if(a == na) a else fix(f, na) + } + fix(compressGraph, graph) + } + + //does a one level compression of the graph + private def compressGraph(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var newGraph = graph + + graph.find{ + case (point, edges) => { + edges.exists{ + case edge@(p2@ExpressionPoint(e, _), TransitionLabel(BooleanLiteral(true), assign)) if assign.isEmpty && !e.isInstanceOf[Waypoint] => { + val edgesOfPoint: Set[(ProgramPoint, TransitionLabel)] = graph.get(p2).getOrElse(Set()) //should be unique entry point and cannot be a FunctionStart + newGraph += (point -> ((edges - edge) ++ edgesOfPoint)) + newGraph -= p2 + true + } + case _ => false + } + } + } + + newGraph + } + + + private def addFunctionInvocationsEdges(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var augmentedGraph = graph + + graph.foreach{ + case (point@ExpressionPoint(FunctionInvocation(fd, args), _), edges) => { + val newPoint = FunctionStart(fd) + val newTransition = TransitionLabel(BooleanLiteral(true), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + augmentedGraph += (point -> (edges + ((newPoint, newTransition)))) + } + case _ => ; + } + + augmentedGraph + } + + private def collectWithPathCondition(expression: Expr, startingPoint: ProgramPoint): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var callGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = Map() + + def rec(expr: Expr, path: List[Expr], startingPoint: ProgramPoint): Unit = { + val transitions: Set[(ProgramPoint, TransitionLabel)] = callGraph.get(startingPoint) match { + case None => Set() + case Some(s) => s + } + expr match { + //case FunctionInvocation(fd, args) => { + // val newPoint = FunctionStart(fd) + // val newTransition = TransitionLabel(And(path.toSeq), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + // callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + // args.foreach(arg => rec(arg, path, startingPoint)) + //} + //this case is actually now handled in the unaryOp case + //case way@Waypoint(i, e) => { + // val newPoint = ExpressionPoint(way) + // val newTransition = TransitionLabel(And(path.toSeq), Map()) + // callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + // rec(e, List(), newPoint) + //} + case IfExpr(cond, then, elze) => { + rec(cond, path, startingPoint) + rec(then, cond :: path, startingPoint) + rec(elze, Not(cond) :: path, startingPoint) + } + case n@NAryOperator(args, _) => { + val newPoint = freshExpressionPoint(n) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + args.foreach(rec(_, List(), newPoint)) + } + case b@BinaryOperator(t1, t2, _) => { + val newPoint = freshExpressionPoint(b) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + rec(t1, List(), newPoint) + rec(t2, List(), newPoint) + } + case u@UnaryOperator(t, _) => { + val newPoint = freshExpressionPoint(u) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + rec(t, List(), newPoint) + } + case t : Terminal => { + val newPoint = freshExpressionPoint(t) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + } + case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr) + } + } + + rec(expression, List(), startingPoint) + callGraph + } + + //given a path, follow the path to build the logical constraint that need to be satisfiable + def pathConstraint(path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)], assigns: List[Map[Expr, Expr]] = List()): Expr = { + if(path.isEmpty) BooleanLiteral(true) else { + val (_, _, TransitionLabel(cond, assign)) = path.head + val finalCond = assigns.foldRight(cond)((map, acc) => replace(map, acc)) + And(finalCond, + pathConstraint( + path.tail, + if(assign.isEmpty) assigns else assign.asInstanceOf[Map[Expr, Expr]] :: assigns + ) + ) + } + } + + private def isMain(fd: FunDef): Boolean = { + fd.annotations.exists(_ == "main") + } + + def findAllPaths(z3Solver: FairZ3Solver): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } + val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { + val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) + i1 <= i2 + }) + + val functionPoints: Set[ProgramPoint] = programPoints.flatMap{ case f@FunctionStart(fd) => Set[ProgramPoint](f) case _ => Set[ProgramPoint]() } + val mainPoint: Option[ProgramPoint] = functionPoints.find{ case FunctionStart(fd) => isMain(fd) case p => sys.error("unexpected: " + p) } + + assert(mainPoint != None) + + if(sortedWaypoints.size == 0) { + findSimplePaths(mainPoint.get) + } else { + visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solver) match { + case None => Set() + case Some(p) => Set(p) + } + //Set( + // sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => + // path ++ findPath(waypoint._1, waypoint._2)) + //) + } + } + + def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3Solver): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def rec(head: ProgramPoint, tail: List[ProgramPoint], path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]): + Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + tail match { + case Nil => Some(path) + case x::xs => { + val allPaths = findSimplePaths(head, Some(x)) + var completePath: Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = None + allPaths.find(intermediatePath => { + val pc = pathConstraint(path ++ intermediatePath) + z3Solver.init() + z3Solver.restartZ3 + + var testcase: Option[Map[Identifier, Expr]] = None + + val (solverResult, model) = z3Solver.decideWithModel(pc, false) + solverResult match { + case None => { + false + } + case Some(true) => { + false + } + case Some(false) => { + val recPath = rec(x, xs, path ++ intermediatePath) + recPath match { + case None => false + case Some(path) => { + completePath = Some(path) + true + } + } + } + } + }) + completePath + } + } + } + rec(waypoints.head, waypoints.tail, Seq()) + } + + def findSimplePaths(from: ProgramPoint, to: Option[ProgramPoint] = None): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def dfs(point: ProgramPoint, path: List[(ProgramPoint, ProgramPoint, TransitionLabel)], visitedPoints: Set[ProgramPoint]): + Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = graph.get(point) match { + case None => Set(path.reverse) + case Some(edges) => { + if(to != None && to.get == point) + Set(path.reverse) + else if(to == None && edges.forall((edge: (ProgramPoint, TransitionLabel)) => visitedPoints.contains(edge._1) || point == edge._1)) + Set(path.reverse) + else { + edges.flatMap((edge: (ProgramPoint, TransitionLabel)) => { + val (neighbour, transition) = edge + if(visitedPoints.contains(neighbour) || point == neighbour) + Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]]() + else + dfs(neighbour, (point, neighbour, transition) :: path, visitedPoints + point) + }) + } + } + } + + dfs(from, List(), Set()) + } + + //find a path that goes through all waypoint in order + def findPath(from: ProgramPoint, to: ProgramPoint): Seq[(ProgramPoint, ProgramPoint, TransitionLabel)] = { + var visitedPoints: Set[ProgramPoint] = Set() + var history: Map[ProgramPoint, (ProgramPoint, TransitionLabel)] = Map() + var toVisit: List[ProgramPoint] = List(from) + + var currentPoint: ProgramPoint = null + while(!toVisit.isEmpty && currentPoint != to) { + currentPoint = toVisit.head + if(currentPoint != to) { + visitedPoints += currentPoint + toVisit = toVisit.tail + graph.get(currentPoint).foreach(edges => edges.foreach{ + case (neighbour, transition) => + if(!visitedPoints.contains(neighbour) && !toVisit.contains(neighbour)) { + toVisit ::= neighbour + history += (neighbour -> ((currentPoint, transition))) + } + }) + } + } + + def rebuildPath(point: ProgramPoint, path: List[(ProgramPoint, ProgramPoint, TransitionLabel)]): Seq[(ProgramPoint, ProgramPoint, TransitionLabel)] = { + if(point == from) path else { + val (previousPoint, transition) = history(point) + val newPath = (previousPoint, point, transition) :: path + rebuildPath(previousPoint, newPath) + } + } + + //TODO: handle case where the target node is not found + rebuildPath(to, List()) + } + + + lazy val toDotString: String = { + var vertexLabels: Set[(String, String)] = Set() + var vertexId = -1 + var point2vertex: Map[ProgramPoint, Int] = Map() + //return id and label + def getVertex(p: ProgramPoint): (String, String) = point2vertex.get(p) match { + case Some(id) => ("v_" + id, ppPoint(p)) + case None => { + vertexId += 1 + point2vertex += (p -> vertexId) + val pair = ("v_" + vertexId, ppPoint(p)) + vertexLabels += pair + pair + } + } + + def ppPoint(p: ProgramPoint): String = p match { + case FunctionStart(fd) => fd.id.name + case ExpressionPoint(Waypoint(i, e), _) => "WayPoint " + i + case ExpressionPoint(e, _) => e.toString + } + def ppLabel(l: TransitionLabel): String = { + val TransitionLabel(cond, assignments) = l + cond.toString + ", " + assignments.map(p => p._1 + " -> " + p._2).mkString("\n") + } + + val edges: List[(String, String, String)] = graph.flatMap(pair => { + val (startPoint, edges) = pair + val (startId, _) = getVertex(startPoint) + edges.map(pair => { + val (endPoint, label) = pair + val (endId, _) = getVertex(endPoint) + (startId, endId, ppLabel(label)) + }).toList + }).toList + + val res = ( + "digraph " + program.id.name + " {\n" + + vertexLabels.map(p => p._1 + " [label=\"" + p._2 + "\"];").mkString("\n") + "\n" + + edges.map(p => p._1 + " -> " + p._2 + " [label=\"" + p._3 + "\"];").mkString("\n") + "\n" + + "}") + + res + } + + def writeDotFile(filename: String) { + import java.io.FileWriter + import java.io.BufferedWriter + val fstream = new FileWriter(filename) + val out = new BufferedWriter(fstream) + out.write(toDotString) + out.close + } + +} + + //def hoistIte(expr: Expr): (Seq[Expr] => Expr, Seq[Expr]) = expr match { + // case ite@IfExpr(c, t, e) => { + // val (iteThen, valsThen) = hoistIte(t) + // val nbValsThen = valsThen.size + // val (iteElse, valsElse) = hoistIte(e) + // val nbValsElse = valsElse.size + // def ite(es: Seq[Expr]): Expr = { + // val argsThen = es.take(nbValsThen) + // val argsElse = es.drop(nbValsThen) + // IfExpr(c, iteThen(argsThen), iteElse(argsElse), e2) + // } + // (ite, valsThen ++ valsElse) + // } + // case BinaryOperator(t1, t2, op) => { + // val (iteLeft, valsLeft) = hoistIte(t1) + // val (iteRight, valsRight) = hoistIte(t2) + // def ite(e1: Expr, e2: Expr): Expr = { + + // } + // iteLeft( + // iteRight( + // op(thenValRight, thenValLeft), + // op(thenValRight, elseValLeft) + // ), iteRight( + // op(elseValRight, thenValLeft), + // op(elseValRight, elseValLeft) + // ) + // ) + // } + // case NAryOperator(args, op) => { + + // } + // case (t: Terminal) => { + // def ite(es: Seq[Expr]): Expr = { + // require(es.size == 1) + // es.head + // } + // (ite, Seq(t)) + // } + // case _ => scala.sys.error("Unhandled tree in hoistIte : " + expr) + //} + diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b48d72579f0ff620ef7d35f41e9bcfe841f3ac4 --- /dev/null +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -0,0 +1,145 @@ +package leon.testgen + +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ +import leon.purescala.ScalaPrinter +import leon.Extensions._ +import leon.FairZ3Solver +import leon.Reporter + +import scala.collection.mutable.{Set => MutableSet} + +class TestGeneration(reporter: Reporter) extends Analyser(reporter) { + + def description: String = "Generate random testcases" + override def shortDescription: String = "test" + + private val z3Solver = new FairZ3Solver(reporter) + + def analyse(program: Program) { + z3Solver.setProgram(program) + reporter.info("Running test generation") + + val testcases = generateTestCases(program) + + val topFunDef = program.definedFunctions.find(fd => isMain(fd)).get +//fd.body.exists(body => body match { +// case Waypoint(1, _) => true +// case _ => false +// }) + val testFun = new FunDef(FreshIdentifier("test"), UnitType, Seq()) + val funInvocs = testcases.map(testcase => { + val params = topFunDef.args + val args = topFunDef.args.map{ + case VarDecl(id, tpe) => testcase.get(id) match { + case Some(v) => v + case None => simplestValue(tpe) + } + } + FunctionInvocation(topFunDef, args) + }).toSeq + testFun.body = Some(Block(funInvocs, UnitLiteral)) + + val Program(id, ObjectDef(objId, defs, invariants)) = program + val testProgram = Program(id, ObjectDef(objId, testFun +: defs , invariants)) + testProgram.writeScalaFile("TestGen.scalax") + + reporter.info("Running from waypoint with the following testcases:\n") + reporter.info(testcases.mkString("\n")) + } + + private def isMain(fd: FunDef): Boolean = { + fd.annotations.exists(_ == "main") + } + + def generatePathConditions(program: Program): Set[Expr] = { + + val callGraph = new CallGraph(program) + callGraph.writeDotFile("testgen.dot") + val constraints = callGraph.findAllPaths(z3Solver).map(path => { + println("Path is: " + path) + val cnstr = callGraph.pathConstraint(path) + println("constraint is: " + cnstr) + cnstr + }) + constraints + } + + private def generateTestCases(program: Program): Set[Map[Identifier, Expr]] = { + val allPaths = generatePathConditions(program) + + allPaths.flatMap(pathCond => { + reporter.info("Now considering path condition: " + pathCond) + + var testcase: Option[Map[Identifier, Expr]] = None + //val z3Solver: FairZ3Solver = loadedSolverExtensions.find(se => se.isInstanceOf[FairZ3Solver]).get.asInstanceOf[FairZ3Solver] + + z3Solver.init() + z3Solver.restartZ3 + val (solverResult, model) = z3Solver.decideWithModel(pathCond, false) + + solverResult match { + case None => Seq() + case Some(true) => { + reporter.info("The path is unreachable") + Seq() + } + case Some(false) => { + reporter.info("The model should be used as the testcase") + Seq(model) + } + } + }) + } + + //private def generatePathConditions(funDef: FunDef): Seq[Expr] = if(!funDef.hasImplementation) Seq() else { + // val body = funDef.body.get + // val cleanBody = hoistIte(expandLets(matchToIfThenElse(body))) + // collectWithPathCondition(cleanBody) + //} + + //private def generateTestCases(funDef: FunDef): Seq[Map[Identifier, Expr]] = { + // val allPaths = generatePathConditions(funDef) + + // allPaths.flatMap(pathCond => { + // reporter.info("Now considering path condition: " + pathCond) + + // var testcase: Option[Map[Identifier, Expr]] = None + // //val z3Solver: FairZ3Solver = loadedSolverExtensions.find(se => se.isInstanceOf[FairZ3Solver]).get.asInstanceOf[FairZ3Solver] + // + // z3Solver.init() + // z3Solver.restartZ3 + // val (solverResult, model) = z3Solver.decideWithModel(pathCond, false) + + // solverResult match { + // case None => Seq() + // case Some(true) => { + // reporter.info("The path is unreachable") + // Seq() + // } + // case Some(false) => { + // reporter.info("The model should be used as the testcase") + // Seq(model) + // } + // } + // }) + //} + + //prec: ite are hoisted and no lets nor match occurs + //private def collectWithPathCondition(expression: Expr): Seq[Expr] = { + // var allPaths: Seq[Expr] = Seq() + + // def rec(expr: Expr, path: List[Expr]): Seq[Expr] = expr match { + // case IfExpr(cond, then, elze) => rec(then, cond :: path) ++ rec(elze, Not(cond) :: path) + // case _ => Seq(And(path.toSeq)) + // } + + // rec(expression, List()) + //} + +} + + + diff --git a/testcases/Abs.scala b/testcases/Abs.scala index 1a071ee10cd76c98267eb2ea3cdbef510bcc681f..52baf710140b76bf5edeeda52ef9e0da5385734f 100644 --- a/testcases/Abs.scala +++ b/testcases/Abs.scala @@ -1,39 +1,5 @@ -import leon.Utils._ - object Abs { - - def abs(tab: Array[Int]): Array[Int] = ({ - require(tab.length >= 0) - var k = 0 - val tabres = Array.fill(tab.length)(0) - (while(k < tab.length) { - if(tab(k) < 0) - tabres(k) = -tab(k) - else - tabres(k) = tab(k) - k = k + 1 - }) invariant( - tabres.length == tab.length && - k >= 0 && k <= tab.length && - isPositive(tabres, k)) - tabres - }) ensuring(res => isPositive(res, res.length)) - - def isPositive(a: Array[Int], size: Int): Boolean = { - require(a.length >= 0 && size <= a.length) - def rec(i: Int): Boolean = { - require(i >= 0) - if(i >= size) - true - else { - if(a(i) < 0) - false - else - rec(i+1) - } - } - rec(0) - } + def abs(x: Int): Int = (if(x < 0) -x else x) ensuring(_ >= 0) } diff --git a/testcases/AbsArray.scala b/testcases/AbsArray.scala new file mode 100644 index 0000000000000000000000000000000000000000..0086891cdafc5a62987d5563695b38983de24fb4 --- /dev/null +++ b/testcases/AbsArray.scala @@ -0,0 +1,51 @@ +import leon.Utils._ + +object AbsArray { + + + def abs(tab: Map[Int, Int], size: Int): Map[Int, Int] = ({ + require(size <= 5 && isArray(tab, size)) + var k = 0 + var tabres = Map.empty[Int, Int] + (while(k < size) { + if(tab(k) < 0) + tabres = tabres.updated(k, -tab(k)) + else + tabres = tabres.updated(k, tab(k)) + k = k + 1 + }) invariant(isArray(tabres, k) && k >= 0 && k <= size && isPositive(tabres, k)) + tabres + }) ensuring(res => isArray(res, size) && isPositive(res, size)) + + def isPositive(a: Map[Int, Int], size: Int): Boolean = { + require(size <= 10 && isArray(a, size)) + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) + true + else { + if(a(i) < 0) + false + else + rec(i+1) + } + } + rec(0) + } + + def isArray(a: Map[Int, Int], size: Int): Boolean = { + + def rec(i: Int): Boolean = { + require(i >= 0) + if(i >= size) true else { + if(a.isDefinedAt(i)) rec(i+1) else false + } + } + + if(size < 0) + false + else + rec(0) + } + +} diff --git a/testcases/testgen/Abs.scala b/testcases/testgen/Abs.scala new file mode 100644 index 0000000000000000000000000000000000000000..4aa8307c8ec3da1d9ff941994b084d8590ce3afc --- /dev/null +++ b/testcases/testgen/Abs.scala @@ -0,0 +1,11 @@ +import leon.Utils._ +import leon.Annotations._ + +object Abs { + + @main + def abs(x: Int): Int = { + if(x < 0) -x else x + } ensuring(_ >= 0) + +} diff --git a/testcases/testgen/Abs2.scala b/testcases/testgen/Abs2.scala new file mode 100644 index 0000000000000000000000000000000000000000..d640f941bf299ce06e0e84e40e2a2f98c39a4c4c --- /dev/null +++ b/testcases/testgen/Abs2.scala @@ -0,0 +1,11 @@ +import leon.Utils._ +import leon.Annotations._ + +object Abs2 { + + @main + def f(x: Int): Int = if(x < 0) g(-x) else g(x) + + def g(y: Int): Int = if(y < 0) -y else y + +} diff --git a/testcases/testgen/Diamond.scala b/testcases/testgen/Diamond.scala new file mode 100644 index 0000000000000000000000000000000000000000..dbd354e94ffe82de38d7a83997fa1553189f7c0b --- /dev/null +++ b/testcases/testgen/Diamond.scala @@ -0,0 +1,9 @@ +import leon.Utils._ + +object Diamond { + + def foo(x: Int): Int = waypoint(1, if(x < 0) bar(x) else bar(x)) + + def bar(y: Int): Int = if(y > 5) y else -y + +} diff --git a/testcases/testgen/Imp.scala b/testcases/testgen/Imp.scala new file mode 100644 index 0000000000000000000000000000000000000000..8257e6619ca299a8e56fd3d78076c54f8b731882 --- /dev/null +++ b/testcases/testgen/Imp.scala @@ -0,0 +1,17 @@ +import leon.Utils._ +import leon.Annotations._ + +object Imp { + + @main + def foo(i: Int): Int = { + var a = 0 + a = a + 3 + if(i < a) + a = a + 1 + else + a = a - 1 + a + } ensuring(_ >= 0) + +} diff --git a/testcases/testgen/ImpWaypoint.scala b/testcases/testgen/ImpWaypoint.scala new file mode 100644 index 0000000000000000000000000000000000000000..932d362b6faada0d8acf684d96d6bcaada0cca21 --- /dev/null +++ b/testcases/testgen/ImpWaypoint.scala @@ -0,0 +1,20 @@ + +import leon.Utils._ +import leon.Annotations._ + +object Imp { + + @main + def foo(i: Int): Int = { + var a = 0 + a = a + 3 + if(i < a) + waypoint(1, a = a + 1) + else + a = a - 1 + a + } ensuring(_ >= 0) + +} + +// vim: set ts=4 sw=4 et: diff --git a/testcases/testgen/List.scala b/testcases/testgen/List.scala new file mode 100644 index 0000000000000000000000000000000000000000..b21928d924998d5dcd89ea2c278928c77aecf46a --- /dev/null +++ b/testcases/testgen/List.scala @@ -0,0 +1,20 @@ +import leon.Utils._ +import leon.Annotations._ + +object List { + + abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + @main + def size(l: List): Int = (l match { + case Cons(_, tail) => sizeTail(tail, 1) + case Nil() => 0 + }) ensuring(_ >= 0) + + def sizeTail(l2: List, acc: Int): Int = l2 match { + case Cons(_, tail) => sizeTail(tail, acc+1) + case Nil() => acc + } +} diff --git a/testcases/testgen/MultiCall.scala b/testcases/testgen/MultiCall.scala new file mode 100644 index 0000000000000000000000000000000000000000..6742cc4ca224b8bef044745dc6cd12d60683ac80 --- /dev/null +++ b/testcases/testgen/MultiCall.scala @@ -0,0 +1,15 @@ +import leon.Utils._ +import leon.Annotations._ + +object MultiCall { + + @main + def a(i: Int): Int = if(i < 0) b(i) else c(i) + + def b(j: Int): Int = if(j == -5) d(j) else e(j) + def c(k: Int): Int = if(k == 5) d(k) else e(k) + + def d(l: Int): Int = l + def e(m: Int): Int = m + +} diff --git a/testcases/testgen/Sum.scala b/testcases/testgen/Sum.scala new file mode 100644 index 0000000000000000000000000000000000000000..5f1c6c481f36d195f1fe6d58c591a5d91918871b --- /dev/null +++ b/testcases/testgen/Sum.scala @@ -0,0 +1,11 @@ +import leon.Utils._ +import leon.Annotations._ + +object Sum { + + @main + def sum(n: Int): Int = { + if(n <= 0) waypoint(4, 0) else waypoint(3, waypoint(2, n + sum(n-1))) + } ensuring(_ >= 0) + +}