diff --git a/src/funcheck/CPComponent.scala b/src/funcheck/CPComponent.scala index d34b9161717fc077dc208445b90c3f8462acca0f..e049328f2f1e97e58d25bd2e9c036a3574bec7a2 100644 --- a/src/funcheck/CPComponent.scala +++ b/src/funcheck/CPComponent.scala @@ -28,15 +28,15 @@ class CPComponent(val global: Global, val pluginInstance: FunCheckPlugin) fresh = unit.fresh val prog: purescala.Definitions.Program = extractCode(unit) - val fileName = writeProgram(prog) - println("Program extracted and written into: " + fileName) + val filename = writeProgram(prog) + println("Program extracted and written into: " + filename) - transformCalls(unit, prog) + transformCalls(unit, prog, filename) println("Finished transformation") /* try { - val recovered = readProgram(fileName) + val recovered = readProgram(filename) println println("Recovered: " + recovered) } catch { diff --git a/src/funcheck/CallTransformation.scala b/src/funcheck/CallTransformation.scala index 46f92a5849a8639386372625f6bb7a5268f35924..661b09708e60ce819548d6095e7ee328eed966a7 100644 --- a/src/funcheck/CallTransformation.scala +++ b/src/funcheck/CallTransformation.scala @@ -4,24 +4,23 @@ import scala.tools.nsc.transform.TypingTransformers import purescala.FairZ3Solver import purescala.DefaultReporter import purescala.Definitions._ +import purescala.Trees._ trait CallTransformation extends TypingTransformers - with CodeExtraction + with CodeGeneration { + self: CPComponent => import global._ private lazy val funcheckPackage = definitions.getModule("funcheck") private lazy val cpDefinitionsModule = definitions.getModule("funcheck.CP") - private lazy val purescalaPackage = definitions.getModule("purescala") - private lazy val fairZ3Solver = definitions.getClass("purescala.FairZ3Solver") - private lazy val defaultReporter = definitions.getClass("purescala.DefaultReporter") - def transformCalls(unit: CompilationUnit, prog: Program) : Unit = - unit.body = new CallTransformer(unit, prog).transform(unit.body) + def transformCalls(unit: CompilationUnit, prog: Program, filename: String) : Unit = + unit.body = new CallTransformer(unit, prog, filename).transform(unit.body) - class CallTransformer(unit: CompilationUnit, prog: Program) extends TypingTransformer(unit) { + class CallTransformer(unit: CompilationUnit, prog: Program, filename: String) extends TypingTransformer(unit) { override def transform(tree: Tree) : Tree = { tree match { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { @@ -34,35 +33,17 @@ trait CallTransformation println("Here is the extracted FunDef:") println(fd) - val solverSymbol = currentOwner.newValue(NoPosition, unit.fresh.newName(NoPosition, "s")).setInfo(fairZ3Solver.tpe) - - val code = Block( - ValDef( - solverSymbol, - New( - Ident(fairZ3Solver), - List( - List( - New( - Ident(defaultReporter), - List(Nil) - ) - ) - ) - ) - ) :: Nil, - Literal(Constant(0)) - ) - + /* typer.typed(atOwner(currentOwner) { code }) + */ + super.transform(tree) /* val solver = new FairZ3Solver(new DefaultReporter) solver.setProgram(prog) println(solver.decide(fd.body.get, false)) - super.transform(tree) */ } diff --git a/src/funcheck/CodeGeneration.scala b/src/funcheck/CodeGeneration.scala new file mode 100644 index 0000000000000000000000000000000000000000..3d3cc497059dcb9cbc235cb0e5840025710301e1 --- /dev/null +++ b/src/funcheck/CodeGeneration.scala @@ -0,0 +1,84 @@ +package funcheck + +import purescala.Trees._ + +trait CodeGeneration { + self: CallTransformation => + import global._ + + private lazy val serializationModule = definitions.getClass("funcheck.Serialization") + private lazy val readProgramFunction = definitions.getMember(serializationModule, "readProgram") + private lazy val purescalaPackage = definitions.getModule("purescala") + private lazy val definitionsModule = definitions.getModule("purescala.Definitions") + private lazy val programClass = definitions.getClass("purescala.Definitions.Program") + private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") + private lazy val decideWithModelFunction = definitions.getMember(fairZ3SolverClass, "decideWithModel") + private lazy val setProgramFunction = definitions.getMember(fairZ3SolverClass, "setProgram") + + private lazy val defaultReporter = definitions.getClass("purescala.DefaultReporter") + + class CodeGenerator(unit: CompilationUnit, owner: Symbol) { + /* + def exprToTree(expr: Expr) : Tree = expr match { + case Variable(id) => + } + */ + + def generateProgramRead(filename: String) : (Tree, Symbol) = { + val progSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe) + val readStatement = + ValDef( + progSymbol, + Apply( + Select( + Ident(serializationModule), + readProgramFunction + ), + List(Literal(Constant(filename))) + ) + ) + (readStatement, progSymbol) + } + + def generateSolverInvocation(formula: Expr) : Tree = { + val solverSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) + val solverDeclaration = + ValDef( + solverSymbol, + New( + Ident(fairZ3SolverClass), + List( + List( + New( + Ident(defaultReporter), + List(Nil) + ) + ) + ) + ) + ) + val setProgram = + Apply( + Select( + Ident(solverSymbol), + setProgramFunction + ), + List(/* read program into a var and plug its symbol here */) + ) + + val invocation = + Apply( + Select( + Ident(solverSymbol), + decideWithModelFunction + ), + List(/* convert pred into scala AST and plug it here */) + ) + + Block( + solverDeclaration :: setProgram :: invocation :: Nil, + Literal(Constant(0)) + ) + } + } +} diff --git a/src/funcheck/Serialization.scala b/src/funcheck/Serialization.scala index 198b666e0c807666d7fdb5d241e99e3f308b6ad4..12d08f4a874c91dac0c25f990b5efe282f52d5e1 100644 --- a/src/funcheck/Serialization.scala +++ b/src/funcheck/Serialization.scala @@ -37,3 +37,5 @@ trait Serialization { recovered } } + +object Serialization extends Serialization