From caa9d19b7e3b56bcd58ff9e9e1d0338cfeb3c659 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Mon, 15 Nov 2010 22:33:22 +0000
Subject: [PATCH] morebetterfaster

---
 src/purescala/DefaultTactic.scala             |   6 +-
 src/purescala/Trees.scala                     |  21 +++-
 src/purescala/Z3Solver.scala                  |   4 -
 .../z3plugins/instantiator/Instantiator.scala | 100 ++++++++++++++----
 4 files changed, 106 insertions(+), 25 deletions(-)

diff --git a/src/purescala/DefaultTactic.scala b/src/purescala/DefaultTactic.scala
index 1a4338cdb..cecf34278 100644
--- a/src/purescala/DefaultTactic.scala
+++ b/src/purescala/DefaultTactic.scala
@@ -110,7 +110,11 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) {
     }
 
     def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = {
-      errorConditions(function).filter(_.kind == VCKind.Precondition)
+      val toRet = errorConditions(function).filter(_.kind == VCKind.Precondition)
+
+      println("PRECONDITIONS FOR " + function.id.name)
+      println(toRet.map(_.condition).toList.mkString("\n"))
+      toRet
     }
 
     def generatePatternMatchingExhaustivenessChecks(function: FunDef) : Seq[VerificationCondition] = {
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 37a056542..de585fcb3 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -958,7 +958,7 @@ object Trees {
 
   def explicitPreconditions(expr: Expr) : Expr = {
     def rewriteFunctionCall(e: Expr) : Option[Expr] = e match {
-      case fi @ FunctionInvocation(fd, args) if(fd.hasPrecondition) => {
+      case fi @ FunctionInvocation(fd, args) if(fd.hasPrecondition && fd.precondition.get != BooleanLiteral(true)) => {
         val fTpe = fi.getType
         val prec = matchToIfThenElse(fd.precondition.get)
         val newLetIDs = fd.args.map(a => FreshIdentifier("precarg_" + a.id.name, true).setType(a.tpe))
@@ -1044,4 +1044,23 @@ object Trees {
     
     searchAndReplaceDFS(rewritePM)(expr)
   }
+
+  // prec: expression does not contain match expressions
+  def measureADTChildrenDepth(expression: Expr) : Int = {
+    import scala.math.max
+
+    def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match {
+      case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm)))
+      case Variable(id) => lm.getOrElse(id, 0)
+      case CaseClassSelector(_, e, _) => rec(e,lm) + 1
+      case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max
+      case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm))
+      case UnaryOperator(e,_) => rec(e,lm)
+      case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm))
+      case t: Terminal => 0
+      case _ => scala.Predef.error("Not handled in measureChildrenDepth : " + ex)
+    }
+    
+    rec(expression,Map.empty)
+  }
 }
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index ca2f0f878..e837ba80b 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -350,10 +350,6 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
     val result = toZ3Formula(z3, toConvert) match {
       case None => None // means it could not be translated
       case Some(z3f) => {
-        if(Settings.experimental) {
-          reporter.info("Z3 Formula:")
-          reporter.info(z3f)
-        }
         //z3.push
         z3.assertCnstr(z3f)
         //z3.print
diff --git a/src/purescala/z3plugins/instantiator/Instantiator.scala b/src/purescala/z3plugins/instantiator/Instantiator.scala
index 5f9239a16..0b861a084 100644
--- a/src/purescala/z3plugins/instantiator/Instantiator.scala
+++ b/src/purescala/z3plugins/instantiator/Instantiator.scala
@@ -12,6 +12,7 @@ import purescala.Z3Solver
 import purescala.PartialEvaluator
 
 import scala.collection.mutable.{Map => MutableMap, Set => MutableSet}
+import scala.collection.mutable.PriorityQueue
 
 class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instantiator") {
   import z3Solver.{z3,program,typeToSort,fromZ3Formula,toZ3Formula}
@@ -25,15 +26,16 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     pop = true,
     newApp = true,
     newAssignment = true,
-    newRelevant = true,
-//    newEq = true,
-//    newDiseq = true,
+    //newRelevant = true,
+    newEq = true,
+    newDiseq = true,
     reset = true,
     restart = true
   )
 
   //showCallbacks(true)
 
+  // Related to creating and recovering Z3 function symbols
   private var functionMap : Map[FunDef,Z3FuncDecl] = Map.empty
   private var reverseFunctionMap : Map[Z3FuncDecl,FunDef] = Map.empty
 
@@ -56,23 +58,39 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     reverseFunctionMap.getOrElse(decl, scala.Predef.error("No FunDef found for Z3 definition " + decl + " in Instantiator."))
   }
 
-  // The logic starts here.
-  private var stillToAssert : List[(Int,Expr)] = Nil
+  // Related to discovering function calls and adding instantiations
+  private var queue : PriorityQueue[Unrolling] = new PriorityQueue[Unrolling]()(UnrollingOrdering)
+//  private var stillToAssert : List[(Int,Expr)] = Nil
 
   override def newAssignment(ast: Z3AST, polarity: Boolean) : Unit = safeBlockToAssertAxioms {
     examineAndUnroll(ast)
   }
 
+  // Just using these to assert axioms early when possible...
+  override def newEq(ast1: Z3AST, ast2: Z3AST) : Unit = safeBlockToAssertAxioms {}
+  override def newDiseq(ast1: Z3AST, ast2: Z3AST) : Unit = safeBlockToAssertAxioms {}
+
   override def newApp(ast: Z3AST) : Unit = {
-    // examineAndUnroll(ast)
+    examineAndUnroll(ast)
   }
 
   override def newRelevant(ast: Z3AST) : Unit = {
-    examineAndUnroll(ast)
+    // WARNING : CURRENTLY NOT CALLED !
+    //examineAndUnroll(ast)
   }
 
   private var bodyInlined : Int = 0
+  private var seen : Set[Z3AST] = Set.empty
+  private var seenCount : Int = 0
   def examineAndUnroll(ast: Z3AST, allFunctions: Boolean = false) : Unit = if(bodyInlined < Settings.unrollingLevel) {
+    if(seen(ast)) {
+      seenCount += 1
+//      println(" HIT ! seenCount now at " + seenCount)
+      return
+    } else {
+      seen += ast
+    }
+
     val aps = fromZ3Formula(ast)
     val fis : Set[FunctionInvocation] = if(allFunctions) {
       functionCallsOf(aps)
@@ -86,31 +104,50 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     //println("As Purescala: " + aps)
     for(fi <- fis) {
       val FunctionInvocation(fd, args) = fi
-      //println("interesting function call : " + fi)
       if(bodyInlined < Settings.unrollingLevel && fd.hasPostcondition) {
         bodyInlined += 1
         val post = matchToIfThenElse(fd.postcondition.get)
 
         val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*) + (ResultVariable() -> fi)
-        // println(substMap)
         val newBody = partialEvaluator(replace(substMap, post))
-        //println("I'm going to add this : " + newBody)
-        assertIfSafeOrDelay(newBody)//, isSafe)
+
+        val unrolling = new Unrolling(fi, newBody, true, pushLevel)
+        queue += unrolling
+
+        //assertIfSafeOrDelay(newBody)//, isSafe)
       }
 
       if(bodyInlined < Settings.unrollingLevel && fd.hasBody) {
         bodyInlined += 1
         val body = matchToIfThenElse(fd.body.get)
         val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*)
-        val newBody = replace(substMap, body)
-        val theEquality = Equals(fi, partialEvaluator(newBody))
-        //println("I'm going to add this : " + theEquality)
-        assertIfSafeOrDelay(theEquality)
+        val newBody = partialEvaluator(replace(substMap, body))
+        val theEquality = Equals(fi, newBody)
+
+        val unrolling = new Unrolling(fi, newBody, false, pushLevel)
+        queue += unrolling
+
+        //assertIfSafeOrDelay(theEquality)
       }
     }
   }
 
   override def finalCheck : Boolean = safeBlockToAssertAxioms {
+    if(!queue.isEmpty) {
+      val smallest = queue.head.depth
+      while(!queue.isEmpty) { // && queue.head.depth == smallest) {
+        val highest : Unrolling = queue.dequeue()
+        val toConvertAndAssert = if(highest.isContract) {
+          highest.body
+        } else {
+          Equals(highest.invocation, highest.body)
+        }
+        assertAxiomASAP(toZ3Formula(z3, toConvertAndAssert).get, 0)
+      }
+    }
+//    stillToAssert = Nil
+    true
+/*
     if(stillToAssert.isEmpty) {
       true
     } else {
@@ -121,13 +158,17 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
       stillToAssert = Nil
       true
     }
+*/
   }
 
   // This is concerned with how many new function calls the assertion is going
   // to introduce.
-  private def assertIfSafeOrDelay(ast: Expr, isSafe: Boolean = false) : Unit = {
-    stillToAssert = ((pushLevel, ast)) :: stillToAssert
-  }
+  // private def assertIfSafeOrDelay(ast: Expr, isSafe: Boolean = false) : Unit = {
+  //   println("I'm going to assert this at the next final check :")
+  //   println(ast)
+  //   println("BTW, I think you should know the depth of this thing is : " + measureADTChildrenDepth(ast))
+  //   stillToAssert = ((pushLevel, ast)) :: stillToAssert
+  // }
 
   // Assert as soon as possible and keep asserting as long as level is >= lvl.
   private def assertAxiomASAP(expr: Expr, lvl: Int) : Unit = assertAxiomASAP(toZ3Formula(z3, expr).get, lvl)
@@ -213,13 +254,15 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     pushLevel = 0
     canAssertAxiom = false
     toAssertASAP = Set.empty
-    stillToAssert = Nil
+ //   stillToAssert = Nil
+    queue.clear
   }
 
   private def safeBlockToAssertAxioms[A](block: => A): A = {
     canAssertAxiom = true
 
     if (toAssertASAP.nonEmpty) {
+      println("In a safe block. " + toAssertASAP.size + " axioms to add.")
       for ((lvl, ax) <- toAssertASAP) {
         if(lvl <= pushLevel) {
           assertAxiomNow(ax)
@@ -236,4 +279,23 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
     canAssertAxiom = false
     result
   }
+
+  private object UnrollingOrdering extends Ordering[Unrolling] {
+    def compare(u1: Unrolling, u2: Unrolling) : Int = {
+      u2.depth - u1.depth
+    }
+  }
+
+  private class Unrolling(val invocation: FunctionInvocation, val body: Expr, val isContract: Boolean, val fromLevel: Int) {
+    // the maximal depth of selector calls in arguments of the invocation
+    val depth : Int = measureADTChildrenDepth(invocation)
+//    println("unrolling built. It has depth " + depth)
+  }
+  private object Unrolling {
+    def unapply(u: Unrolling) : Option[(FunctionInvocation,Expr)] = if(u != null) {
+      Some((u.invocation, u.body))
+    } else {
+      None
+    }
+  }
 }
-- 
GitLab