From 0bafdbfb31e5654c698903f06921f2b853298d90 Mon Sep 17 00:00:00 2001 From: Philippe Suter <philippe.suter@gmail.com> Date: Wed, 30 Jun 2010 15:51:34 +0000 Subject: [PATCH] Now helper function to expand let expressions. Plus, SetEquals and Iff's are correctly extracted. --- src/funcheck/CodeExtraction.scala | 10 ++++++- src/purescala/PrettyPrinter.scala | 6 ++-- src/purescala/Trees.scala | 47 +++++++++++++++++++++++++++++-- src/purescala/Z3Solver.scala | 11 ++++++-- 4 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index d0e0fed69..683db51ad 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -352,7 +352,15 @@ trait CodeExtraction extends Extractors { case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) - case ExEquals(l, r) => Equals(rec(l), rec(r)).setType(BooleanType) + case ExEquals(l, r) => { + val rl = rec(l) + val rr = rec(r) + ((rl.getType,rr.getType) match { + case (SetType(_), SetType(_)) => SetEquals(rl, rr) + case (BooleanType, BooleanType) => Iff(rl, rr) + case (_, _) => Equals(rl, rr) + }).setType(BooleanType) + } case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index f273e877f..b07d7766f 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -68,14 +68,16 @@ object PrettyPrinter { private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match { case Variable(id) => sb.append(id) case Let(b,d,e) => { - pp(e, pp(d, sb.append("(let (" + b + " = "), lvl).append(") in "), lvl).append(")") + pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") } case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl) // \land case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl) // \lor case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ", lvl) // \neq case Not(expr) => ppUnary(sb, expr, "\u00AC(", ")", lvl) // \neg - case Implies(l,r) => ppBinary(sb, l, r, "==>", lvl) + case Iff(l,r) => ppBinary(sb, l, r, " <=> ", lvl) + case Implies(l,r) => ppBinary(sb, l, r, " ==> ", lvl) case UMinus(expr) => ppUnary(sb, expr, "-(", ")", lvl) + case SetEquals(l,r) => ppBinary(sb, l, r, " =S= ", lvl) case Equals(l,r) => ppBinary(sb, l, r, " == ", lvl) case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 6a86a0d19..63d03f20e 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -83,6 +83,10 @@ object Trees { val fixedType = BooleanType } + case class Iff(left: Expr, right: Expr) extends Expr with FixedType { + val fixedType = BooleanType + } + case class Implies(left: Expr, right: Expr) extends Expr with FixedType { val fixedType = BooleanType } @@ -91,7 +95,7 @@ object Trees { val fixedType = BooleanType } - /* Maybe we should split this one depending on types? */ + /* For all types that don't have their own XXXEquals */ case class Equals(left: Expr, right: Expr) extends Expr with FixedType { val fixedType = BooleanType } @@ -140,7 +144,9 @@ object Trees { case class FiniteSet(elements: Seq[Expr]) extends Expr case class ElementOfSet(element: Expr, set: Expr) extends Expr case class IsEmptySet(set: Expr) extends Expr - case class SetEquals(set1: Expr, set2: Expr) extends Expr + case class SetEquals(set1: Expr, set2: Expr) extends Expr with FixedType { + val fixedType = BooleanType + } case class SetCardinality(set: Expr) extends Expr with FixedType { val fixedType = Int32Type } @@ -194,6 +200,7 @@ object Trees { object BinaryOperator { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { case Equals(t1,t2) => Some((t1,t2,Equals)) + case Iff(t1,t2) => Some((t1,t2,Iff)) case Implies(t1,t2) => Some((t1,t2,Implies)) case Plus(t1,t2) => Some((t1,t2,Plus)) case Minus(t1,t2) => Some((t1,t2,Minus)) @@ -262,4 +269,40 @@ object Trees { rec(expr) } + + def expandLets(expr: Expr) : Expr = { + def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { + case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) + case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) + case f @ FunctionInvocation(fd, args) => FunctionInvocation(fd, args.map(rec(_, s))).setType(f.getType) + case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType) + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType) + case And(exs) => And(exs.map(rec(_, s))) + case Or(exs) => Or(exs.map(rec(_, s))) + case Not(e) => Not(rec(e, s)) + case u @ UnaryOperator(t,recons) => { + val r = rec(t, s) + if(r != t) + recons(r).setType(u.getType) + else + u + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1, s) + val r2 = rec(t2, s) + if(r1 != t1 || r2 != t2) + recons(r1,r2).setType(b.getType) + else + b + } + case _ => ex + } + + def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) + } + + rec(expr, Map.empty) + } } diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 814162eac..966d9cded 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -15,7 +15,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { z3cfg.setParamValue("MODEL", "true") val z3 = new Z3Context(z3cfg) - toZ3Formula(z3, vc) match { + val result = toZ3Formula(z3, vc) match { case None => None // means it could not be translated case Some(z3f) => { z3.assertCnstr(z3.mkNot(z3f)) @@ -34,9 +34,12 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } } } + + z3.delete + result } - def toZ3Formula(z3: Z3Context, expr: Expr) : Option[Z3AST] = { + private def toZ3Formula(z3: Z3Context, expr: Expr) : Option[Z3AST] = { class CantTranslateException extends Exception lazy val intSort = z3.mkIntSort() @@ -69,9 +72,11 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { case IfExpr(c,t,e) => z3.mkITE(rec(c), rec(t), rec(e)) case And(exs) => z3.mkAnd(exs.map(rec(_)) : _*) case Or(exs) => z3.mkOr(exs.map(rec(_)) : _*) + case Implies(l,r) => z3.mkImplies(rec(l), rec(r)) + case Iff(l,r) => z3.mkIff(rec(l), rec(r)) + case Not(Iff(l,r)) => z3.mkXor(rec(l), rec(r)) case Not(Equals(l,r)) => z3.mkDistinct(rec(l),rec(r)) case Not(e) => z3.mkNot(rec(e)) - case Implies(l,r) => z3.mkImplies(rec(l), rec(r)) case IntLiteral(v) => z3.mkInt(v, intSort) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case Equals(l,r) => z3.mkEq(rec(l),rec(r)) -- GitLab