From c7cecedab763c7817433971303b57acbc7f4c75a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com>
Date: Fri, 14 Jan 2011 21:16:26 +0000
Subject: [PATCH] Testcase generation extension.

---
 src/funcheck/FunCheckPlugin.scala         |  4 +
 src/purescala/Settings.scala              |  2 +
 src/purescala/Trees.scala                 | 44 +++++++++++
 src/purescala/Z3ModelReconstruction.scala | 20 ++++-
 src/purescala/Z3Solver.scala              | 20 +++--
 src/purescala/testcases/Main.scala        | 95 ++++++++++++++++++++++-
 6 files changed, 173 insertions(+), 12 deletions(-)

diff --git a/src/funcheck/FunCheckPlugin.scala b/src/funcheck/FunCheckPlugin.scala
index 1d27e3f21..89af09e58 100644
--- a/src/funcheck/FunCheckPlugin.scala
+++ b/src/funcheck/FunCheckPlugin.scala
@@ -27,6 +27,8 @@ class FunCheckPlugin(val global: Global) extends Plugin {
     "  -P:funcheck:axioms             Generate simple forall axioms for recursive functions when possible" + "\n" + 
     "  -P:funcheck:tolerant           Silently extracts non-pure function bodies as ''unknown''" + "\n" +
     "  -P:funcheck:nobapa             Disable BAPA Z3 extension" + "\n" +
+    "  -P:funcheck:impure             Generate testcases only for impure functions" + "\n" +
+    "  -P:funcheck:testcases=[1,2]    Number of testcases to generate per function" + "\n" +
     "  -P:funcheck:quiet              No info and warning messages from the extensions" + "\n" +
     "  -P:funcheck:XP                 Enable weird transformations and other bug-producing features" + "\n" +
     "  -P:funcheck:PLDI               PLDI 2011 settings. Now frozen. Not completely functional. See CAV." + "\n" +
@@ -46,6 +48,7 @@ class FunCheckPlugin(val global: Global) extends Plugin {
         case "nodefaults" =>                     purescala.Settings.runDefaultExtensions = false
         case "axioms"     =>                     purescala.Settings.noForallAxioms = false
         case "nobapa"     =>                     purescala.Settings.useBAPA = false
+        case "impure"     =>                     purescala.Settings.impureTestcases = true
         case "newPM"      =>                     { println("''newPM'' is no longer a command-line option, because the new translation is now on by default."); System.exit(0) }
         case "XP"         =>                     purescala.Settings.experimental = true
         case "PLDI"       =>                     { purescala.Settings.experimental = true; purescala.Settings.useInstantiator = true; purescala.Settings.useFairInstantiator = false; purescala.Settings.useBAPA = false; purescala.Settings.zeroInlining = true }
@@ -53,6 +56,7 @@ class FunCheckPlugin(val global: Global) extends Plugin {
         case s if s.startsWith("unrolling=") =>  purescala.Settings.unrollingLevel = try { s.substring("unrolling=".length, s.length).toInt } catch { case _ => 0 }
         case s if s.startsWith("functions=") =>  purescala.Settings.functionsToAnalyse = Set(splitList(s.substring("functions=".length, s.length)): _*)
         case s if s.startsWith("extensions=") => purescala.Settings.extensionNames = splitList(s.substring("extensions=".length, s.length))
+        case s if s.startsWith("testcases=") =>  purescala.Settings.nbTestcases = try { s.substring("testcases=".length, s.length).toInt } catch { case _ => 1 }
         case _ => error("Invalid option: " + option)
       }
     }
diff --git a/src/purescala/Settings.scala b/src/purescala/Settings.scala
index c3ffa74ec..720b3ea58 100644
--- a/src/purescala/Settings.scala
+++ b/src/purescala/Settings.scala
@@ -14,6 +14,8 @@ object Settings {
   var unrollingLevel: Int = 0
   var zeroInlining : Boolean = false
   var useBAPA: Boolean = true
+  var impureTestcases: Boolean = false
+  var nbTestcases: Int = 1
   var useInstantiator: Boolean = false
   var useFairInstantiator: Boolean = false
   def useAnyInstantiator : Boolean = useInstantiator || useFairInstantiator
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 8be7aaeb8..0ffbd5311 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -1051,4 +1051,48 @@ object Trees {
     
     rec(expression,Map.empty)
   }
+
+  private val random = new scala.util.Random()
+
+  def randomValue(v: Variable) : Expr = randomValue(v.getType)
+  def simplestValue(v: Variable) : Expr = simplestValue(v.getType)
+
+  private def randomValue(tpe: TypeTree) : Expr = tpe match {
+    case Int32Type => IntLiteral(random.nextInt(42))
+    case BooleanType => BooleanLiteral(random.nextBoolean())
+    case AbstractClassType(acd) =>
+      val children = acd.knownChildren
+      randomValue(classDefToClassType(children(random.nextInt(children.size))))
+    case CaseClassType(cd) =>
+      val fields = cd.fields
+      CaseClass(cd, fields.map(f => randomValue(f.getType)))
+    case _ => throw new Exception("I can't choose random value for type " + tpe)
+  }
+
+  private def simplestValue(tpe: TypeTree) : Expr = tpe match {
+    case Int32Type => IntLiteral(0)
+    case BooleanType => BooleanLiteral(false)
+    case AbstractClassType(acd) => {
+      val children = acd.knownChildren
+      val simplerChildren = children.filter{
+        case ccd @ CaseClassDef(id, Some(parent), fields) =>
+          !fields.exists(vd => vd.getType match {
+            case AbstractClassType(fieldAcd) => acd == fieldAcd
+            case CaseClassType(fieldCcd) => ccd == fieldCcd
+            case _ => false
+          })
+        case _ => false
+      }
+      def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match {
+        case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size
+        case _ => true
+      }
+      val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields)
+      simplestValue(classDefToClassType(orderedChildren.head))
+    }
+    case CaseClassType(ccd) =>
+      val fields = ccd.fields
+      CaseClass(ccd, fields.map(f => simplestValue(f.getType)))
+    case _ => throw new Exception("I can't choose simplest value for type " + tpe)
+  }
 }
diff --git a/src/purescala/Z3ModelReconstruction.scala b/src/purescala/Z3ModelReconstruction.scala
index c7d94612a..c9c046e79 100644
--- a/src/purescala/Z3ModelReconstruction.scala
+++ b/src/purescala/Z3ModelReconstruction.scala
@@ -10,6 +10,9 @@ import TypeTrees._
 trait Z3ModelReconstruction {
   self: Z3Solver =>
 
+  private val AUTOCOMPLETEMODELS : Boolean = true
+  private val SIMPLESTCOMPLETION : Boolean = false // if true, use 0, Nil(), etc., else random
+
   def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = {
     val expectedType = if(tpe == null) id.getType else tpe
     
@@ -32,7 +35,22 @@ trait Z3ModelReconstruction {
     for(id <- ids) {
       modelValue(model, id) match {
         case None => ; // can't do much here
-        case Some(ex) => asMap = asMap + ((id -> ex))
+        case Some(ex) =>
+          if (AUTOCOMPLETEMODELS) {
+            ex match {
+              case v @ Variable(exprId) if exprId == id =>
+                if (SIMPLESTCOMPLETION) {
+                  asMap = asMap + ((id -> simplestValue(id.toVariable)))
+                  reporter.info("Completing variable '" + id + "' to simplest value")
+                } else {
+                  asMap = asMap + ((id -> randomValue(id.toVariable)))
+                  reporter.info("Completing variable '" + id + "' to random value")
+                }
+              case _ => asMap = asMap + ((id -> ex))
+            }
+          } else {
+            asMap = asMap + ((id -> ex))
+          }
       }
     }
     asMap
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index dc2b0730b..44bffb448 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -411,6 +411,10 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
   }
 
   def decideIterative(vc: Expr, forValidity: Boolean) : Option[Boolean] = {
+    decideIterativeWithModel(vc, forValidity)._1
+  }
+
+  def decideIterativeWithModel(vc: Expr, forValidity: Boolean) : (Option[Boolean], Map[Identifier, Expr]) = {
     restartZ3
     assert(instantiator != null)
     assert(!useBAPA)
@@ -424,14 +428,14 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
     val toConvert = if (forValidity) negate(vc) else vc
     val toCheckAgainstModels = toConvert
 
-    val result = toZ3Formula(z3, toConvert) match {
-      case None => None // means it could not be translated
+    val result : (Option[Boolean], Map[Identifier, Expr]) = toZ3Formula(z3, toConvert) match {
+      case None => (None, Map.empty) // means it could not be translated
       case Some(z3f) => {
         z3.assertCnstr(z3f)
 
         // THE LOOP STARTS HERE.
         var foundDefinitiveSolution : Boolean = false
-        var finalResult : Option[Boolean] = None
+        var finalResult : (Option[Boolean], Map[Identifier, Expr]) = (None, Map.empty)
 
         while(!foundDefinitiveSolution && instantiator.possibleContinuation) {
           instantiator.increaseSearchDepth()
@@ -456,19 +460,19 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
                   reporter.error("Counter-example found and confirmed:")
                   reporter.error(modelAsString)
                   foundDefinitiveSolution = true
-                  finalResult = Some(false)
+                  finalResult = (Some(false), asMap)
                 }
                 case InfiniteComputation() => {
                   reporter.info("Model seems to lead to divergent computation.")
                   reporter.error(modelAsString)
                   foundDefinitiveSolution = true
-                  finalResult = None
+                  finalResult = (None, asMap)
                 }
                 case RuntimeError(msg) => {
                   reporter.info("Model leads to runtime error: " + msg)
                   reporter.error(modelAsString)
                   foundDefinitiveSolution = true
-                  finalResult = Some(false)
+                  finalResult = (Some(false), asMap)
                 }
                 case t @ TypeError(_,_) => {
                   scala.Predef.error("Type error in model evaluation.\n" + t.msg)
@@ -482,13 +486,13 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco
             case (Some(false), _) => {
               // This means a definitive proof of unsatisfiability has been found.
               foundDefinitiveSolution = true
-              finalResult = Some(true)
+              finalResult = (Some(true), Map.empty)
             }
 
             case (None, m) => {
               reporter.warning("Iterative Z3 gave up because: " + z3.getSearchFailure.message)
               foundDefinitiveSolution = true
-              finalResult = None
+              finalResult = (None, modelToMap(m, varsInVC))
             }
           }
         }
diff --git a/src/purescala/testcases/Main.scala b/src/purescala/testcases/Main.scala
index 593f3d48f..546cdd0ca 100644
--- a/src/purescala/testcases/Main.scala
+++ b/src/purescala/testcases/Main.scala
@@ -4,20 +4,109 @@ import purescala.Reporter
 import purescala.Trees._
 import purescala.Definitions._
 import purescala.Extensions._
+import purescala.Settings
+import purescala.Common.Identifier
 
-class Main(reporter: Reporter) extends Analyser(reporter) {
+class Main(reporter : Reporter) extends Analyser(reporter) {
   val description = "Testcase generation from preconditions"
   override val shortDescription = "testcases"
 
   def analyse(program : Program) : Unit = {
     // things that we could control with options:
     //   - generate for all functions or just impure
-    //   - number of cases per function
     //   - do not generate for private functions (check FunDef.isPrivate)
+    //   - number of cases per function
 
     reporter.info("Running testcase generation...")
 
-    // when you build the solver, call setProgram !
+    val solver = new purescala.Z3Solver(reporter)
+    solver.setProgram(program)
+    
+    def writeToFile(filename: String, content: String) : Unit = {
+      try {
+        val fw = new java.io.FileWriter(filename)
+        fw.write(content)
+        fw.close
+      } catch {
+        case e => reporter.error("There was an error while generating the test file" + filename)
+      }
+    }
+
+    def generateTestInput(funDef: FunDef) : Seq[Seq[Expr]] = {
+      var constraints: Expr = BooleanLiteral(true)
+      val prec = funDef.precondition.getOrElse(BooleanLiteral(true))
+      var inputList: List[Seq[Expr]] = Nil
+      for (i <- 1 to Settings.nbTestcases) {
+        reporter.info("Current constraints: " + constraints)
+        val argMap = solver.decideIterativeWithModel(And(prec, constraints), false)
+        argMap match {
+          case (Some(true), _) => None
+          case (_ , map) =>
+            reporter.info("Solver returned the following assignment: " + map)
+            val testInput = (for (arg <- funDef.args) yield {
+              map.get(arg.id) match {
+                case Some(value) => value
+                case None => randomValue(arg.toVariable)
+              }
+            })
+            inputList = testInput :: inputList
+            val newConstraints = And(funDef.args.map(_.toVariable).zip(testInput).map{
+              case (variable, value) => Equals(variable, value)
+            })
+            constraints = And(constraints, negate(newConstraints))
+        }
+      }
+
+      inputList.reverse
+    }
+
+    def indent(s: String) : String = "  " + s.split("\n").mkString("\n  ")
+    def capitalize(s: String) : String = s.substring(0, 1).toUpperCase + s.substring(1)
+
+    def testFunctionName(id: Identifier) : String = "test" + capitalize(id.toString)
+    def testFunction(id: Identifier, inputs: Seq[Seq[Expr]]) : String = {
+      val idString = id.toString
+      val toReturn = 
+        "def " + testFunctionName(id) + " : Unit = {" + "\n" +
+        inputs.map(input => indent(idString + input.mkString("(", ", ", ")"))).mkString("\n") + "\n" +
+        "}" + "\n"
+      toReturn
+    }
+
+    def testMainMethod(funcIds: Seq[Identifier]) : String = {
+      "def main(args: Array[String]) : Unit = {" + "\n" +
+      indent(funcIds.map(testFunctionName(_)).mkString("\n")) + "\n" +
+      "}" + "\n"
+    }
+
+    def testObject(funcInputPairs: Seq[(Identifier, Seq[Seq[Expr]])]) : String = {
+      val objectName = program.mainObject.id.toString
+      val toReturn =
+        "import " + objectName + "._" + "\n" +
+        "\n" +
+        "object Test" + capitalize(objectName) + " {" + "\n" +
+        indent(testMainMethod(funcInputPairs.map(_._1))) + "\n" +
+        "\n" +
+        indent(funcInputPairs.map(pair => testFunction(pair._1, pair._2)).mkString("\n")) + "\n" +
+        "}"
+      toReturn
+    }
+
+    val funcInputPairs: Seq[(Identifier, Seq[Seq[Expr]])] = (for (funDef <- program.definedFunctions.toList.sortWith((fd1, fd2) => fd1 < fd2) if (!funDef.isPrivate && (!Settings.impureTestcases || !funDef.hasBody))) yield {
+      reporter.info("Considering function definition: " + funDef.id)
+      funDef.precondition match {
+        case Some(p) => reporter.info("The precondition is: " + p)
+        case None =>    reporter.info("Function has no precondition")
+      }
+
+      val testInput = generateTestInput(funDef)
+      reporter.info("Generated test input is: " + testInput)
+      (funDef.id, testInput)
+    })
+
+    writeToFile("Test" + program.mainObject.id.toString + ".scala", testObject(funcInputPairs))
+    
     reporter.info("Done.")
   }
+
 }
-- 
GitLab