Skip to content
Snippets Groups Projects
Commit 0bf89378 authored by ravi's avatar ravi
Browse files

Packing all states into a single case class. Need to modify type checker so...

Packing all states into a single case class. Need to modify type checker so that types are correctly re-inferred.
parent 91a8b6c4
No related branches found
No related tags found
No related merge requests found
...@@ -33,7 +33,17 @@ object $ { ...@@ -33,7 +33,17 @@ object $ {
@library @library
case class WithState[T](v: T) { case class WithState[T](v: T) {
@extern @extern
def withState[U](x: Set[$[U]]): T = sys.error("withState method is not executable!") def withState[U](u: Set[$[U]]): T = sys.error("withState method is not executable!")
@extern
def withState[U, V](u: Set[$[U]], v: Set[$[V]]): T = sys.error("withState method is not executable!")
@extern
def withState[U, V, W](u: Set[$[U]], v: Set[$[V]], w: Set[$[W]]): T = sys.error("withState method is not executable!")
@extern
def withState[U, V, W, X](u: Set[$[U]], v: Set[$[V]], w: Set[$[W]], x: Set[$[X]]): T = sys.error("withState method is not executable!")
// extend this to more arguments if needed
} }
@inline @inline
......
...@@ -81,6 +81,7 @@ object LazinessEliminationPhase extends TransformationPhase { ...@@ -81,6 +81,7 @@ object LazinessEliminationPhase extends TransformationPhase {
//println("After closure conversion: \n" + ScalaPrinter.apply(progWithClosures, purescala.PrinterOptions(printUniqueIds = true))) //println("After closure conversion: \n" + ScalaPrinter.apply(progWithClosures, purescala.PrinterOptions(printUniqueIds = true)))
prettyPrintProgramToFile(progWithClosures, ctx, "-closures") prettyPrintProgramToFile(progWithClosures, ctx, "-closures")
} }
System.exit(0)
//Rectify type parameters and local types //Rectify type parameters and local types
val typeCorrectProg = (new TypeRectifier(progWithClosures, tp => tp.id.name.endsWith("@"))).apply val typeCorrectProg = (new TypeRectifier(progWithClosures, tp => tp.id.name.endsWith("@"))).apply
......
...@@ -53,6 +53,7 @@ object LazinessUtil { ...@@ -53,6 +53,7 @@ object LazinessUtil {
val pgmText = pat.replaceAllIn(ScalaPrinter.apply(p), val pgmText = pat.replaceAllIn(ScalaPrinter.apply(p),
m => m.group("base") + m.group("mid") + ( m => m.group("base") + m.group("mid") + (
if (!m.group("star").isEmpty()) "S" else "") + m.group("rest")) if (!m.group("star").isEmpty()) "S" else "") + m.group("rest"))
//val pgmText = ScalaPrinter.apply(p)
out.write(pgmText) out.write(pgmText)
out.close() out.close()
} catch { } catch {
...@@ -108,8 +109,12 @@ object LazinessUtil { ...@@ -108,8 +109,12 @@ object LazinessUtil {
case _ => false case _ => false
} }
/**
* There are many overloads of withState functions with different number
* of arguments. All of them should pass this check.
*/
def isWithStateFun(e: Expr)(implicit p: Program): Boolean = e match { def isWithStateFun(e: Expr)(implicit p: Program): Boolean = e match {
case FunctionInvocation(TypedFunDef(fd, _), Seq(_, _)) => case FunctionInvocation(TypedFunDef(fd, _), _) =>
fullName(fd)(p) == "leon.lazyeval.WithState.withState" fullName(fd)(p) == "leon.lazyeval.WithState.withState"
case _ => false case _ => false
} }
...@@ -168,6 +173,10 @@ object LazinessUtil { ...@@ -168,6 +173,10 @@ object LazinessUtil {
name.substring(4) name.substring(4)
} }
def typeToFieldName(name: String) = {
name.toLowerCase()
}
def closureConsName(typeName: String) = { def closureConsName(typeName: String) = {
"new@" + typeName "new@" + typeName
} }
...@@ -184,38 +193,6 @@ object LazinessUtil { ...@@ -184,38 +193,6 @@ object LazinessUtil {
fd.id.name.startsWith("eval@") fd.id.name.startsWith("eval@")
} }
/**
* Returns all functions that 'need' states to be passed in
* and those that return a new state.
* TODO: implement backwards BFS by reversing the graph
*/
/*def funsNeedingnReturningState(prog: Program) = {
val cg = CallGraphUtil.constructCallGraph(prog, false, true)
var needRoots = Set[FunDef]()
var retRoots = Set[FunDef]()
prog.definedFunctions.foreach {
case fd if fd.hasBody && !fd.isLibrary =>
postTraversal {
case finv: FunctionInvocation if isLazyInvocation(finv)(prog) =>
// the lazy invocation constructor will need the state
needRoots += fd
case finv: FunctionInvocation if isEvaluatedInvocation(finv)(prog) =>
needRoots += fd
case finv: FunctionInvocation if isValueInvocation(finv)(prog) =>
needRoots += fd
retRoots += fd
case _ =>
;
}(fd.body.get)
case _ => ;
}
val funsNeedStates = prog.definedFunctions.filterNot(fd =>
cg.transitiveCallees(fd).toSet.intersect(needRoots).isEmpty).toSet
val funsRetStates = prog.definedFunctions.filterNot(fd =>
cg.transitiveCallees(fd).toSet.intersect(retRoots).isEmpty).toSet
(funsNeedStates, funsRetStates)
}*/
def freshenTypeArguments(tpe: TypeTree): TypeTree = { def freshenTypeArguments(tpe: TypeTree): TypeTree = {
tpe match { tpe match {
case NAryType(targs, tcons) => case NAryType(targs, tcons) =>
......
This diff is collapsed.
...@@ -144,4 +144,66 @@ class LazyClosureFactory(p: Program) { ...@@ -144,4 +144,66 @@ class LazyClosureFactory(p: Program) {
* This avoids the use of additional maps. * This avoids the use of additional maps.
*/ */
def lazyTypeNameOfClosure(cl: CaseClassDef) = adtNameToTypeName(cl.parent.get.classDef.id.name) def lazyTypeNameOfClosure(cl: CaseClassDef) = adtNameToTypeName(cl.parent.get.classDef.id.name)
/**
* Define a state as an ADT whose fields are sets of closures.
* Note that we need to ensure that there are state ADT is not recursive.
*/
val state = {
var tparams = Seq[TypeParameter]()
var i = 0
def freshTParams(n: Int): Seq[TypeParameter] = {
val start = i + 1
i += n // create 'n' fresh ids
val nparams = (start to i).map(index => TypeParameter.fresh("T"+index))
tparams ++= nparams
nparams
}
// field of the ADT
val fields = lazyTypeNames map { tn =>
val absClass = absClosureType(tn)
val tparams = freshTParams(absClass.tparams.size)
val fldType = SetType(AbstractClassType(absClass, tparams))
ValDef(FreshIdentifier(typeToFieldName(tn), fldType))
}
val ccd = CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false)
ccd.setFields(fields)
ccd
}
def selectFieldOfState(tn: String, st: Expr, stType: CaseClassType) = {
val selName = typeToFieldName(tn)
stType.classDef.fields.find{ fld => fld.id.name == selName} match {
case Some(fld) =>
CaseClassSelector(stType, st, fld.id)
case _ =>
throw new IllegalStateException(s"Cannot find a field of $stType with name: $selName")
}
}
val stateUpdateFuns : Map[String, FunDef] =
lazyTypeNames.map{ tn =>
val fldname = typeToFieldName(tn)
val tparams = state.tparams.map(_.tp)
val stType = CaseClassType(state, tparams)
val param1 = FreshIdentifier("st@", stType)
val SetType(baseT) = stType.classDef.fields.find{ fld => fld.id.name == fldname}.get.getType
val param2 = FreshIdentifier("cl", baseT)
// TODO: as an optimization we can mark all these functions as inline and inline them at their callees
val updateFun = new FunDef(FreshIdentifier("updState"+tn),
state.tparams, Seq(ValDef(param1), ValDef(param2)), stType)
// create a body for the updateFun:
val nargs = state.fields.map{ fld =>
val fldSelect = CaseClassSelector(stType, param1.toVariable, fld.id)
if(fld.id.name == fldname) {
SetUnion(fldSelect, FiniteSet(Set(param2.toVariable), baseT)) // st@.tn + Set(param2)
} else {
fldSelect
}
}
val nst = CaseClass(stType, nargs)
updateFun.body = Some(nst)
(tn -> updateFun)
}.toMap
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment