From c72dba7ef5bae78eb87e2d7a06480b33dd45a292 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com>
Date: Wed, 30 Mar 2011 20:17:24 +0000
Subject: [PATCH] Store and read output variable lists.

---
 src/cp/CallTransformation.scala | 31 +++++++++++++++++++------------
 src/cp/CodeGeneration.scala     | 22 ++++++++++++++++++----
 src/cp/Serialization.scala      | 33 +++++++++++++++++++++++----------
 3 files changed, 60 insertions(+), 26 deletions(-)

diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala
index d32410de1..2f5dab938 100644
--- a/src/cp/CallTransformation.scala
+++ b/src/cp/CallTransformation.scala
@@ -4,6 +4,7 @@ import scala.tools.nsc.transform.TypingTransformers
 import scala.tools.nsc.ast.TreeDSL
 import purescala.FairZ3Solver
 import purescala.DefaultReporter
+import purescala.Common.Identifier
 import purescala.Definitions._
 import purescala.Trees._
 
@@ -27,15 +28,25 @@ trait CallTransformation
     var exprToScalaSym : Symbol = null
     var exprToScalaCode : Tree = null
 
+    def outputAssignmentList(outputVars: List[String], model: Map[Identifier, Expr]) : List[Any] = {
+      val modelWithStrings = model.map{ case (k, v) => (k.name, v) }
+      outputVars.map(ov => (modelWithStrings.get(ov) match {
+        case Some(value) => value
+        case None => scala.Predef.error("Did not find assignment for variable '" + ov + "'")
+      }))
+    }
+
     override def transform(tree: Tree) : Tree = {
       tree match {
         case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => {
-          println("I'm inside a choose call!")
-
           val Function(funValDefs, funBody) = predicate
 
           val fd = extractPredicate(unit, funValDefs, funBody)
 
+          val outputVarList = funValDefs.map(_.name.toString)
+          println("Here is the output variable list: " + outputVarList.mkString(", "))
+          val outputVarListFilename = writeObject(outputVarList)
+
           println("Here is the extracted FunDef:") 
           println(fd)
           val codeGen = new CodeGenerator(unit, currentOwner)
@@ -43,13 +54,14 @@ trait CallTransformation
           fd.body match {
             case None => println("Could not extract choose predicate: " + funBody); super.transform(tree)
             case Some(b) =>
-              val exprFilename = writeExpr(b)
+              val exprFilename = writeObject(b)
               val (programGet, progSym) = codeGen.getProgram(programFilename)
               val (exprGet, exprSym) = codeGen.getExpr(exprFilename)
+              val (outputVarListGet, outputVarListSym) = codeGen.getOutputVarList(outputVarListFilename)
               val solverInvocation = codeGen.invokeSolver(progSym, exprSym)
               val exprToScalaInvocation = codeGen.invokeExprToScala(exprToScalaSym)
-              // val code = BLOCK(programGet, exprGet, solverInvocation)
-              val code = BLOCK(programGet, exprGet, solverInvocation, exprToScalaInvocation)
+
+              val code = BLOCK(programGet, exprGet, outputVarListGet, solverInvocation, exprToScalaInvocation)
 
               typer.typed(atOwner(currentOwner) {
                 code
@@ -58,10 +70,9 @@ trait CallTransformation
         }
 
         case cd @ ClassDef(mods, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => {
-          println("I'm inside the object " + name.toString + " !")
-
           val codeGen = new CodeGenerator(unit, currentOwner)
-          val (e2sSym, e2sCode) = codeGen.exprToScala(cd.symbol)
+
+          val (e2sSym, e2sCode) = codeGen.exprToScala(cd.symbol, prog)
           exprToScalaSym = e2sSym
           exprToScalaCode = e2sCode
           atOwner(tree.symbol) {
@@ -74,10 +85,6 @@ trait CallTransformation
           }
         }
 
-        case dd @ DefDef(mods, name, _, _, _, _) => {
-          super.transform(tree)
-        }
-
         case _ => super.transform(tree)
       }
     }
diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala
index b6e60daa1..e5f02ba05 100644
--- a/src/cp/CodeGeneration.scala
+++ b/src/cp/CodeGeneration.scala
@@ -1,12 +1,14 @@
 package cp
 
 import purescala.Trees._
+import purescala.Definitions._
 
 trait CodeGeneration {
   self: CallTransformation =>
   import global._
   import CODE._
 
+  private lazy val scalaPackage = definitions.ScalaPackage
   private lazy val exceptionClass = definitions.getClass("java.lang.Exception")
 
   private lazy val cpPackage = definitions.getModule("cp")
@@ -14,6 +16,7 @@ trait CodeGeneration {
   private lazy val serializationModule = definitions.getModule("cp.Serialization")
   private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram")
   private lazy val getExprFunction = definitions.getMember(serializationModule, "getExpr")
+  private lazy val getOutputVarListFunction = definitions.getMember(serializationModule, "getOutputVarList")
 
   private lazy val purescalaPackage = definitions.getModule("purescala")
 
@@ -22,8 +25,9 @@ trait CodeGeneration {
 
   private lazy val treesModule = definitions.getModule("purescala.Trees")
   private lazy val exprClass = definitions.getClass("purescala.Trees.Expr")
-  private lazy val intLiteralClass = definitions.getClass("purescala.Trees.IntLiteral")
   private lazy val intLiteralModule = definitions.getModule("purescala.Trees.IntLiteral")
+  private lazy val booleanLiteralModule = definitions.getModule("purescala.Trees.BooleanLiteral")
+  private lazy val caseClassModule = definitions.getModule("purescala.Trees.CaseClass")
 
   private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver")
   private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel")
@@ -45,6 +49,12 @@ trait CodeGeneration {
       (getStatement, exprSym)
     }
 
+    def getOutputVarList(filename : String) : (Tree, Symbol) = {
+      val listSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "ovl")).setInfo(typeRef(NoPrefix, definitions.ListClass, List(definitions.StringClass.tpe)))
+      val getStatement = VAL(listSym) === ((cpPackage DOT serializationModule DOT getOutputVarListFunction) APPLY LIT(filename))
+      (getStatement, listSym)
+    }
+
     def invokeSolver(progSym : Symbol, exprSym : Symbol) : Tree = {
       val solverSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe)
       val solverDeclaration = VAL(solverSym) === NEW(ID(fairZ3SolverClass), NEW(ID(defaultReporter)))
@@ -54,16 +64,20 @@ trait CodeGeneration {
       BLOCK(solverDeclaration, setProgram, invocation, LIT(0))
     }
 
-    def exprToScala(owner : Symbol) : (Symbol, Tree) = {
+    def exprToScala(owner : Symbol, prog : Program) : (Symbol, Tree) = {
       val methodSym = owner.newMethod(NoPosition, unit.fresh.newName(NoPosition, "exprToScala"))
       methodSym.setInfo(MethodType(methodSym.newSyntheticValueParams(List(definitions.AnyClass.tpe)), definitions.AnyClass.tpe))
       owner.info.decls.enter(methodSym)
 
       val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe)
+      val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe)
+
+      val definedCaseClasses : Seq[CaseClassDef] = prog.definedClasses.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef])
 
       val matchExpr = (methodSym ARG 0) MATCH (
-        CASE((intLiteralModule) APPLY (intSym BIND WILD())) ==> ID(intSym) ,
-        DEFAULT                                             ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))
+        CASE((intLiteralModule) APPLY (intSym BIND WILD()))         ==> ID(intSym) ,
+        CASE((booleanLiteralModule) APPLY (booleanSym BIND WILD())) ==> ID(booleanSym) ,
+        DEFAULT                                                     ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))
       )
 
       (methodSym, DEF(methodSym) === matchExpr)
diff --git a/src/cp/Serialization.scala b/src/cp/Serialization.scala
index 12b43e341..e5e0b782a 100644
--- a/src/cp/Serialization.scala
+++ b/src/cp/Serialization.scala
@@ -5,12 +5,14 @@ trait Serialization {
   import purescala.Definitions._
   import purescala.Trees._
 
+  private val filePrefix = "serialized"
   private val fileSuffix = ""
   private val dirName = "serialized"
   private val directory = new java.io.File(dirName)
 
   private var cachedProgram : Option[Program] = None
   private val exprMap = new scala.collection.mutable.HashMap[String,Expr]()
+  private val outputVarListMap = new scala.collection.mutable.HashMap[String,List[String]]()
 
   def programFileName(prog : Program) : String = {
     prog.mainObject.id.toString + fileSuffix
@@ -25,7 +27,17 @@ trait Serialization {
     fos.close()
 
     file.getAbsolutePath()
+
+  }
+
+  def writeObject(obj : Any) : String = {
+    directory.mkdir()
+
+    val file = java.io.File.createTempFile(filePrefix, fileSuffix, directory)
+    
+    writeObject(file, obj)
   }
+
   def writeProgram(prog : Program) : String = {
     directory.mkdir()
 
@@ -35,14 +47,6 @@ trait Serialization {
     writeObject(file, prog)
   }
 
-  def writeExpr(pred : Expr) : String = {
-    directory.mkdir()
-
-    val file = java.io.File.createTempFile("expr", fileSuffix, directory)
-    
-    writeObject(file, pred)
-  }
-
   private def readObject[A](filename : String) : A = {
     val fis = new FileInputStream(filename)
     val ois = new ObjectInputStream(fis)
@@ -61,14 +65,23 @@ trait Serialization {
     readObject[Expr](filename)
   }
 
+  private def readOutputVarList(filename : String) : List[String] = {
+    readObject[List[String]](filename)
+  }
+
   def getProgram(filename : String) : Program = cachedProgram match {
     case None => val r = readProgram(filename); cachedProgram = Some(r); r
     case Some(cp) => cp
   }
 
   def getExpr(filename : String) : Expr = exprMap.get(filename) match {
-    case Some(p) => p
-    case None => val p = readExpr(filename); exprMap += (filename -> p); p
+    case Some(e) => e
+    case None => val e = readExpr(filename); exprMap += (filename -> e); e
+  }
+
+  def getOutputVarList(filename : String) : List[String] = outputVarListMap.get(filename) match {
+    case Some(l) => l
+    case None => val l = readOutputVarList(filename); outputVarListMap += (filename -> l); l
   }
 }
 
-- 
GitLab