From 36c2029f9f2fdb4957310c5c4923d17b828005eb Mon Sep 17 00:00:00 2001 From: ravi <ravi.kandhadai@epfl.ch> Date: Thu, 24 Sep 2015 18:39:10 +0200 Subject: [PATCH] Adding support for verifying programs with laziness --- library/annotation/package.scala | 2 + library/lazy/package.scala | 20 + src/main/scala/leon/Main.scala | 12 +- .../invariant/structure/FunctionUtils.scala | 1 + .../leon/invariant/util/DisjointSet.scala | 60 + .../util/LetTupleSimplifications.scala | 47 +- src/main/scala/leon/invariant/util/Util.scala | 312 ++++- .../scala/leon/purescala/Expressions.scala | 1 + .../scala/leon/purescala/PrinterHelpers.scala | 7 +- .../LazinessEliminationPhase.scala | 1092 +++++++++++++++++ .../RealTimeQueue-transformed.scala | 292 +++++ .../lazy-datastructures/RealTimeQueue.scala | 133 ++ 12 files changed, 1955 insertions(+), 24 deletions(-) create mode 100644 library/lazy/package.scala create mode 100644 src/main/scala/leon/invariant/util/DisjointSet.scala create mode 100644 src/main/scala/leon/transformations/LazinessEliminationPhase.scala create mode 100644 testcases/lazy-datastructures/RealTimeQueue-transformed.scala create mode 100644 testcases/lazy-datastructures/RealTimeQueue.scala diff --git a/library/annotation/package.scala b/library/annotation/package.scala index 00ae0743c..c28a524de 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -19,4 +19,6 @@ package object annotation { class monotonic extends StaticAnnotation @ignore class compose extends StaticAnnotation + @ignore + class axiom extends StaticAnnotation } \ No newline at end of file diff --git a/library/lazy/package.scala b/library/lazy/package.scala new file mode 100644 index 000000000..f70aebb0f --- /dev/null +++ b/library/lazy/package.scala @@ -0,0 +1,20 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon.lazyeval + +import leon.annotation._ +import leon.lang._ +import scala.language.implicitConversions + +@library +object $ { + def apply[T](f: => T) = new $(Unit => f) +} + +@library +case class $[T](f: Unit => T) { // leon does not support call by name as of now + lazy val value = f(()) + def * = f(()) + def isEvaluated = true // for now this is a dummy function, but it will be made sound when leon supports mutable fields. +} + diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 3a8cb39df..d6cd9e9f0 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -27,7 +27,9 @@ object Main { solvers.isabelle.AdaptationPhase, solvers.isabelle.IsabellePhase, transformations.InstrumentationPhase, - invariant.engine.InferInvariantsPhase) + invariant.engine.InferInvariantsPhase, + transformations.LazinessEliminationPhase + ) } // Add whatever you need here. @@ -53,9 +55,10 @@ object Main { val optHelp = LeonFlagOptionDef("help", "Show help message", false) val optInstrument = LeonFlagOptionDef("instrument", "Instrument the code for inferring time/depth/stack bounds", false) val optInferInv = LeonFlagOptionDef("inferInv", "Infer invariants from (instrumented) the code", false) + val optLazyEval = LeonFlagOptionDef("lazy", "Handles programs that may use the lazy construct", false) override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify, optInstrument, optInferInv) + Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify, optInstrument, optInferInv, optLazyEval) } lazy val allOptions: Set[LeonOptionDef[Any]] = allComponents.flatMap(_.definedOptions) @@ -153,7 +156,8 @@ object Main { import solvers.isabelle.IsabellePhase import MainComponent._ import invariant.engine.InferInvariantsPhase - import transformations.InstrumentationPhase + import transformations._ + val helpF = ctx.findOptionOrDefault(optHelp) val noopF = ctx.findOptionOrDefault(optNoop) @@ -166,6 +170,7 @@ object Main { val evalF = ctx.findOption(optEval).isDefined val inferInvF = ctx.findOptionOrDefault(optInferInv) val instrumentF = ctx.findOptionOrDefault(optInstrument) + val lazyevalF = ctx.findOptionOrDefault(optLazyEval) val analysisF = verifyF && terminationF if (helpF) { @@ -189,6 +194,7 @@ object Main { else if (evalF) EvaluationPhase else if (inferInvF) InferInvariantsPhase else if (instrumentF) InstrumentationPhase andThen FileOutputPhase + else if (lazyevalF) LazinessEliminationPhase else analysis } diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala index f676520cd..7c812d781 100644 --- a/src/main/scala/leon/invariant/structure/FunctionUtils.scala +++ b/src/main/scala/leon/invariant/structure/FunctionUtils.scala @@ -25,6 +25,7 @@ object FunctionUtils { lazy val isCommutative = fd.annotations.contains("commutative") lazy val isDistributive = fd.annotations.contains("distributive") lazy val compose = fd.annotations.contains("compose") + lazy val isLibrary = fd.annotations.contains("library") //the template function lazy val tmplFunctionName = "tmpl" diff --git a/src/main/scala/leon/invariant/util/DisjointSet.scala b/src/main/scala/leon/invariant/util/DisjointSet.scala new file mode 100644 index 000000000..520139764 --- /dev/null +++ b/src/main/scala/leon/invariant/util/DisjointSet.scala @@ -0,0 +1,60 @@ +package leon +package invariant.util + +import scala.collection.mutable.{ Map => MutableMap } +import scala.collection.mutable.{ Set => MutableSet } + +class DisjointSets[T] { + // A map from elements to their parent and rank + private var disjTree = MutableMap[T, (T, Int)]() + + private def findInternal(x: T): (T, Int) = { + val (p, rank) = disjTree(x) + if (p == x) + (x, rank) + else { + val root = findInternal(p) + // compress path + disjTree(x) = root + root + } + } + + private def findOrCreateInternal(x: T) = + if (!disjTree.contains(x)) { + disjTree += (x -> (x, 1)) + (x, 1) + } else findInternal(x) + + def findOrCreate(x: T) = findOrCreateInternal(x)._1 + + def find(x: T) = findInternal(x)._1 + + def union(x: T, y: T) { + val (rep1, rank1) = findOrCreateInternal(x) + val (rep2, rank2) = findOrCreateInternal(y) + if (rank1 < rank2) { + disjTree(rep1) = (rep2, rank2) + } else if (rank2 < rank1) { + disjTree(rep2) = (rep1, rank1) + } else + disjTree(rep1) = (rep2, rank2 + 1) + } + + def toMap = { + val repToSet = disjTree.keys.foldLeft(MutableMap[T, Set[T]]()) { + case (acc, k) => + val root = find(k) + if (acc.contains(root)) + acc(root) = acc(root) + k + else + acc += (root -> Set(k)) + acc + } + disjTree.keys.map {k => (k -> repToSet(find(k)))}.toMap + } + + override def toString = { + disjTree.toString + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala b/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala index 966121756..c89eae941 100644 --- a/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala +++ b/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala @@ -31,10 +31,9 @@ object LetTupleSimplification { def letSanityChecks(ine: Expr) = { simplePostTransform(_ match { - case letExpr @ Let(binderId, letValue, body) - if (binderId.getType != letValue.getType) => - throw new IllegalStateException("Binder and value type mismatch: "+ - s"(${binderId.getType},${letValue.getType})") + case letExpr @ Let(binderId, letValue, body) if (binderId.getType != letValue.getType) => + throw new IllegalStateException("Binder and value type mismatch: " + + s"(${binderId.getType},${letValue.getType})") case e => e })(ine) } @@ -266,11 +265,12 @@ object LetTupleSimplification { // by pulling them out def pullLetToTop(e: Expr): Expr = { val transe = e match { - //note: do not pull let's out of ensuring or requires + case Lambda(args, body) => + Lambda(args, pullLetToTop(body)) case Ensuring(body, pred) => - Ensuring(pullLetToTop(body), pred) + Ensuring(pullLetToTop(body), pullLetToTop(pred)) case Require(pre, body) => - Require(pre, pullLetToTop(body)) + Require(pullLetToTop(pre), pullLetToTop(body)) case letExpr @ Let(binder, letValue, body) => // transform the 'letValue' with the current map @@ -293,9 +293,16 @@ object LetTupleSimplification { replaceLetBody(pullLetToTop(e1), te1 => replaceLetBody(pullLetToTop(e2), te2 => op(Seq(te1, te2)))) - //don't pull things out of if-then-else and match (don't know why this is a problem) + //don't pull lets out of if-then-else branches and match cases case IfExpr(c, th, elze) => - IfExpr(pullLetToTop(c), pullLetToTop(th), pullLetToTop(elze)) + replaceLetBody(pullLetToTop(c), IfExpr(_, pullLetToTop(th), pullLetToTop(elze))) + + case MatchExpr(scr, cases) => + val newcases = cases.map { + case MatchCase(pat, guard, rhs) => + MatchCase(pat, guard map pullLetToTop, pullLetToTop(rhs)) + } + replaceLetBody(pullLetToTop(scr), MatchExpr(_, newcases)) case Operator(Seq(), op) => op(Seq()) @@ -324,7 +331,8 @@ object LetTupleSimplification { } transe } - val res = pullLetToTop(matchToIfThenElse(ine)) + //val res = pullLetToTop(matchToIfThenElse(ine)) + val res = pullLetToTop(ine) // println("After Pulling lets to top : \n" + ScalaPrinter.apply(res)) res } @@ -348,7 +356,7 @@ object LetTupleSimplification { } else if (occurrences == 1) { Some(replace(Map(Variable(i) -> e), b)) } else { - //TODO: we can also remove zero occurrences and compress the tuples + //TODO: we can also remove zero occurrences and compress the tuples // this may be necessary when instrumentations are combined. letExpr match { case letExpr @ Let(binder, lval @ Tuple(subes), b) => @@ -359,24 +367,26 @@ object LetTupleSimplification { }(b) res } + val binderVar = binder.toVariable val repmap: Map[Expr, Expr] = subes.zipWithIndex.collect { - case (sube, i) if occurrences(i + 1) == 1 => - (TupleSelect(binder.toVariable, i + 1) -> sube) + case (sube, i) if occurrences(i + 1) == 1 => // sube is used only once ? + (TupleSelect(binderVar, i + 1) -> sube) + case (v @ Variable(_), i) => // sube is a variable ? + (TupleSelect(binderVar, i + 1) -> v) + case (ts @ TupleSelect(Variable(_), _), i) => // sube is a tuple select of a variable ? + (TupleSelect(binderVar, i + 1) -> ts) }.toMap Some(Let(binder, lval, replace(repmap, b))) //note: here, we cannot remove the let, //if it is not used it will be removed in the next iteration - case _ => None } } } - case _ => None } res } - val transforms = removeLetsFromLetValues _ andThen fixpoint(postMap(simplerLet)) _ andThen simplifyArithmetic transforms(ine) } @@ -413,8 +423,7 @@ object LetTupleSimplification { // Reconstruct the expressin tree with the non-constants and the result of constant evaluation above if (allConstantsOpped != identity) { allNonConstants.foldLeft(InfiniteIntegerLiteral(allConstantsOpped): Expr)((acc: Expr, currExpr) => makeTree(acc, currExpr)) - } - else { + } else { if (allNonConstants.size == 0) InfiniteIntegerLiteral(identity) else { allNonConstants.tail.foldLeft(allNonConstants.head)((acc: Expr, currExpr) => makeTree(acc, currExpr)) @@ -430,7 +439,7 @@ object LetTupleSimplification { case Plus(e1, e2) => { getAllSummands(e1, false) ++ getAllSummands(e2, false) } - case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) + case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) } } diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala index 5a101569e..855dba093 100644 --- a/src/main/scala/leon/invariant/util/Util.scala +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -105,13 +105,50 @@ object Util { def copyProgram(prog: Program, mapdefs: (Seq[Definition] => Seq[Definition])): Program = { prog.copy(units = prog.units.collect { case unit if (!unit.defs.isEmpty) => unit.copy(defs = unit.defs.collect { - case module : ModuleDef if (!module.defs.isEmpty) => + case module: ModuleDef if (!module.defs.isEmpty) => module.copy(defs = mapdefs(module.defs)) case other => other }) }) } + def appendDefsToModules(p: Program, defs: Map[ModuleDef, Traversable[Definition]]): Program = { + val res = p.copy(units = for (u <- p.units) yield { + u.copy( + defs = u.defs.map { + case m: ModuleDef if defs.contains(m) => + m.copy(defs = m.defs ++ defs(m)) + case other => other + }) + }) + res + } + + def addDefs(p: Program, defs: Traversable[Definition], after: Definition): Program = { + var found = false + val res = p.copy(units = for (u <- p.units) yield { + u.copy( + defs = u.defs.map { + case m: ModuleDef => + val newdefs = for (df <- m.defs) yield { + df match { + case `after` => + found = true + after +: defs.toSeq + case d => + Seq(d) + } + } + m.copy(defs = newdefs.flatten) + case other => other + }) + }) + if (!found) { + println("addDefs could not find anchor definition!") + } + res + } + def createTemplateFun(plainTemp: Expr): FunctionInvocation = { val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), Seq(ValDef(FreshIdentifier("arg", tmpl.getType), @@ -547,6 +584,48 @@ object Util { case _ => false } } + + def matchToITE(ine: Expr) = { + val liftedExpr = simplePostTransform { + case me @ MatchExpr(scrut, cases) => scrut match { + case t: Terminal => me + case _ => { + val freshid = FreshIdentifier("m", scrut.getType, true) + Let(freshid, scrut, MatchExpr(freshid.toVariable, cases)) + } + } + case e => e + }(ine) + purescala.ExprOps.matchToIfThenElse(liftedExpr) + } + + def precOrTrue(fd: FunDef): Expr = fd.precondition match { + case Some(pre) => pre + case None => BooleanLiteral(true) + } + + /* + * Apply an expression operation on all expressions contained in a FunDef + */ + def applyOnFunDef(operation: Expr => Expr)(funDef: FunDef): FunDef = { + val newFunDef = funDef.duplicate + newFunDef.fullBody = operation(funDef.fullBody) + newFunDef + } + + /** + * Apply preMap on all expressions contained in a FunDef + */ + def preMapOnFunDef(repl: Expr => Option[Expr], applyRec: Boolean = false)(funDef: FunDef): FunDef = { + applyOnFunDef(preMap(repl, applyRec))(funDef) + } + + /** + * Apply postMap on all expressions contained in a FunDef + */ + def postMapOnFunDef(repl: Expr => Option[Expr], applyRec: Boolean = false)(funDef: FunDef): FunDef = { + applyOnFunDef(postMap(repl, applyRec))(funDef) + } } /** @@ -714,4 +793,235 @@ class CounterMap[T] extends scala.collection.mutable.HashMap[T, Int] { this(v) += 1 else this += (v -> 1) } +} + +/** + * A class that looks for structural equality of expressions + * by ignoring the variable names. + * Useful for factoring common parts of two expressions into functions. + */ +class ExprStructure(val e: Expr) { + def structurallyEqual(e1: Expr, e2: Expr): Boolean = { + (e1, e2) match { + case (t1: Terminal, t2: Terminal) => + // we need to specially handle type parameters as they are not considered equal by default + (t1.getType, t2.getType) match { + case (ct1: ClassType, ct2: ClassType) => + if (ct1.classDef == ct2.classDef && ct1.tps.size == ct2.tps.size) { + (ct1.tps zip ct2.tps).forall { + case (TypeParameter(_), TypeParameter(_)) => + true + case (a, b) => + println(s"Checking Type arguments: $a, $b") + a == b + } + } else false + case (ty1, ty2) => ty1 == ty2 + } + case (Operator(args1, op1), Operator(args2, op2)) => + (op1 == op2) && (args1.size == args2.size) && (args1 zip args2).forall { + case (a1, a2) => structurallyEqual(a1, a2) + } + case _ => + false + } + } + + override def equals(other: Any) = { + other match { + case other: ExprStructure => + structurallyEqual(e, other.e) + case _ => + false + } + } + + override def hashCode = { + var opndcount = 0 // operand count + var opcount = 0 // operator count + postTraversal { + case t: Terminal => opndcount += 1 + case _ => opcount += 1 + }(e) + (opndcount << 16) ^ opcount + } +} + +object TypeUtil { + def getTypeParameters(t: TypeTree): Seq[TypeParameter] = { + t match { + case tp @ TypeParameter(_) => Seq(tp) + case NAryType(tps, _) => + (tps flatMap getTypeParameters).distinct + } + } + + def typeNameWOParams(t: TypeTree): String = t match { + case ct: ClassType => ct.id.name + case TupleType(ts) => ts.map(typeNameWOParams).mkString("(", ",", ")") + case ArrayType(t) => s"Array[${typeNameWOParams(t)}]" + case SetType(t) => s"Set[${typeNameWOParams(t)}]" + case MapType(from, to) => s"Map[${typeNameWOParams(from)}, ${typeNameWOParams(to)}]" + case FunctionType(fts, tt) => + val ftstr = fts.map(typeNameWOParams).mkString("(", ",", ")") + s"$ftstr => ${typeNameWOParams(tt)}" + case t => t.toString + } + + def instantiateTypeParameters(tpMap: Map[TypeParameter, TypeTree])(t: TypeTree): TypeTree = { + t match { + case tp: TypeParameter => tpMap.getOrElse(tp, tp) + case NAryType(subtypes, tcons) => + tcons(subtypes map instantiateTypeParameters(tpMap) _) + } + } + + /** + * `gamma` is the initial type environment which has + * type bindings for free variables of `ine`. + * It is not necessary that gamma should match the types of the + * identifiers of the free variables. + * Set and Maps are not supported yet + */ + def inferTypesOfLocals(ine: Expr, initGamma: Map[Identifier, TypeTree]): Expr = { + var idmap = Map[Identifier, Identifier]() + var gamma = initGamma + + /** + * Note this method has side-effects + */ + def makeIdOfType(oldId: Identifier, tpe: TypeTree): Identifier = { + if (oldId.getType != tpe) { + val freshid = FreshIdentifier(oldId.name, tpe, true) + idmap += (oldId -> freshid) + gamma += (oldId -> tpe) + freshid + } else oldId + } + + def rec(e: Expr): (TypeTree, Expr) = { + val res = e match { + case Let(id, value, body) => + val (valType, nval) = rec(value) + val nid = makeIdOfType(id, valType) + val (btype, nbody) = rec(body) + (btype, Let(nid, nval, nbody)) + + case Ensuring(body, Lambda(Seq(resdef @ ValDef(resid, _)), postBody)) => + val (btype, nbody) = rec(body) + val nres = makeIdOfType(resid, btype) + (btype, Ensuring(nbody, Lambda(Seq(ValDef(nres)), rec(postBody)._2))) + + case MatchExpr(scr, mcases) => + val (scrtype, nscr) = rec(scr) + val ncases = mcases.map { + case MatchCase(pat, optGuard, rhs) => + // resetting the type of patterns in the matches + def mapPattern(p: Pattern, expType: TypeTree): (Pattern, TypeTree) = { + p match { + case InstanceOfPattern(bopt, ict) => + // choose the subtype of the `expType` that + // has the same constructor as `ict` + val ntype = subcast(ict, expType.asInstanceOf[ClassType]) + if (!ntype.isDefined) + throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") + val nbopt = bopt.map(makeIdOfType(_, ntype.get)) + (InstanceOfPattern(nbopt, ntype.get), ntype.get) + + case CaseClassPattern(bopt, ict, subpats) => + val ntype = subcast(ict, expType.asInstanceOf[ClassType]) + if (!ntype.isDefined) + throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") + val cct = ntype.get.asInstanceOf[CaseClassType] + val nbopt = bopt.map(makeIdOfType(_, cct)) + val npats = (subpats zip cct.fieldsTypes).map { + case (p, t) => + //println(s"Subpat: $p expected type: $t") + mapPattern(p, t)._1 + } + (CaseClassPattern(nbopt, cct, npats), cct) + + case TuplePattern(bopt, subpats) => + val TupleType(subts) = scrtype + val patnTypes = (subpats zip subts).map { + case (p, t) => mapPattern(p, t) + } + val npats = patnTypes.map(_._1) + val ntype = TupleType(patnTypes.map(_._2)) + val nbopt = bopt.map(makeIdOfType(_, ntype)) + (TuplePattern(nbopt, npats), ntype) + + case WildcardPattern(bopt) => + val nbopt = bopt.map(makeIdOfType(_, expType)) + (WildcardPattern(nbopt), expType) + + case LiteralPattern(bopt, lit) => + val ntype = lit.getType + val nbopt = bopt.map(makeIdOfType(_, ntype)) + (LiteralPattern(nbopt, lit), ntype) + case _ => + throw new IllegalStateException("Not supported yet!") + } + } + val npattern = mapPattern(pat, scrtype)._1 + val nguard = optGuard.map(rec(_)._2) + val nrhs = rec(rhs)._2 + //println(s"New rhs: $nrhs inferred type: ${nrhs.getType}") + MatchCase(npattern, nguard, nrhs) + } + val nmatch = MatchExpr(nscr, ncases) + (nmatch.getType, nmatch) + + case cs @ CaseClassSelector(cltype, clExpr, fld) => + val (ncltype: CaseClassType, nclExpr) = rec(clExpr) + (ncltype, CaseClassSelector(ncltype, nclExpr, fld)) + + case AsInstanceOf(clexpr, cltype) => + val (ncltype: ClassType, nexpr) = rec(clexpr) + subcast(cltype, ncltype) match { + case Some(ntype) => (ntype, AsInstanceOf(nexpr, ntype)) + case _ => + //println(s"asInstanceOf type of $clExpr is: $cltype inferred type of $nclExpr : $ct") + throw new IllegalStateException(s"$nexpr : $ncltype cannot be cast to case class type: $cltype") + } + + case v @ Variable(id) => + if (gamma.contains(id)) { + if (idmap.contains(id)) + (gamma(id), idmap(id).toVariable) + else { + (gamma(id), v) + } + } else (id.getType, v) + + // need to handle tuple select specially + case TupleSelect(tup, i) => + val nop = TupleSelect(rec(tup)._2, i) + (nop.getType, nop) + case Operator(args, op) => + val nop = op(args.map(arg => rec(arg)._2)) + (nop.getType, nop) + case t: Terminal => + (t.getType, t) + } + //println(s"Inferred type of $e : ${res._1} new expression: ${res._2}") + if (res._1 == Untyped) { + throw new IllegalStateException(s"Cannot infer type for expression: $e") + } + res + } + + def subcast(oldType: ClassType, newType: ClassType): Option[ClassType] = { + newType match { + case AbstractClassType(absClass, tps) if absClass.knownCCDescendants.contains(oldType.classDef) => + //here oldType.classDef <: absClass + Some(CaseClassType(oldType.classDef.asInstanceOf[CaseClassDef], tps)) + case cct: CaseClassType => + Some(cct) + case _ => + None + } + } + rec(ine)._2 + } } \ No newline at end of file diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 7216652a4..f016f0bc9 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -475,6 +475,7 @@ object Expressions { val getType = { if (typesCompatible(lhs.getType, rhs.getType)) BooleanType else { + //println(s"Incompatible argument types: arguments: ($lhs, $rhs) types: ${lhs.getType}, ${rhs.getType}") Untyped } } diff --git a/src/main/scala/leon/purescala/PrinterHelpers.scala b/src/main/scala/leon/purescala/PrinterHelpers.scala index 535d590c0..17ae02e0c 100644 --- a/src/main/scala/leon/purescala/PrinterHelpers.scala +++ b/src/main/scala/leon/purescala/PrinterHelpers.scala @@ -24,7 +24,10 @@ object PrinterHelpers { var firstElem = true while(strings.hasNext) { - val s = strings.next.stripMargin + val currval = strings.next + val s = if(currval != " || ") { + currval.stripMargin + } else currval // Compute indentation val start = s.lastIndexOf('\n') @@ -45,6 +48,8 @@ object PrinterHelpers { if (expressions.hasNext) { val e = expressions.next + if(e == "||") + println("Seen Expression: "+e) e match { case (t1, t2) => diff --git a/src/main/scala/leon/transformations/LazinessEliminationPhase.scala b/src/main/scala/leon/transformations/LazinessEliminationPhase.scala new file mode 100644 index 000000000..4bae04842 --- /dev/null +++ b/src/main/scala/leon/transformations/LazinessEliminationPhase.scala @@ -0,0 +1,1092 @@ +package leon +package transformations + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import purescala.FunctionClosure +import scala.collection.mutable.{ Map => MutableMap } +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.solvers.TimeoutSolverFactory +import leon.verification.VerificationContext +import leon.verification.DefaultTactic +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter + +object LazinessEliminationPhase extends TransformationPhase { + val debug = false + val dumpProgramWithClosures = false + val dumpTypeCorrectProg = false + val dumpFinalProg = false + + // flags + val removeRecursionViaEval = false + val skipVerification = false + val prettyPrint = true + val optOutputDirectory = new LeonOptionDef[String] { + val name = "o" + val description = "Output directory" + val default = "leon.out" + val usageRhs = "dir" + val parser = (x: String) => x + } + + val name = "Laziness Elimination Phase" + val description = "Coverts a program that uses lazy construct" + + " to a program that does not use lazy constructs" + + /** + * TODO: enforce that the specs do not create lazy closures + * TODO: we are forced to make an assumption that lazy ops takes as type parameters only those + * type parameters of their return type and not more. (enforce this) + * TODO: Check that lazy types are not nested + */ + def apply(ctx: LeonContext, prog: Program): Program = { + + val nprog = liftLazyExpressions(prog) + val (tpeToADT, opToAdt) = createClosures(nprog) + //Override the types of the lazy fields in the case class definition + nprog.definedClasses.foreach { + case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) => + val nfields = ccd.fields.map { fld => + unwrapLazyType(fld.getType) match { + case None => fld + case Some(btype) => + val adtType = AbstractClassType(tpeToADT(typeNameWOParams(btype))._2, + getTypeParameters(btype)) + ValDef(fld.id, Some(adtType)) // overriding the field type + } + } + ccd.setFields(nfields) + case _ => ; + } + // TODO: for now pick one suitable module. But ideally the lazy closure will be added to a separate module + // and imported every where + val progWithClasses = addDefs(nprog, + tpeToADT.values.flatMap(v => v._2 +: v._3), + opToAdt.keys.last) + if (debug) + println("After adding case class corresponding to lazyops: \n" + ScalaPrinter.apply(progWithClasses)) + val progWithClosures = (new TransformProgramUsingClosures(progWithClasses, tpeToADT, opToAdt))() + //Rectify type parameters and local types + val typeCorrectProg = rectifyLocalTypeAndTypeParameters(progWithClosures) + if (dumpTypeCorrectProg) + println("After rectifying types: \n" + ScalaPrinter.apply(typeCorrectProg)) + + val transProg = assertClosurePres(typeCorrectProg) + if (dumpFinalProg) + println("After asserting closure preconditions: \n" + ScalaPrinter.apply(transProg)) + // handle 'axiom annotation + transProg.definedFunctions.foreach { fd => + if (fd.annotations.contains("axiom")) + fd.addFlag(Annotation("library", Seq())) + } + // check specifications (to be moved to a different phase) + if (!skipVerification) + checkSpecifications(transProg) + if (prettyPrint) + prettyPrintProgramToFile(transProg, ctx) + transProg + } + + def prettyPrintProgramToFile(p: Program, ctx: LeonContext) { + val outputFolder = ctx.findOptionOrDefault(optOutputDirectory) + try { + new File(outputFolder).mkdir() + } catch { + case _ : java.io.IOException => ctx.reporter.fatalError("Could not create directory " + outputFolder) + } + + for (u <- p.units if u.isMainUnit) { + val outputFile = s"$outputFolder${File.separator}${u.id.toString}.scala" + try { + val out = new BufferedWriter(new FileWriter(outputFile)) + // remove '@' from the end of the identifier names + val pat = new Regex("""(\S+)(@)""", "base", "suffix") + val pgmText = pat.replaceAllIn(PrettyPrinter.apply(p), m => m.group("base")) + out.write(pgmText) + out.close() + } + catch { + case _ : java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) + } + } + ctx.reporter.info("Output written on " + outputFolder) + } + + def isLazyInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.apply" + case _ => + false + } + + def isEvaluatedInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.isEvaluated" + case _ => false + } + + def isValueInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.value" + case _ => false + } + + def isStarInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.*" + case _ => false + } + + def isLazyType(tpe: TypeTree): Boolean = tpe match { + case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => + cid.name == "$" + case _ => false + } + + /** + * TODO: Check that lazy types are not nested + */ + def unwrapLazyType(tpe: TypeTree) = tpe match { + case ctype @ CaseClassType(_, Seq(innerType)) if isLazyType(ctype) => + Some(innerType) + case _ => None + } + + def rootType(t: TypeTree): Option[AbstractClassType] = t match { + case absT: AbstractClassType => Some(absT) + case ct: ClassType => ct.parent + case _ => None + } + + def opNameToCCName(name: String) = { + name.capitalize + "@" + } + + /** + * Convert the first character to lower case + * and remove the last character. + */ + def ccNameToOpName(name: String) = { + name.substring(0, 1).toLowerCase() + + name.substring(1, name.length() - 1) + } + + def typeNameToADTName(name: String) = { + "Lazy" + name + } + + def adtNameToTypeName(name: String) = { + name.substring(4) + } + + def closureConsName(typeName: String) = { + "new@" + typeName + } + + def isClosureCons(fd: FunDef) = { + fd.id.name.startsWith("new@") + } + + /** + * convert the argument of every lazy constructors to a procedure + */ + def liftLazyExpressions(prog: Program): Program = { + var newfuns = Map[ExprStructure, (FunDef, ModuleDef)]() + var anchorFun: Option[FunDef] = None + val fdmap = prog.modules.flatMap { md => + md.definedFunctions.map { + case fd if fd.hasBody && !fd.isLibrary => + //println("FunDef: "+fd) + val nfd = preMapOnFunDef { + case finv @ FunctionInvocation(lazytfd, Seq(arg)) if isLazyInvocation(finv)(prog) && !arg.isInstanceOf[FunctionInvocation] => + val freevars = variablesOf(arg).toList + val tparams = freevars.map(_.getType) flatMap getTypeParameters + val argstruc = new ExprStructure(arg) + val argfun = + if (newfuns.contains(argstruc)) { + newfuns(argstruc)._1 + } else { + //construct type parameters for the function + val nfun = new FunDef(FreshIdentifier("lazyarg", arg.getType, true), + tparams map TypeParameterDef.apply, arg.getType, freevars.map(ValDef(_))) + nfun.body = Some(arg) + newfuns += (argstruc -> (nfun, md)) + nfun + } + Some(FunctionInvocation(lazytfd, Seq(FunctionInvocation(TypedFunDef(argfun, tparams), + freevars.map(_.toVariable))))) + case _ => None + }(fd) + (fd -> Some(nfd)) + case fd => fd -> Some(fd.duplicate) + } + }.toMap + // map the new functions to themselves + val newfdMap = newfuns.values.map { case (nfd, _) => nfd -> None }.toMap + val (repProg, _) = replaceFunDefs(prog)(fdmap ++ newfdMap) + val nprog = + if (!newfuns.isEmpty) { + val modToNewDefs = newfuns.values.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }.toMap + appendDefsToModules(repProg, modToNewDefs) + } else + throw new IllegalStateException("Cannot find any lazy computation") + if (debug) + println("After lifiting arguments of lazy constructors: \n" + ScalaPrinter.apply(nprog)) + nprog + } + + /** + * Create a mapping from types to the lazyops that may produce a value of that type + * TODO: relax that requirement that type parameters of return type of a function + * lazy evaluated should include all of its type parameters + */ + type ClosureData = (TypeTree, AbstractClassDef, Seq[CaseClassDef]) + def createClosures(implicit p: Program) = { + // collect all the operations that could be lazily evaluated. + val lazyops = p.definedFunctions.flatMap { + case fd if (fd.hasBody) => + filter(isLazyInvocation)(fd.body.get) map { + case FunctionInvocation(_, Seq(FunctionInvocation(tfd, _))) => tfd.fd + } + case _ => Seq() + }.distinct + if (debug) { + println("Lazy operations found: \n" + lazyops.map(_.id).mkString("\n")) + } + val tpeToLazyops = lazyops.groupBy(_.returnType) + val tpeToAbsClass = tpeToLazyops.map(_._1).map { tpe => + val name = typeNameWOParams(tpe) + val absTParams = getTypeParameters(tpe) map TypeParameterDef.apply + // using tpe name below to avoid mismatches due to type parameters + name -> AbstractClassDef(FreshIdentifier(typeNameToADTName(name), Untyped), + absTParams, None) + }.toMap + var opToAdt = Map[FunDef, CaseClassDef]() + val tpeToADT = tpeToLazyops map { + case (tpe, ops) => + val name = typeNameWOParams(tpe) + val absClass = tpeToAbsClass(name) + val absType = AbstractClassType(absClass, absClass.tparams.map(_.tp)) + val absTParams = absClass.tparams + // create a case class for every operation + val cdefs = ops map { opfd => + assert(opfd.tparams.size == absTParams.size) + val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) + val cdef = CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) + val nfields = opfd.params.map { vd => + val fldType = vd.getType + unwrapLazyType(fldType) match { + case None => + ValDef(FreshIdentifier(vd.id.name, fldType)) + case Some(btype) => + val adtType = AbstractClassType(absClass, getTypeParameters(btype)) + ValDef(FreshIdentifier(vd.id.name, adtType)) + } + } + cdef.setFields(nfields) + absClass.registerChild(cdef) + opToAdt += (opfd -> cdef) + cdef + } + (name -> (tpe, absClass, cdefs)) + } + (tpeToADT, opToAdt) + } + + /** + * (a) add state to every function in the program + * (b) thread state through every expression in the program sequentially + * (c) replace lazy constructions with case class creations + * (d) replace isEvaluated with currentState.contains() + * (e) replace accesses to $.value with calls to dispatch with current state + */ + class TransformProgramUsingClosures(p: Program, + tpeToADT: Map[String, ClosureData], + opToAdt: Map[FunDef, CaseClassDef]) { + + val (funsNeedStates, funsRetStates) = funsNeedingnReturningState(p) + // fix an ordering on types so that we can instrument programs with their states in the same order + val tnames = tpeToADT.keys.toSeq + // create a mapping from functions to new functions + val funMap = p.definedFunctions.collect { + case fd if (fd.hasBody && !fd.isLibrary) => + // replace lazy types in parameters and return values + val nparams = fd.params map { vd => + ValDef(vd.id, Some(replaceLazyTypes(vd.getType))) // override type of identifier + } + val nretType = replaceLazyTypes(fd.returnType) + // does the function require implicit state ? + val nfd = if (funsNeedStates(fd)) { + var newTParams = Seq[TypeParameterDef]() + val stTypes = tnames map { tn => + val absClass = tpeToADT(tn)._2 + val tparams = absClass.tparams.map(_ => + TypeParameter.fresh("P@")) + newTParams ++= tparams map TypeParameterDef + SetType(AbstractClassType(absClass, tparams)) + } + val stParams = stTypes.map { stType => + ValDef(FreshIdentifier("st@", stType, true)) + } + val retTypeWithState = + if (funsRetStates(fd)) + TupleType(nretType +: stTypes) + else + nretType + // the type parameters will be unified later + new FunDef(FreshIdentifier(fd.id.name, Untyped), + fd.tparams ++ newTParams, retTypeWithState, nparams ++ stParams) + // body of these functions are defined later + } else { + new FunDef(FreshIdentifier(fd.id.name, Untyped), fd.tparams, nretType, nparams) + } + // copy annotations + fd.flags.foreach(nfd.addFlag(_)) + (fd -> nfd) + }.toMap + + /** + * A set of uninterpreted functions that may be used as targets + * of closures in the eval functions, for efficiency reasons. + */ + lazy val uninterpretedTargets = { + funMap.map { + case (k, v) => + val ufd = new FunDef(FreshIdentifier(v.id.name, v.id.getType, true), + v.tparams, v.returnType, v.params) + (k -> ufd) + }.toMap + } + + /** + * A function for creating a state type for every lazy type. Note that Leon + * doesn't support 'Any' type yet. So we have to have multiple + * state (though this is much clearer it results in more complicated code) + */ + def getStateType(tname: String, tparams: Seq[TypeParameter]) = { + val (_, absdef, _) = tpeToADT(tname) + SetType(AbstractClassType(absdef, tparams)) + } + + def replaceLazyTypes(t: TypeTree): TypeTree = { + unwrapLazyType(t) match { + case None => + val NAryType(tps, tcons) = t + tcons(tps map replaceLazyTypes) + case Some(btype) => + val ntype = AbstractClassType(tpeToADT( + typeNameWOParams(btype))._2, getTypeParameters(btype)) + val NAryType(tps, tcons) = ntype + tcons(tps map replaceLazyTypes) + } + } + + /** + * Create dispatch functions for each lazy type. + * Note: the dispatch functions will be annotated as library so that + * their pre/posts are not checked (the fact that they hold are verified separately) + * Note by using 'assume-pre' we can also assume the preconditions of closures at + * the call-sites. + */ + val adtToOp = opToAdt map { case (k, v) => v -> k } + val evalFunctions = { + tpeToADT map { + case (tname, (tpe, absdef, cdefs)) => + val tparams = getTypeParameters(tpe) + val tparamDefs = tparams map TypeParameterDef.apply + val param1 = FreshIdentifier("cl", AbstractClassType(absdef, tparams)) + val stType = getStateType(tname, tparams) + val param2 = FreshIdentifier("st@", stType) + val retType = TupleType(Seq(tpe, stType)) + val dfun = new FunDef(FreshIdentifier("eval" + absdef.id.name, Untyped), + tparamDefs, retType, Seq(ValDef(param1), ValDef(param2))) + // assign body + // create a match case to switch over the possible class defs and invoke the corresponding functions + val bodyMatchCases = cdefs map { cdef => + val ctype = CaseClassType(cdef, tparams) // we assume that the type parameters of cdefs are same as absdefs + val binder = FreshIdentifier("t", ctype) + val pattern = InstanceOfPattern(Some(binder), ctype) + // create a body of the match + val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } + val op = adtToOp(cdef) + val stArgs = // TODO: here we are assuming that only one state is used, fix this. + if (funsNeedStates(op)) + // Note: it is important to use empty state here to eliminate + // dependency on state for the result value + Seq(FiniteSet(Set(), stType.base)) + else Seq() + val targetFun = + if (removeRecursionViaEval && op.hasPostcondition) { + // checking for postcondition is a hack to make sure that we + // do not make all targets uninterpreted + uninterpretedTargets(op) + } else funMap(op) + val invoke = FunctionInvocation(TypedFunDef(targetFun, tparams), args ++ stArgs) // we assume that the type parameters of lazy ops are same as absdefs + val invPart = if (funsRetStates(op)) { + TupleSelect(invoke, 1) // we are only interested in the value + } else invoke + val newst = SetUnion(param2.toVariable, FiniteSet(Set(binder.toVariable), stType.base)) + val rhs = Tuple(Seq(invPart, newst)) + MatchCase(pattern, None, rhs) + } + dfun.body = Some(MatchExpr(param1.toVariable, bodyMatchCases)) + dfun.addFlag(Annotation("axiom", Seq())) + (tname -> dfun) + } + } + + /** + * These are evalFunctions that do not affect the state + */ + val computeFunctions = { + evalFunctions map { + case (tname, evalfd) => + val (tpe, _, _) = tpeToADT(tname) + val param1 = evalfd.params.head + val fun = new FunDef(FreshIdentifier(evalfd.id.name + "*", Untyped), + evalfd.tparams, tpe, Seq(param1)) + val invoke = FunctionInvocation(TypedFunDef(evalfd, evalfd.tparams.map(_.tp)), + Seq(param1.id.toVariable, FiniteSet(Set(), + getStateType(tname, getTypeParameters(tpe)).base))) + fun.body = Some(TupleSelect(invoke, 1)) + (tname -> fun) + } + } + + /** + * Create closure construction functions that ensures a postcondition. + * They are defined for each lazy class since it avoids generics and + * simplifies the type inference (which is not full-fledged in Leon) + */ + val closureCons = tpeToADT map { + case (tname, (_, adt, _)) => + val param1Type = AbstractClassType(adt, adt.tparams.map(_.tp)) + val param1 = FreshIdentifier("cc", param1Type) + val stType = SetType(param1Type) + val param2 = FreshIdentifier("st@", stType) + val tparamDefs = adt.tparams + val fun = new FunDef(FreshIdentifier(closureConsName(tname)), adt.tparams, param1Type, + Seq(ValDef(param1), ValDef(param2))) + fun.body = Some(param1.toVariable) + val resid = FreshIdentifier("res", param1Type) + val postbody = Not(SubsetOf(FiniteSet(Set(resid.toVariable), param1Type), param2.toVariable)) + fun.postcondition = Some(Lambda(Seq(ValDef(resid)), postbody)) + fun.addFlag(Annotation("axiom", Seq())) + (tname -> fun) + } + + def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Map[String, Expr] => Expr, Boolean)) = { + // create n variables to model n lets + val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) + (args zip letvars).foldRight(op(letvars)) { + case ((arg, letvar), (accCons, stUpdatedBefore)) => + val (argCons, stUpdateFlag) = mapBody(arg) + val cl = if (!stUpdateFlag) { + // here arg does not affect the newstate + (st: Map[String, Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) // here, we don't have to create a let + } else { + // here arg does affect the newstate + (st: Map[String, Expr]) => + { + val narg = argCons(st) + val argres = FreshIdentifier("a", narg.getType, true).toVariable + val nstateSeq = tnames.zipWithIndex.map { + // states start from index 2 + case (tn, i) => TupleSelect(argres, i + 2) + } + val nstate = (tnames zip nstateSeq).map { + case (tn, st) => (tn -> st) + }.toMap[String, Expr] + val letbody = + if (stUpdatedBefore) accCons(nstate) // here, 'acc' already returns a superseeding state + else Tuple(accCons(nstate) +: nstateSeq) // here, 'acc; only retruns the result + Let(argres.id, narg, + Let(letvar.id, TupleSelect(argres, 1), letbody)) + } + } + (cl, stUpdatedBefore || stUpdateFlag) + } + } + + def mapBody(body: Expr): (Map[String, Expr] => Expr, Boolean) = body match { + + case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // lazy construction ? + if isLazyInvocation(finv)(p) => + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val adt = opToAdt(argfd) + val cc = CaseClass(CaseClassType(adt, tparams), nargs) + val baseLazyType = adtNameToTypeName(adt.parent.get.classDef.id.name) + FunctionInvocation(TypedFunDef(closureCons(baseLazyType), tparams), + Seq(cc, st(baseLazyType))) + }, false) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isEvaluatedInvocation(finv)(p) => // isEval function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val narg = nargs(0) // there must be only one argument here + val baseType = unwrapLazyType(narg.getType).get + val tname = typeNameWOParams(baseType) + val adtType = AbstractClassType(tpeToADT(tname)._2, getTypeParameters(baseType)) + SubsetOf(FiniteSet(Set(narg), adtType), st(tname)) + }, false) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isValueInvocation(finv)(p) => // is value function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here + val tname = typeNameWOParams(baseType) + val dispFun = evalFunctions(tname) + val dispFunInv = FunctionInvocation(TypedFunDef(dispFun, + getTypeParameters(baseType)), nargs :+ st(tname)) + val dispRes = FreshIdentifier("dres", dispFun.returnType) + val nstates = tnames map { + case `tname` => + TupleSelect(dispRes.toVariable, 2) + case other => st(other) + } + Let(dispRes, dispFunInv, Tuple(TupleSelect(dispRes.toVariable, 1) +: nstates)) + }, true) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isStarInvocation(finv)(p) => // is * function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here + val tname = typeNameWOParams(baseType) + val dispFun = computeFunctions(tname) + FunctionInvocation(TypedFunDef(dispFun, getTypeParameters(baseType)), nargs) + }, false) + mapNAryOperator(args, op) + + case FunctionInvocation(TypedFunDef(fd, tparams), args) if funMap.contains(fd) => + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val stArgs = + if (funsNeedStates(fd)) { + (tnames map st.apply) + } else Seq() + FunctionInvocation(TypedFunDef(funMap(fd), tparams), nargs ++ stArgs) + }, funsRetStates(fd))) + + case Let(id, value, body) => + val (valCons, valUpdatesState) = mapBody(value) + val (bodyCons, bodyUpdatesState) = mapBody(body) + ((st: Map[String, Expr]) => { + val nval = valCons(st) + if (valUpdatesState) { + val freshid = FreshIdentifier(id.name, nval.getType, true) + val nextStates = tnames.zipWithIndex.map { + case (tn, i) => + TupleSelect(freshid.toVariable, i + 2) + }.toSeq + val nsMap = (tnames zip nextStates).toMap + val transBody = replace(Map(id.toVariable -> TupleSelect(freshid.toVariable, 1)), + bodyCons(nsMap)) + if (bodyUpdatesState) + Let(freshid, nval, transBody) + else + Let(freshid, nval, Tuple(transBody +: nextStates)) + } else + Let(id, nval, bodyCons(st)) + }, valUpdatesState || bodyUpdatesState) + + case IfExpr(cond, thn, elze) => + val (condCons, condState) = mapBody(cond) + val (thnCons, thnState) = mapBody(thn) + val (elzeCons, elzeState) = mapBody(elze) + ((st: Map[String, Expr]) => { + val (ncondCons, nst) = + if (condState) { + val cndExpr = condCons(st) + val bder = FreshIdentifier("c", cndExpr.getType) + val condst = tnames.zipWithIndex.map { + case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) + }.toMap + ((th: Expr, el: Expr) => + Let(bder, cndExpr, IfExpr(TupleSelect(bder.toVariable, 1), th, el)), + condst) + } else { + ((th: Expr, el: Expr) => IfExpr(condCons(st), th, el), st) + } + val nelze = + if ((condState || thnState) && !elzeState) + Tuple(elzeCons(nst) +: tnames.map(nst.apply)) + else elzeCons(nst) + val nthn = + if (!thnState && (condState || elzeState)) + Tuple(thnCons(nst) +: tnames.map(nst.apply)) + else thnCons(nst) + ncondCons(nthn, nelze) + }, condState || thnState || elzeState) + + case MatchExpr(scr, cases) => + val (scrCons, scrUpdatesState) = mapBody(scr) + val casesRes = cases.foldLeft(Seq[(Map[String, Expr] => Expr, Boolean)]()) { + case (acc, MatchCase(pat, None, rhs)) => + acc :+ mapBody(rhs) + case mcase => + throw new IllegalStateException("Match case with guards are not supported yet: " + mcase) + } + val casesUpdatesState = casesRes.exists(_._2) + ((st: Map[String, Expr]) => { + val scrExpr = scrCons(st) + val (nscrCons, scrst) = if (scrUpdatesState) { + val bder = FreshIdentifier("scr", scrExpr.getType) + val scrst = tnames.zipWithIndex.map { + case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) + }.toMap + ((ncases: Seq[MatchCase]) => + Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), + scrst) + } else { + ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) + } + val ncases = (cases zip casesRes).map { + case (MatchCase(pat, None, _), (caseCons, caseUpdatesState)) => + val nrhs = + if ((scrUpdatesState || casesUpdatesState) && !caseUpdatesState) + Tuple(caseCons(scrst) +: tnames.map(scrst.apply)) + else caseCons(scrst) + MatchCase(pat, None, nrhs) + } + nscrCons(ncases) + }, scrUpdatesState || casesUpdatesState) + + // need to reset types in the case of case class constructor calls + case CaseClass(cct, args) => + val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => CaseClass(ntype, nargs), false)) + + case Operator(args, op) => + // here, 'op' itself does not create a new state + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => op(nargs), false)) + + case t: Terminal => (_ => t, false) + } + + /** + * Assign bodies for the functions in funMap. + */ + def transformFunctions = { + funMap foreach { + case (fd, nfd) => + // /println("Considering function: "+fd) + // Here, using name to identify 'state' parameters, also relying + // on fact that nfd.params are in the same order as tnames + val stateParams = nfd.params.foldLeft(Seq[Expr]()) { + case (acc, ValDef(id, _)) if id.name.startsWith("st@") => + acc :+ id.toVariable + case (acc, _) => acc + } + val initStateMap = tnames zip stateParams toMap + val (nbodyFun, bodyUpdatesState) = mapBody(fd.body.get) + val nbody = nbodyFun(initStateMap) + val bodyWithState = + if (!bodyUpdatesState && funsRetStates(fd)) { + val freshid = FreshIdentifier("bres", nbody.getType) + Let(freshid, nbody, Tuple(freshid.toVariable +: stateParams)) + } else nbody + nfd.body = Some(simplifyLets(bodyWithState)) + //println(s"Body of ${fd.id.name} after conversion&simp: ${nfd.body}") + + // Important: specifications use lazy semantics but + // their state changes are ignored after their execution. + // This guarantees their observational purity/transparency + // collect class invariants that need to be added + if (fd.hasPrecondition) { + val (npreFun, preUpdatesState) = mapBody(fd.precondition.get) + nfd.precondition = + if (preUpdatesState) + Some(TupleSelect(npreFun(initStateMap), 1)) // ignore state updated by pre + else Some(npreFun(initStateMap)) + } + + if (fd.hasPostcondition) { + val Lambda(arg @ Seq(ValDef(resid, _)), post) = fd.postcondition.get + val (npostFun, postUpdatesState) = mapBody(post) + val newres = FreshIdentifier(resid.name, bodyWithState.getType) + val npost1 = + if (bodyUpdatesState || funsRetStates(fd)) { + val stmap = tnames.zipWithIndex.map { + case (tn, i) => (tn -> TupleSelect(newres.toVariable, i + 2)) + }.toMap + replace(Map(resid.toVariable -> TupleSelect(newres.toVariable, 1)), npostFun(stmap)) + } else + replace(Map(resid.toVariable -> newres.toVariable), npostFun(initStateMap)) + val npost = + if (postUpdatesState) + TupleSelect(npost1, 1) // ignore state updated by post + else npost1 + nfd.postcondition = Some(Lambda(Seq(ValDef(newres)), npost)) + } + } + } + + /** + * Assign specs for the eval functions + */ + def transformEvals = { + evalFunctions.foreach { + case (tname, evalfd) => + val cdefs = tpeToADT(tname)._3 + val tparams = evalfd.tparams.map(_.tp) + val postres = FreshIdentifier("res", evalfd.returnType) + // create a match case to switch over the possible class defs and invoke the corresponding functions + val postMatchCases = cdefs map { cdef => + val ctype = CaseClassType(cdef, tparams) + val binder = FreshIdentifier("t", ctype) + val pattern = InstanceOfPattern(Some(binder), ctype) + // create a body of the match + val op = adtToOp(cdef) + val targetFd = funMap(op) + val rhs = if (targetFd.hasPostcondition) { + val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } + val stArgs = + if (funsNeedStates(op)) // TODO: here we are assuming that only one state is used, fix this. + Seq(evalfd.params.last.toVariable) + else Seq() + val Lambda(Seq(resarg), targetPost) = targetFd.postcondition.get + val replaceMap = (targetFd.params.map(_.toVariable) zip (args ++ stArgs)).toMap[Expr, Expr] + + (resarg.toVariable -> postres.toVariable) + replace(replaceMap, targetPost) + } else + Util.tru + MatchCase(pattern, None, rhs) + } + evalfd.postcondition = Some( + Lambda(Seq(ValDef(postres)), + MatchExpr(evalfd.params.head.toVariable, postMatchCases))) + } + } + + def apply() = { + transformFunctions + transformEvals + val progWithClosures = addFunDefs(copyProgram(p, + (defs: Seq[Definition]) => defs.flatMap { + case fd: FunDef if funMap.contains(fd) => + if (removeRecursionViaEval) + Seq(funMap(fd), uninterpretedTargets(fd)) + else Seq(funMap(fd)) + case d => Seq(d) + }), closureCons.values ++ evalFunctions.values ++ computeFunctions.values, + funMap.values.last) + if (dumpProgramWithClosures) + println("Program with closures \n" + ScalaPrinter(progWithClosures)) + progWithClosures + } + } + + /** + * Generate lemmas that ensure that preconditions hold for closures. + */ + def assertClosurePres(p: Program): Program = { + + def hasClassInvariants(cc: CaseClass): Boolean = { + val opname = ccNameToOpName(cc.ct.classDef.id.name) + functionByName(opname, p).get.hasPrecondition + } + + var anchorfd: Option[FunDef] = None + val lemmas = p.definedFunctions.flatMap { + case fd if (fd.hasBody && !fd.isLibrary) => + //println("collection closure creation preconditions for: "+fd) + val closures = CollectorWithPaths { + case FunctionInvocation(TypedFunDef(fund, _), + Seq(cc: CaseClass, st)) if isClosureCons(fund) && hasClassInvariants(cc) => + (cc, st) + } traverse (fd.body.get) // Note: closures cannot be created in specs + // Note: once we have separated normal preconditions from state preconditions + // it suffices to just consider state preconditions here + closures.map { + case ((CaseClass(CaseClassType(ccd, _), args), st), path) => + anchorfd = Some(fd) + val target = functionByName(ccNameToOpName(ccd.id.name), p).get //find the target corresponding to the closure + val pre = target.precondition.get + val nargs = + if (target.params.size > args.size) // target takes state ? + args :+ st + else args + val pre2 = replaceFromIDs((target.params.map(_.id) zip nargs).toMap, pre) + val vc = Implies(And(Util.precOrTrue(fd), path), pre2) + // create a function for each vc + val lemmaid = FreshIdentifier(ccd.id.name + fd.id.name + "Lem", Untyped, true) + val params = variablesOf(vc).toSeq.map(v => ValDef(v)) + val tparams = params.flatMap(p => getTypeParameters(p.getType)).distinct map TypeParameterDef + val lemmafd = new FunDef(lemmaid, tparams, BooleanType, params) + lemmafd.body = Some(vc) + // assert the lemma is true + val resid = FreshIdentifier("holds", BooleanType) + lemmafd.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) +// /println("Created lemma function: "+fd) + lemmafd + } + case _ => Seq() + } + if (!lemmas.isEmpty) + addFunDefs(p, lemmas, anchorfd.get) + else p + } + + /** + * Returns all functions that 'need' states to be passed in + * and those that return a new state. + * TODO: implement backwards BFS by reversing the graph + */ + def funsNeedingnReturningState(prog: Program) = { + val cg = CallGraphUtil.constructCallGraph(prog, false, true) + var needRoots = Set[FunDef]() + var retRoots = Set[FunDef]() + prog.definedFunctions.foreach { + case fd if fd.hasBody && !fd.isLibrary => + postTraversal { + case finv: FunctionInvocation if isEvaluatedInvocation(finv)(prog) => + needRoots += fd + case finv: FunctionInvocation if isValueInvocation(finv)(prog) => + needRoots += fd + retRoots += fd + case _ => + ; + }(fd.body.get) + case _ => ; + } + val funsNeedStates = prog.definedFunctions.filterNot(fd => + cg.transitiveCallees(fd).toSet.intersect(needRoots).isEmpty).toSet + val funsRetStates = prog.definedFunctions.filterNot(fd => + cg.transitiveCallees(fd).toSet.intersect(retRoots).isEmpty).toSet + (funsNeedStates, funsRetStates) + } + + /** + * This performs a little bit of Hindley-Milner type Inference + * to correct the local types and also unify type parameters + */ + def rectifyLocalTypeAndTypeParameters(prog: Program): Program = { + var typeClasses = new DisjointSets[TypeTree]() + prog.definedFunctions.foreach { + case fd if fd.hasBody && !fd.isLibrary => + postTraversal { + case call @ FunctionInvocation(TypedFunDef(fd, tparams), args) => + // unify formal type parameters with actual type arguments + (fd.tparams zip tparams).foreach(x => typeClasses.union(x._1.tp, x._2)) + // unify the type parameters of types of formal parameters with + // type arguments of actual arguments + (fd.params zip args).foreach { x => + (x._1.getType, x._2.getType) match { + case (SetType(tf: ClassType), SetType(ta: ClassType)) => + (tf.tps zip ta.tps).foreach { x => typeClasses.union(x._1, x._2) } + case (tf: TypeParameter, ta: TypeParameter) => + typeClasses.union(tf, ta) + case _ => + // others could be ignored as they are not part of the state + ; + /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" + + s"do not match for call: $call")*/ + } + } + case _ => ; + }(fd.fullBody) + case _ => ; + } + + val equivTPs = typeClasses.toMap + val fdMap = prog.definedFunctions.collect { + case fd if !fd.isLibrary => //fd.hasBody && + + val (tempTPs, otherTPs) = fd.tparams.map(_.tp).partition { + case tp if tp.id.name.endsWith("@") => true + case _ => false + } + val others = otherTPs.toSet[TypeTree] + // for each of the type parameter pick one representative from its equvialence class + val tpMap = fd.tparams.map { + case TypeParameterDef(tp) => + val tpclass = equivTPs.getOrElse(tp, Set(tp)) + val candReps = tpclass.filter(r => others.contains(r) || !r.isInstanceOf[TypeParameter]) + val concRep = candReps.find(!_.isInstanceOf[TypeParameter]) + val rep = + if (concRep.isDefined) // there exists a concrete type ? + concRep.get + else if (!candReps.isEmpty) + candReps.head + else + throw new IllegalStateException(s"Cannot find a non-placeholder in equivalence class $tpclass for fundef: \n $fd") + tp -> rep + }.toMap + val instf = instantiateTypeParameters(tpMap) _ + val paramMap = fd.params.map { + case vd @ ValDef(id, _) => + (id -> FreshIdentifier(id.name, instf(vd.getType))) + }.toMap + val ntparams = tpMap.values.toSeq.distinct.collect { case tp: TypeParameter => tp } map TypeParameterDef + val nfd = new FunDef(fd.id.freshen, ntparams, instf(fd.returnType), + fd.params.map(vd => ValDef(paramMap(vd.id)))) + fd -> (nfd, tpMap, paramMap) + }.toMap + + /* + * Replace fundefs and unify type parameters in function invocations. + * Replace old parameters by new parameters + */ + def transformFunBody(ifd: FunDef) = { + val (nfd, tpMap, paramMap) = fdMap(ifd) + // need to handle tuple select specially as it breaks if the type of + // the tupleExpr if it is not TupleTyped. + // cannot use simplePostTransform because of this + def rec(e: Expr): Expr = e match { + case FunctionInvocation(TypedFunDef(callee, targsOld), args) => + val targs = targsOld.map { + case tp: TypeParameter => tpMap.getOrElse(tp, tp) + case t => t + }.distinct + val ncallee = + if (fdMap.contains(callee)) + fdMap(callee)._1 + else callee + FunctionInvocation(TypedFunDef(ncallee, targs), args map rec) + case Variable(id) if paramMap.contains(id) => + paramMap(id).toVariable + case TupleSelect(tup, index) => TupleSelect(rec(tup), index) + case Operator(args, op) => op(args map rec) + case t: Terminal => t + } + val nbody = rec(ifd.fullBody) + //println("Inferring types for: "+ifd.id) + val initGamma = nfd.params.map(vd => vd.id -> vd.getType).toMap + inferTypesOfLocals(nbody, initGamma) + } + copyProgram(prog, (defs: Seq[Definition]) => defs.map { + case fd: FunDef if fdMap.contains(fd) => + val nfd = fdMap(fd)._1 + if (!fd.fullBody.isInstanceOf[NoTree]) { + nfd.fullBody = simplifyLetsAndLetsWithTuples(transformFunBody(fd)) + } + fd.flags.foreach(nfd.addFlag(_)) + nfd + case d => d + }) + } + + import leon.solvers._ + import leon.solvers.z3._ + + def checkSpecifications(prog: Program) { + val ctx = Main.processOptions(Seq("--solvers=smt-cvc4","--debug=solver")) + val report = AnalysisPhase.run(ctx)(prog) + println(report.summaryString) + /*val timeout = 10 + val rep = ctx.reporter + * val fun = prog.definedFunctions.find(_.id.name == "firstUnevaluated").get + // create a verification context. + val solver = new FairZ3Solver(ctx, prog) with TimeoutSolver + val solverF = new TimeoutSolverFactory(SolverFactory(() => solver), timeout * 1000) + val vctx = VerificationContext(ctx, prog, solverF, rep) + val vc = (new DefaultTactic(vctx)).generatePostconditions(fun)(0) + val s = solverF.getNewSolver() + try { + rep.info(s" - Now considering '${vc.kind}' VC for ${vc.fd.id} @${vc.getPos}...") + val tStart = System.currentTimeMillis + s.assertVC(vc) + val satResult = s.check + val dt = System.currentTimeMillis - tStart + val res = satResult match { + case None => + rep.info("Cannot prove or disprove specifications") + case Some(false) => + rep.info("Valid") + case Some(true) => + println("Invalid - counter-example: " + s.getModel) + } + } finally { + s.free() + }*/ + } + + /** + * NOT USED CURRENTLY + * Lift the specifications on functions to the invariants corresponding + * case classes. + * Ideally we should class invariants here, but it is not currently supported + * so we create a functions that can be assume in the pre and post of functions. + * TODO: can this be optimized + */ + /* def liftSpecsToClosures(opToAdt: Map[FunDef, CaseClassDef]) = { + val invariants = opToAdt.collect { + case (fd, ccd) if fd.hasPrecondition => + val transFun = (args: Seq[Identifier]) => { + val argmap: Map[Expr, Expr] = (fd.params.map(_.id.toVariable) zip args.map(_.toVariable)).toMap + replace(argmap, fd.precondition.get) + } + (ccd -> transFun) + }.toMap + val absTypes = opToAdt.values.collect { + case cd if cd.parent.isDefined => cd.parent.get.classDef + } + val invFuns = absTypes.collect { + case abs if abs.knownCCDescendents.exists(invariants.contains) => + val absType = AbstractClassType(abs, abs.tparams.map(_.tp)) + val param = ValDef(FreshIdentifier("$this", absType)) + val tparams = abs.tparams + val invfun = new FunDef(FreshIdentifier(abs.id.name + "$Inv", Untyped), + tparams, BooleanType, Seq(param)) + (abs -> invfun) + }.toMap + // assign bodies for the 'invfuns' + invFuns.foreach { + case (abs, fd) => + val bodyCases = abs.knownCCDescendents.collect { + case ccd if invariants.contains(ccd) => + val ctype = CaseClassType(ccd, fd.tparams.map(_.tp)) + val cvar = FreshIdentifier("t", ctype) + val fldids = ctype.fields.map { + case ValDef(fid, Some(fldtpe)) => + FreshIdentifier(fid.name, fldtpe) + } + val pattern = CaseClassPattern(Some(cvar), ctype, + fldids.map(fid => WildcardPattern(Some(fid)))) + val rhsInv = invariants(ccd)(fldids) + // assert the validity of substructures + val rhsValids = fldids.flatMap { + case fid if fid.getType.isInstanceOf[ClassType] => + val t = fid.getType.asInstanceOf[ClassType] + val rootDef = t match { + case absT: AbstractClassType => absT.classDef + case _ if t.parent.isDefined => + t.parent.get.classDef + } + if (invFuns.contains(rootDef)) { + List(FunctionInvocation(TypedFunDef(invFuns(rootDef), t.tps), + Seq(fid.toVariable))) + } else + List() + case _ => List() + } + val rhs = Util.createAnd(rhsInv +: rhsValids) + MatchCase(pattern, None, rhs) + } + // create a default case + val defCase = MatchCase(WildcardPattern(None), None, Util.tru) + val matchExpr = MatchExpr(fd.params.head.id.toVariable, bodyCases :+ defCase) + fd.body = Some(matchExpr) + } + invFuns + }*/ +} + diff --git a/testcases/lazy-datastructures/RealTimeQueue-transformed.scala b/testcases/lazy-datastructures/RealTimeQueue-transformed.scala new file mode 100644 index 000000000..d416dacb0 --- /dev/null +++ b/testcases/lazy-datastructures/RealTimeQueue-transformed.scala @@ -0,0 +1,292 @@ +//import leon.lazyeval._ +import leon.lang._ +import leon.annotation._ +import leon.collection._ + +object RealTimeQueue { + abstract class LList[T] + + def LList$isEmpty[T]($this : LList[T]): Boolean = { + $this match { + case SNil() => + true + case _ => + false + } + } + + def LList$isCons[T]($this : LList[T]): Boolean = { + $this match { + case SCons(_, _) => + true + case _ => + false + } + } + + def LList$size[T]($this : LList[T]): BigInt = { + $this match { + case SNil() => + BigInt(0) + case SCons(x, t) => + BigInt(1) + ssize[T](t) + } + } ensuring { + (x$1 : BigInt) => (x$1 >= BigInt(0)) + } + + case class SCons[T](x : T, tail : LazyLList[T]) extends LList[T] + + case class SNil[T]() extends LList[T] + + // TODO: closures are not ADTs since two closures with same arguments are not necessarily equal but + // ADTs are equal. This creates a bit of problem in checking if a closure belongs to a set or not. + // However, currently we are assuming that such problems do not happen. + // A solution is to pass around a dummy id that is unique for each closure. + abstract class LazyLList[T] + + case class Rotate[T](f : LazyLList[T], r : List[T], a : LazyLList[T], res: LList[T]) extends LazyLList[T] + + case class Lazyarg[T](newa : LList[T]) extends LazyLList[T] + + + def ssize[T](l : LazyLList[T]): BigInt = { + val clist = evalLazyLList[T](l, Set[LazyLList[T]]())._1 + LList$size[T](clist) + + } ensuring(res => res >= 0) + + def isConcrete[T](l : LazyLList[T], st : Set[LazyLList[T]]): Boolean = { + Set[LazyLList[T]](l).subsetOf(st) && (evalLazyLList[T](l, st)._1 match { + case SCons(_, tail) => + isConcrete[T](tail, st) + case _ => + true + }) || LList$isEmpty[T](evalLazyLList[T](l, st)._1) + } + + // an assertion: closures created by evaluating a closure will result in unevaluated closure + @library + def lemmaLazy[T](l : LazyLList[T], st : Set[LazyLList[T]]) : Boolean = { + Set[LazyLList[T]](l).subsetOf(st) || { + evalLazyLList[T](l, Set[LazyLList[T]]())._1 match { + case SCons(_, tail) => + l != tail && !Set[LazyLList[T]](tail).subsetOf(st) + case _ => true + } + } +// Set[LazyLList[T]](l).subsetOf(st) || { +// val (nval, nst, _) = evalLazyLList[T](l, st) +// nval match { +// case SCons(_, tail) => +// !Set[LazyLList[T]](tail).subsetOf(nst) +// case _ => true +// } +// } + } holds + + def firstUnevaluated[T](l : LazyLList[T], st : Set[LazyLList[T]]): LazyLList[T] = { + if (Set[LazyLList[T]](l).subsetOf(st)) { + evalLazyLList[T](l, st)._1 match { + case SCons(_, tail) => + firstUnevaluated[T](tail, st) + case _ => + l + } + } else { + l + } + } ensuring(res => (!LList$isEmpty[T](evalLazyLList[T](res, st)._1) || isConcrete[T](l, st)) && + (LList$isEmpty[T](evalLazyLList[T](res, st)._1) || !Set[LazyLList[T]](res).subsetOf(st)) && + { val (nval, nst, _) = evalLazyLList[T](res, st) + nval match { + case SCons(_, tail) => + firstUnevaluated(l, nst) == tail + case _ => true + } + } && + lemmaLazy(res, st) + ) + + @library + def evalLazyLList[T](cl : LazyLList[T], st : Set[LazyLList[T]]): (LList[T], Set[LazyLList[T]], BigInt) = { + cl match { + case t : Rotate[T] => + val (rres, _, rtime) = rotate(t.f, t.r, t.a, st) + val tset = Set[LazyLList[T]](t) + val tme = + if(tset.subsetOf(st)) + BigInt(1) + else // time of rotate + rtime + (rotate(t.f, t.r, t.a, Set[LazyLList[T]]())._1, st ++ tset, tme) + + case t : Lazyarg[T] => + (lazyarg(t.newa), st ++ Set[LazyLList[T]](t), BigInt(1)) + } + } ensuring(res => (cl match { + case t : Rotate[T] => + LList$size(res._1) == ssize(t.f) + t.r.size + ssize(t.a) && + res._1 != SNil[T]() && + res._3 <= 4 + case _ => true + }) + // && + // (res._1 match { + // case SCons(_, tail) => + // Set[LazyLList[T]](cl).subsetOf(st) || !Set[LazyLList[T]](tail).subsetOf(res._2) + // case _ => true + // }) + ) + + @extern + def rotate2[T](f : LazyLList[T], r : List[T], a : LazyLList[T], st : Set[LazyLList[T]]): (LList[T], Set[LazyLList[T]], BigInt) = ??? + + def lazyarg[T](newa : LList[T]): LList[T] = { + newa + } + + def streamScheduleProperty[T](s : LazyLList[T], sch : LazyLList[T], st : Set[LazyLList[T]]): Boolean = { + firstUnevaluated[T](s, st) == sch + } + + case class Queue[T](f : LazyLList[T], r : List[T], s : LazyLList[T]) + + def Queue$isEmpty[T]($this : Queue[T], st : Set[LazyLList[T]]): Boolean = { + LList$isEmpty[T](evalLazyLList[T]($this.f, st)._1) + } + + def Queue$valid[T]($this : Queue[T], st : Set[LazyLList[T]]): Boolean = { + streamScheduleProperty[T]($this.f, $this.s, st) && + ssize[T]($this.s) == ssize[T]($this.f) - $this.r.size + } + + // things to prove: + // (a0) prove that pre implies post for 'rotate' (this depends on the assumption on eval) + // (a) Rotate closure creations satisfy the preconditions of 'rotate' (or) + // for the preconditions involving state, the state at the Rotate invocation sites (through eval) + // satisfy the preconditions of 'rotate' + // (b) If we verify that preconditoins involving state hold at creation time, + // then we can assume them for calling time only if the preconditions are monotonic + // with respect to inclusion of relation of state (this also have to be shown) + // Note: using both captured and calling context is possible but is more involved + // (c) Assume that 'eval' ensures the postcondition of 'rotate' + // (d) Moreover we can also assume that the preconditons of rotate hold whenever we use a closure + + // proof of (a) + // (i) for stateless invariants this can be proven by treating lazy eager, + // so not doing this here + + // monotonicity of isConcrete + def lemmaConcreteMonotone[T](f : LazyLList[T], st1 : Set[LazyLList[T]], st2 : Set[LazyLList[T]]) : Boolean = { + (evalLazyLList[T](f, st1)._1 match { + case SCons(_, tail) => + lemmaConcreteMonotone(tail, st1, st2) + case _ => + true + }) && + !(st1.subsetOf(st2) && isConcrete(f, st1)) || isConcrete(f, st2) + } holds + + // proof that the precondition isConcrete(f, st) holds for closure creation in 'rotate' function + def rotateClosureLemma1[T](f : LazyLList[T], st : Set[LazyLList[T]]): Boolean = { + require(isConcrete(f, st)) + val dres = evalLazyLList[T](f, st); + dres._1 match { + case SCons(x, tail) => + isConcrete(tail, st) + case _ => true + } + } holds + + // proof that the precondition isConcrete(f, st) holds for closure creation in 'createQueue' function + def rotateClosureLemma2[T](f : LazyLList[T], sch : LazyLList[T], st : Set[LazyLList[T]]): Boolean = { + require(streamScheduleProperty[T](f, sch, st)) // && ssize[T](sch) == (ssize[T](f) - r.size) + BigInt(1)) + val dres4 = evalLazyLList[T](sch, st); + dres4._1 match { + case SNil() => + //isConcrete(f, dres4._2) + isConcrete(f, st) // the above is implied by the monotonicity of 'isConcrete' + case _ => true + } + } holds + + // proof that the precondition isConcrete(f, st) holds for closure creation in 'dequeue' function + def rotateClosureLemma3[T](q : Queue[T], st : Set[LazyLList[T]]): Boolean = { + require(!Queue$isEmpty[T](q, st) && Queue$valid[T](q, st)) + val dres7 = evalLazyLList[T](q.f, st); + val SCons(x, nf) = dres7._1 + val dres8 = evalLazyLList[T](q.s, dres7._2); + dres8._1 match { + case SNil() => + isConcrete(nf, st) + // the above would imply: isConcrete(nf, dres8._2) by monotonicity + case _ => true + } + } holds + + // part(c) assume postconditon of 'rotate' in closure invocation time and also + // the preconditions of 'rotate' if necesssary, and prove the properties of the + // methods that invoke closures + + // proving specifications of 'rotate' (only state specifications are interesting) + def rotate[T](f : LazyLList[T], r : List[T], a : LazyLList[T], st : Set[LazyLList[T]]): (LList[T], Set[LazyLList[T]], BigInt) = { + require(r.size == ssize[T](f) + BigInt(1) && isConcrete(f, st)) + val dres = evalLazyLList[T](f, st); + (dres._1, r) match { + case (SNil(), Cons(y, _)) => + (SCons[T](y, a), dres._2, dres._3 + 2) + case (SCons(x, tail), Cons(y, r1)) => + val na = Lazyarg[T](SCons[T](y, a)) + (SCons[T](x, Rotate[T](tail, r1, na, SNil[T]())), dres._2, dres._3 + 3) + } + } ensuring(res => LList$size(res._1) == ssize(f) + r.size + ssize(a) && + res._1 != SNil[T]() && + res._3 <= 4) + + + // a stub for creating new closure (ensures that the new closures do not belong to the old state) + // Note: this could result in inconsistency since we are associating unique ids with closures + @library + def newclosure[T](rt: Rotate[T], st: Set[LazyLList[T]]) = { + (rt, st) + } ensuring(res => !Set[LazyLList[T]](res._1).subsetOf(st)) + + // proving specifications of 'createQueue' (only state specifications are interesting) + def createQueue[T](f : LazyLList[T], r : List[T], sch : LazyLList[T], st : Set[LazyLList[T]]): (Queue[T], Set[LazyLList[T]], BigInt) = { + require(streamScheduleProperty[T](f, sch, st) && + ssize[T](sch) == (ssize[T](f) - r.size) + BigInt(1)) + val dres4 = evalLazyLList[T](sch, st); + dres4._1 match { + case SCons(_, tail) => + (Queue[T](f, r, tail), dres4._2, dres4._3 + 2) + case SNil() => + val rotres1 = newclosure(Rotate[T](f, r, Lazyarg[T](SNil[T]()), SNil[T]()), dres4._2); // can also directly call rotate here + (Queue[T](rotres1._1, List[T](), rotres1._1), dres4._2, dres4._3 + 3) + } + } ensuring(res => Queue$valid[T](res._1, res._2) && + res._3 <= 7) + + // proving specifications of 'enqueue' + def enqueue[T](x : T, q : Queue[T], st : Set[LazyLList[T]]): (Queue[T], Set[LazyLList[T]], BigInt) = { + require(Queue$valid[T](q, st)) + createQueue[T](q.f, Cons[T](x, q.r), q.s, st) + } ensuring (res => Queue$valid[T](res._1, res._2) && + res._3 <= 7) + + // proving specifications of 'dequeue' + def dequeue[T](q : Queue[T], st : Set[LazyLList[T]]): (Queue[T], Set[LazyLList[T]], BigInt) = { + require(!Queue$isEmpty[T](q, st) && Queue$valid[T](q, st)) + val dres7 = evalLazyLList[T](q.f, st); + val SCons(x, nf) = dres7._1 + val dres8 = evalLazyLList[T](q.s, dres7._2); + dres8._1 match { + case SCons(_, nsch) => + (Queue[T](nf, q.r, nsch), dres8._2, dres7._3 + dres8._3 + 3) + case _ => + val rotres3 = newclosure(Rotate[T](nf, q.r, Lazyarg[T](SNil[T]()), SNil[T]()), dres8._2); + (Queue[T](rotres3._1, List[T](), rotres3._1), dres8._2, dres7._3 + dres8._3 + 4) + } + } ensuring(res => Queue$valid[T](res._1, res._2) && + res._3 <= 12) +} diff --git a/testcases/lazy-datastructures/RealTimeQueue.scala b/testcases/lazy-datastructures/RealTimeQueue.scala new file mode 100644 index 000000000..a058163be --- /dev/null +++ b/testcases/lazy-datastructures/RealTimeQueue.scala @@ -0,0 +1,133 @@ +import leon.lazyeval._ +import leon.lang._ +import leon.annotation._ +import leon.collection._ + +object RealTimeQueue { + + sealed abstract class LList[T] { + def isEmpty: Boolean = { + this match { + case SNil() => true + case _ => false + } + } + + def isCons: Boolean = { + this match { + case SCons(_, _) => true + case _ => false + } + } + + def size: BigInt = { + this match { + case SNil() => BigInt(0) + case SCons(x, t) => 1 + ssize(t) + } + } ensuring (_ >= 0) + } + case class SCons[T](x: T, tail: $[LList[T]]) extends LList[T] + case class SNil[T]() extends LList[T] + + def ssize[T](l: $[LList[T]]): BigInt = (l*).size + + def isConcrete[T](l: $[LList[T]]): Boolean = { + (l.isEvaluated && (l* match { + case SCons(_, tail) => + isConcrete(tail) + case _ => true + })) || (l*).isEmpty + } + + // an axiom about lazy streams (this should be a provable axiom, but not as of now) + @axiom + def streamLemma[T](l: $[LList[T]]) = { + l.isEvaluated || + (l* match { + case SCons(_, tail) => + l != tail && !tail.isEvaluated + case _ => true + }) + } holds + + def firstUnevaluated[T](l: $[LList[T]]): $[LList[T]] = { + if (l.isEvaluated) { + l* match { + case SCons(_, tail) => + firstUnevaluated(tail) + case _ => l + } + } else + l + } ensuring (res => (!(res*).isEmpty || isConcrete(l)) && //if there are no lazy closures then the stream is concrete + ((res*).isEmpty || !res.isEvaluated) && // if the return value is not a Nil closure then it would not have been evaluated + streamLemma(res) && + (res.value match { + case SCons(_, tail) => + firstUnevaluated(l) == tail // after evaluating the firstUnevaluated closure in 'l' we get the next unevaluated closure + case _ => true + }) + ) + + def streamScheduleProperty[T](s: $[LList[T]], sch: $[LList[T]]) = { + firstUnevaluated(s) == sch + } + + case class Queue[T](f: $[LList[T]], r: List[T], s: $[LList[T]]) { + def isEmpty = (f*).isEmpty + def valid = { + streamScheduleProperty(f, s) && + ssize(s) == ssize(f) - r.size //invariant: |s| = |f| - |r| + } + } + + //@lazyproc + def rotate[T](f: $[LList[T]], r: List[T], a: $[LList[T]]): LList[T] = { + require(r.size == ssize(f) + 1) // isConcrete(f) // size invariant between 'f' and 'r' holds + (f.value, r) match { + case (SNil(), Cons(y, _)) => //in this case 'y' is the only element in 'r' + SCons[T](y, a) + case (SCons(x, tail), Cons(y, r1)) => + val newa: LList[T] = SCons[T](y, a) + val rot = $(rotate(tail, r1, $(newa))) //this creates a lazy rotate operation + SCons[T](x, rot) + } + } ensuring (res => res.size == ssize(f) + r.size + ssize(a) && res.isCons) + //&& res._2 <= O(1) //time bound) + + // TODO: make newa into sch to avoid a different closure category + def createQueue[T](f: $[LList[T]], r: List[T], sch: $[LList[T]]): Queue[T] = { + require(streamScheduleProperty(f, sch) && + ssize(sch) == ssize(f) - r.size + 1) //size invariant holds + sch.value match { // evaluate the schedule if it is not evaluated + case SCons(_, tail) => + Queue(f, r, tail) + case SNil() => + val newa: LList[T] = SNil[T]() + val rotres = $(rotate(f, r, $(newa))) + Queue(rotres, Nil[T](), rotres) + } + } ensuring (res => res.valid) + + def enqueue[T](x: T, q: Queue[T]): Queue[T] = { + require(q.valid) + createQueue(q.f, Cons[T](x, q.r), q.s) + } ensuring (res => res.valid) + + def dequeue[T](q: Queue[T]): Queue[T] = { + require(!q.isEmpty && q.valid) + q.f.value match { + case SCons(x, nf) => + q.s.value match { + case SCons(_, st) => //here, the precondition of createQueue (reg. suffix property) may get violated, so it is handled specially here. + Queue(nf, q.r, st) + case _ => + val newa: LList[T] = SNil[T]() + val rotres = $(rotate(nf, q.r, $(newa))) + Queue(rotres, Nil[T](), rotres) + } + } + } ensuring (res => res.valid) +} + \ No newline at end of file -- GitLab