diff --git a/src/funcheck/FunCheckPlugin.scala b/src/funcheck/FunCheckPlugin.scala index 587e69a527c3293574c7f68bab28dbf69cb109d7..d038a8f51a614113f7b2c9aef37a02ad74efb62d 100644 --- a/src/funcheck/FunCheckPlugin.scala +++ b/src/funcheck/FunCheckPlugin.scala @@ -23,6 +23,7 @@ class FunCheckPlugin(val global: Global) extends Plugin { " -P:funcheck:extensions=ex1:... Specifies a list of qualified class names of extensions to be loaded" + "\n" + " -P:funcheck:nodefaults Runs only the analyses provided by the extensions" + "\n" + " -P:funcheck:functions=fun1:... Only generates verification conditions for the specified functions" + "\n" + + " -P:funcheck:unrolling=[0,1,2] Unrolling depth for recursive functions" + "\n" + " -P:funcheck:tolerant Silently extracts non-pure function bodies as ''unknown''" + "\n" + " -P:funcheck:quiet No info and warning messages from the extensions" ) @@ -38,6 +39,7 @@ class FunCheckPlugin(val global: Global) extends Plugin { case "tolerant" => silentlyTolerateNonPureBodies = true case "quiet" => purescala.Settings.quietExtensions = true case "nodefaults" => purescala.Settings.runDefaultExtensions = false + case s if s.startsWith("unrolling=") => purescala.Settings.unrollingLevel = try { s.substring("unrolling=".length, s.length).toInt } catch { case _ => 0 } case s if s.startsWith("functions=") => purescala.Settings.functionsToAnalyse = Set(splitList(s.substring("functions=".length, s.length)): _*) case s if s.startsWith("extensions=") => purescala.Settings.extensionNames = splitList(s.substring("extensions=".length, s.length)) case _ => error("Invalid option: " + option) diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 205de861b0473c43fd72fa71507a61011eceac1d..a56acce566ee2629702b902b7ba5acd774a4ee32 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -72,8 +72,9 @@ class Analysis(val program: Program) { reporter.info(vc) // reporter.info("Negated:") // reporter.info(negate(vc)) - reporter.info("Negated, expanded:") - reporter.info(expandLets(negate(vc))) + // reporter.info("Negated, expanded:") + // val exp = expandLets(negate(vc)) + // reporter.info(exp) // try all solvers until one returns a meaningful answer solverExtensions.find(se => { @@ -107,14 +108,15 @@ class Analysis(val program: Program) { } else { val resFresh = FreshIdentifier("result", true).setType(body.getType) val bodyAndPost = Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post.get)) - val newExpr = if(prec.isEmpty) { + val withPrec = if(prec.isEmpty) { bodyAndPost } else { Implies(prec.get, bodyAndPost) } import Analysis._ - val (newExpr1, sideExprs1) = rewriteSimplePatternMatching(newExpr) + val newExpr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) + val (newExpr1, sideExprs1) = rewriteSimplePatternMatching(newExpr0) val (newExpr2, sideExprs2) = inlineFunctionsAndContracts(program, newExpr1) if(sideExprs1.isEmpty && sideExprs2.isEmpty) { @@ -162,12 +164,37 @@ object Analysis { (searchAndApply(isFunCall, applyToCall, expr), extras.reverse) } + def unrollRecursiveFunctions(program: Program, expr: Expr, times: Int) : Expr = { + def isRecursiveCall(e: Expr) = e match { + case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true + case _ => false + } + def unrollCall(t: Int)(e: Expr) = e match { + case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { + val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) + val newLetVars = newLetIDs.map(Variable(_)) + val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) + val bodyWithLetVars: Expr = replace(substs, fd.body.get) + val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) + unrollRecursiveFunctions(program, bigLet, t - 1) + } + case o => o + } + + if(times > 0) + searchAndApply(isRecursiveCall, unrollCall(times), expr, false) + else + expr + } + // Rewrites pattern matching expressions where the cases simply correspond to // the list of constructors def rewriteSimplePatternMatching(expr: Expr) : (Expr, Seq[Expr]) = { var extras : List[Expr] = Nil - def isPMExpr(e: Expr) : Boolean = e.isInstanceOf[MatchExpr] + def isPMExpr(e: Expr) : Boolean = { + e.isInstanceOf[MatchExpr] + } def rewritePM(e: Expr) : Expr = { val MatchExpr(scrutinee, cases) = e.asInstanceOf[MatchExpr] @@ -194,7 +221,8 @@ object Analysis { } else { Variable(FreshIdentifier("pat", true)).setType(p._1.tpe) }) - (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rhs)))) + val (rewrittenRHS, moreExtras) = rewriteSimplePatternMatching(rhs) + (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rewrittenRHS))) ::: moreExtras.toList) } case _ => (null,Nil) }).toList @@ -219,13 +247,19 @@ object Analysis { } } - // this gets us "extras", but we will still need to clean these up. - val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) - val theExtras = extras.reverse - val onExtras: Seq[(Expr,Seq[Expr])] = theExtras.map(rewriteSimplePatternMatching(_)) - // the "moreExtras" should be cleaned up due to the recursive call.. - val (rewrittenExtras, moreExtras) = onExtras.unzip - - (cleanerTree, rewrittenExtras ++ moreExtras.flatten) + val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) + // println("******************") + // println("rewrote: " + expr) + // println(" *** to ***") + // println(cleanerTree) + // println(" ** with side conds ** ") + // println(extras.reverse) + // println("******************") + (cleanerTree, extras.reverse) + // val theExtras = extras.reverse + // val onExtras: Seq[(Expr,Seq[Expr])] = theExtras.map(rewriteSimplePatternMatching(_)) + // // the "moreExtras" should be cleaned up due to the recursive call.. + // val (rewrittenExtras, moreExtras) = onExtras.unzip + // (cleanerTree, rewrittenExtras ++ moreExtras.flatten) } } diff --git a/src/purescala/Settings.scala b/src/purescala/Settings.scala index 7842e1fec5fa333a86e47ca9a41812428ff37036..5e591acf78b2154213abab700df97c79e9633091 100644 --- a/src/purescala/Settings.scala +++ b/src/purescala/Settings.scala @@ -9,4 +9,5 @@ object Settings { var reporter: Reporter = new DefaultReporter var quietReporter: Reporter = new QuietReporter var runDefaultExtensions: Boolean = true + var unrollingLevel: Int = 0 } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 8593de5f03f90bcec664a1a380ac941cb5cfc9cc..f6b31f72d57daef0836d60c230c30f955bbb625c 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -103,7 +103,6 @@ object Trees { val fixedType = BooleanType } - /* Literals */ case class Variable(id: Identifier) extends Expr { override def getType = id.getType override def setType(tt: TypeTree) = { id.setType(tt); this } @@ -112,6 +111,7 @@ object Trees { // represents the result in post-conditions case class ResultVariable() extends Expr + /* Literals */ sealed abstract class Literal[T] extends Expr { val value: T } @@ -275,7 +275,9 @@ object Trees { // the replacement map should be understood as follows: // - on each subexpression, checkFun checks whether it should be replaced // - repFun is applied is checkFun succeeded - def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr) : Expr = { + // - if the result of repFun is different from its argument and recursive + // is set to true, search/replace is reapplied on the result. + def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr, recursive: Boolean=true) : Expr = { def rec(ex: Expr, skip: Expr = null) : Expr = ex match { case _ if (ex != skip && checkFun(ex)) => { val newExpr = repFun(ex) @@ -283,9 +285,9 @@ object Trees { Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) } if(ex == newExpr) - rec(ex, ex) + if(recursive) rec(ex, ex) else ex else - rec(newExpr) + if(recursive) rec(newExpr) else newExpr } case l @ Let(i,e,b) => { val re = rec(e) @@ -361,16 +363,17 @@ object Trees { def simplifyLets(expr: Expr) : Expr = { val isLet = ((t: Expr) => t.isInstanceOf[Let]) def simplerLet(t: Expr) : Expr = t match { - case letExpr @ Let(_, Variable(_), _) => expandLets(letExpr) + case letExpr @ Let(i, Variable(v), b) => replace(Map((Variable(i) -> Variable(v))), b) + case letExpr @ Let(i, l: Literal[_], b) => replace(Map((Variable(i) -> l)), b) case letExpr @ Let(i,e,b) => { var occurences = 0 - def isOcc(tr: Expr) = (occurences < 2 && tr.isInstanceOf[Variable] && tr.asInstanceOf[Variable].id == i) + def isOcc(tr: Expr) = (occurences < 2 && tr == Variable(i)) def incCount(tr: Expr) = { occurences = occurences + 1; tr } - searchAndApply(isOcc,incCount,b) + searchAndApply(isOcc, incCount, b, false) if(occurences == 0) { b } else if(occurences == 1) { - expandLets(letExpr) + replace(Map((Variable(i) -> e)), b) } else { t } diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 52b9d91ad8ccbfc3202b9e334a922872514e2db6..b0fb46068c9a8c72267e74c5520b52caa8bcea67 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -150,7 +150,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { Equals(fOfX, funDef.body.get) } val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) - val (newExpr2, sideExprs2) = Analysis.inlineFunctionsAndContracts(program, newExpr1) + val (newExpr2, sideExprs2) = (newExpr1, Nil) // Analysis.inlineFunctionsAndContracts(program, newExpr1) val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { newExpr2 } else { @@ -165,6 +165,9 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } case None => { reporter.warning("Could not generate forall axiom for " + funDef.id.name) + reporter.warning(toConvert) + reporter.warning(newExpr1) + reporter.warning(newExpr2) reporter.warning(finalToConvert) } } diff --git a/testcases/ExprComp.scala b/testcases/ExprComp.scala index 92ae4131994589c6187a61d1a7dc6b8c2ccbd961..1346f1abbb32eb1e1bd846d7d9949f8ede4f9028 100644 --- a/testcases/ExprComp.scala +++ b/testcases/ExprComp.scala @@ -5,8 +5,8 @@ object ExprComp { // Operations sealed abstract class BinOp - case class Plus extends BinOp - case class Times extends BinOp + case class Plus() extends BinOp + case class Times() extends BinOp // Expressions sealed abstract class Expr @@ -34,20 +34,20 @@ object ExprComp { // Programs sealed abstract class Program - case class EProgram extends Program + case class EProgram() extends Program case class NProgram(first : Instruction, rest : Program) extends Program // Value stack sealed abstract class ValueStack - case class EStack extends ValueStack + case class EStack() extends ValueStack case class NStack(v : Value, rest : ValueStack) extends ValueStack // Outcomes of running the program sealed abstract class Outcome case class Ok(v : ValueStack) extends Outcome - case class Fail extends Outcome + case class Fail() extends Outcome // Running programs on a given initial stack def run(p : Program, vs : ValueStack) : Outcome = p match { diff --git a/testcases/Fibonacci.scala b/testcases/Fibonacci.scala new file mode 100644 index 0000000000000000000000000000000000000000..f2140e54dafb4ff08df3533ae67539db23a67a13 --- /dev/null +++ b/testcases/Fibonacci.scala @@ -0,0 +1,15 @@ +object Fibonacci { + def fib(x: Int) : Int = { + require(x >= 0) + if(x < 2) { + x + } else { + fib(x - 1) + fib(x - 2) + } + } + + // requires that fib is universally quantified to work... + def check() : Boolean = { + fib(5) == 5 + } ensuring(_ == true) +}