diff --git a/project/build/funcheck.scala b/project/build/funcheck.scala index fd0aac5b3c1ef18c1210fbb8cc652069dce2a2ff..30521628ec2dce3816037e144467e916b3a6cb95 100644 --- a/project/build/funcheck.scala +++ b/project/build/funcheck.scala @@ -45,7 +45,8 @@ class FunCheckProject(info: ProjectInfo) extends DefaultProject(info) with FileT fw.write("done" + nl + nl) fw.write("SCALACCLASSPATH=\"") fw.write(multisetsLib.jarPath.absolutePath + ":") - fw.write(plugin.jarPath.absolutePath) + fw.write(plugin.jarPath.absolutePath + ":") + fw.write(purescala.jarPath.absolutePath) fw.write("\"" + nl + nl) fw.write("LD_LIBRARY_PATH=" + ("." / "lib-bin").absolutePath + " \\" + nl) fw.write("java -Xmx1024M \\" + nl) diff --git a/src/funcheck/CPComponent.scala b/src/funcheck/CPComponent.scala index 175f57b373b072015661f18e0ac95fdb4daa93dd..d34b9161717fc077dc208445b90c3f8462acca0f 100644 --- a/src/funcheck/CPComponent.scala +++ b/src/funcheck/CPComponent.scala @@ -31,7 +31,7 @@ class CPComponent(val global: Global, val pluginInstance: FunCheckPlugin) val fileName = writeProgram(prog) println("Program extracted and written into: " + fileName) - transformCalls(unit) + transformCalls(unit, prog) println("Finished transformation") /* diff --git a/src/funcheck/CallTransformation.scala b/src/funcheck/CallTransformation.scala index 795ae494423b805056c6970808de2ccbade056da..46f92a5849a8639386372625f6bb7a5268f35924 100644 --- a/src/funcheck/CallTransformation.scala +++ b/src/funcheck/CallTransformation.scala @@ -1,6 +1,9 @@ package funcheck import scala.tools.nsc.transform.TypingTransformers +import purescala.FairZ3Solver +import purescala.DefaultReporter +import purescala.Definitions._ trait CallTransformation extends TypingTransformers @@ -11,10 +14,14 @@ trait CallTransformation private lazy val funcheckPackage = definitions.getModule("funcheck") private lazy val cpDefinitionsModule = definitions.getModule("funcheck.CP") - def transformCalls(unit: CompilationUnit) : Unit = - unit.body = new CallTransformer(unit).transform(unit.body) + private lazy val purescalaPackage = definitions.getModule("purescala") + private lazy val fairZ3Solver = definitions.getClass("purescala.FairZ3Solver") + private lazy val defaultReporter = definitions.getClass("purescala.DefaultReporter") + + def transformCalls(unit: CompilationUnit, prog: Program) : Unit = + unit.body = new CallTransformer(unit, prog).transform(unit.body) - class CallTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { + class CallTransformer(unit: CompilationUnit, prog: Program) extends TypingTransformer(unit) { override def transform(tree: Tree) : Tree = { tree match { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { @@ -27,7 +34,36 @@ trait CallTransformation println("Here is the extracted FunDef:") println(fd) - super.transform(a) + val solverSymbol = currentOwner.newValue(NoPosition, unit.fresh.newName(NoPosition, "s")).setInfo(fairZ3Solver.tpe) + + val code = Block( + ValDef( + solverSymbol, + New( + Ident(fairZ3Solver), + List( + List( + New( + Ident(defaultReporter), + List(Nil) + ) + ) + ) + ) + ) :: Nil, + Literal(Constant(0)) + ) + + typer.typed(atOwner(currentOwner) { + code + }) + + /* + val solver = new FairZ3Solver(new DefaultReporter) + solver.setProgram(prog) + println(solver.decide(fd.body.get, false)) + super.transform(tree) + */ } case _ => super.transform(tree)