From fcdb10620ae5e126cde812f4dfaedfef7ce2fbec Mon Sep 17 00:00:00 2001 From: ravi <ravi.kandhadai@epfl.ch> Date: Tue, 19 Jan 2016 20:20:54 +0100 Subject: [PATCH] Updating handling of memoization --- library/lazy/package.scala | 8 +++ .../laziness/LazinessEliminationPhase.scala | 20 ++++--- .../scala/leon/laziness/LazinessUtil.scala | 12 ++++ .../leon/laziness/LazyClosureConverter.scala | 50 +++++++++-------- .../memoization/FibonacciMemoized.scala | 7 +-- .../memoization/Knapsack.scala | 56 +++++++++++++++++++ 6 files changed, 119 insertions(+), 34 deletions(-) create mode 100644 testcases/lazy-datastructures/memoization/Knapsack.scala diff --git a/library/lazy/package.scala b/library/lazy/package.scala index f64cec654..3d03565d5 100644 --- a/library/lazy/package.scala +++ b/library/lazy/package.scala @@ -50,6 +50,14 @@ object $ { @inline implicit def toWithState[T](x: T) = new WithState(x) + @library + case class Mem[T](v: T) { + @extern + def isCached: Boolean = sys.error("not implemented!") + } + @inline + implicit def toMem[T](x: T) = new Mem(x) + /** * annotations for monotonicity proofs. * Note implemented as of now. diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala index 88ef5bda2..40b56168a 100644 --- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala +++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala @@ -203,7 +203,7 @@ object LazinessEliminationPhase extends TransformationPhase { lazy val valueFun = ProgramUtil.functionByFullName("leon.lazyeval.$.value", prog).get prog.modules.foreach { md => - def exprLifter(inspec: Boolean)(expr: Expr) = expr match { + def exprLifter(inmem: Boolean)(expr: Expr) = expr match { // is the arugment of lazy invocation not a function ? case finv @ FunctionInvocation(lazytfd, Seq(arg)) if isLazyInvocation(finv)(prog) && !arg.isInstanceOf[FunctionInvocation] => val freevars = variablesOf(arg).toList @@ -230,7 +230,7 @@ object LazinessEliminationPhase extends TransformationPhase { val freshid = FreshIdentifier("t", rootType) Let(freshid, arg, FunctionInvocation(TypedFunDef(fd, Seq(rootType)), Seq(freshid.toVariable))) - case FunctionInvocation(TypedFunDef(fd, targs), args) if isMemoized(fd) && !inspec => + case FunctionInvocation(TypedFunDef(fd, targs), args) if isMemoized(fd) && !inmem => // calling a memoized function is modeled as creating a lazy closure and forcing it val tfd = TypedFunDef(fdmap.getOrElse(fd, fd), targs) val finv = FunctionInvocation(tfd, args) @@ -244,14 +244,20 @@ object LazinessEliminationPhase extends TransformationPhase { } md.definedFunctions.foreach { case fd if fd.hasBody && !fd.isLibrary => - val nbody = simplePostTransform(exprLifter(false))(fd.body.get) + def rec(inmem: Boolean)(e: Expr): Expr = e match { + case Operator(args, op) => + val nargs = args map rec(inmem || isMemCons(e)(prog)) + exprLifter(inmem)(op(nargs)) + } + val nfd = fdmap(fd) + nfd.fullBody = rec(false)(fd.fullBody) + /*val nbody = simplePostTransform(exprLifter(false))(fd.body.get) val npre = fd.precondition.map(simplePostTransform(exprLifter(true))) - val npost = fd.postcondition.map(simplePostTransform(exprLifter(true))) + val npost = fd.postcondition.map(simplePostTransform(exprLifter(true)))*/ //println(s"New body of $fd: $nbody") - val nfd = fdmap(fd) - nfd.body = Some(nbody) + /*nfd.body = Some(nbody) nfd.precondition = npre - nfd.postcondition = npost + nfd.postcondition = npost*/ case _ => ; } } diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala index 7c6a3b45f..1c543fa2d 100644 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -113,6 +113,12 @@ object LazinessUtil { case _ => false } + def isMemCons(e: Expr)(implicit p: Program): Boolean = e match { + case CaseClass(cct, Seq(_)) => + fullName(cct.classDef)(p) == "leon.lazyeval.$.Mem" + case _ => false + } + /** * There are many overloads of withState functions with different number * of arguments. All of them should pass this check. @@ -123,6 +129,12 @@ object LazinessUtil { case _ => false } + def isCachedInv(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.Mem.isCached" + case _ => false + } + def isValueInvocation(e: Expr)(implicit p: Program): Boolean = e match { case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => fullName(fd)(p) == "leon.lazyeval.$.value" diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala index c86ec021f..45534d6b6 100644 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -299,10 +299,10 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, (tname -> fun) }.toMap - def mapExpr(expr: Expr)(implicit stTparams: Seq[TypeParameter], inSpec: Boolean): (Option[Expr] => Expr, Boolean) = expr match { + def mapExpr(expr: Expr)(implicit stTparams: Seq[TypeParameter]): (Option[Expr] => Expr, Boolean) = expr match { case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // lazy construction ? - if isLazyInvocation(finv)(p) && !inSpec => + if isLazyInvocation(finv)(p) => val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { val adt = closureFactory.closureOfLazyOp(argfd) // create lets to bind the nargs to variables @@ -325,18 +325,10 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, }, false) mapNAryOperator(args, op) - // TODO: we don't have to create a dummy return value for memoized functions, as they wouldn't be stored in other structures - case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // invocation in spec ? - if isLazyInvocation(finv)(p) && inSpec => + case cc @ CaseClass(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) if isMemCons(cc)(p) => // in this case argfd is a memoized function val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { val adt = closureFactory.closureOfLazyOp(argfd) - /*// create lets to bind the nargs to variables - val (flatArgs, letCons) = nargs.foldRight((Seq[Variable](), (e : Expr) => e)){ - case (narg, (fargs, lcons)) => - val id = FreshIdentifier("a", narg.getType, true) - (id.toVariable +: fargs, e => Let(id, narg, lcons(e))) - }*/ CaseClass(CaseClassType(adt, tparams), nargs) }, false) mapNAryOperator(args, op) @@ -360,13 +352,24 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val stType = CaseClassType(closureFactory.state, stTparams) val cls = closureFactory.selectFieldOfState(tname, st, stType) val memberTest = ElementOfSet(narg, cls) - if(closureFactory.isMemType(tname)) { - memberTest - } else { - val subtypeTest = IsInstanceOf(narg, + val subtypeTest = IsInstanceOf(narg, CaseClassType(closureFactory.eagerClosure(tname).get, getTypeArguments(baseType))) - Or(memberTest, subtypeTest) - } + Or(memberTest, subtypeTest) + }, false) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isCachedInv(finv)(p) => // isCached function ? + val baseType = args(0).getType match { + case cct: CaseClassType => cct.fieldsTypes(0) + } + val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { + val narg = nargs(0) // there must be only one argument here + //println("narg: "+narg+" type: "+narg.getType) + val tname = typeNameWOParams(baseType) + val st = stOpt.get + val stType = CaseClassType(closureFactory.state, stTparams) + val cls = closureFactory.selectFieldOfState(tname, st, stType) + ElementOfSet(narg, cls) }, false) mapNAryOperator(args, op) @@ -545,7 +548,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, case t: Terminal => (_ => t, false) } - def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Option[Expr] => Expr, Boolean))(implicit stTparams: Seq[TypeParameter], inSpec: Boolean) = { + def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Option[Expr] => Expr, Boolean))(implicit stTparams: Seq[TypeParameter]) = { // create n variables to model n lets val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) (args zip letvars).foldRight(op(letvars)) { @@ -591,7 +594,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val stTparams = nfd.tparams.collect{ case tpd if isPlaceHolderTParam(tpd.tp) => tpd.tp } - val (nbodyFun, bodyUpdatesState) = mapExpr(fd.body.get)(stTparams, false) + val (nbodyFun, bodyUpdatesState) = mapExpr(fd.body.get)(stTparams) val nbody = nbodyFun(stateParam) val bodyWithState = if (!bodyUpdatesState && funsRetStates(fd)) @@ -606,7 +609,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, // This guarantees their observational purity/transparency // collect class invariants that need to be added if (fd.hasPrecondition) { - val (npreFun, preUpdatesState) = mapExpr(fd.precondition.get)(stTparams, true) + val (npreFun, preUpdatesState) = mapExpr(fd.precondition.get)(stTparams) nfd.precondition = if (preUpdatesState) Some(TupleSelect(npreFun(stateParam), 1)) // ignore state updated by pre @@ -670,7 +673,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, case e => e }(post) // thread state through postcondition - val (npostFun, postUpdatesState) = mapExpr(tpost)(stTparams, true) + val (npostFun, postUpdatesState) = mapExpr(tpost)(stTparams) val resval = if (bodyUpdatesState || funsRetStates(fd)) TupleSelect(newres.toVariable, 1) @@ -703,7 +706,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, val binder = FreshIdentifier("t", ctype) val pattern = InstanceOfPattern(Some(binder), ctype) // t.clres == res._1 - val clause1 = if(ismem) { + val clause1 = if(!ismem) { val clresField = cdef.fields.last Equals(TupleSelect(postres.toVariable, 1), CaseClassSelector(ctype, binder.toVariable, clresField.id)) @@ -734,7 +737,8 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, * Overrides the types of the lazy fields in the case class definitions */ def transformCaseClasses = p.definedClasses.foreach { - case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) => + case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) + if !ccd.flags.contains(Annotation("library", Seq())) => val nfields = ccd.fields.map { fld => unwrapLazyType(fld.getType) match { case None => fld diff --git a/testcases/lazy-datastructures/memoization/FibonacciMemoized.scala b/testcases/lazy-datastructures/memoization/FibonacciMemoized.scala index 786c392b6..6e614db98 100644 --- a/testcases/lazy-datastructures/memoization/FibonacciMemoized.scala +++ b/testcases/lazy-datastructures/memoization/FibonacciMemoized.scala @@ -1,6 +1,5 @@ -package Unproved - import leon.lazyeval._ +import leon.lazyeval.$._ import leon.lang._ import leon.annotation._ import leon.instrumentation._ @@ -19,8 +18,8 @@ object FibMem { else fibRec(n-1) + fibRec(n-2) // postcondition implies that the second call would be cached } ensuring(_ => - (n <= 2 || ($(fibRec(n-1)).isEvaluated && - $(fibRec(n-2)).isEvaluated)) && time <= 40*n + 10) + (n <= 2 || (fibRec(n-1).isCached && + fibRec(n-2).isCached)) && time <= 40*n + 10) /*def fibRange(i: BigInt, k: BigInt): IList = { require(k >= 1 && i <= k && diff --git a/testcases/lazy-datastructures/memoization/Knapsack.scala b/testcases/lazy-datastructures/memoization/Knapsack.scala new file mode 100644 index 000000000..5526114a1 --- /dev/null +++ b/testcases/lazy-datastructures/memoization/Knapsack.scala @@ -0,0 +1,56 @@ +import leon.lazyeval._ +import leon.lazyeval.$._ +import leon.lang._ +import leon.annotation._ +import leon.instrumentation._ +//import leon.invariant._ + +object Knapscak { + sealed abstract class IList + case class Cons(x: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def depsEval(i: BigInt, items: IList): Boolean = { + require(i >= 0) + if(i <= 0) true + else { + knapSack(i, items).isCached && depsEval(i-1, items) + } + } + + def maxValue(items: IList, w: BigInt, currList: IList): BigInt = { + require(w >= 0) + currList match { + case Cons(wi, tail) => + val oldMax = maxValue(items, w, tail) + if (wi <= w) { + val choiceVal = wi + knapSack(w - wi, items) + if (choiceVal >= oldMax) + choiceVal + else + oldMax + } else oldMax + case Nil() => BigInt(0) + } + } + + @memoize + def knapSack(w: BigInt, items: IList): BigInt = { + require(w >= 0) + if (w == 0) BigInt(0) + else { + maxValue(items, w, items) + } + } + + def bottomup(i: BigInt, w: BigInt, items: IList): IList = { + require(i >= 0 && w >= 0 && + (i == 0 || depsEval(i-1, items))) + if (i == w) + Cons(knapSack(i, items), Nil()) + else { + val ri = knapSack(i, items) + Cons(ri, bottomup(i + 1, w, items)) + } + } +} -- GitLab