From a1f6ecc3e26fa1d7a6348dcc5cd464c083cdea5f Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Mon, 7 Mar 2011 19:12:02 +0000
Subject: [PATCH] towards getting enumeration of testcases to work again

---
 src/purescala/AbstractZ3Solver.scala | 58 +++++++++++++++++++++++++++-
 src/purescala/FairZ3Solver.scala     | 31 +++++++++++----
 src/purescala/Z3Solver.scala         | 48 +++--------------------
 src/purescala/testcases/Main.scala   |  8 +++-
 4 files changed, 92 insertions(+), 53 deletions(-)

diff --git a/src/purescala/AbstractZ3Solver.scala b/src/purescala/AbstractZ3Solver.scala
index 92b88eb3f..3f50e723c 100644
--- a/src/purescala/AbstractZ3Solver.scala
+++ b/src/purescala/AbstractZ3Solver.scala
@@ -10,7 +10,17 @@ import TypeTrees._
 // This is just to factor out the things that are common in "classes that deal
 // with a Z3 instance"
 trait AbstractZ3Solver {
-  val reporter : Reporter
+  self: Solver =>
+
+  val reporter: Reporter
+
+  protected[purescala] var z3 : Z3Context
+  protected[purescala] var program : Program
+
+  def typeToSort(tt: TypeTree): Z3Sort
+  protected[purescala] var adtTesters: Map[CaseClassDef, Z3FuncDecl]
+  protected[purescala] var adtConstructors: Map[CaseClassDef, Z3FuncDecl]
+  protected[purescala] var adtFieldSelectors: Map[Identifier, Z3FuncDecl]
 
   protected[purescala] var exprToZ3Id : Map[Expr,Z3AST]
   protected[purescala] def fromZ3Formula(tree : Z3AST) : Expr
@@ -22,4 +32,50 @@ trait AbstractZ3Solver {
       case _ => None
     }
   }
+
+  protected[purescala] def solveWithBounds(vc: Expr, forValidity: Boolean) : (Option[Boolean], Map[Identifier, Expr]) 
+
+  protected[purescala] def boundValues : Unit = {
+    val lowerBound: Z3AST = z3.mkInt(Settings.testBounds._1, z3.mkIntSort)
+    val upperBound: Z3AST = z3.mkInt(Settings.testBounds._2, z3.mkIntSort)
+
+    def isUnbounded(field: VarDecl) : Boolean = field.getType match {
+      case Int32Type => true
+      case _ => false
+    }
+
+    def boundConstraint(boundVar: Z3AST) : Z3AST = {
+      lowerBound <= boundVar && boundVar <= upperBound
+    }
+
+    // for all recursive type roots
+    //   for all child ccd of a root
+    //     if ccd contains unbounded types
+    //       create bound vars (mkBound) for each field
+    //       create pattern that says (valueBounds(ccd(f1, ..)))
+    //       create axiom tree that says "values of unbounded types are within bounds"
+    //       assert axiom for the tree above
+
+    val roots = program.classHierarchyRoots
+    for (root <- roots) {
+      val children: List[CaseClassDef] = (root match {
+        case c: CaseClassDef => List(c)
+        case a: AbstractClassDef => a.knownChildren.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]).toList
+      })
+      for (child <- children) child match {
+        case CaseClassDef(id, parent, fields) =>
+          val unboundedFields = fields.filter(isUnbounded(_))
+          if (!unboundedFields.isEmpty) {
+            val boundVars = fields.zipWithIndex.map{case (f, i) => z3.mkBound(i, typeToSort(f.getType))}
+            val term = adtConstructors(child)(boundVars : _*)
+            val pattern = z3.mkPattern(term)
+            //val constraint = (fields zip boundVars).filter((p: (VarDecl, Z3AST)) => isUnbounded(p._1)).map((p: (VarDecl, Z3AST)) => boundConstraint(p._2)).foldLeft(z3.mkTrue)((a, b) => a && b)
+            val constraint = (fields zip boundVars).filter((p: (VarDecl, Z3AST)) => isUnbounded(p._1)).map((p: (VarDecl, Z3AST)) => boundConstraint(adtFieldSelectors(p._1.id)(term))).foldLeft(z3.mkTrue)((a, b) => a && b)
+            val axiom = z3.mkForAll(0, List(pattern), fields.zipWithIndex.map{case (f, i) => (z3.mkIntSymbol(i), typeToSort(f.getType))}, constraint)
+            println("Asserting: " + axiom)
+            z3.assertCnstr(axiom)
+          }
+      }
+    }
+  }
 }
diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala
index 57dc5e33b..41c649c34 100644
--- a/src/purescala/FairZ3Solver.scala
+++ b/src/purescala/FairZ3Solver.scala
@@ -13,6 +13,8 @@ import scala.collection.mutable.{Set => MutableSet}
 class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with AbstractZ3Solver with Z3ModelReconstruction {
   assert(Settings.useFairInstantiator)
 
+  private final val UNKNOWNASSAT : Boolean = true
+
   val description = "Fair Z3 Solver"
   override val shortDescription = "Z3-f"
 
@@ -70,9 +72,9 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
   private var adtSorts: Map[ClassTypeDef, Z3Sort] = Map.empty
   private var fallbackSorts: Map[TypeTree, Z3Sort] = Map.empty
 
-  private var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty
-  private var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty
-  private var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty
 
   private var reverseADTTesters: Map[Z3FuncDecl, CaseClassDef] = Map.empty
   private var reverseADTConstructors: Map[Z3FuncDecl, CaseClassDef] = Map.empty
@@ -224,10 +226,19 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
 
   def solve(vc: Expr) = decide(vc, true)
 
-  def decide(vc: Expr, forValidity: Boolean):Option[Boolean] = decideWithModel(vc, forValidity)._1
-  def decideWithModel(vc: Expr, forValidity: Boolean): (Option[Boolean], Map[Identifier,Expr]) = {
+  def solveWithBounds(vc: Expr, fv: Boolean) : (Option[Boolean], Map[Identifier,Expr]) = {
+    restartZ3
+    boundValues
+    println(z3.check)
+    decideWithModel(vc, fv)
+  }
+
+  def decide(vc: Expr, forValidity: Boolean):Option[Boolean] = {
     restartZ3
+    decideWithModel(vc, forValidity)._1
+  }
 
+  def decideWithModel(vc: Expr, forValidity: Boolean): (Option[Boolean], Map[Identifier,Expr]) = {
     val unrollingBank = new UnrollingBank
 
     lazy val varsInVC = variablesOf(vc) 
@@ -276,6 +287,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
           z3.assertCnstr(z3.mkAnd(blockingSetAsZ3 : _*))
       }
 
+      reporter.info(" - Running Z3 search...")
       val (answer, model, core) : (Option[Boolean], Z3Model, Seq[Z3AST]) = if(Settings.useCores) {
         println(blockingSetAsZ3)
         z3.checkAssumptions(blockingSetAsZ3 : _*)
@@ -284,8 +296,12 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
         (a, m, Seq.empty[Z3AST])
       }
 
-      reporter.info(" - Running Z3 search...")
-      (answer, model) match {
+      val reinterpretedAnswer = if(!UNKNOWNASSAT) answer else (answer match {
+        case None | Some(true) => Some(true)
+        case Some(false) => Some(false)
+      })
+
+      (reinterpretedAnswer, model) match {
         case (None, m) => { // UNKNOWN
           reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message)
           foundDefinitiveAnswer = true
@@ -428,6 +444,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
         definitiveAnswer = None
         definitiveModel = Map.empty
         reporter.error("Max. number of iterations reached.")
+        println("Max. number of iterations reached.")
       }
     }
 
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index 30585f288..012a50039 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -83,9 +83,9 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with AbstractZ3S
   private var adtSorts: Map[ClassTypeDef, Z3Sort] = Map.empty
   private var fallbackSorts: Map[TypeTree, Z3Sort] = Map.empty
 
-  private var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty
-  private var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty
-  private var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty
+  protected[purescala] var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty
 
   private var reverseADTTesters: Map[Z3FuncDecl, CaseClassDef] = Map.empty
   private var reverseADTConstructors: Map[Z3FuncDecl, CaseClassDef] = Map.empty
@@ -175,46 +175,6 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with AbstractZ3S
     // ...and now everything should be in there...
   }
 
-  private def boundValues : Unit = {
-    val lowerBound: Z3AST = z3.mkInt(Settings.testBounds._1, z3.mkIntSort)
-    val upperBound: Z3AST = z3.mkInt(Settings.testBounds._2, z3.mkIntSort)
-
-    def isUnbounded(field: VarDecl) : Boolean = field.getType match {
-      case Int32Type => true
-      case _ => false
-    }
-
-    def boundConstraint(boundVar: Z3AST) : Z3AST = {
-      lowerBound <= boundVar && boundVar <= upperBound
-    }
-
-    // for all recursive type roots
-    //   for all child ccd of a root
-    //     if ccd contains unbounded types
-    //       create bound vars (mkBound) for each field
-    //       create pattern that says (valueBounds(ccd(f1, ..)))
-    //       create axiom tree that says "values of unbounded types are within bounds"
-    //       assert axiom for the tree above
-
-    val roots = program.classHierarchyRoots
-    for (root <- roots) {
-      val children: List[CaseClassDef] = (root match {
-        case c: CaseClassDef => List(c)
-        case a: AbstractClassDef => a.knownChildren.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]).toList
-      })
-      for (child <- children) child match {
-        case CaseClassDef(id, parent, fields) =>
-          val unboundedFields = fields.filter(isUnbounded(_))
-          if (!unboundedFields.isEmpty) {
-            val boundVars = fields.zipWithIndex.map{case (f, i) => z3.mkBound(i, typeToSort(f.getType))}
-            val pattern = z3.mkPattern(adtConstructors(child)(boundVars: _*))
-            val constraint = (fields zip boundVars).filter((p: (VarDecl, Z3AST)) => isUnbounded(p._1)).map((p: (VarDecl, Z3AST)) => boundConstraint(p._2)).foldLeft(z3.mkTrue)((a, b) => a && b)
-            val axiom = z3.mkForAll(0, List(pattern), fields.zipWithIndex.map{case (f, i) => (z3.mkIntSymbol(i), typeToSort(f.getType))}, constraint)
-            z3.assertCnstr(axiom)
-          }
-      }
-    }
-  }
 
   def isKnownDef(funDef: FunDef) : Boolean = if(useAnyInstantiator) {
     instantiator.isKnownDef(funDef)
@@ -456,6 +416,8 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with AbstractZ3S
     decideIterativeWithModel(vc, forValidity)._1
   }
 
+  def solveWithBounds(vc: Expr, fv: Boolean) : (Option[Boolean], Map[Identifier,Expr]) = decideIterativeWithBounds(vc, fv)
+
   def decideIterativeWithBounds(vc: Expr, forValidity: Boolean) : (Option[Boolean], Map[Identifier, Expr]) = {
     restartZ3
     boundValues
diff --git a/src/purescala/testcases/Main.scala b/src/purescala/testcases/Main.scala
index 9eb39831c..b67b26cad 100644
--- a/src/purescala/testcases/Main.scala
+++ b/src/purescala/testcases/Main.scala
@@ -19,7 +19,11 @@ class Main(reporter : Reporter) extends Analyser(reporter) {
 
     reporter.info("Running testcase generation...")
 
-    val solver = new purescala.Z3Solver(reporter)
+    val solver = if(Settings.useFairInstantiator) {
+      new purescala.FairZ3Solver(reporter)
+    } else {
+      new purescala.Z3Solver(reporter)
+    }
     solver.setProgram(program)
     
     def writeToFile(filename: String, content: String) : Unit = {
@@ -39,7 +43,7 @@ class Main(reporter : Reporter) extends Analyser(reporter) {
       var noMoreModels = false
       for (i <- 1 to Settings.nbTestcases if !noMoreModels) {
         // reporter.info("Current constraints: " + constraints)
-        val argMap = solver.decideIterativeWithBounds(And(prec, constraints), false)
+        val argMap = solver.solveWithBounds(And(prec, constraints), false)
         argMap match {
           case (Some(true), _) => noMoreModels = true
           case (_ , map) =>
-- 
GitLab