From c72dba7ef5bae78eb87e2d7a06480b33dd45a292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com> Date: Wed, 30 Mar 2011 20:17:24 +0000 Subject: [PATCH] Store and read output variable lists. --- src/cp/CallTransformation.scala | 31 +++++++++++++++++++------------ src/cp/CodeGeneration.scala | 22 ++++++++++++++++++---- src/cp/Serialization.scala | 33 +++++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 26 deletions(-) diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala index d32410de1..2f5dab938 100644 --- a/src/cp/CallTransformation.scala +++ b/src/cp/CallTransformation.scala @@ -4,6 +4,7 @@ import scala.tools.nsc.transform.TypingTransformers import scala.tools.nsc.ast.TreeDSL import purescala.FairZ3Solver import purescala.DefaultReporter +import purescala.Common.Identifier import purescala.Definitions._ import purescala.Trees._ @@ -27,15 +28,25 @@ trait CallTransformation var exprToScalaSym : Symbol = null var exprToScalaCode : Tree = null + def outputAssignmentList(outputVars: List[String], model: Map[Identifier, Expr]) : List[Any] = { + val modelWithStrings = model.map{ case (k, v) => (k.name, v) } + outputVars.map(ov => (modelWithStrings.get(ov) match { + case Some(value) => value + case None => scala.Predef.error("Did not find assignment for variable '" + ov + "'") + })) + } + override def transform(tree: Tree) : Tree = { tree match { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { - println("I'm inside a choose call!") - val Function(funValDefs, funBody) = predicate val fd = extractPredicate(unit, funValDefs, funBody) + val outputVarList = funValDefs.map(_.name.toString) + println("Here is the output variable list: " + outputVarList.mkString(", ")) + val outputVarListFilename = writeObject(outputVarList) + println("Here is the extracted FunDef:") println(fd) val codeGen = new CodeGenerator(unit, currentOwner) @@ -43,13 +54,14 @@ trait CallTransformation fd.body match { case None => println("Could not extract choose predicate: " + funBody); super.transform(tree) case Some(b) => - val exprFilename = writeExpr(b) + val exprFilename = writeObject(b) val (programGet, progSym) = codeGen.getProgram(programFilename) val (exprGet, exprSym) = codeGen.getExpr(exprFilename) + val (outputVarListGet, outputVarListSym) = codeGen.getOutputVarList(outputVarListFilename) val solverInvocation = codeGen.invokeSolver(progSym, exprSym) val exprToScalaInvocation = codeGen.invokeExprToScala(exprToScalaSym) - // val code = BLOCK(programGet, exprGet, solverInvocation) - val code = BLOCK(programGet, exprGet, solverInvocation, exprToScalaInvocation) + + val code = BLOCK(programGet, exprGet, outputVarListGet, solverInvocation, exprToScalaInvocation) typer.typed(atOwner(currentOwner) { code @@ -58,10 +70,9 @@ trait CallTransformation } case cd @ ClassDef(mods, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => { - println("I'm inside the object " + name.toString + " !") - val codeGen = new CodeGenerator(unit, currentOwner) - val (e2sSym, e2sCode) = codeGen.exprToScala(cd.symbol) + + val (e2sSym, e2sCode) = codeGen.exprToScala(cd.symbol, prog) exprToScalaSym = e2sSym exprToScalaCode = e2sCode atOwner(tree.symbol) { @@ -74,10 +85,6 @@ trait CallTransformation } } - case dd @ DefDef(mods, name, _, _, _, _) => { - super.transform(tree) - } - case _ => super.transform(tree) } } diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala index b6e60daa1..e5f02ba05 100644 --- a/src/cp/CodeGeneration.scala +++ b/src/cp/CodeGeneration.scala @@ -1,12 +1,14 @@ package cp import purescala.Trees._ +import purescala.Definitions._ trait CodeGeneration { self: CallTransformation => import global._ import CODE._ + private lazy val scalaPackage = definitions.ScalaPackage private lazy val exceptionClass = definitions.getClass("java.lang.Exception") private lazy val cpPackage = definitions.getModule("cp") @@ -14,6 +16,7 @@ trait CodeGeneration { private lazy val serializationModule = definitions.getModule("cp.Serialization") private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram") private lazy val getExprFunction = definitions.getMember(serializationModule, "getExpr") + private lazy val getOutputVarListFunction = definitions.getMember(serializationModule, "getOutputVarList") private lazy val purescalaPackage = definitions.getModule("purescala") @@ -22,8 +25,9 @@ trait CodeGeneration { private lazy val treesModule = definitions.getModule("purescala.Trees") private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") - private lazy val intLiteralClass = definitions.getClass("purescala.Trees.IntLiteral") private lazy val intLiteralModule = definitions.getModule("purescala.Trees.IntLiteral") + private lazy val booleanLiteralModule = definitions.getModule("purescala.Trees.BooleanLiteral") + private lazy val caseClassModule = definitions.getModule("purescala.Trees.CaseClass") private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel") @@ -45,6 +49,12 @@ trait CodeGeneration { (getStatement, exprSym) } + def getOutputVarList(filename : String) : (Tree, Symbol) = { + val listSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "ovl")).setInfo(typeRef(NoPrefix, definitions.ListClass, List(definitions.StringClass.tpe))) + val getStatement = VAL(listSym) === ((cpPackage DOT serializationModule DOT getOutputVarListFunction) APPLY LIT(filename)) + (getStatement, listSym) + } + def invokeSolver(progSym : Symbol, exprSym : Symbol) : Tree = { val solverSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) val solverDeclaration = VAL(solverSym) === NEW(ID(fairZ3SolverClass), NEW(ID(defaultReporter))) @@ -54,16 +64,20 @@ trait CodeGeneration { BLOCK(solverDeclaration, setProgram, invocation, LIT(0)) } - def exprToScala(owner : Symbol) : (Symbol, Tree) = { + def exprToScala(owner : Symbol, prog : Program) : (Symbol, Tree) = { val methodSym = owner.newMethod(NoPosition, unit.fresh.newName(NoPosition, "exprToScala")) methodSym.setInfo(MethodType(methodSym.newSyntheticValueParams(List(definitions.AnyClass.tpe)), definitions.AnyClass.tpe)) owner.info.decls.enter(methodSym) val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) + val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe) + + val definedCaseClasses : Seq[CaseClassDef] = prog.definedClasses.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) val matchExpr = (methodSym ARG 0) MATCH ( - CASE((intLiteralModule) APPLY (intSym BIND WILD())) ==> ID(intSym) , - DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term")) + CASE((intLiteralModule) APPLY (intSym BIND WILD())) ==> ID(intSym) , + CASE((booleanLiteralModule) APPLY (booleanSym BIND WILD())) ==> ID(booleanSym) , + DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term")) ) (methodSym, DEF(methodSym) === matchExpr) diff --git a/src/cp/Serialization.scala b/src/cp/Serialization.scala index 12b43e341..e5e0b782a 100644 --- a/src/cp/Serialization.scala +++ b/src/cp/Serialization.scala @@ -5,12 +5,14 @@ trait Serialization { import purescala.Definitions._ import purescala.Trees._ + private val filePrefix = "serialized" private val fileSuffix = "" private val dirName = "serialized" private val directory = new java.io.File(dirName) private var cachedProgram : Option[Program] = None private val exprMap = new scala.collection.mutable.HashMap[String,Expr]() + private val outputVarListMap = new scala.collection.mutable.HashMap[String,List[String]]() def programFileName(prog : Program) : String = { prog.mainObject.id.toString + fileSuffix @@ -25,7 +27,17 @@ trait Serialization { fos.close() file.getAbsolutePath() + + } + + def writeObject(obj : Any) : String = { + directory.mkdir() + + val file = java.io.File.createTempFile(filePrefix, fileSuffix, directory) + + writeObject(file, obj) } + def writeProgram(prog : Program) : String = { directory.mkdir() @@ -35,14 +47,6 @@ trait Serialization { writeObject(file, prog) } - def writeExpr(pred : Expr) : String = { - directory.mkdir() - - val file = java.io.File.createTempFile("expr", fileSuffix, directory) - - writeObject(file, pred) - } - private def readObject[A](filename : String) : A = { val fis = new FileInputStream(filename) val ois = new ObjectInputStream(fis) @@ -61,14 +65,23 @@ trait Serialization { readObject[Expr](filename) } + private def readOutputVarList(filename : String) : List[String] = { + readObject[List[String]](filename) + } + def getProgram(filename : String) : Program = cachedProgram match { case None => val r = readProgram(filename); cachedProgram = Some(r); r case Some(cp) => cp } def getExpr(filename : String) : Expr = exprMap.get(filename) match { - case Some(p) => p - case None => val p = readExpr(filename); exprMap += (filename -> p); p + case Some(e) => e + case None => val e = readExpr(filename); exprMap += (filename -> e); e + } + + def getOutputVarList(filename : String) : List[String] = outputVarListMap.get(filename) match { + case Some(l) => l + case None => val l = readOutputVarList(filename); outputVarListMap += (filename -> l); l } } -- GitLab