diff --git a/src/funcheck/CallTransformation.scala b/src/funcheck/CallTransformation.scala index a2b0705a2d833922f1b3405c976fee246731fe82..42913bce279b7a5b96e75595a7d695146baa5255 100644 --- a/src/funcheck/CallTransformation.scala +++ b/src/funcheck/CallTransformation.scala @@ -17,10 +17,10 @@ trait CallTransformation private lazy val cpDefinitionsModule = definitions.getModule("funcheck.CP") - def transformCalls(unit: CompilationUnit, prog: Program, filename: String) : Unit = - unit.body = new CallTransformer(unit, prog, filename).transform(unit.body) + def transformCalls(unit: CompilationUnit, prog: Program, programFilename: String) : Unit = + unit.body = new CallTransformer(unit, prog, programFilename).transform(unit.body) - class CallTransformer(unit: CompilationUnit, prog: Program, filename: String) extends TypingTransformer(unit) { + class CallTransformer(unit: CompilationUnit, prog: Program, programFilename: String) extends TypingTransformer(unit) { override def transform(tree: Tree) : Tree = { tree match { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { @@ -38,9 +38,11 @@ trait CallTransformation fd.body match { case None => println("Could not extract choose predicate: " + funBody); super.transform(tree) case Some(b) => - val (programGet, progSym) = codeGen.generateProgramGet(filename) - val solverInvocation = codeGen.generateSolverInvocation(b, progSym) - val code = Block(programGet :: Nil, solverInvocation) + val exprFilename = writeExpr(b) + val (programGet, progSym) = codeGen.generateProgramGet(programFilename) + val (exprGet, exprSym) = codeGen.generateExprGet(exprFilename) + val solverInvocation = codeGen.generateSolverInvocation(b, progSym, exprSym) + val code = Block(programGet :: exprGet :: Nil, solverInvocation) typer.typed(atOwner(currentOwner) { code diff --git a/src/funcheck/CodeGeneration.scala b/src/funcheck/CodeGeneration.scala index c81a2d71b2a6d7a9cc94a25f8fc0170a1ef07433..e2052a5e3e811ba328f2cd84dc5eade8731bfa9a 100644 --- a/src/funcheck/CodeGeneration.scala +++ b/src/funcheck/CodeGeneration.scala @@ -6,27 +6,39 @@ trait CodeGeneration { self: CallTransformation => import global._ - private lazy val serializationModule = definitions.getClass("funcheck.Serialization") + private lazy val funcheckPackage = definitions.getModule("funcheck") + + private lazy val serializationModule = definitions.getModule("funcheck.Serialization") private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram") + private lazy val getExprFunction = definitions.getMember(serializationModule, "getExpr") + private lazy val purescalaPackage = definitions.getModule("purescala") + private lazy val definitionsModule = definitions.getModule("purescala.Definitions") private lazy val programClass = definitions.getClass("purescala.Definitions.Program") + + private lazy val treesModule = definitions.getModule("purescala.Trees") + private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") + private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") - private lazy val decideWithModelFunction = definitions.getMember(fairZ3SolverClass, "decideWithModel") + private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel") private lazy val setProgramFunction = definitions.getMember(fairZ3SolverClass, "setProgram") private lazy val defaultReporter = definitions.getClass("purescala.DefaultReporter") - class CodeGenerator(unit: CompilationUnit, owner: Symbol) { + class CodeGenerator(unit : CompilationUnit, owner : Symbol) { - def generateProgramGet(filename: String) : (Tree, Symbol) = { + def generateProgramGet(filename : String) : (Tree, Symbol) = { val progSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe) val getStatement = ValDef( progSymbol, Apply( Select( - Ident(serializationModule), + Select( + Ident(funcheckPackage), + serializationModule + ) , getProgramFunction ), List(Literal(Constant(filename))) @@ -35,7 +47,26 @@ trait CodeGeneration { (getStatement, progSymbol) } - def generateSolverInvocation(formula: Expr, progSymbol: Symbol) : Tree = { + def generateExprGet(filename : String) : (Tree, Symbol) = { + val exprSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "expr")).setInfo(exprClass.tpe) + val getStatement = + ValDef( + exprSymbol, + Apply( + Select( + Select( + Ident(funcheckPackage), + serializationModule + ), + getExprFunction + ), + List(Literal(Constant(filename))) + ) + ) + (getStatement, exprSymbol) + } + + def generateSolverInvocation(formula : Expr, progSymbol : Symbol, exprSymbol : Symbol) : Tree = { val solverSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) val solverDeclaration = ValDef( @@ -65,9 +96,9 @@ trait CodeGeneration { Apply( Select( Ident(solverSymbol), - decideWithModelFunction + restartAndDecideWithModel ), - List(/* convert pred into scala AST of funcheck expression and plug it here */) + List(Ident(exprSymbol), Literal(Constant(false))) ) Block( diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala index 6945611ad87be9b73a4f0f9b6c0bbc56df189336..7e62e89965e7ea9702afedd0b1652da97a091cf2 100644 --- a/src/purescala/FairZ3Solver.scala +++ b/src/purescala/FairZ3Solver.scala @@ -238,6 +238,11 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac decideWithModel(vc, forValidity)._1 } + def restartAndDecideWithModel(vc: Expr, forValidity: Boolean): (Option[Boolean], Map[Identifier,Expr]) = { + restartZ3 + decideWithModel(vc, forValidity) + } + def decideWithModel(vc: Expr, forValidity: Boolean): (Option[Boolean], Map[Identifier,Expr]) = { val unrollingBank = new UnrollingBank