diff --git a/library/lazy/package.scala b/library/lazy/package.scala index 448eb2a3deb45807c27d45e0a2629ba086236343..2a6a29cf2f11c4075db553f17047543d993d525a 100644 --- a/library/lazy/package.scala +++ b/library/lazy/package.scala @@ -33,7 +33,17 @@ object $ { @library case class WithState[T](v: T) { @extern - def withState[U](x: Set[$[U]]): T = sys.error("withState method is not executable!") + def withState[U](u: Set[$[U]]): T = sys.error("withState method is not executable!") + + @extern + def withState[U, V](u: Set[$[U]], v: Set[$[V]]): T = sys.error("withState method is not executable!") + + @extern + def withState[U, V, W](u: Set[$[U]], v: Set[$[V]], w: Set[$[W]]): T = sys.error("withState method is not executable!") + + @extern + def withState[U, V, W, X](u: Set[$[U]], v: Set[$[V]], w: Set[$[W]], x: Set[$[X]]): T = sys.error("withState method is not executable!") + // extend this to more arguments if needed } @inline diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala index fbf45de51c6673d0eeecad523f258bfbd20cff93..becbedd5475167829d4a0fa25769fffcf1752fb5 100644 --- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala +++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala @@ -81,6 +81,7 @@ object LazinessEliminationPhase extends TransformationPhase { //println("After closure conversion: \n" + ScalaPrinter.apply(progWithClosures, purescala.PrinterOptions(printUniqueIds = true))) prettyPrintProgramToFile(progWithClosures, ctx, "-closures") } + System.exit(0) //Rectify type parameters and local types val typeCorrectProg = (new TypeRectifier(progWithClosures, tp => tp.id.name.endsWith("@"))).apply diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala index 80a43f648d5fa4f70e7bb3b777193fd1ef8d2ce2..e19f31b271467832ab8b7ae11439fa068d72cca9 100644 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -53,6 +53,7 @@ object LazinessUtil { val pgmText = pat.replaceAllIn(ScalaPrinter.apply(p), m => m.group("base") + m.group("mid") + ( if (!m.group("star").isEmpty()) "S" else "") + m.group("rest")) + //val pgmText = ScalaPrinter.apply(p) out.write(pgmText) out.close() } catch { @@ -108,8 +109,12 @@ object LazinessUtil { case _ => false } + /** + * There are many overloads of withState functions with different number + * of arguments. All of them should pass this check. + */ def isWithStateFun(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_, _)) => + case FunctionInvocation(TypedFunDef(fd, _), _) => fullName(fd)(p) == "leon.lazyeval.WithState.withState" case _ => false } @@ -168,6 +173,10 @@ object LazinessUtil { name.substring(4) } + def typeToFieldName(name: String) = { + name.toLowerCase() + } + def closureConsName(typeName: String) = { "new@" + typeName } @@ -184,38 +193,6 @@ object LazinessUtil { fd.id.name.startsWith("eval@") } - /** - * Returns all functions that 'need' states to be passed in - * and those that return a new state. - * TODO: implement backwards BFS by reversing the graph - */ - /*def funsNeedingnReturningState(prog: Program) = { - val cg = CallGraphUtil.constructCallGraph(prog, false, true) - var needRoots = Set[FunDef]() - var retRoots = Set[FunDef]() - prog.definedFunctions.foreach { - case fd if fd.hasBody && !fd.isLibrary => - postTraversal { - case finv: FunctionInvocation if isLazyInvocation(finv)(prog) => - // the lazy invocation constructor will need the state - needRoots += fd - case finv: FunctionInvocation if isEvaluatedInvocation(finv)(prog) => - needRoots += fd - case finv: FunctionInvocation if isValueInvocation(finv)(prog) => - needRoots += fd - retRoots += fd - case _ => - ; - }(fd.body.get) - case _ => ; - } - val funsNeedStates = prog.definedFunctions.filterNot(fd => - cg.transitiveCallees(fd).toSet.intersect(needRoots).isEmpty).toSet - val funsRetStates = prog.definedFunctions.filterNot(fd => - cg.transitiveCallees(fd).toSet.intersect(retRoots).isEmpty).toSet - (funsNeedStates, funsRetStates) - }*/ - def freshenTypeArguments(tpe: TypeTree): TypeTree = { tpe match { case NAryType(targs, tcons) => diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala index 35046cdb49c5cf2b7c0a905c46439d79542dfa94..92c22569a15c4a8f9dc6b879e39ef782786b28ab 100644 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -60,33 +60,21 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } val nretType = replaceLazyTypes(fd.returnType) val nfd = if (funsNeedStates(fd)) { // this also includes lazy constructors - var newTParams = Seq[TypeParameterDef]() - val stTypes = tnames map { tn => - val absClass = closureFactory.absClosureType(tn) - val tparams = absClass.tparams.map(_ => - TypeParameter.fresh("P@")) - newTParams ++= tparams map TypeParameterDef - SetType(AbstractClassType(absClass, tparams)) - } - val stParams = stTypes.map { stType => - ValDef(FreshIdentifier("st@", stType, true)) - } -// val flParams = -// if(transCons(fd)) { -// Seq(ValDef(FreshIdentifier("fl@", fvFactory.absType, true))) -// } else -// Seq() + // create fresh type parameters for the state + val ntparams = closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) + val stType = CaseClassType(closureFactory.state, ntparams) + val stParam = ValDef(FreshIdentifier("st@", stType)) val retTypeWithState = if (funsRetStates(fd)) - TupleType(nretType +: stTypes) + TupleType(Seq(nretType, stType)) else nretType // the type parameters will be unified later - new FunDef(FreshIdentifier(fd.id.name, Untyped), - fd.tparams ++ newTParams, nparams ++ stParams, retTypeWithState) + new FunDef(FreshIdentifier(fd.id.name), fd.tparams ++ (ntparams map TypeParameterDef), + nparams :+ stParam, retTypeWithState) // body of these functions are defined later } else { - new FunDef(FreshIdentifier(fd.id.name, Untyped), fd.tparams, nparams, nretType) + new FunDef(FreshIdentifier(fd.id.name), fd.tparams, nparams, nretType) } // copy annotations fd.flags.foreach(nfd.addFlag(_)) @@ -172,10 +160,10 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, * doesn't support 'Any' type yet. So we have to have multiple * state (though this is much clearer it results in more complicated code) */ - def getStateType(tname: String, tparams: Seq[TypeParameter]) = { + /*def getStateType(tname: String, tparams: Seq[TypeParameter]) = { //val (_, absdef, _) = tpeToADT(tname) SetType(AbstractClassType(closureFactory.absClosureType(tname), tparams)) - } + }*/ def replaceLazyTypes(t: TypeTree): TypeTree = { unwrapLazyType(t) match { @@ -203,50 +191,57 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val cdefs = closureFactory.closures(tname) // construct parameters and return types - val tparams = getTypeParameters(tpe) - val tparamDefs = tparams map TypeParameterDef.apply - val param1 = FreshIdentifier("cl", AbstractClassType(absdef, tparams)) - val stType = getStateType(tname, tparams) + val recvTparams = getTypeParameters(tpe) + val stTparams = closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) + val param1 = FreshIdentifier("cl", AbstractClassType(absdef, recvTparams)) + val stType = CaseClassType(closureFactory.state, stTparams) val param2 = FreshIdentifier("st@", stType) val retType = TupleType(Seq(tpe, stType)) // create a eval function - val dfun = new FunDef(FreshIdentifier(evalFunctionName(absdef.id.name), Untyped), - tparamDefs, Seq(ValDef(param1), ValDef(param2)), retType) + val dfun = new FunDef(FreshIdentifier(evalFunctionName(absdef.id.name)), + (recvTparams ++ stTparams) map TypeParameterDef, + Seq(ValDef(param1), ValDef(param2)), retType) + //println("Creating eval function: "+dfun) // assign body of the eval fucntion // create a match case to switch over the possible class defs and invoke the corresponding functions val bodyMatchCases = cdefs map { cdef => - val ctype = CaseClassType(cdef, tparams) // we assume that the type parameters of cdefs are same as absdefs + val ctype = CaseClassType(cdef, recvTparams) // we assume that the type parameters of cdefs are same as absdefs val binder = FreshIdentifier("t", ctype) val pattern = InstanceOfPattern(Some(binder), ctype) // create a body of the match // the last field represents the result val args = cdef.fields.dropRight(1) map { fld => - CaseClassSelector(ctype, binder.toVariable, fld.id) } + CaseClassSelector(ctype, binder.toVariable, fld.id) + } val op = closureFactory.caseClassToOp(cdef) val targetFun = funMap(op) // invoke the target fun with appropriate values - val invoke = FunctionInvocation(TypedFunDef(targetFun, tparams), - args ++ (if (funsNeedStates(op)) Seq(param2.toVariable) else Seq())) - val invokeRes = FreshIdentifier("dres", invoke.getType) + val invoke = + if (funsNeedStates(op)) + FunctionInvocation(TypedFunDef(targetFun, recvTparams ++ stTparams), args :+ param2.toVariable) + else + FunctionInvocation(TypedFunDef(targetFun, recvTparams), args) + val invokeRes = FreshIdentifier("dres", invoke.getType) //println(s"invoking function $targetFun with args $args") - val (valPart, stPart) = if (funsRetStates(op)) { - // TODO: here we are assuming that only one state is used, fix this. - val invokeSt = TupleSelect(invokeRes.toVariable, 2) - (TupleSelect(invokeRes.toVariable, 1), - SetUnion(invokeSt, FiniteSet(Set(binder.toVariable), stType.base))) - } else { - (invokeRes.toVariable, - SetUnion(param2.toVariable, FiniteSet(Set(binder.toVariable), stType.base))) - } + val updateFun = TypedFunDef(closureFactory.stateUpdateFuns(tname), stTparams) + val (valPart, stPart) = + if (funsRetStates(op)) { + val invokeSt = TupleSelect(invokeRes.toVariable, 2) + val nst = FunctionInvocation(updateFun, Seq(invokeSt, binder.toVariable)) + (TupleSelect(invokeRes.toVariable, 1), nst) + } else { + val nst = FunctionInvocation(updateFun, Seq(param2.toVariable, binder.toVariable)) + (invokeRes.toVariable, nst) + } val rhs = Let(invokeRes, invoke, Tuple(Seq(valPart, stPart))) MatchCase(pattern, None, rhs) } // create a new match case for eager evaluation val eagerCase = { val eagerDef = closureFactory.eagerClosure(tname) - val ctype = CaseClassType(eagerDef, tparams) + val ctype = CaseClassType(eagerDef, recvTparams) val binder = FreshIdentifier("t", ctype) // create a body of the match val valPart = CaseClassSelector(ctype, binder.toVariable, eagerDef.fields(0).id) @@ -285,33 +280,33 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val adt = closureFactory.absClosureType(tname) val param1Type = AbstractClassType(adt, adt.tparams.map(_.tp)) val param1 = FreshIdentifier("cc", param1Type) - val stType = SetType(param1Type) + val stTparams = closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) + val stType = CaseClassType(closureFactory.state, stTparams) val param2 = FreshIdentifier("st@", stType) - val tparamDefs = adt.tparams - val fun = new FunDef(FreshIdentifier(closureConsName(tname)), adt.tparams, + val tparamdefs = adt.tparams ++ (stTparams map TypeParameterDef) + val fun = new FunDef(FreshIdentifier(closureConsName(tname)), tparamdefs, Seq(ValDef(param1), ValDef(param2)), param1Type) fun.body = Some(param1.toVariable) // assert that the closure in unevaluated if useRefEquality is enabled - if (refEq) { + // not supported as of now + /*if (refEq) { val resid = FreshIdentifier("res", param1Type) val postbody = Not(ElementOfSet(resid.toVariable, param2.toVariable)) fun.postcondition = Some(Lambda(Seq(ValDef(resid)), postbody)) fun.addFlag(Annotation("axiom", Seq())) - } + }*/ (tname -> fun) }.toMap - def mapBody(body: Expr): (Map[String, Expr] => Expr, Boolean) = body match { + def mapBody(body: Expr): (Option[Expr] => Expr, Boolean) = body match { case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // lazy construction ? if isLazyInvocation(finv)(p) => - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { val adt = closureFactory.closureOfLazyOp(argfd) // create lets to bind the nargs to variables val (flatArgs, letCons) = nargs.foldRight((Seq[Variable](), (e : Expr) => e)){ -// case (narg : Variable, (fargs, lcons)) => -// (narg +: fargs, lcons) case (narg, (fargs, lcons)) => val id = FreshIdentifier("a", narg.getType, true) (id.toVariable +: fargs, e => Let(id, narg, lcons(e))) @@ -320,15 +315,14 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val resval = FunctionInvocation(TypedFunDef(uiFuncs(argfd)._1, tparams), flatArgs) val cc = CaseClass(CaseClassType(adt, tparams), flatArgs :+ resval) val baseLazyTypeName = closureFactory.lazyTypeNameOfClosure(adt) - val fi = FunctionInvocation(TypedFunDef(closureCons(baseLazyTypeName), tparams), - Seq(cc, st(baseLazyTypeName))) + val fi = FunctionInvocation(TypedFunDef(closureCons(baseLazyTypeName), tparams), Seq(cc, st.get)) letCons(fi) // this could be 'fi' wrapped into lets }, false) mapNAryOperator(args, op) case finv @ FunctionInvocation(_, Seq(arg)) if isEagerInvocation(finv)(p) => // here arg is guaranteed to be a variable - ((st: Map[String, Expr]) => { + ((st: Option[Expr]) => { val rootType = bestRealType(arg.getType) val tname = typeNameWOParams(rootType) val tparams = getTypeArguments(rootType) @@ -337,11 +331,16 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, }, false) case finv @ FunctionInvocation(_, args) if isEvaluatedInvocation(finv)(p) => // isEval function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { val narg = nargs(0) // there must be only one argument here val baseType = unwrapLazyType(narg.getType).get val tname = typeNameWOParams(baseType) - val memberTest = ElementOfSet(narg, st(tname)) // should we use subtype instead ? + // select the set using the tname + val st = stOpt.get + val stTparams = closureFactory.state.tparams.map(_.tp) // using dummy set of tparams. The correct type should be inferred later + val stType = CaseClassType(closureFactory.state, stTparams) + val cls = closureFactory.selectFieldOfState(tname, st, stType) + val memberTest = ElementOfSet(narg, cls) val subtypeTest = IsInstanceOf(narg, CaseClassType(closureFactory.eagerClosure(tname), getTypeArguments(baseType))) Or(memberTest, subtypeTest) @@ -349,7 +348,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, mapNAryOperator(args, op) case finv @ FunctionInvocation(_, Seq(recvr, funcArg)) if isSuspInvocation(finv)(p) => - ((st: Map[String, Expr]) => { + ((st: Option[Expr]) => { // `funcArg` is a closure whose body is a function invocation //TODO: make sure the function is not partially applied in the body funcArg match { @@ -365,61 +364,72 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } }, false) - case finv @ FunctionInvocation(_, Seq(recvr, stArg)) if isWithStateFun(finv)(p) => - // recvr is a `WithStateCaseClass` and `stArg` could be an arbitrary expression that returns a state + case finv @ FunctionInvocation(_, Seq(recvr, stArgs @ _*)) if isWithStateFun(finv)(p) => + // recvr is a `WithStateCaseClass` and `stArgs` could be arbitrary expressions that return values of types of fileds of state + val numStates = closureFactory.state.fields.size + if(stArgs.size != numStates) + throw new IllegalStateException("The arguments to `withState` should equal the number of states: "+numStates) + val CaseClass(_, Seq(exprNeedingState)) = recvr - val (nexpr, exprReturnsState) = mapBody(exprNeedingState) - val (nstArg, stArgReturnsState) = mapBody(stArg) - if(stArgReturnsState) - throw new IllegalStateException("The state argument to `withState` returns a new state, which is not supported: "+finv) + val (nexprCons, exprReturnsState) = mapBody(exprNeedingState) + val nstConses = stArgs map mapBody + if(nstConses.exists(_._2)) // any 'stArg' returning state + throw new IllegalStateException("One of the arguments to `withState` returns a new state, which is not supported: "+finv) else { - ((st: Map[String, Expr]) => { - val nst = nstArg(st) - // compute the baseType + ((st: Option[Expr]) => { + // create a new state using the nstConses + val nstSets = nstConses map { case (stCons, _) => stCons(st) } + val tparams = nstSets.flatMap(nst => getTypeParameters(nst.getType)).distinct + val nst = CaseClass(CaseClassType(closureFactory.state, tparams), nstSets) + nexprCons(Some(nst)) + /* // compute the baseType stArg.getType match { - case SetType(lazyType) => + case SetType(lazyType) => // note that stArg would still have the set type val baseType = unwrapLazyType(lazyType).get val tname = typeNameWOParams(baseType) val newStates = st + (tname -> nst) nexpr(newStates) case t => throw new IllegalStateException(s"$stArg should have a set type, current type: "+t) - } + }*/ }, exprReturnsState) } case finv @ FunctionInvocation(_, args) if isValueInvocation(finv)(p) => // is value function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { + val st = stOpt.get val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here val tname = typeNameWOParams(baseType) val dispFun = evalFunctions(tname) - val dispFunInv = FunctionInvocation(TypedFunDef(dispFun, - getTypeParameters(baseType)), nargs :+ st(tname)) - val dispRes = FreshIdentifier("dres", dispFun.returnType, true) - val nstates = tnames map { + val tparams = (getTypeParameters(baseType) ++ getTypeParameters(st.getType)).distinct + FunctionInvocation(TypedFunDef(dispFun, tparams), nargs :+ st) + //val dispRes = FreshIdentifier("dres", dispFun.returnType, true) + /*val nstates = tnames map { case `tname` => TupleSelect(dispRes.toVariable, 2) case other => st(other) - } - Let(dispRes, dispFunInv, Tuple(TupleSelect(dispRes.toVariable, 1) +: nstates)) + }*/ + //Let(dispRes, dispFunInv, Tuple(TupleSelect(dispRes.toVariable, 1) +: nstates)) }, true) mapNAryOperator(args, op) case finv @ FunctionInvocation(_, args) if isStarInvocation(finv)(p) => // is * function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here val tname = typeNameWOParams(baseType) val dispFun = computeFunctions(tname) + // TODO: important: what do we do with type parameters here. + val tparams = getTypeParameters(baseType) FunctionInvocation(TypedFunDef(dispFun, getTypeParameters(baseType)), nargs) }, false) mapNAryOperator(args, op) case FunctionInvocation(TypedFunDef(fd, tparams), args) if funMap.contains(fd) => mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + (nargs: Seq[Expr]) => ((st: Option[Expr]) => { val stArgs = if (funsNeedStates(fd)) { - (tnames map st.apply) + st.toSeq } else Seq() FunctionInvocation(TypedFunDef(funMap(fd), tparams), nargs ++ stArgs) }, funsRetStates(fd))) @@ -427,21 +437,18 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, case Let(id, value, body) => val (valCons, valUpdatesState) = mapBody(value) val (bodyCons, bodyUpdatesState) = mapBody(body) - ((st: Map[String, Expr]) => { + ((st: Option[Expr]) => { val nval = valCons(st) if (valUpdatesState) { val freshid = FreshIdentifier(id.name, nval.getType, true) - val nextStates = tnames.zipWithIndex.map { - case (tn, i) => - TupleSelect(freshid.toVariable, i + 2) - }.toSeq - val nsMap = (tnames zip nextStates).toMap + val nextState = TupleSelect(freshid.toVariable, 2) + //val nsMap = (tnames zip nextStates).toMap val transBody = replace(Map(id.toVariable -> TupleSelect(freshid.toVariable, 1)), - bodyCons(nsMap)) + bodyCons(Some(nextState))) if (bodyUpdatesState) Let(freshid, nval, transBody) else - Let(freshid, nval, Tuple(transBody +: nextStates)) + Let(freshid, nval, Tuple(Seq(transBody, nextState))) } else Let(id, nval, bodyCons(st)) }, valUpdatesState || bodyUpdatesState) @@ -450,59 +457,56 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val (condCons, condState) = mapBody(cond) val (thnCons, thnState) = mapBody(thn) val (elzeCons, elzeState) = mapBody(elze) - ((st: Map[String, Expr]) => { + ((st: Option[Expr]) => { val (ncondCons, nst) = if (condState) { val cndExpr = condCons(st) val bder = FreshIdentifier("c", cndExpr.getType) - val condst = tnames.zipWithIndex.map { - case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) - }.toMap + val condst = TupleSelect(bder.toVariable, 2) ((th: Expr, el: Expr) => Let(bder, cndExpr, IfExpr(TupleSelect(bder.toVariable, 1), th, el)), - condst) + Some(condst)) } else { ((th: Expr, el: Expr) => IfExpr(condCons(st), th, el), st) } val nelze = if ((condState || thnState) && !elzeState) - Tuple(elzeCons(nst) +: tnames.map(nst.apply)) + Tuple(Seq(elzeCons(nst), nst.get)) else elzeCons(nst) val nthn = if (!thnState && (condState || elzeState)) - Tuple(thnCons(nst) +: tnames.map(nst.apply)) + Tuple(Seq(thnCons(nst), nst.get)) else thnCons(nst) ncondCons(nthn, nelze) }, condState || thnState || elzeState) case MatchExpr(scr, cases) => val (scrCons, scrUpdatesState) = mapBody(scr) - val casesRes = cases.foldLeft(Seq[(Map[String, Expr] => Expr, Boolean)]()) { + val casesRes = cases.foldLeft(Seq[(Option[Expr] => Expr, Boolean)]()) { case (acc, MatchCase(pat, None, rhs)) => acc :+ mapBody(rhs) case mcase => throw new IllegalStateException("Match case with guards are not supported yet: " + mcase) } val casesUpdatesState = casesRes.exists(_._2) - ((st: Map[String, Expr]) => { + ((st: Option[Expr]) => { val scrExpr = scrCons(st) - val (nscrCons, scrst) = if (scrUpdatesState) { - val bder = FreshIdentifier("scr", scrExpr.getType) - val scrst = tnames.zipWithIndex.map { - case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) - }.toMap - ((ncases: Seq[MatchCase]) => - Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), - scrst) - } else { - //println(s"Scrutiny does not update state: current state: $st") - ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) - } + val (nscrCons, scrst) = + if (scrUpdatesState) { + val bder = FreshIdentifier("scr", scrExpr.getType) + val scrst = Some(TupleSelect(bder.toVariable, 2)) + ((ncases: Seq[MatchCase]) => + Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), + scrst) + } else { + //println(s"Scrutiny does not update state: current state: $st") + ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) + } val ncases = (cases zip casesRes).map { case (MatchCase(pat, None, _), (caseCons, caseUpdatesState)) => val nrhs = if ((scrUpdatesState || casesUpdatesState) && !caseUpdatesState) - Tuple(caseCons(scrst) +: tnames.map(scrst.apply)) + Tuple(Seq(caseCons(scrst), scrst.get)) else caseCons(scrst) MatchCase(pat, None, nrhs) } @@ -513,17 +517,17 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, case CaseClass(cct, args) => val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => CaseClass(ntype, nargs), false)) + (nargs: Seq[Expr]) => ((st: Option[Expr]) => CaseClass(ntype, nargs), false)) case Operator(args, op) => // here, 'op' itself does not create a new state mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => op(nargs), false)) + (nargs: Seq[Expr]) => ((st: Option[Expr]) => op(nargs), false)) case t: Terminal => (_ => t, false) } - def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Map[String, Expr] => Expr, Boolean)) = { + def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Option[Expr] => Expr, Boolean)) = { // create n variables to model n lets val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) (args zip letvars).foldRight(op(letvars)) { @@ -531,23 +535,24 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val (argCons, stUpdateFlag) = mapBody(arg) val cl = if (!stUpdateFlag) { // here arg does not affect the newstate - (st: Map[String, Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) + (st: Option[Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) } else { // here arg does affect the newstate - (st: Map[String, Expr]) => + (st: Option[Expr]) => { val narg = argCons(st) val argres = FreshIdentifier("a", narg.getType, true).toVariable - val nstateSeq = tnames.zipWithIndex.map { + val nstate = Some(TupleSelect(argres, 2)) + /*val nstateSeq = tnames.zipWithIndex.map { // states start from index 2 case (tn, i) => TupleSelect(argres, i + 2) } val nstate = (tnames zip nstateSeq).map { case (tn, st) => (tn -> st) - }.toMap[String, Expr] + }.toMap[String, Expr]*/ val letbody = if (stUpdatedBefore) accCons(nstate) // here, 'acc' already returns a superseeding state - else Tuple(accCons(nstate) +: nstateSeq) // here, 'acc; only retruns the result + else Tuple(Seq(accCons(nstate), nstate.get)) // here, 'acc; only returns the result Let(argres.id, narg, Let(letvar.id, TupleSelect(argres, 1), letbody)) } @@ -556,24 +561,33 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } } + def fieldsOfState(st: Expr, stType: CaseClassType): Seq[Expr] = { + closureFactory.lazyTypeNames.map { tn => + closureFactory.selectFieldOfState(tn, st, stType) + } + } + def assignBodiesToFunctions = funMap foreach { case (fd, nfd) => - // /println("Considering function: "+fd) - // Here, using name to identify 'state' parameters, also relying - // on fact that nfd.params are in the same order as tnames - val stateParams = nfd.params.foldLeft(Seq[Expr]()) { + println("Considering function: "+fd) + // Here, using name to identify 'state' parameters + /*nfd.params.foldLeft(Seq[Expr]()) { case (acc, ValDef(id, _)) if id.name.startsWith("st@") => acc :+ id.toVariable case (acc, _) => acc + }*/ + val stateParam = nfd.params.collectFirst { + case vd if vd.id.name.startsWith("st@") => + vd.id.toVariable } - val initStateMap = tnames zip stateParams toMap + val stType = stateParam.map(_.getType.asInstanceOf[CaseClassType]) val (nbodyFun, bodyUpdatesState) = mapBody(fd.body.get) - val nbody = nbodyFun(initStateMap) + val nbody = nbodyFun(stateParam) val bodyWithState = - if (!bodyUpdatesState && funsRetStates(fd)) { - val freshid = FreshIdentifier("bres", nbody.getType) - Let(freshid, nbody, Tuple(freshid.toVariable +: stateParams)) - } else nbody + if (!bodyUpdatesState && funsRetStates(fd)) + Tuple(Seq(nbody, stateParam.get)) + else + nbody nfd.body = Some(simplifyLets(bodyWithState)) //println(s"Body of ${fd.id.name} after conversion&simp: ${nfd.body}") @@ -585,8 +599,8 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val (npreFun, preUpdatesState) = mapBody(fd.precondition.get) nfd.precondition = if (preUpdatesState) - Some(TupleSelect(npreFun(initStateMap), 1)) // ignore state updated by pre - else Some(npreFun(initStateMap)) + Some(TupleSelect(npreFun(stateParam), 1)) // ignore state updated by pre + else Some(npreFun(stateParam)) } // create a new result variable @@ -597,24 +611,23 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } else FreshIdentifier("r", nfd.returnType) // create an output state map - val outStateMap = + val outState = if (bodyUpdatesState || funsRetStates(fd)) { - tnames.zipWithIndex.map { - case (tn, i) => (tn -> TupleSelect(newres.toVariable, i + 2)) - }.toMap + Some(TupleSelect(newres.toVariable, 2)) } else - initStateMap + stateParam // create a specification that relates input-output states val stateRel = if (funsRetStates(fd)) { // add specs on states - val instates = initStateMap.values.toSeq - val outstates = outStateMap.values.toSeq + val instates = fieldsOfState(stateParam.get, stType.get) + val outstates = fieldsOfState(outState.get, stType.get) val stateRel = if(fd.annotations.contains("invstate")) Equals.apply _ else SubsetOf.apply _ Some(createAnd((instates zip outstates).map(p => stateRel(p._1, p._2)))) } else None + println("stateRel: "+stateRel) // create a predicate that ensures that the value part is independent of the state val valRel = @@ -636,11 +649,13 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val tpost = simplePostTransform { case e if LazinessUtil.isInStateCall(e)(p) => val baseType = getTypeArguments(e.getType).head - initStateMap(typeNameWOParams(baseType)) + val tname = typeNameWOParams(baseType) + closureFactory.selectFieldOfState(tname, stateParam.get, stType.get) case e if LazinessUtil.isOutStateCall(e)(p) => val baseType = getTypeArguments(e.getType).head - outStateMap(typeNameWOParams(baseType)) + val tname = typeNameWOParams(baseType) + closureFactory.selectFieldOfState(tname, outState.get, stType.get) case e => e }(post) @@ -650,7 +665,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, if (bodyUpdatesState || funsRetStates(fd)) TupleSelect(newres.toVariable, 1) else newres.toVariable - val npostWithState = replace(Map(resid.toVariable -> resval), npostFun(outStateMap)) + val npostWithState = replace(Map(resid.toVariable -> resval), npostFun(outState)) val npost = if (postUpdatesState) { TupleSelect(npostWithState, 1) // ignore state updated by post @@ -662,51 +677,18 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } nfd.postcondition = Some(Lambda(Seq(ValDef(newres)), createAnd(stateRel.toList ++ valRel.toList ++ targetPost.toList))) -// if (removeRecursionViaEval) { -// uninterpretedTargets.get(fd) match { -// case Some(uitarget) => -// // uitarget uses the same identifiers as nfd -// uitarget.postcondition = targetPost -// case None => ; -// } -// } - // add invariants on state - /*val fpost = - if (funsRetStates(fd)) { // add specs on states - val instates = stateParams - val resid = if (fd.hasPostcondition) { - val Lambda(Seq(ValDef(r, _)), _) = simpPost.get - r - } else FreshIdentifier("r", nfd.returnType) - val outstates = (0 until tnames.size).map(i => TupleSelect(resid.toVariable, i + 2)) - val invstate = fd.annotations.contains("invstate") - val statePred = PredicateUtil.createAnd((instates zip outstates).map { - case (x, y) => - if (invstate) - Equals(x, y) - else SubsetOf(x, y) - }) - val post = Lambda(Seq(ValDef(resid)), (if (fd.hasPostcondition) { - val Lambda(Seq(ValDef(_, _)), p) = simpPost.get - And(p, statePred) - } else statePred)) - Some(post) - } else simpPost*/ - //if (fpost.isDefined) { - // also attach postcondition to uninterpreted targets - //} } def assignContractsForEvals = evalFunctions.foreach { case (tname, evalfd) => val cdefs = closureFactory.closures(tname) - val tparams = evalfd.tparams.map(_.tp) + val recvTparams = getTypeParameters(evalfd.params.head.getType) val postres = FreshIdentifier("res", evalfd.returnType) val postMatchCases = cdefs map { cdef => // create a body of the match (which asserts that return value equals the uninterpreted function) // and also that the result field equals the result val op = closureFactory.lazyopOfClosure(cdef) - val ctype = CaseClassType(cdef, tparams) + val ctype = CaseClassType(cdef, recvTparams) val binder = FreshIdentifier("t", ctype) val pattern = InstanceOfPattern(Some(binder), ctype) // t.clres == res._1 @@ -719,7 +701,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } Some(Equals(TupleSelect(postres.toVariable, 1), - FunctionInvocation(TypedFunDef(uiFuncs(op)._1, tparams), args))) + FunctionInvocation(TypedFunDef(uiFuncs(op)._1, recvTparams), args))) } else None val rhs = createAnd(clause1 +: clause2.toList) MatchCase(pattern, None, rhs) @@ -771,7 +753,9 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } case d => Seq(d) }), - closureFactory.allClosuresAndParents ++ closureCons.values ++ - evalFunctions.values ++ computeFunctions.values ++ uiStateFuns.values, anchor) + closureFactory.allClosuresAndParents ++ Seq(closureFactory.state) ++ + closureCons.values ++ evalFunctions.values ++ + computeFunctions.values ++ uiStateFuns.values ++ + closureFactory.stateUpdateFuns.values, anchor) } } diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala index f5fffde3112dff7cf7f547b58a586d378bdb6e8d..49fd9b124d5be4b96d25ec7d710d1beab969f99d 100644 --- a/src/main/scala/leon/laziness/LazyClosureFactory.scala +++ b/src/main/scala/leon/laziness/LazyClosureFactory.scala @@ -144,4 +144,66 @@ class LazyClosureFactory(p: Program) { * This avoids the use of additional maps. */ def lazyTypeNameOfClosure(cl: CaseClassDef) = adtNameToTypeName(cl.parent.get.classDef.id.name) + + /** + * Define a state as an ADT whose fields are sets of closures. + * Note that we need to ensure that there are state ADT is not recursive. + */ + val state = { + var tparams = Seq[TypeParameter]() + var i = 0 + def freshTParams(n: Int): Seq[TypeParameter] = { + val start = i + 1 + i += n // create 'n' fresh ids + val nparams = (start to i).map(index => TypeParameter.fresh("T"+index)) + tparams ++= nparams + nparams + } + // field of the ADT + val fields = lazyTypeNames map { tn => + val absClass = absClosureType(tn) + val tparams = freshTParams(absClass.tparams.size) + val fldType = SetType(AbstractClassType(absClass, tparams)) + ValDef(FreshIdentifier(typeToFieldName(tn), fldType)) + } + val ccd = CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) + ccd.setFields(fields) + ccd + } + + def selectFieldOfState(tn: String, st: Expr, stType: CaseClassType) = { + val selName = typeToFieldName(tn) + stType.classDef.fields.find{ fld => fld.id.name == selName} match { + case Some(fld) => + CaseClassSelector(stType, st, fld.id) + case _ => + throw new IllegalStateException(s"Cannot find a field of $stType with name: $selName") + } + } + + val stateUpdateFuns : Map[String, FunDef] = + lazyTypeNames.map{ tn => + val fldname = typeToFieldName(tn) + val tparams = state.tparams.map(_.tp) + val stType = CaseClassType(state, tparams) + val param1 = FreshIdentifier("st@", stType) + val SetType(baseT) = stType.classDef.fields.find{ fld => fld.id.name == fldname}.get.getType + val param2 = FreshIdentifier("cl", baseT) + + // TODO: as an optimization we can mark all these functions as inline and inline them at their callees + val updateFun = new FunDef(FreshIdentifier("updState"+tn), + state.tparams, Seq(ValDef(param1), ValDef(param2)), stType) + // create a body for the updateFun: + val nargs = state.fields.map{ fld => + val fldSelect = CaseClassSelector(stType, param1.toVariable, fld.id) + if(fld.id.name == fldname) { + SetUnion(fldSelect, FiniteSet(Set(param2.toVariable), baseT)) // st@.tn + Set(param2) + } else { + fldSelect + } + } + val nst = CaseClass(stType, nargs) + updateFun.body = Some(nst) + (tn -> updateFun) + }.toMap }