From 0bafdbfb31e5654c698903f06921f2b853298d90 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Wed, 30 Jun 2010 15:51:34 +0000
Subject: [PATCH] Now helper function to expand let expressions. Plus,
 SetEquals and Iff's are correctly extracted.

---
 src/funcheck/CodeExtraction.scala | 10 ++++++-
 src/purescala/PrettyPrinter.scala |  6 ++--
 src/purescala/Trees.scala         | 47 +++++++++++++++++++++++++++++--
 src/purescala/Z3Solver.scala      | 11 ++++++--
 4 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index d0e0fed69..683db51ad 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -352,7 +352,15 @@ trait CodeExtraction extends Extractors {
       case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type)
       case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type)
       case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type)
-      case ExEquals(l, r) => Equals(rec(l), rec(r)).setType(BooleanType)
+      case ExEquals(l, r) => {
+        val rl = rec(l)
+        val rr = rec(r)
+        ((rl.getType,rr.getType) match {
+          case (SetType(_), SetType(_)) => SetEquals(rl, rr)
+          case (BooleanType, BooleanType) => Iff(rl, rr)
+          case (_, _) => Equals(rl, rr)
+        }).setType(BooleanType) 
+      }
       case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType)
       case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType)
       case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType)
diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala
index f273e877f..b07d7766f 100644
--- a/src/purescala/PrettyPrinter.scala
+++ b/src/purescala/PrettyPrinter.scala
@@ -68,14 +68,16 @@ object PrettyPrinter {
   private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match {
     case Variable(id) => sb.append(id)
     case Let(b,d,e) => {
-        pp(e, pp(d, sb.append("(let (" + b + " = "), lvl).append(") in "), lvl).append(")")
+        pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")")
     }
     case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl)            // \land
     case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl)             // \lor
     case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ", lvl)    // \neq
     case Not(expr) => ppUnary(sb, expr, "\u00AC(", ")", lvl)               // \neg
-    case Implies(l,r) => ppBinary(sb, l, r, "==>", lvl)              
+    case Iff(l,r) => ppBinary(sb, l, r, " <=> ", lvl)              
+    case Implies(l,r) => ppBinary(sb, l, r, " ==> ", lvl)              
     case UMinus(expr) => ppUnary(sb, expr, "-(", ")", lvl)
+    case SetEquals(l,r) => ppBinary(sb, l, r, " =S= ", lvl)
     case Equals(l,r) => ppBinary(sb, l, r, " == ", lvl)
     case IntLiteral(v) => sb.append(v)
     case BooleanLiteral(v) => sb.append(v)
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 6a86a0d19..63d03f20e 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -83,6 +83,10 @@ object Trees {
     val fixedType = BooleanType
   }
 
+  case class Iff(left: Expr, right: Expr) extends Expr with FixedType {
+    val fixedType = BooleanType
+  }
+
   case class Implies(left: Expr, right: Expr) extends Expr with FixedType {
     val fixedType = BooleanType
   }
@@ -91,7 +95,7 @@ object Trees {
     val fixedType = BooleanType
   }
 
-  /* Maybe we should split this one depending on types? */
+  /* For all types that don't have their own XXXEquals */
   case class Equals(left: Expr, right: Expr) extends Expr with FixedType {
     val fixedType = BooleanType
   }
@@ -140,7 +144,9 @@ object Trees {
   case class FiniteSet(elements: Seq[Expr]) extends Expr 
   case class ElementOfSet(element: Expr, set: Expr) extends Expr 
   case class IsEmptySet(set: Expr) extends Expr 
-  case class SetEquals(set1: Expr, set2: Expr) extends Expr 
+  case class SetEquals(set1: Expr, set2: Expr) extends Expr with FixedType {
+    val fixedType = BooleanType
+  }
   case class SetCardinality(set: Expr) extends Expr with FixedType {
     val fixedType = Int32Type
   }
@@ -194,6 +200,7 @@ object Trees {
   object BinaryOperator {
     def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match {
       case Equals(t1,t2) => Some((t1,t2,Equals))
+      case Iff(t1,t2) => Some((t1,t2,Iff))
       case Implies(t1,t2) => Some((t1,t2,Implies))
       case Plus(t1,t2) => Some((t1,t2,Plus))
       case Minus(t1,t2) => Some((t1,t2,Minus))
@@ -262,4 +269,40 @@ object Trees {
 
     rec(expr)
   }
+
+  def expandLets(expr: Expr) : Expr = {
+    def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match {
+      case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s)
+      case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s)))
+      case f @ FunctionInvocation(fd, args) => FunctionInvocation(fd, args.map(rec(_, s))).setType(f.getType)
+      case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType)
+      case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType)
+      case And(exs) => And(exs.map(rec(_, s)))
+      case Or(exs) => Or(exs.map(rec(_, s)))
+      case Not(e) => Not(rec(e, s))
+      case u @ UnaryOperator(t,recons) => {
+        val r = rec(t, s)
+        if(r != t)
+          recons(r).setType(u.getType)
+        else
+          u
+      }
+      case b @ BinaryOperator(t1,t2,recons) => {
+        val r1 = rec(t1, s)
+        val r2 = rec(t2, s)
+        if(r1 != t1 || r2 != t2)
+          recons(r1,r2).setType(b.getType)
+        else
+          b
+      }
+      case _ => ex
+    }
+
+    def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match {
+      case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s))
+      case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s))
+    }
+
+    rec(expr, Map.empty)
+  }
 }
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index 814162eac..966d9cded 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -15,7 +15,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) {
     z3cfg.setParamValue("MODEL", "true")
     val z3 = new Z3Context(z3cfg)
 
-    toZ3Formula(z3, vc) match {
+    val result = toZ3Formula(z3, vc) match {
       case None => None // means it could not be translated
       case Some(z3f) => {
         z3.assertCnstr(z3.mkNot(z3f))
@@ -34,9 +34,12 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) {
         }
       }
     }
+
+    z3.delete
+    result
   }
 
-  def toZ3Formula(z3: Z3Context, expr: Expr) : Option[Z3AST] = {
+  private def toZ3Formula(z3: Z3Context, expr: Expr) : Option[Z3AST] = {
     class CantTranslateException extends Exception
 
     lazy val intSort  = z3.mkIntSort()
@@ -69,9 +72,11 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) {
       case IfExpr(c,t,e) => z3.mkITE(rec(c), rec(t), rec(e))
       case And(exs) => z3.mkAnd(exs.map(rec(_)) : _*)
       case Or(exs) => z3.mkOr(exs.map(rec(_)) : _*)
+      case Implies(l,r) => z3.mkImplies(rec(l), rec(r))
+      case Iff(l,r) => z3.mkIff(rec(l), rec(r))
+      case Not(Iff(l,r)) => z3.mkXor(rec(l), rec(r))   
       case Not(Equals(l,r)) => z3.mkDistinct(rec(l),rec(r))
       case Not(e) => z3.mkNot(rec(e))
-      case Implies(l,r) => z3.mkImplies(rec(l), rec(r))
       case IntLiteral(v) => z3.mkInt(v, intSort)
       case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse()
       case Equals(l,r) => z3.mkEq(rec(l),rec(r))
-- 
GitLab