From 9dded0b9e0a98cf3764887525b7a66a2e5ab1596 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Sun, 14 Nov 2010 22:19:06 +0000
Subject: [PATCH] now generating VCs for preconditions and pattern-matching
 exhaustiveness...

---
 src/purescala/DefaultTactic.scala             | 61 ++++++++++++++++++-
 src/purescala/PartialEvaluator.scala          |  6 +-
 src/purescala/Trees.scala                     | 18 ++++++
 src/purescala/VerificationCondition.scala     |  4 +-
 src/purescala/Z3Solver.scala                  |  6 +-
 .../z3plugins/instantiator/Instantiator.scala | 16 ++---
 testcases/BSTSimpler.scala                    | 51 +++++++++++++---
 testcases/ExprComp.scala                      | 22 ++++++-
 testcases/RedBlackTree.scala                  |  4 +-
 9 files changed, 162 insertions(+), 26 deletions(-)

diff --git a/src/purescala/DefaultTactic.scala b/src/purescala/DefaultTactic.scala
index da1f6c9f0..1a4338cdb 100644
--- a/src/purescala/DefaultTactic.scala
+++ b/src/purescala/DefaultTactic.scala
@@ -5,6 +5,8 @@ import purescala.Trees._
 import purescala.Definitions._
 import Extensions.Tactic
 
+import scala.collection.mutable.{Map => MutableMap}
+
 class DefaultTactic(reporter: Reporter) extends Tactic(reporter) {
     val description = "Default verification condition generation approach"
     override val shortDescription = "default"
@@ -78,16 +80,71 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) {
         Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this))
       }
     }
+  
+    private val errConds : MutableMap[FunDef,Seq[VerificationCondition]] = MutableMap.empty
+    private def errorConditions(function: FunDef) : Seq[VerificationCondition] = {
+      if(errConds.isDefinedAt(function)) {
+        errConds(function)
+      } else {
+        val conds = if(function.hasBody) {
+          val bodyToUse = if(function.hasPrecondition) {
+            IfExpr(function.precondition.get, function.body.get, Error("XXX").setType(function.returnType)).setType(function.returnType)
+          } else {
+            function.body.get
+          }
+          val withExplicit = expandLets(explicitPreconditions(matchToIfThenElse(bodyToUse)))
+  
+          val allPathConds = collectWithPathCondition((_ match { case Error(_) => true ; case _ => false }), withExplicit)
+  
+          allPathConds.filter(_._2 != Error("XXX")).map(pc => pc._2 match {
+            case Error("precondition violated") => new VerificationCondition(Not(And(pc._1)), function, VCKind.Precondition, this)
+            case Error("non-exhaustive match") => new VerificationCondition(Not(And(pc._1)), function, VCKind.ExhaustiveMatch, this)
+            case _ => scala.Predef.error("Don't know what to do with this path condition target: " + pc._2)
+          }).toSeq
+        } else {
+          Seq.empty
+        }
+        errConds(function) = conds
+        conds
+      }
+    }
 
     def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = {
-      Seq.empty
+      errorConditions(function).filter(_.kind == VCKind.Precondition)
     }
 
     def generatePatternMatchingExhaustivenessChecks(function: FunDef) : Seq[VerificationCondition] = {
-      Seq.empty
+      errorConditions(function).filter(_.kind == VCKind.ExhaustiveMatch)
     }
 
     def generateMiscCorrectnessConditions(function: FunDef) : Seq[VerificationCondition] = {
       Seq.empty
     }
+
+    // prec: there should be no lets and no pattern-matching in this expression
+    private def collectWithPathCondition(matcher: Expr=>Boolean, expression: Expr) : Set[(Seq[Expr],Expr)] = {
+      var collected : Set[(Seq[Expr],Expr)] = Set.empty
+
+      def rec(expr: Expr, path: List[Expr]) : Unit = {
+        if(matcher(expr)) {
+          collected = collected + ((path.reverse, expr))
+        }
+
+        expr match {
+          case IfExpr(cond, then, elze) => {
+            rec(cond, path)
+            rec(then, cond :: path)
+            rec(elze, Not(cond) :: path)
+          }
+          case NAryOperator(args, _) => args.foreach(rec(_, path))
+          case BinaryOperator(t1, t2, _) => rec(t1, path); rec(t2, path)
+          case UnaryOperator(t, _) => rec(t, path)
+          case t : Terminal => ;
+          case _ => scala.Predef.error("Unhandled tree in collectWithPathCondition : " + expr)
+        }
+      }
+
+      rec(expression, Nil)
+      collected
+    }
 }
diff --git a/src/purescala/PartialEvaluator.scala b/src/purescala/PartialEvaluator.scala
index aabb6c437..8a8eca843 100644
--- a/src/purescala/PartialEvaluator.scala
+++ b/src/purescala/PartialEvaluator.scala
@@ -8,12 +8,12 @@ import TypeTrees._
 class PartialEvaluator(val program: Program) {
   val reporter = Settings.reporter
 
-  def apply0(expression:Expr) : Expr = expression
+  def apply(expression:Expr) : Expr = expression
   // Simplifies by partially evaluating.
   // Of course, I still have to decide what 'simplified' means.
-  def apply(expression: Expr) : Expr = {
+  def apply0(expression: Expr) : Expr = {
     def rec(expr: Expr, letMap: Map[Identifier,Expr]) : Expr = {
-//      println("****** rec called on " + expr + " *********")
+      println("****** rec called on " + expr + " *********")
       (expr match {
       case i @ IfExpr(cond, then, elze) => {
         val simpCond = rec(cond, letMap)
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index d098e1768..37a056542 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -956,6 +956,24 @@ object Trees {
     })
   }
 
+  def explicitPreconditions(expr: Expr) : Expr = {
+    def rewriteFunctionCall(e: Expr) : Option[Expr] = e match {
+      case fi @ FunctionInvocation(fd, args) if(fd.hasPrecondition) => {
+        val fTpe = fi.getType
+        val prec = matchToIfThenElse(fd.precondition.get)
+        val newLetIDs = fd.args.map(a => FreshIdentifier("precarg_" + a.id.name, true).setType(a.tpe))
+        val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*)
+        val newPrec = replace(substMap, prec)
+        val newThen = FunctionInvocation(fd, newLetIDs.map(_.toVariable)).setType(fTpe)
+        val ifExpr: Expr = IfExpr(newPrec, newThen, Error("precondition violated").setType(fTpe)).setType(fTpe)
+        Some((newLetIDs zip args).foldRight(ifExpr)((iap,e) => Let(iap._1, iap._2, e)))
+      }
+      case _ => None
+    }
+
+    searchAndReplaceDFS(rewriteFunctionCall)(expr)
+  }
+
   private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
   /** Rewrites all pattern-matching expressions into if-then-else expressions,
    * with additional error conditions. Does not introduce additional variables.
diff --git a/src/purescala/VerificationCondition.scala b/src/purescala/VerificationCondition.scala
index b64796f56..d43458ba4 100644
--- a/src/purescala/VerificationCondition.scala
+++ b/src/purescala/VerificationCondition.scala
@@ -38,7 +38,9 @@ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: V
 
 object VerificationCondition {
   val infoFooter : String = "╚" + ("═" * 69) + "╝"
-  val infoHeader : String = "╔══ Summary " + ("═" * 58) + "╗"
+  val infoHeader : String = ". ┌─────────┐\n" +
+                            "╔═╡ Summary ╞" + ("═" * 57) + "╗\n" +
+                            "║ └─────────┘" + (" " * 57) + "║"
 }
 
 object VCKind extends Enumeration {
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index 268ce8a89..ca2f0f878 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -57,6 +57,8 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
     if (useBAPA) bapa = new BAPATheoryType(z3)
     if (useInstantiator) instantiator = new Instantiator(this)
 
+    exprToZ3Id = Map.empty
+    z3IdToExpr = Map.empty
     counter = 0
     prepareSorts
     prepareFunctions
@@ -573,7 +575,9 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
       case Z3AppAST(decl, args) => {
         val argsSize = args.size
         if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) {
-          z3IdToExpr(t)
+          val toRet = z3IdToExpr(t)
+          // println("Map says I should replace " + t + " by " + toRet)
+          toRet
         } else if(isKnownDecl(decl)) {
           val fd = functionDeclToDef(decl)
           assert(fd.args.size == argsSize)
diff --git a/src/purescala/z3plugins/instantiator/Instantiator.scala b/src/purescala/z3plugins/instantiator/Instantiator.scala
index 65042fdce..5f9239a16 100644
--- a/src/purescala/z3plugins/instantiator/Instantiator.scala
+++ b/src/purescala/z3plugins/instantiator/Instantiator.scala
@@ -64,7 +64,7 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
   }
 
   override def newApp(ast: Z3AST) : Unit = {
-    examineAndUnroll(ast)
+    // examineAndUnroll(ast)
   }
 
   override def newRelevant(ast: Z3AST) : Unit = {
@@ -72,13 +72,15 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
   }
 
   private var bodyInlined : Int = 0
-  def examineAndUnroll(ast: Z3AST) : Unit = if(bodyInlined < Settings.unrollingLevel) {
+  def examineAndUnroll(ast: Z3AST, allFunctions: Boolean = false) : Unit = if(bodyInlined < Settings.unrollingLevel) {
     val aps = fromZ3Formula(ast)
-    //val fis = functionCallsOf(aps)
-
-    val fis : Set[FunctionInvocation] = aps match {
-      case f @ FunctionInvocation(_,_) => Set(f)
-      case _ => Set.empty
+    val fis : Set[FunctionInvocation] = if(allFunctions) {
+      functionCallsOf(aps)
+    } else {
+      aps match {
+        case f @ FunctionInvocation(_,_) => Set(f)
+        case _ => Set.empty
+      }
     }
 
     //println("As Purescala: " + aps)
diff --git a/testcases/BSTSimpler.scala b/testcases/BSTSimpler.scala
index 83fa55e2d..98c55124e 100644
--- a/testcases/BSTSimpler.scala
+++ b/testcases/BSTSimpler.scala
@@ -6,9 +6,50 @@ object BSTSimpler {
   case class Node(left: Tree, value: Int, right: Tree) extends Tree
   case class Leaf() extends Tree
 
+  def size(t : Tree) : Int = (t match {
+    case Leaf() => 1
+    case Node(l,_,r) => size(l) + 1 + size(r)
+  }) ensuring(_ >= 1)
+
+  sealed abstract class IntOpt
+  case class Some(value: Int) extends IntOpt
+  case class None() extends IntOpt
+
+  sealed abstract class TripleAbs
+  case class Triple(lmax: IntOpt, isSorted: Boolean, rmin: IntOpt) extends TripleAbs
+
+  sealed abstract class TriplePairAbs
+  case class TriplePair(left: TripleAbs, right: TripleAbs) extends TriplePairAbs
+
+  def isBST(tree: Tree) : Boolean = isBST0(tree) match {
+    case Triple(_, v, _) => v
+  }
+
+  def isBST0(tree: Tree) : TripleAbs = tree match {
+    case Leaf() => Triple(None(), true, None())
+
+    case Node(l, v, r) => TriplePair(isBST0(l), isBST0(r)) match {
+      case TriplePair(Triple(None(),t1,None()),Triple(None(),t2,None()))
+        if(t1 && t2) =>
+          Triple(Some(v),true,Some(v))
+      case TriplePair(Triple(Some(minL),t1,Some(maxL)),Triple(None(),t2,None()))
+        if(t1 && t2 && minL <= maxL && maxL < v) =>
+          Triple(Some(minL),true,Some(v))
+      case TriplePair(Triple(None(),t1,None()),Triple(Some(minR),t2,Some(maxR)))
+        if(t1 && t2 && minR <= maxR && v < minR) =>
+          Triple(Some(v),true,Some(maxR))
+      case TriplePair(Triple(Some(minL),t1,Some(maxL)),Triple(Some(minR),t2,Some(maxR)))
+        if(t1 && t2 && minL <= maxL && minR <= maxR && maxL < v && v < minR) =>
+          Triple(Some(minL),true,Some(maxR))
+
+      case _ => Triple(None(),false,None())
+    }
+  }
+
   def emptySet(): Tree = Leaf()
 
   def insert(tree: Tree, value: Int): Node = {
+    require(size(tree) <= 1 && isBST(tree))
     tree match {
       case Leaf() => Node(Leaf(), value, Leaf())
       case Node(l, v, r) => (if (v < value) {
@@ -19,14 +60,8 @@ object BSTSimpler {
         Node(l, v, r)
       })
     }
-  } ensuring (contents(_) == contents(tree) ++ Set(value))
-
-  def dumbInsert(tree: Tree): Node = {
-    tree match {
-      case Leaf() => Node(Leaf(), 0, Leaf())
-      case Node(l, e, r) => Node(dumbInsert(l), e, r)
-    }
-  } ensuring (contents(_) == contents(tree) ++ Set(0))
+//  } ensuring (contents(_) == contents(tree) ++ Set(value))
+  } ensuring(isBST(_))
 
   def createRoot(v: Int): Node = {
     Node(Leaf(), v, Leaf())
diff --git a/testcases/ExprComp.scala b/testcases/ExprComp.scala
index aaf9499be..91afcd261 100644
--- a/testcases/ExprComp.scala
+++ b/testcases/ExprComp.scala
@@ -16,6 +16,11 @@ object ExprComp {
   case class Constant(v : Value) extends Expr
   case class Binary(exp1 : Expr, op : BinOp, exp2 : Expr) extends Expr
 
+  def exprSize(e: Expr) : Int = (e match {
+    case Constant(_) => 1
+    case Binary(e1, _, e2) => 1 + exprSize(e1) + exprSize(e2)
+  }) ensuring(_ >= 1)
+
   def evalOp(v1 : Value, op : BinOp, v2 : Value) : Value = op match {
     case Plus() => Value(v1.v + v2.v)
     case Times() => Value(v1.v * v2.v)
@@ -40,12 +45,22 @@ object ExprComp {
   case class EProgram() extends Program
   case class NProgram(first : Instruction, rest : Program) extends Program
 
+  def progSize(p: Program) : Int = (p match {
+    case EProgram() => 0
+    case NProgram(_, r) => 1 + progSize(r)
+  }) ensuring(_ >= 0)
+
   // Value stack
 
   sealed abstract class ValueStack
   case class EStack() extends ValueStack
   case class NStack(v : Value, rest : ValueStack) extends ValueStack
 
+  def stackSize(vs: ValueStack) : Int = (vs match {
+    case EStack() => 0
+    case NStack(_, r) => 1 + stackSize(r)
+  }) ensuring(_ >= 0)
+
   // Outcomes of running the program
 
   sealed abstract class Outcome
@@ -87,16 +102,18 @@ object ExprComp {
 
 /*
   def property(e : Expr, acc : Program, vs : ValueStack) : Boolean = {
+    require(exprSize(e) <= 1 && progSize(acc) <= 1 && stackSize(vs) <= 1)
     run(compile(e, acc), vs) == Ok(NStack(eval(e), vs))
   } holds
 
+*/
   def property0() : Boolean = {
     val e = Binary(Constant(Value(3)), Plus(), Constant(Value(5)))
     val vs = EStack()
     val acc = EProgram()
     run(compile(e, acc), vs) == Ok(NStack(eval(e), vs))
   } holds
-
+/*
   def main(args : Array[String]) = {
     val e = Binary(Constant(Value(100)), Times(), Binary(Constant(Value(3)), Plus(), Constant(Value(5))))
     val vs = EStack()
@@ -106,5 +123,6 @@ object ExprComp {
     println(Ok(NStack(eval(e), vs)))
     assert(property(e,acc,vs))
   }
-*/
+  */
+
 }
diff --git a/testcases/RedBlackTree.scala b/testcases/RedBlackTree.scala
index 04ccb08fa..70f3d6316 100644
--- a/testcases/RedBlackTree.scala
+++ b/testcases/RedBlackTree.scala
@@ -27,8 +27,8 @@ object RedBlackTree {
       else if (x == y) Node(c,a,y,b)
       else             balance(c,a,y,ins(x, b))
   }) ensuring (res => (
-             content(res) == content(t) ++ Set(x) 
-//          && size(t) <= size(res) && size(res) < size(t) + 2)
+             (content(res) == content(t) ++ Set(x))
+          && (size(t) <= size(res) && size(res) < size(t) + 2)
               ))
 
   def add(x: Int, t: Tree): Tree = {
-- 
GitLab