Skip to content
Snippets Groups Projects
Commit 4d278747 authored by Ali Sinan Köksal's avatar Ali Sinan Köksal
Browse files

Now saving original Scala functions as part of our Term instances

parent 637c2fc7
No related branches found
No related tags found
No related merge requests found
import cp.Definitions._ import cp.Definitions._
import cp.Constraints._ import cp.Terms._
import purescala.Stopwatch import purescala.Stopwatch
@spec object Specs { @spec object Specs {
......
...@@ -94,7 +94,7 @@ trait CallTransformation ...@@ -94,7 +94,7 @@ trait CallTransformation
transformHelper(tree, function, codeGen) match { transformHelper(tree, function, codeGen) match {
case Some((serializedInputVarList, serializedOutputVars, serializedExpr, inputVarValues, arity)) => { case Some((serializedInputVarList, serializedOutputVars, serializedExpr, inputVarValues, arity)) => {
// create constraint instance // create constraint instance
val code = codeGen.newBaseTerm(exprToScalaSym, serializedProg, serializedInputVarList, serializedOutputVars, serializedExpr, inputVarValues, arity) val code = codeGen.newBaseTerm(exprToScalaSym, serializedProg, serializedInputVarList, serializedOutputVars, serializedExpr, inputVarValues, function, typeTreeList, arity)
typer.typed(atOwner(currentOwner) { typer.typed(atOwner(currentOwner) {
code code
......
...@@ -229,14 +229,16 @@ trait CodeGeneration { ...@@ -229,14 +229,16 @@ trait CodeGeneration {
(scalaPackage DOT collectionModule DOT immutableModule DOT definitions.ListModule DOT listModuleApplyFunction) APPLY (inputVarTrees) (scalaPackage DOT collectionModule DOT immutableModule DOT definitions.ListModule DOT listModuleApplyFunction) APPLY (inputVarTrees)
} }
def newBaseTerm(exprToScalaSym : Symbol, serializedProg : Serialized, serializedInputVarList : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Tree, arity : Int) : Tree = { def newBaseTerm(exprToScalaSym : Symbol, serializedProg : Serialized, serializedInputVarList : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Tree, function : Function, typeTreeList : List[Tree], arity : Int) : Tree = {
termModules(arity) APPLY ( TypeApply(
Ident(termModules(arity)), typeTreeList) APPLY(
newConverter(exprToScalaSym), newConverter(exprToScalaSym),
newSerialized(serializedProg), newSerialized(serializedProg),
newSerialized(serializedInputVarList), newSerialized(serializedInputVarList),
newSerialized(serializedOutputVars), newSerialized(serializedOutputVars),
newSerialized(serializedExpr), newSerialized(serializedExpr),
inputVarValues inputVarValues,
function
) )
} }
......
This diff is collapsed.
...@@ -24,7 +24,10 @@ object Utils { ...@@ -24,7 +24,10 @@ object Utils {
val replacedGParamsTuple = replacedGParams.mkString("(", ",", ")") val replacedGParamsTuple = replacedGParams.mkString("(", ",", ")")
val newTermSize = arityG + arityF - 1 val newTermSize = arityG + arityF - 1
val resultParams = (gParams.take(index) ++ fParams ++ gParams.drop(index + 1) ++ Seq("R2")).mkString("[", ",", "]") val resultParams = (gParams.take(index) ++ fParams ++ gParams.drop(index + 1) ++ Seq("R2"))
val resultParamsBrackets = resultParams.mkString("[", ",", "]")
val anonFunParams = gParams.take(index) ++ fParams ++ gParams.drop(index + 1)
val anonFunParamsParen = anonFunParams.mkString("(", ",", ")")
val fParamsBrackets = fParams.mkString("[", ",", "]") val fParamsBrackets = fParams.mkString("[", ",", "]")
val rangeType = "T" + (index + 1) val rangeType = "T" + (index + 1)
...@@ -33,11 +36,15 @@ object Utils { ...@@ -33,11 +36,15 @@ object Utils {
val classParams = (1 to arityG) map ("T" + _) val classParams = (1 to arityG) map ("T" + _)
val resultTermParams = (classParams.take(index) ++ fParams ++ classParams.drop(index + 1) ++ Seq("R")).mkString("[", ",", "]") val resultTermParams = (classParams.take(index) ++ fParams ++ classParams.drop(index + 1) ++ Seq("R")).mkString("[", ",", "]")
val anonFunArg = "(p : %s)" format (anonFunParamsParen)
val anonFunArgsF = if (arityG + arityF == 2) Seq("p") else ( ((index + 1) to (index + arityF)) map ("p._" + _) )
val anonFunArgsG = ((1 to (index)) map ("p._" + _)) ++ Seq("f.scalaExpr" + anonFunArgsF.mkString("((", ",", "))")) ++ (((index + arityF + 1) to (arityF + arityG - 1)) map ("p._" + _))
val anonFunArgsGParen = anonFunArgsG.mkString("((", ",", "))")
val s1 = val s1 =
"""private def %s%s(f : Term[%s,%s], g : Term[%s,%s]) : Term%d%s = { """private def %s%s(f : Term[%s,%s], g : Term[%s,%s]) : Term%d%s = {
val (newExpr, newTypes) = compose(f, g, %d, %d, %d) val (newExpr, newTypes) = compose(f, g, %d, %d, %d)
Term%d(f.program, newExpr, newTypes, f.converter) Term%d(f.program, newExpr, %s => g.scalaExpr%s, newTypes, f.converter)
}""" format (methodName, methodParams, fParamsTuple, "R1", replacedGParamsTuple, "R2", newTermSize, resultParams, index, arityF, arityG, newTermSize) }""" format (methodName, methodParams, fParamsTuple, "R1", replacedGParamsTuple, "R2", newTermSize, resultParamsBrackets, index, arityF, arityG, newTermSize, anonFunArg, anonFunArgsGParen)
val s2 = val s2 =
"""def compose%d%s(other : Term%d%s) : Term%d%s = %s(other, this)""" format (index, fParamsBrackets, arityF, otherTypeParams, resultTermArity, resultTermParams, methodName) """def compose%d%s(other : Term%d%s) : Term%d%s = %s(other, this)""" format (index, fParamsBrackets, arityF, otherTypeParams, resultTermArity, resultTermParams, methodName)
...@@ -118,32 +125,34 @@ object Utils { ...@@ -118,32 +125,34 @@ object Utils {
val booleanTermTraitName = "Term%d%s" format (arity, (argParams ++ Seq("Boolean")).mkString("[", ",", "]")) val booleanTermTraitName = "Term%d%s" format (arity, (argParams ++ Seq("Boolean")).mkString("[", ",", "]"))
val objectString = val objectString =
"""object Term%d { """object Term%d {
def apply%s(conv : Converter, serializedProg : Serialized, serializedInputVars: Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr]) = { def apply%s(conv : Converter, serializedProg : Serialized, serializedInputVars: Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr], scalaExpr : %s => %s) = {
val (converter, program, expr, types) = Term.processArgs(conv, serializedProg, serializedInputVars, serializedOutputVars, serializedExpr, inputVarValues) val (converter, program, expr, types) = Term.processArgs(conv, serializedProg, serializedInputVars, serializedOutputVars, serializedExpr, inputVarValues)
new %s(program, expr, types, converter) with %s new %s(program, expr, scalaExpr%s, types, converter) with %s
} }
def apply%s(program : Program, expr : Expr, types : Seq[TypeTree], converter : Converter) = def apply%s(program : Program, expr : Expr, scalaExpr : (%s) => %s, types : Seq[TypeTree], converter : Converter) =
new %s(program, expr, types, converter) with %s new %s(program, expr, scalaExpr, types, converter) with %s
}""" format (arity, applyParamString, termClassName, termTraitName, applyParamString, termClassName, termTraitName) }""" format (arity, applyParamString, argParamTuple, "R", termClassName, if (arity == 1) "" else ".tupled", termTraitName, applyParamString, argParamTuple, "R", termClassName, termTraitName)
val anonFunArgs = "(p : %s)" format (argParamTuple)
val anonFunArgTuple = "(p)"
val binaryOpObjectString = val binaryOpObjectString =
"""object %sConstraint%d { """object %sConstraint%d {
def apply%s(l : %s, r : %s) : %s = (l, r) match { def apply%s(l : %s, r : %s) : %s = (l, r) match {
case (Term(p1,ex1,ts1,conv1), Term(p2,ex2,ts2,conv2)) => Term%d(p1,%s(ex1,ex2),ts1,conv1) case (Term(p1,ex1,scalaEx1,ts1,conv1), Term(p2,ex2,scalaEx2,ts2,conv2)) => Term%d(p1,%s(ex1,ex2),%s => scalaEx1%s %s scalaEx2%s,ts1,conv1)
} }
}""" }"""
val orObjectString = binaryOpObjectString format ("Or", arity, argParamsString, booleanTermClassName, booleanTermClassName, booleanTermTraitName, arity, "Or") val orObjectString = binaryOpObjectString format ("Or", arity, argParamsString, booleanTermClassName, booleanTermClassName, booleanTermTraitName, arity, "Or", anonFunArgs, anonFunArgTuple, "||", anonFunArgTuple)
val andObjectString = binaryOpObjectString format ("And", arity, argParamsString, booleanTermClassName, booleanTermClassName, booleanTermTraitName, arity, "And") val andObjectString = binaryOpObjectString format ("And", arity, argParamsString, booleanTermClassName, booleanTermClassName, booleanTermTraitName, arity, "And", anonFunArgs, anonFunArgTuple, "&&", anonFunArgTuple)
val unaryOpObjectString = val unaryOpObjectString =
"""object %sConstraint%d { """object %sConstraint%d {
def apply%s(c : %s) : %s = c match { def apply%s(c : %s) : %s = c match {
case Term(p,ex,ts,conv) => Term%d(p,%s(ex),ts,conv) case Term(p,ex,scalaEx,ts,conv) => Term%d(p,%s(ex),%s => %s scalaEx%s,ts,conv)
} }
}""" }"""
val notObjectString = unaryOpObjectString format ("Not", arity, argParamsString, booleanTermClassName, booleanTermTraitName, arity, "Not") val notObjectString = unaryOpObjectString format ("Not", arity, argParamsString, booleanTermClassName, booleanTermTraitName, arity, "Not", anonFunArgs, "!", anonFunArgTuple)
List(objectString, orObjectString, andObjectString, notObjectString).mkString("\n\n") List(objectString, orObjectString, andObjectString, notObjectString).mkString("\n\n")
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment