From d2d70bba9504102b00f6e51e3be1b89c28d28373 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Sat, 13 Nov 2010 17:53:04 +0000
Subject: [PATCH] some random commit

---
 pldi2011-testcases/LambdaEval.scala           | 133 ++++++++------
 src/purescala/Z3ModelReconstruction.scala     |   4 +-
 src/purescala/Z3Solver.scala                  |  92 +++++++---
 .../z3plugins/instantiator/Instantiator.scala | 164 ++++++++++++++++--
 testcases/RedBlackTree.scala                  |   7 +-
 5 files changed, 296 insertions(+), 104 deletions(-)

diff --git a/pldi2011-testcases/LambdaEval.scala b/pldi2011-testcases/LambdaEval.scala
index afe2ae81c..e7f12bde2 100644
--- a/pldi2011-testcases/LambdaEval.scala
+++ b/pldi2011-testcases/LambdaEval.scala
@@ -23,74 +23,97 @@ object LambdaEval {
     case Snd(_) => false
   }
 
-  def okPair(p: StoreExprPair): Boolean = p match {
+  def okPair(p: StoreExprPairAbs): Boolean = p match {
     case StoreExprPair(_, res) => ok(res)
   }
 
   sealed abstract class List
-  case class Cons(head: BindingPair, tail: List) extends List
+  case class Cons(head: BindingPairAbs, tail: List) extends List
   case class Nil() extends List
 
-  sealed abstract class AbstractPair
-  case class BindingPair(key: Int, value: Expr) extends AbstractPair
-  case class StoreExprPair(store: List, expr: Expr) extends AbstractPair
+  sealed abstract class BindingPairAbs
+  case class BindingPair(key: Int, value: Expr) extends BindingPairAbs
 
+  sealed abstract class StoreExprPairAbs
+  case class StoreExprPair(store: List, expr: Expr) extends StoreExprPairAbs
+
+  def storeElems(store: List) : Set[Int] = store match {
+    case Nil() => Set.empty[Int]
+    case Cons(BindingPair(k,_), xs) => Set(k) ++ storeElems(xs)
+  }
+
+  def freeVars(expr: Expr) : Set[Int] = expr match {
+    case Const(_) => Set.empty[Int]
+    case Plus(l,r) => freeVars(l) ++ freeVars(r)
+    case Lam(x, bdy) => freeVars(bdy) -- Set(x)
+    case Pair(f,s) => freeVars(f) ++ freeVars(s)
+    case Var(n) => Set(n)
+    case App(l,r) => freeVars(l) ++ freeVars(r)
+    case Fst(e) => freeVars(e)
+    case Snd(e) => freeVars(e)
+  }
 
   // Find first element in list that has first component 'x' and return its
   // second component, analogous to List.assoc in OCaml
-  def find(x: Int, l: List): Expr = l match {
-    case Cons(i, is) => if (i.key == x) i.value else find(x, is)
+  def find(x: Int, l: List): Expr = {
+    require(storeElems(l).contains(x))
+    l match {
+      case Cons(BindingPair(k,v), is) => if (k == x) v else find(x, is)
+    }
   }
 
   // Evaluator
-  def eval(store: List, expr: Expr): StoreExprPair = (expr match {
-    case Const(i) => StoreExprPair(store, Const(i))
-    case Var(x) => StoreExprPair(store, find(x, store))
-    case Plus(e1, e2) =>
-      val i1 = eval(store, e1) match {
-        case StoreExprPair(_, Const(i)) => i
-      }
-      val i2 = eval(store, e2) match {
-        case StoreExprPair(_, Const(i)) => i
-      }
-      StoreExprPair(store, Const(i1 + i2))
-    case App(e1, e2) =>
-      val store1 = eval(store, e1) match {
-        case StoreExprPair(resS,_) => resS
-      }
-      val x = eval(store, e1) match {
-        case StoreExprPair(_, Lam(resX, _)) => resX
-      }
-      val e = eval(store, e1) match {
-        case StoreExprPair(_, Lam(_, resE)) => resE
-      }
-      /*
-      val StoreExprPair(store1, Lam(x, e)) = eval(store, e1) match {
-        case StoreExprPair(resS, Lam(resX, resE)) => StoreExprPair(resS, Lam(resX, resE))
-      }
-      */
-      val v2 = eval(store, e2) match {
-        case StoreExprPair(_, v) => v
-      }
-      eval(Cons(BindingPair(x, v2), store1), e)
-    case Lam(x, e) => StoreExprPair(store, Lam(x, e))
-    case Pair(e1, e2) =>
-      val v1 = eval(store, e1) match {
-        case StoreExprPair(_, v) => v
-      }
-      val v2 = eval(store, e2) match {
-        case StoreExprPair(_, v) => v
-      }
-      StoreExprPair(store, Pair(v1, v2))
-    case Fst(e) =>
-      eval(store, e) match {
-        case StoreExprPair(_, Pair(v1, _)) => StoreExprPair(store, v1)
-      }
-    case Snd(e) =>
-      eval(store, e) match {
-        case StoreExprPair(_, Pair(_, v2)) => StoreExprPair(store, v2)
-      }
-  }) ensuring(res => okPair(res))
+  def eval(store: List, expr: Expr): StoreExprPair = {
+    require(freeVars(expr) subsetOf storeElems(store))
+    expr match {
+      case Const(i) => StoreExprPair(store, Const(i))
+      case Var(x) => StoreExprPair(store, find(x, store))
+      case Plus(e1, e2) =>
+        val i1 = eval(store, e1) match {
+          case StoreExprPair(_, Const(i)) => i
+        }
+        val i2 = eval(store, e2) match {
+          case StoreExprPair(_, Const(i)) => i
+        }
+        StoreExprPair(store, Const(i1 + i2))
+      case App(e1, e2) =>
+        val store1 = eval(store, e1) match {
+          case StoreExprPair(resS,_) => resS
+        }
+        val x = eval(store, e1) match {
+          case StoreExprPair(_, Lam(resX, _)) => resX
+        }
+        val e = eval(store, e1) match {
+          case StoreExprPair(_, Lam(_, resE)) => resE
+        }
+        /*
+        val StoreExprPair(store1, Lam(x, e)) = eval(store, e1) match {
+          case StoreExprPair(resS, Lam(resX, resE)) => StoreExprPair(resS, Lam(resX, resE))
+        }
+        */
+        val v2 = eval(store, e2) match {
+          case StoreExprPair(_, v) => v
+        }
+        eval(Cons(BindingPair(x, v2), store1), e)
+      case Lam(x, e) => StoreExprPair(store, Lam(x, e))
+      case Pair(e1, e2) =>
+        val v1 = eval(store, e1) match {
+          case StoreExprPair(_, v) => v
+        }
+        val v2 = eval(store, e2) match {
+          case StoreExprPair(_, v) => v
+        }
+        StoreExprPair(store, Pair(v1, v2))
+      case Fst(e) =>
+        eval(store, e) match {
+          case StoreExprPair(_, Pair(v1, _)) => StoreExprPair(store, v1)
+        }
+      case Snd(e) =>
+        eval(store, e) match {
+          case StoreExprPair(_, Pair(_, v2)) => StoreExprPair(store, v2)
+        }
+    }
+  } ensuring(res => okPair(res))
   /*ensuring(res => res match {
     case StoreExprPair(_, resExpr) => ok(resExpr)
   }) */
diff --git a/src/purescala/Z3ModelReconstruction.scala b/src/purescala/Z3ModelReconstruction.scala
index 83ae2b1ee..6565549cd 100644
--- a/src/purescala/Z3ModelReconstruction.scala
+++ b/src/purescala/Z3ModelReconstruction.scala
@@ -17,8 +17,8 @@ trait Z3ModelReconstruction {
       val z3ID : Z3AST = exprToZ3Id(id.toVariable)
 
       expectedType match {
-        case BooleanType => model.evalAsBool(z3ID).map(BooleanLiteral(_))
-        case Int32Type => model.evalAsInt(z3ID).map(IntLiteral(_))
+        case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral(_))
+        case Int32Type => model.evalAs[Int](z3ID).map(IntLiteral(_))
         case other => model.eval(z3ID) match {
           case None => None
           case Some(t) => softFromZ3Formula(t)
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index d372a45d0..bfadf1dc7 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -570,38 +570,80 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
     class CantTranslateException(t: Z3AST) extends Exception("Can't translate from Z3 tree: " + t)
 
     def rec(t: Z3AST) : Expr = z3.getASTKind(t) match {
-      case Z3AppAST(_, args) if args.size == 0 && z3IdToExpr.isDefinedAt(t) => {
-        z3IdToExpr(t)
-      }
-      case Z3AppAST(decl, args) if isKnownDecl(decl) => {
-        val fd = functionDeclToDef(decl)
-        assert(fd.args.size == args.size)
-        FunctionInvocation(fd, args.map(rec(_)))
-      }
-      case Z3AppAST(decl, args) if args.size == 1 && reverseADTTesters.isDefinedAt(decl) => {
-        CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0)))
-      }
-      case Z3AppAST(decl, args) if args.size == 1 && reverseADTFieldSelectors.isDefinedAt(decl) => {
-        val (ccd, fid) = reverseADTFieldSelectors(decl)
-        CaseClassSelector(ccd, rec(args(0)), fid)
-      }
-      case Z3AppAST(decl, args) if reverseADTConstructors.isDefinedAt(decl) => {
-        val ccd = reverseADTConstructors(decl)
-        assert(args.size == ccd.fields.size)
-        CaseClass(ccd, args.map(rec(_)))
+      case Z3AppAST(decl, args) => {
+        val argsSize = args.size
+        if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) {
+          z3IdToExpr(t)
+        } else if(isKnownDecl(decl)) {
+          val fd = functionDeclToDef(decl)
+          assert(fd.args.size == argsSize)
+          FunctionInvocation(fd, args.map(rec(_)))
+        } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) {
+          CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0)))
+        } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) {
+          val (ccd, fid) = reverseADTFieldSelectors(decl)
+          CaseClassSelector(ccd, rec(args(0)), fid)
+        } else if(reverseADTConstructors.isDefinedAt(decl)) {
+          val ccd = reverseADTConstructors(decl)
+          assert(argsSize == ccd.fields.size)
+          CaseClass(ccd, args.map(rec(_)))
+        } else {
+          import Z3DeclKind._
+          val rargs = args.map(rec(_))
+          z3.getDeclKind(decl) match {
+            case OpTrue => BooleanLiteral(true)
+            case OpFalse => BooleanLiteral(false)
+            case OpEq => Equals(rargs(0), rargs(1))
+            case OpITE => {
+              assert(argsSize == 3)
+              val r0 = rargs(0)
+              val r1 = rargs(1)
+              val r2 = rargs(2)
+              IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType))
+            }
+            case OpAnd => And(rargs)
+            case OpOr => Or(rargs)
+            case OpIff => Iff(rargs(0), rargs(1))
+            case OpXor => Not(Iff(rargs(0), rargs(1)))
+            case OpNot => Not(rargs(0))
+            case OpImplies => Implies(rargs(0), rargs(1))
+            case OpLE => LessEquals(rargs(0), rargs(1))
+            case OpGE => GreaterEquals(rargs(0), rargs(1))
+            case OpLT => LessThan(rargs(0), rargs(1))
+            case OpGT => GreaterThan(rargs(0), rargs(1))
+            case OpAdd => {
+              assert(argsSize == 2)
+              Plus(rargs(0), rargs(1))
+            }
+            case OpSub => {
+              assert(argsSize == 2)
+              Minus(rargs(0), rargs(1))
+            }
+            case OpUMinus => UMinus(rargs(0))
+            case OpMul => {
+              assert(argsSize == 2)
+              Times(rargs(0), rargs(1))
+            }
+            case other => {
+              System.err.println("Don't know what to do with this declKind : " + other)
+              throw new CantTranslateException(t)
+            }
+          }
+        }
       }
+
       case Z3NumeralAST(Some(v)) => IntLiteral(v)
       case other @ _ => {
-        println("Don't know what this is " + other) 
+        System.err.println("Don't know what this is " + other) 
         if(useInstantiator) {
           instantiator.dumpFunctionMap
         } else {
-          println("REVERSE FUNCTION MAP:")
-          println(reverseFunctionMap.toSeq.mkString("\n"))
+          System.err.println("REVERSE FUNCTION MAP:")
+          System.err.println(reverseFunctionMap.toSeq.mkString("\n"))
         }
-        println("REVERSE CONS MAP:")
-        println(reverseADTConstructors.toSeq.mkString("\n"))
-        System.exit(-1)
+        System.err.println("REVERSE CONS MAP:")
+        System.err.println(reverseADTConstructors.toSeq.mkString("\n"))
+        // System.exit(-1)
         throw new CantTranslateException(t)
       }
     }
diff --git a/src/purescala/z3plugins/instantiator/Instantiator.scala b/src/purescala/z3plugins/instantiator/Instantiator.scala
index 2eecf5e87..9c84c857f 100644
--- a/src/purescala/z3plugins/instantiator/Instantiator.scala
+++ b/src/purescala/z3plugins/instantiator/Instantiator.scala
@@ -10,20 +10,22 @@ import purescala.Settings
 
 import purescala.Z3Solver
 
+import scala.collection.mutable.{Map => MutableMap, Set => MutableSet}
+
 class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instantiator") {
   import z3Solver.{z3,program,typeToSort,fromZ3Formula,toZ3Formula}
 
   setCallbacks(
 //    reduceApp = true,
-//    finalCheck = true,
-//    push = true,
-//    pop = true,
+    finalCheck = true,
+    push = true,
+    pop = true,
     newApp = true,
     newAssignment = true,
     newRelevant = true,
 //    newEq = true,
 //    newDiseq = true,
-//    reset = true,
+    reset = true,
     restart = true
   )
 
@@ -51,35 +53,39 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     reverseFunctionMap.getOrElse(decl, scala.Predef.error("No FunDef found for Z3 definition " + decl + " in Instantiator."))
   }
 
-  override def newAssignment(ast: Z3AST, polarity: Boolean) : Unit = {
+  // The logic starts here.
+  private var stillToAssert : Set[(Int,Expr)] = Set.empty
+
+  override def newAssignment(ast: Z3AST, polarity: Boolean) : Unit = safeBlockToAssertAxioms {
 
   }
 
   override def newApp(ast: Z3AST) : Unit = {
+    examineAndUnroll(ast)
+  }
 
+  override def newRelevant(ast: Z3AST) : Unit = {
+    examineAndUnroll(ast)
   }
 
   private var bodyInlined : Int = 0
-  override def newRelevant(ast: Z3AST) : Unit = {
+  def examineAndUnroll(ast: Z3AST) : Unit = if(bodyInlined < Settings.unrollingLevel) {
     val aps = fromZ3Formula(ast)
     val fis = functionCallsOf(aps)
     println("As Purescala: " + aps)
     for(fi <- fis) {
       val FunctionInvocation(fd, args) = fi
       println("interesting function call : " + fi)
-      if(fd.hasPostcondition) {
+      if(bodyInlined < Settings.unrollingLevel && fd.hasPostcondition) {
+        bodyInlined += 1
         val post = matchToIfThenElse(fd.postcondition.get)
-        // FIXME TODO we could use let identifiers here to speed things up a little bit...
-        //  val resFresh = FreshIdentifier("resForPostOf" + fd.id.uniqueName, true).setType(fi.getType)
-        //  val newLetIDs = fd.args.map(a => FreshIdentifier("argForPostOf" + fd.id.uniqueName, true).setType(a.tpe)).toList
-        //  val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + (ResultVariable() -> Variable(resFresh))
+        val isSafe = functionCallsOf(post).isEmpty
+
         val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*) + (ResultVariable() -> fi)
         // println(substMap)
         val newBody = replace(substMap, post)
         println("I'm going to add this : " + newBody)
-        val newAxiom = toZ3Formula(z3, newBody).get
-        println("As Z3: " + newAxiom)
-        assertAxiom(newAxiom)
+        assertIfSafeOrDelay(newBody)//, isSafe)
       }
 
       if(bodyInlined < Settings.unrollingLevel && fd.hasBody) {
@@ -87,13 +93,133 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
         val body = matchToIfThenElse(fd.body.get)
         val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*)
         val newBody = replace(substMap, body)
-        println("I'm going to add this : " + newBody)
-        val newAxiom = z3.mkEq(toZ3Formula(z3, fi).get, toZ3Formula(z3, newBody).get)
-        println("As Z3: " + newAxiom)
-        assertAxiom(newAxiom)
+        val theEquality = Equals(fi, newBody)
+        println("I'm going to add this : " + theEquality)
+        assertIfSafeOrDelay(theEquality)
+      }
+    }
+  }
+
+  override def finalCheck : Boolean = safeBlockToAssertAxioms {
+    if(stillToAssert.isEmpty) {
+      true
+    } else {
+      for((lvl,ast) <- stillToAssert) {
+        assertAxiomASAP(ast, lvl)
+        // assertPermanently(ast)
       }
+      stillToAssert = Set.empty
+      true
     }
   }
 
-  override def restart : Unit = { }
+  // This is concerned with how many new function calls the assertion is going
+  // to introduce.
+  private def assertIfSafeOrDelay(ast: Expr, isSafe: Boolean = false) : Unit = {
+    stillToAssert += ((pushLevel, ast))
+  }
+
+  // Assert as soon as possible and keep asserting as long as level is >= lvl.
+  private def assertAxiomASAP(expr: Expr, lvl: Int) : Unit = assertAxiomASAP(toZ3Formula(z3, expr).get, lvl)
+  private def assertAxiomASAP(ast: Z3AST, lvl: Int) : Unit = {
+    if(canAssertAxiom) {
+      assertAxiomNow(ast)
+      if(lvl < pushLevel) {
+        // Remember to reassert when we backtrack.
+        if(pushLevel > 0) {
+          rememberToReassertAt(pushLevel - 1, lvl, ast)
+        }
+      }
+    } else {
+      toAssertASAP = toAssertASAP + ((lvl, ast))
+    }
+  }
+
+  private def assertAxiomFrom(ast: Z3AST, level: Int) : Unit = {
+    toAssertASAP = toAssertASAP + ((level, ast))
+  }
+
+//  private def assertPermanently(expr: Expr) : Unit = {
+//    val asZ3 = toZ3Formula(z3, expr).get
+//
+//    if(canAssertAxiom) {
+//      assertAxiomNow(asZ3)
+//    } else {
+//      toAssertASAP = toAssertASAP + ((0, asZ3))
+//    }
+//  }
+
+  private def assertAxiomNow(ast: Z3AST) : Unit = {
+    if(!canAssertAxiom)
+      println("WARNING ! ASSERTING AXIOM WHEN NOT SAFE !")
+
+    println("Now asserting : " + ast)
+    assertAxiom(ast)
+  }
+
+  override def push : Unit = {
+    pushLevel += 1
+  }
+
+  override def pop : Unit = {
+    pushLevel -= 1
+
+    if(toReassertAt.isDefinedAt(pushLevel)) {
+      for((lvl,ax) <- toReassertAt(pushLevel)) {
+        assertAxiomFrom(ax, lvl)
+      }
+      toReassertAt(pushLevel).clear
+    }
+
+    assert(pushLevel >= 0)
+  }
+
+  override def restart : Unit = {
+    pushLevel = 0
+  }
+
+  override def reset : Unit = reinit
+
+  // Below is all the machinery to be able to assert axioms in safe states.
+
+  private var pushLevel : Int = _
+  private var canAssertAxiom : Boolean = _
+  private var toAssertASAP : Set[(Int,Z3AST)] = _
+  private val toReassertAt : MutableMap[Int,MutableSet[(Int,Z3AST)]] = MutableMap.empty
+
+  private def rememberToReassertAt(lvl: Int, axLvl: Int, ax: Z3AST) : Unit = {
+    if(toReassertAt.isDefinedAt(lvl)) {
+      toReassertAt(lvl) += ((axLvl, ax))
+    } else {
+      toReassertAt(lvl) = MutableSet((axLvl, ax))
+    }
+  }
+
+  reinit
+  private def reinit : Unit = {
+    pushLevel = 0
+    canAssertAxiom = false
+    toAssertASAP = Set.empty
+    stillToAssert = Set.empty
+  }
+
+  private def safeBlockToAssertAxioms[A](block: => A): A = {
+    canAssertAxiom = true
+
+    if (toAssertASAP.nonEmpty) {
+      for ((lvl, ax) <- toAssertASAP) {
+        if(lvl <= pushLevel) {
+          assertAxiomNow(ax)
+          if(lvl < pushLevel && pushLevel > 0) {
+            rememberToReassertAt(pushLevel - 1, lvl, ax)
+          }
+        }
+      }
+      toAssertASAP = Set.empty
+    }
+    
+    val result = block
+    canAssertAxiom = false
+    result
+  }
 }
diff --git a/testcases/RedBlackTree.scala b/testcases/RedBlackTree.scala
index 9cb55c574..04ccb08fa 100644
--- a/testcases/RedBlackTree.scala
+++ b/testcases/RedBlackTree.scala
@@ -26,9 +26,10 @@ object RedBlackTree {
       if      (x < y)  balance(c, ins(x, a), y, b)
       else if (x == y) Node(c,a,y,b)
       else             balance(c,a,y,ins(x, b))
-  }) ensuring (res => 
-          content(res) == content(t) ++ Set(x) &&
-          size(t) <= size(res) && size(res) < size(t) + 2)
+  }) ensuring (res => (
+             content(res) == content(t) ++ Set(x) 
+//          && size(t) <= size(res) && size(res) < size(t) + 2)
+              ))
 
   def add(x: Int, t: Tree): Tree = {
     makeBlack(ins(x, t))
-- 
GitLab