Skip to content
Snippets Groups Projects
Commit 6baa0c66 authored by Ali Sinan Köksal's avatar Ali Sinan Köksal
Browse files

Using TreeDSL to make code generation code more compact and readable,...

Using TreeDSL to make code generation code more compact and readable, attempting to generate method that will convert funcheck expressions to scala terms.

parent 1b833855
Branches
Tags
No related merge requests found
package cp package cp
import scala.tools.nsc.transform.TypingTransformers import scala.tools.nsc.transform.TypingTransformers
import scala.tools.nsc.ast.TreeDSL
import purescala.FairZ3Solver import purescala.FairZ3Solver
import purescala.DefaultReporter import purescala.DefaultReporter
import purescala.Definitions._ import purescala.Definitions._
...@@ -9,9 +10,11 @@ import purescala.Trees._ ...@@ -9,9 +10,11 @@ import purescala.Trees._
trait CallTransformation trait CallTransformation
extends TypingTransformers extends TypingTransformers
with CodeGeneration with CodeGeneration
with TreeDSL
{ {
self: CPComponent => self: CPComponent =>
import global._ import global._
import CODE._
private lazy val cpPackage = definitions.getModule("cp") private lazy val cpPackage = definitions.getModule("cp")
private lazy val cpDefinitionsModule = definitions.getModule("cp.CP") private lazy val cpDefinitionsModule = definitions.getModule("cp.CP")
...@@ -21,6 +24,9 @@ trait CallTransformation ...@@ -21,6 +24,9 @@ trait CallTransformation
unit.body = new CallTransformer(unit, prog, programFilename).transform(unit.body) unit.body = new CallTransformer(unit, prog, programFilename).transform(unit.body)
class CallTransformer(unit: CompilationUnit, prog: Program, programFilename: String) extends TypingTransformer(unit) { class CallTransformer(unit: CompilationUnit, prog: Program, programFilename: String) extends TypingTransformer(unit) {
val codeGen = new CodeGenerator(unit, currentOwner)
val (exprToScalaSym, exprToScalaCode) = codeGen.exprToScala
override def transform(tree: Tree) : Tree = { override def transform(tree: Tree) : Tree = {
tree match { tree match {
case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => { case a @ Apply(TypeApply(Select(s: Select, n), _), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => {
...@@ -33,16 +39,15 @@ trait CallTransformation ...@@ -33,16 +39,15 @@ trait CallTransformation
println("Here is the extracted FunDef:") println("Here is the extracted FunDef:")
println(fd) println(fd)
val codeGen = new CodeGenerator(unit, currentOwner)
fd.body match { fd.body match {
case None => println("Could not extract choose predicate: " + funBody); super.transform(tree) case None => println("Could not extract choose predicate: " + funBody); super.transform(tree)
case Some(b) => case Some(b) =>
val exprFilename = writeExpr(b) val exprFilename = writeExpr(b)
val (programGet, progSym) = codeGen.getProgram(programFilename) val (programGet, progSym) = codeGen.getProgram(programFilename)
val (exprGet, exprSym) = codeGen.getExpr(exprFilename) val (exprGet, exprSym) = codeGen.getExpr(exprFilename)
val solverInvocation = codeGen.invokeSolver(b, progSym, exprSym) val solverInvocation = codeGen.invokeSolver(progSym, exprSym)
val code = Block(programGet :: exprGet :: Nil, solverInvocation) val exprToScalaInvocation = codeGen.invokeExprToScala(exprToScalaSym)
val code = BLOCK(programGet, exprGet, solverInvocation) //, exprToScalaInvocation)
typer.typed(atOwner(currentOwner) { typer.typed(atOwner(currentOwner) {
code code
...@@ -50,6 +55,23 @@ trait CallTransformation ...@@ -50,6 +55,23 @@ trait CallTransformation
} }
} }
case cd @ ClassDef(mods, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => {
println("I'm inside the object " + name.toString + " !")
atOwner(tree.symbol) {
treeCopy.ClassDef(tree, transformModifiers(mods), name,
transformTypeDefs(tparams), impl match {
case Template(parents, self, body) =>
treeCopy.Template(impl, transformTrees(parents),
transformValDef(self), typer.typed(atOwner(currentOwner) {exprToScalaCode}) :: transformStats(body, tree.symbol))
})
}
}
case dd @ DefDef(mods, name, _, _, _, _) => {
super.transform(tree)
}
case _ => super.transform(tree) case _ => super.transform(tree)
} }
} }
......
...@@ -5,6 +5,9 @@ import purescala.Trees._ ...@@ -5,6 +5,9 @@ import purescala.Trees._
trait CodeGeneration { trait CodeGeneration {
self: CallTransformation => self: CallTransformation =>
import global._ import global._
import CODE._
private lazy val exceptionClass = definitions.getClass("java.lang.Exception")
private lazy val cpPackage = definitions.getModule("cp") private lazy val cpPackage = definitions.getModule("cp")
...@@ -19,6 +22,7 @@ trait CodeGeneration { ...@@ -19,6 +22,7 @@ trait CodeGeneration {
private lazy val treesModule = definitions.getModule("purescala.Trees") private lazy val treesModule = definitions.getModule("purescala.Trees")
private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") private lazy val exprClass = definitions.getClass("purescala.Trees.Expr")
private lazy val intLiteralClass = definitions.getClass("purescala.Trees.IntLiteral")
private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver")
private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel") private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel")
...@@ -29,82 +33,41 @@ trait CodeGeneration { ...@@ -29,82 +33,41 @@ trait CodeGeneration {
class CodeGenerator(unit : CompilationUnit, owner : Symbol) { class CodeGenerator(unit : CompilationUnit, owner : Symbol) {
def getProgram(filename : String) : (Tree, Symbol) = { def getProgram(filename : String) : (Tree, Symbol) = {
val progSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe) val progSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "prog")).setInfo(programClass.tpe)
val getStatement = val getStatement = VAL(progSym) === ((cpPackage DOT serializationModule DOT getProgramFunction) APPLY LIT(filename))
ValDef( (getStatement, progSym)
progSymbol,
Apply(
Select(
Select(
Ident(cpPackage),
serializationModule
) ,
getProgramFunction
),
List(Literal(Constant(filename)))
)
)
(getStatement, progSymbol)
} }
def getExpr(filename : String) : (Tree, Symbol) = { def getExpr(filename : String) : (Tree, Symbol) = {
val exprSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "expr")).setInfo(exprClass.tpe) val exprSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "expr")).setInfo(exprClass.tpe)
val getStatement = val getStatement = VAL(exprSym) === ((cpPackage DOT serializationModule DOT getExprFunction) APPLY LIT(filename))
ValDef( (getStatement, exprSym)
exprSymbol, }
Apply(
Select( def invokeSolver(progSym : Symbol, exprSym : Symbol) : Tree = {
Select( val solverSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe)
Ident(cpPackage), val solverDeclaration = VAL(solverSym) === NEW(ID(fairZ3SolverClass), NEW(ID(defaultReporter)))
serializationModule val setProgram = (solverSym DOT setProgramFunction) APPLY ID(progSym)
), val invocation = (solverSym DOT restartAndDecideWithModel) APPLY (ID(exprSym), LIT(false))
getExprFunction
), BLOCK(solverDeclaration, setProgram, invocation, LIT(0))
List(Literal(Constant(filename)))
)
)
(getStatement, exprSymbol)
} }
def invokeSolver(formula : Expr, progSymbol : Symbol, exprSymbol : Symbol) : Tree = { def exprToScala : (Symbol, Tree) = {
val solverSymbol = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "solver")).setInfo(fairZ3SolverClass.tpe) val scrutSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "scrut")).setInfo(exprClass.tpe)
val solverDeclaration = val intSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe)
ValDef(
solverSymbol, val matchExpr = ID(scrutSym) MATCH (
New( CASE(ID(intLiteralClass) APPLY (intSym BIND WILD())) ==> ID(intSym) ,
Ident(fairZ3SolverClass), DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))
List(
List(
New(
Ident(defaultReporter),
List(Nil)
)
)
)
)
)
val setProgram =
Apply(
Select(
Ident(solverSymbol),
setProgramFunction
),
List(Ident(progSymbol))
)
val invocation =
Apply(
Select(
Ident(solverSymbol),
restartAndDecideWithModel
),
List(Ident(exprSymbol), Literal(Constant(false)))
)
Block(
solverDeclaration :: setProgram :: invocation :: Nil,
Literal(Constant(0))
) )
val methodSym = owner.newMethod(NoPosition, unit.fresh.newName(NoPosition, "exprToScala")).setInfo(MethodType(Nil, definitions.IntClass.tpe))
// (methodSym, DEF(methodSym) === matchExpr)
(methodSym, DEF(methodSym) === LIT(0))
}
def invokeExprToScala(methodSym : Symbol) : Tree = {
methodSym APPLY ()
} }
} }
} }
...@@ -25,6 +25,9 @@ trait CodeExtraction extends Extractors { ...@@ -25,6 +25,9 @@ trait CodeExtraction extends Extractors {
scala.collection.mutable.Map.empty[Symbol,ClassTypeDef] scala.collection.mutable.Map.empty[Symbol,ClassTypeDef]
private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] =
scala.collection.mutable.Map.empty[Symbol,FunDef] scala.collection.mutable.Map.empty[Symbol,FunDef]
private val reverseClassesToClasses: scala.collection.mutable.Map[ClassTypeDef,Symbol] =
scala.collection.mutable.Map.empty[ClassTypeDef,Symbol]
protected def stopIfErrors: Unit = { protected def stopIfErrors: Unit = {
if(reporter.hasErrors) { if(reporter.hasErrors) {
...@@ -266,6 +269,9 @@ trait CodeExtraction extends Extractors { ...@@ -266,6 +269,9 @@ trait CodeExtraction extends Extractors {
stopIfErrors stopIfErrors
// Reverse map for Scala class symbols
reverseClassesToClasses ++= classesToClasses.map{ case (a, b) => (b, a) }
val programName: Identifier = unit.body match { val programName: Identifier = unit.body match {
case PackageDef(name, _) => FreshIdentifier(name.toString) case PackageDef(name, _) => FreshIdentifier(name.toString)
case _ => FreshIdentifier("<program>") case _ => FreshIdentifier("<program>")
...@@ -299,6 +305,19 @@ trait CodeExtraction extends Extractors { ...@@ -299,6 +305,19 @@ trait CodeExtraction extends Extractors {
fd fd
} }
/*
def groundExprToScala(expr : Expr) : Tree = {
val converted = expr match {
case IntLiteral(v) => Literal(Constant(v))
case BooleanLiteral(v) => Literal(Constant(v))
case StringLiteral(v) => Literal(Constant(v))
case CaseClass(cd,args) => New(Ident(reverseClassesToClasses(cd)), List(args.map(groundExprToScala(_)).toList))
case _ => scala.Predef.error("Cannot convert to Scala : " + expr)
}
converted
}
*/
/** An exception thrown when non-purescala compatible code is encountered. */ /** An exception thrown when non-purescala compatible code is encountered. */
sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment