diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 1e189924c04349de5b63e6efefab85b8e591507d..c8c887a13fbbb04b8489bb0dce531f355616343b 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -313,6 +313,16 @@ trait CodeExtraction extends Extractors { } def rec(tr: Tree): Expr = tr match { + case ExValDef(vs, tpt, bdy, rst) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val oldSubsts = varSubsts + val valTree = rec(bdy) + varSubsts(vs) = (() => Variable(newID)) + val restTree = rec(rst) + varSubsts.remove(vs) + Let(newID, valTree, restTree) + } case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 2d108abdcd56d32de01d99163d2e31f5fb41bb72..da04302b6845bc2e7341cf85d0ab4b051caf633f 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -53,6 +53,18 @@ trait Extractors { } } + object ExValDef { + /** Extracts val's in the head of blocks. */ + def unapply(tree: Block): Option[(Symbol,Tree,Tree,Tree)] = tree match { + case Block((vd @ ValDef(_, _, tpt, rhs)) :: rest, expr) => + if(rest.isEmpty) + Some((vd.symbol, tpt, rhs, expr)) + else + Some((vd.symbol, tpt, rhs, Block(rest, expr))) + case _ => None + } + } + object ExObjectDef { /** Matches an object with no type parameters, and regardless of its * visibility. Does not match on the automatically generated companion diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 9018e7565e52538574d36a22e7b1b4465b89e551..20defdec3878fd664b42b2ec3a9b1d11f15704aa 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -41,7 +41,7 @@ class Analysis(val program: Program) { reporter.info(vc) if(Settings.runDefaultExtensions) { - val (z3f,stupidMap) = toZ3Formula(z3, vc) + val z3f = toZ3Formula(z3, vc) z3.assertCnstr(z3.mkNot(z3f)) //z3.print z3.checkAndGetModel() match { @@ -133,17 +133,30 @@ class Analysis(val program: Program) { rec(expr) } - def toZ3Formula(z3: Z3Context, expr: Expr) : (Z3AST,Map[Identifier,Z3AST]) = { - val intSort = z3.mkIntSort() - var varMap: Map[Identifier,Z3AST] = Map.empty + def toZ3Formula(z3: Z3Context, expr: Expr) : (Z3AST) = { + lazy val intSort = z3.mkIntSort() + lazy val boolSort = z3.mkBoolSort() + + // because we create identifiers the first time we see them, this is + // convenient. + var z3Vars: Map[Identifier,Z3AST] = Map.empty def rec(ex: Expr) : Z3AST = ex match { - case v @ Variable(id) => varMap.get(id) match { + case Let(i,e,b) => { + z3Vars = z3Vars + (i -> rec(e)) + rec(b) + } + case v @ Variable(id) => z3Vars.get(id) match { case Some(ast) => ast case None => { - assert(v.getType == Int32Type) - val newAST = z3.mkConst(z3.mkStringSymbol(id.name), intSort) - varMap = varMap + (id -> newAST) + val newAST = if(v.getType == Int32Type) { + z3.mkConst(z3.mkStringSymbol(id.name), intSort) + } else if(v.getType == BooleanType) { + z3.mkConst(z3.mkStringSymbol(id.name), boolSort) + } else { + reporter.fatalError("Unsupported type in Z3 transformation: " + v.getType) + } + z3Vars = z3Vars + (id -> newAST) newAST } } @@ -168,7 +181,6 @@ class Analysis(val program: Program) { case _ => scala.Predef.error("Can't handle this in translation to Z3: " + ex) } - val res = rec(expr) - (res,varMap) + rec(expr) } } diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index ad018307c3472b37b4193eefa9399643e657b159..4e312210afb5ce79e84d684e094cb54df27f76e1 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -67,6 +67,9 @@ object PrettyPrinter { private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match { case Variable(id) => sb.append(id) + case Let(b,d,e) => { + pp(e, pp(d, sb.append("(let (" + b + " = "), lvl).append(") in "), lvl).append(")") + } case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl) // \land case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl) // \lor case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ", lvl) // \neq diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index c0a61240abfbf9145a2fe89d758f83c5a195b76c..2f58fd3cbd65fe7435a6fa5f060959e354e4ab10 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -13,8 +13,8 @@ object Trees { } /* Like vals */ - case class Let(binder: Identifier, expression: Expr) extends Expr { - val et = expression.getType + case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { + val et = body.getType if(et != NoType) setType(et) } diff --git a/testcases/IntOperations.scala b/testcases/IntOperations.scala index d38896f80884455b915a2c779422c05291c08cf0..98a7574342ab4a7cebe3ae2a6741c977dc06fe44 100644 --- a/testcases/IntOperations.scala +++ b/testcases/IntOperations.scala @@ -1,7 +1,9 @@ object IntOperations { def sum(a: Int, b: Int) : Int = { require(b >= 0) - a + b + val b2 = b - 1 + val b3 = b2 + 1 + a + b3 } ensuring(_ >= a)