diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala index 5f93a11e49780bc08765e49b47d648be8902d5c2..8b8a47e3a162e1f204fae68e624c3f9b6a4d1afa 100644 --- a/src/cp/CallTransformation.scala +++ b/src/cp/CallTransformation.scala @@ -1,6 +1,7 @@ package cp import scala.tools.nsc.transform.TypingTransformers +import scala.tools.nsc.ast.TreeDSL import purescala.FairZ3Solver import purescala.DefaultReporter import purescala.Definitions._ @@ -9,9 +10,11 @@ import purescala.Trees._ trait CallTransformation extends TypingTransformers with CodeGeneration + with TreeDSL { self: CPComponent => import global._ + import CODE._ private lazy val cpPackage = definitions.getModule("cp") private lazy val cpDefinitionsModule = definitions.getModule("cp.CP") @@ -21,6 +24,9 @@ trait CallTransformation unit.body = new CallTransformer(unit, prog, programFilename).transform(unit.body) class CallTransformer(unit: CompilationUnit, prog: Program, programFilename: String) extends TypingTransformer(unit) { + val codeGen = new CodeGenerator(unit, currentOwner) + val (exprToScalaSym, exprToScalaCode) = codeGen.exprToScala + override def transform(tree: Tree) : Tree = { tree match { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { @@ -33,16 +39,15 @@ trait CallTransformation println("Here is the extracted FunDef:") println(fd) - val codeGen = new CodeGenerator(unit, currentOwner) - fd.body match { case None => println("Could not extract choose predicate: " + funBody); super.transform(tree) case Some(b) => val exprFilename = writeExpr(b) val (programGet, progSym) = codeGen.getProgram(programFilename) val (exprGet, exprSym) = codeGen.getExpr(exprFilename) - val solverInvocation = codeGen.invokeSolver(b, progSym, exprSym) - val code = Block(programGet :: exprGet :: Nil, solverInvocation) + val solverInvocation = codeGen.invokeSolver(progSym, exprSym) + val exprToScalaInvocation = codeGen.invokeExprToScala(exprToScalaSym) + val code = BLOCK(programGet, exprGet, solverInvocation) //, exprToScalaInvocation) typer.typed(atOwner(currentOwner) { code @@ -50,6 +55,23 @@ trait CallTransformation } } + case cd @ ClassDef(mods, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => { + println("I'm inside the object " + name.toString + " !") + + atOwner(tree.symbol) { + treeCopy.ClassDef(tree, transformModifiers(mods), name, + transformTypeDefs(tparams), impl match { + case Template(parents, self, body) => + treeCopy.Template(impl, transformTrees(parents), + transformValDef(self), typer.typed(atOwner(currentOwner) {exprToScalaCode}) :: transformStats(body, tree.symbol)) + }) + } + } + + case dd @ DefDef(mods, name, _, _, _, _) => { + super.transform(tree) + } + case _ => super.transform(tree) } } diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala index 62b66a6880cdc8c3e49d77b529007d99911f7216..510f9c20a395cfc6a60c53791c0daf4be05876a3 100644 --- a/src/cp/CodeGeneration.scala +++ b/src/cp/CodeGeneration.scala @@ -5,6 +5,9 @@ import purescala.Trees._ trait CodeGeneration { self: CallTransformation => import global._ + import CODE._ + + private lazy val exceptionClass = definitions.getClass("java.lang.Exception") private lazy val cpPackage = definitions.getModule("cp") @@ -19,6 +22,7 @@ trait CodeGeneration { private lazy val treesModule = definitions.getModule("purescala.Trees") private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") + private lazy val intLiteralClass = definitions.getClass("purescala.Trees.IntLiteral") private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel") @@ -29,82 +33,41 @@ trait CodeGeneration { class CodeGenerator(unit : CompilationUnit, owner : Symbol) { def getProgram(filename : String) : (Tree, Symbol) = { - val progSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe) - val getStatement = - ValDef( - progSymbol, - Apply( - Select( - Select( - Ident(cpPackage), - serializationModule - ) , - getProgramFunction - ), - List(Literal(Constant(filename))) - ) - ) - (getStatement, progSymbol) + val progSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe) + val getStatement = VAL(progSym) === ((cpPackage DOT serializationModule DOT getProgramFunction) APPLY LIT(filename)) + (getStatement, progSym) } def getExpr(filename : String) : (Tree, Symbol) = { - val exprSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "expr")).setInfo(exprClass.tpe) - val getStatement = - ValDef( - exprSymbol, - Apply( - Select( - Select( - Ident(cpPackage), - serializationModule - ), - getExprFunction - ), - List(Literal(Constant(filename))) - ) - ) - (getStatement, exprSymbol) + val exprSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "expr")).setInfo(exprClass.tpe) + val getStatement = VAL(exprSym) === ((cpPackage DOT serializationModule DOT getExprFunction) APPLY LIT(filename)) + (getStatement, exprSym) + } + + def invokeSolver(progSym : Symbol, exprSym : Symbol) : Tree = { + val solverSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) + val solverDeclaration = VAL(solverSym) === NEW(ID(fairZ3SolverClass), NEW(ID(defaultReporter))) + val setProgram = (solverSym DOT setProgramFunction) APPLY ID(progSym) + val invocation = (solverSym DOT restartAndDecideWithModel) APPLY (ID(exprSym), LIT(false)) + + BLOCK(solverDeclaration, setProgram, invocation, LIT(0)) } - def invokeSolver(formula : Expr, progSymbol : Symbol, exprSymbol : Symbol) : Tree = { - val solverSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) - val solverDeclaration = - ValDef( - solverSymbol, - New( - Ident(fairZ3SolverClass), - List( - List( - New( - Ident(defaultReporter), - List(Nil) - ) - ) - ) - ) - ) - val setProgram = - Apply( - Select( - Ident(solverSymbol), - setProgramFunction - ), - List(Ident(progSymbol)) - ) - - val invocation = - Apply( - Select( - Ident(solverSymbol), - restartAndDecideWithModel - ), - List(Ident(exprSymbol), Literal(Constant(false))) - ) - - Block( - solverDeclaration :: setProgram :: invocation :: Nil, - Literal(Constant(0)) + def exprToScala : (Symbol, Tree) = { + val scrutSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "scrut")).setInfo(exprClass.tpe) + val intSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) + + val matchExpr = ID(scrutSym) MATCH ( + CASE(ID(intLiteralClass) APPLY (intSym BIND WILD())) ==> ID(intSym) , + DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term")) ) + val methodSym = owner.newMethod(NoPosition, unit.fresh.newName(NoPosition, "exprToScala")).setInfo(MethodType(Nil, definitions.IntClass.tpe)) + // (methodSym, DEF(methodSym) === matchExpr) + (methodSym, DEF(methodSym) === LIT(0)) + } + + def invokeExprToScala(methodSym : Symbol) : Tree = { + methodSym APPLY () } } } diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index f1f2420cb1c9e4bdfe3adbabe989059b237bdc82..335fb44b44897af6d87ce751398f3feee508b26d 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -25,6 +25,9 @@ trait CodeExtraction extends Extractors { scala.collection.mutable.Map.empty[Symbol,ClassTypeDef] private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = scala.collection.mutable.Map.empty[Symbol,FunDef] + + private val reverseClassesToClasses: scala.collection.mutable.Map[ClassTypeDef,Symbol] = + scala.collection.mutable.Map.empty[ClassTypeDef,Symbol] protected def stopIfErrors: Unit = { if(reporter.hasErrors) { @@ -266,6 +269,9 @@ trait CodeExtraction extends Extractors { stopIfErrors + // Reverse map for Scala class symbols + reverseClassesToClasses ++= classesToClasses.map{ case (a, b) => (b, a) } + val programName: Identifier = unit.body match { case PackageDef(name, _) => FreshIdentifier(name.toString) case _ => FreshIdentifier("<program>") @@ -299,6 +305,19 @@ trait CodeExtraction extends Extractors { fd } + /* + def groundExprToScala(expr : Expr) : Tree = { + val converted = expr match { + case IntLiteral(v) => Literal(Constant(v)) + case BooleanLiteral(v) => Literal(Constant(v)) + case StringLiteral(v) => Literal(Constant(v)) + case CaseClass(cd,args) => New(Ident(reverseClassesToClasses(cd)), List(args.map(groundExprToScala(_)).toList)) + case _ => scala.Predef.error("Cannot convert to Scala : " + expr) + } + converted + } + */ + /** An exception thrown when non-purescala compatible code is encountered. */ sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception