From 4ed54a4054c47d2d87a8c355efd65f51bf19117a Mon Sep 17 00:00:00 2001 From: Ravi <ravi.kandhadai@epfl.ch> Date: Sun, 4 Oct 2015 01:12:54 +0200 Subject: [PATCH] Resolving unknown in solver. Blame: type inference for FunInv with type parameters --- src/main/scala/leon/Main.scala | 3 +- .../{util => datastructure}/DisjointSet.scala | 5 +- .../datastructure/ExprStructure.scala | 62 + .../{util => datastructure}/Graph.scala | 2 +- .../leon/invariant/datastructure/Maps.scala | 117 ++ .../templateSolvers/UFADTEliminator.scala | 2 +- .../scala/leon/invariant/util/CallGraph.scala | 1 + .../scala/leon/invariant/util/TypeUtil.scala | 45 + src/main/scala/leon/invariant/util/Util.scala | 337 ----- .../leon/laziness/ClosurePreAsserter.scala | 78 ++ .../laziness/LazinessEliminationPhase.scala | 235 ++++ .../scala/leon/laziness/LazinessUtil.scala | 166 +++ .../leon/laziness/LazyClosureConverter.scala | 526 ++++++++ .../leon/laziness/LazyClosureFactory.scala | 115 ++ .../scala/leon/laziness/TypeChecker.scala | 188 +++ .../scala/leon/laziness/TypeRectifier.scala | 152 +++ .../LazinessEliminationPhase.scala | 1092 ----------------- 17 files changed, 1692 insertions(+), 1434 deletions(-) rename src/main/scala/leon/invariant/{util => datastructure}/DisjointSet.scala (92%) create mode 100644 src/main/scala/leon/invariant/datastructure/ExprStructure.scala rename src/main/scala/leon/invariant/{util => datastructure}/Graph.scala (99%) create mode 100644 src/main/scala/leon/invariant/datastructure/Maps.scala create mode 100644 src/main/scala/leon/invariant/util/TypeUtil.scala create mode 100644 src/main/scala/leon/laziness/ClosurePreAsserter.scala create mode 100644 src/main/scala/leon/laziness/LazinessEliminationPhase.scala create mode 100644 src/main/scala/leon/laziness/LazinessUtil.scala create mode 100644 src/main/scala/leon/laziness/LazyClosureConverter.scala create mode 100644 src/main/scala/leon/laziness/LazyClosureFactory.scala create mode 100644 src/main/scala/leon/laziness/TypeChecker.scala create mode 100644 src/main/scala/leon/laziness/TypeRectifier.scala delete mode 100644 src/main/scala/leon/transformations/LazinessEliminationPhase.scala diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index d6cd9e9f0..542fb6835 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -3,6 +3,7 @@ package leon import leon.utils._ +import leon.laziness.LazinessEliminationPhase object Main { @@ -28,7 +29,7 @@ object Main { solvers.isabelle.IsabellePhase, transformations.InstrumentationPhase, invariant.engine.InferInvariantsPhase, - transformations.LazinessEliminationPhase + laziness.LazinessEliminationPhase ) } diff --git a/src/main/scala/leon/invariant/util/DisjointSet.scala b/src/main/scala/leon/invariant/datastructure/DisjointSet.scala similarity index 92% rename from src/main/scala/leon/invariant/util/DisjointSet.scala rename to src/main/scala/leon/invariant/datastructure/DisjointSet.scala index 520139764..4cab7291b 100644 --- a/src/main/scala/leon/invariant/util/DisjointSet.scala +++ b/src/main/scala/leon/invariant/datastructure/DisjointSet.scala @@ -1,8 +1,9 @@ package leon -package invariant.util +package invariant.datastructure import scala.collection.mutable.{ Map => MutableMap } -import scala.collection.mutable.{ Set => MutableSet } +import scala.annotation.migration +import scala.collection.mutable.{Map => MutableMap} class DisjointSets[T] { // A map from elements to their parent and rank diff --git a/src/main/scala/leon/invariant/datastructure/ExprStructure.scala b/src/main/scala/leon/invariant/datastructure/ExprStructure.scala new file mode 100644 index 000000000..a6e42e76f --- /dev/null +++ b/src/main/scala/leon/invariant/datastructure/ExprStructure.scala @@ -0,0 +1,62 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } + +/** + * A class that looks for structural equality of expressions + * by ignoring the variable names. + * Useful for factoring common parts of two expressions into functions. + */ +class ExprStructure(val e: Expr) { + def structurallyEqual(e1: Expr, e2: Expr): Boolean = { + (e1, e2) match { + case (t1: Terminal, t2: Terminal) => + // we need to specially handle type parameters as they are not considered equal by default + (t1.getType, t2.getType) match { + case (ct1: ClassType, ct2: ClassType) => + if (ct1.classDef == ct2.classDef && ct1.tps.size == ct2.tps.size) { + (ct1.tps zip ct2.tps).forall { + case (TypeParameter(_), TypeParameter(_)) => + true + case (a, b) => + println(s"Checking Type arguments: $a, $b") + a == b + } + } else false + case (ty1, ty2) => ty1 == ty2 + } + case (Operator(args1, op1), Operator(args2, op2)) => + (op1 == op2) && (args1.size == args2.size) && (args1 zip args2).forall { + case (a1, a2) => structurallyEqual(a1, a2) + } + case _ => + false + } + } + + override def equals(other: Any) = { + other match { + case other: ExprStructure => + structurallyEqual(e, other.e) + case _ => + false + } + } + + override def hashCode = { + var opndcount = 0 // operand count + var opcount = 0 // operator count + postTraversal { + case t: Terminal => opndcount += 1 + case _ => opcount += 1 + }(e) + (opndcount << 16) ^ opcount + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/Graph.scala b/src/main/scala/leon/invariant/datastructure/Graph.scala similarity index 99% rename from src/main/scala/leon/invariant/util/Graph.scala rename to src/main/scala/leon/invariant/datastructure/Graph.scala index 22f4e83d1..484f04668 100644 --- a/src/main/scala/leon/invariant/util/Graph.scala +++ b/src/main/scala/leon/invariant/datastructure/Graph.scala @@ -1,5 +1,5 @@ package leon -package invariant.util +package invariant.datastructure class DirectedGraph[T] { diff --git a/src/main/scala/leon/invariant/datastructure/Maps.scala b/src/main/scala/leon/invariant/datastructure/Maps.scala new file mode 100644 index 000000000..777a79375 --- /dev/null +++ b/src/main/scala/leon/invariant/datastructure/Maps.scala @@ -0,0 +1,117 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } +import scala.annotation.tailrec + +class MultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.Set[B]] with scala.collection.mutable.MultiMap[A, B] { + /** + * Creates a new map and does not change the existing map + */ + def append(that: MultiMap[A, B]): MultiMap[A, B] = { + val newmap = new MultiMap[A, B]() + this.foreach { case (k, vset) => newmap += (k -> vset) } + that.foreach { + case (k, vset) => vset.foreach(v => newmap.addBinding(k, v)) + } + newmap + } +} + +/** + * A multimap that allows duplicate entries + */ +class OrderedMultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.ListBuffer[B]] { + + def addBinding(key: A, value: B): this.type = { + get(key) match { + case None => + val list = new scala.collection.mutable.ListBuffer[B]() + list += value + this(key) = list + case Some(list) => + list += value + } + this + } + + /** + * Creates a new map and does not change the existing map + */ + def append(that: OrderedMultiMap[A, B]): OrderedMultiMap[A, B] = { + val newmap = new OrderedMultiMap[A, B]() + this.foreach { case (k, vlist) => newmap += (k -> vlist) } + that.foreach { + case (k, vlist) => vlist.foreach(v => newmap.addBinding(k, v)) + } + newmap + } + + /** + * Make the value of every key distinct + */ + def distinct: OrderedMultiMap[A, B] = { + val newmap = new OrderedMultiMap[A, B]() + this.foreach { case (k, vlist) => newmap += (k -> vlist.distinct) } + newmap + } +} + +/** + * Implements a mapping from Seq[A] to B where Seq[A] + * is stored as a Trie + */ +final class TrieMap[A, B] { + var childrenMap = Map[A, TrieMap[A, B]]() + var dataMap = Map[A, B]() + + @tailrec def addBinding(key: Seq[A], value: B) { + key match { + case Seq() => + throw new IllegalStateException("Key is empty!!") + case Seq(x) => + //add the value to the dataMap + if (dataMap.contains(x)) + throw new IllegalStateException("A mapping for key already exists: " + x + " --> " + dataMap(x)) + else + dataMap += (x -> value) + case head +: tail => //here, tail has at least one element + //check if we have an entry for seq(0) if yes go to the children, if not create one + val child = childrenMap.getOrElse(head, { + val ch = new TrieMap[A, B]() + childrenMap += (head -> ch) + ch + }) + child.addBinding(tail, value) + } + } + + @tailrec def lookup(key: Seq[A]): Option[B] = { + key match { + case Seq() => + throw new IllegalStateException("Key is empty!!") + case Seq(x) => + dataMap.get(x) + case head +: tail => //here, tail has at least one element + childrenMap.get(head) match { + case Some(child) => + child.lookup(tail) + case _ => None + } + } + } +} + +class CounterMap[T] extends scala.collection.mutable.HashMap[T, Int] { + def inc(v: T) = { + if (this.contains(v)) + this(v) += 1 + else this += (v -> 1) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala index fb7519850..409e7f4d4 100644 --- a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala +++ b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala @@ -8,7 +8,7 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ import java.io._ -import leon.invariant.util.UndirectedGraph +import leon.invariant.datastructure.UndirectedGraph import scala.util.control.Breaks._ import invariant.util._ import leon.purescala.TypeOps diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala index 73ac3b099..88f2d3d85 100644 --- a/src/main/scala/leon/invariant/util/CallGraph.scala +++ b/src/main/scala/leon/invariant/util/CallGraph.scala @@ -10,6 +10,7 @@ import purescala.Extractors._ import purescala.Types._ import Util._ import invariant.structure.FunctionUtils._ +import leon.invariant.datastructure.DirectedGraph /** * This represents a call graph of the functions in the program diff --git a/src/main/scala/leon/invariant/util/TypeUtil.scala b/src/main/scala/leon/invariant/util/TypeUtil.scala new file mode 100644 index 000000000..28bd428e9 --- /dev/null +++ b/src/main/scala/leon/invariant/util/TypeUtil.scala @@ -0,0 +1,45 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ + +object TypeUtil { + def getTypeParameters(t: TypeTree): Seq[TypeParameter] = { + t match { + case tp @ TypeParameter(_) => Seq(tp) + case NAryType(tps, _) => + (tps flatMap getTypeParameters).distinct + } + } + + def getTypeArguments(t: TypeTree) : Seq[TypeTree] = t match { + case ct: ClassType => ct.tps + case NAryType(tps, _) => + (tps flatMap getTypeParameters).distinct + } + + def typeNameWOParams(t: TypeTree): String = t match { + case ct: ClassType => ct.id.name + case TupleType(ts) => ts.map(typeNameWOParams).mkString("(", ",", ")") + case ArrayType(t) => s"Array[${typeNameWOParams(t)}]" + case SetType(t) => s"Set[${typeNameWOParams(t)}]" + case MapType(from, to) => s"Map[${typeNameWOParams(from)}, ${typeNameWOParams(to)}]" + case FunctionType(fts, tt) => + val ftstr = fts.map(typeNameWOParams).mkString("(", ",", ")") + s"$ftstr => ${typeNameWOParams(tt)}" + case t => t.toString + } + + def instantiateTypeParameters(tpMap: Map[TypeParameter, TypeTree])(t: TypeTree): TypeTree = { + t match { + case tp: TypeParameter => tpMap.getOrElse(tp, tp) + case NAryType(subtypes, tcons) => + tcons(subtypes map instantiateTypeParameters(tpMap) _) + } + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala index 855dba093..2df7ac71c 100644 --- a/src/main/scala/leon/invariant/util/Util.scala +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -688,340 +688,3 @@ class RealToInt { }.toMap) } } - -class MultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.Set[B]] with scala.collection.mutable.MultiMap[A, B] { - /** - * Creates a new map and does not change the existing map - */ - def append(that: MultiMap[A, B]): MultiMap[A, B] = { - val newmap = new MultiMap[A, B]() - this.foreach { case (k, vset) => newmap += (k -> vset) } - that.foreach { - case (k, vset) => vset.foreach(v => newmap.addBinding(k, v)) - } - newmap - } -} - -/** - * A multimap that allows duplicate entries - */ -class OrderedMultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.ListBuffer[B]] { - - def addBinding(key: A, value: B): this.type = { - get(key) match { - case None => - val list = new scala.collection.mutable.ListBuffer[B]() - list += value - this(key) = list - case Some(list) => - list += value - } - this - } - - /** - * Creates a new map and does not change the existing map - */ - def append(that: OrderedMultiMap[A, B]): OrderedMultiMap[A, B] = { - val newmap = new OrderedMultiMap[A, B]() - this.foreach { case (k, vlist) => newmap += (k -> vlist) } - that.foreach { - case (k, vlist) => vlist.foreach(v => newmap.addBinding(k, v)) - } - newmap - } - - /** - * Make the value of every key distinct - */ - def distinct: OrderedMultiMap[A, B] = { - val newmap = new OrderedMultiMap[A, B]() - this.foreach { case (k, vlist) => newmap += (k -> vlist.distinct) } - newmap - } -} - -/** - * Implements a mapping from Seq[A] to B where Seq[A] - * is stored as a Trie - */ -final class TrieMap[A, B] { - var childrenMap = Map[A, TrieMap[A, B]]() - var dataMap = Map[A, B]() - - @tailrec def addBinding(key: Seq[A], value: B) { - key match { - case Seq() => - throw new IllegalStateException("Key is empty!!") - case Seq(x) => - //add the value to the dataMap - if (dataMap.contains(x)) - throw new IllegalStateException("A mapping for key already exists: " + x + " --> " + dataMap(x)) - else - dataMap += (x -> value) - case head +: tail => //here, tail has at least one element - //check if we have an entry for seq(0) if yes go to the children, if not create one - val child = childrenMap.getOrElse(head, { - val ch = new TrieMap[A, B]() - childrenMap += (head -> ch) - ch - }) - child.addBinding(tail, value) - } - } - - @tailrec def lookup(key: Seq[A]): Option[B] = { - key match { - case Seq() => - throw new IllegalStateException("Key is empty!!") - case Seq(x) => - dataMap.get(x) - case head +: tail => //here, tail has at least one element - childrenMap.get(head) match { - case Some(child) => - child.lookup(tail) - case _ => None - } - } - } -} - -class CounterMap[T] extends scala.collection.mutable.HashMap[T, Int] { - def inc(v: T) = { - if (this.contains(v)) - this(v) += 1 - else this += (v -> 1) - } -} - -/** - * A class that looks for structural equality of expressions - * by ignoring the variable names. - * Useful for factoring common parts of two expressions into functions. - */ -class ExprStructure(val e: Expr) { - def structurallyEqual(e1: Expr, e2: Expr): Boolean = { - (e1, e2) match { - case (t1: Terminal, t2: Terminal) => - // we need to specially handle type parameters as they are not considered equal by default - (t1.getType, t2.getType) match { - case (ct1: ClassType, ct2: ClassType) => - if (ct1.classDef == ct2.classDef && ct1.tps.size == ct2.tps.size) { - (ct1.tps zip ct2.tps).forall { - case (TypeParameter(_), TypeParameter(_)) => - true - case (a, b) => - println(s"Checking Type arguments: $a, $b") - a == b - } - } else false - case (ty1, ty2) => ty1 == ty2 - } - case (Operator(args1, op1), Operator(args2, op2)) => - (op1 == op2) && (args1.size == args2.size) && (args1 zip args2).forall { - case (a1, a2) => structurallyEqual(a1, a2) - } - case _ => - false - } - } - - override def equals(other: Any) = { - other match { - case other: ExprStructure => - structurallyEqual(e, other.e) - case _ => - false - } - } - - override def hashCode = { - var opndcount = 0 // operand count - var opcount = 0 // operator count - postTraversal { - case t: Terminal => opndcount += 1 - case _ => opcount += 1 - }(e) - (opndcount << 16) ^ opcount - } -} - -object TypeUtil { - def getTypeParameters(t: TypeTree): Seq[TypeParameter] = { - t match { - case tp @ TypeParameter(_) => Seq(tp) - case NAryType(tps, _) => - (tps flatMap getTypeParameters).distinct - } - } - - def typeNameWOParams(t: TypeTree): String = t match { - case ct: ClassType => ct.id.name - case TupleType(ts) => ts.map(typeNameWOParams).mkString("(", ",", ")") - case ArrayType(t) => s"Array[${typeNameWOParams(t)}]" - case SetType(t) => s"Set[${typeNameWOParams(t)}]" - case MapType(from, to) => s"Map[${typeNameWOParams(from)}, ${typeNameWOParams(to)}]" - case FunctionType(fts, tt) => - val ftstr = fts.map(typeNameWOParams).mkString("(", ",", ")") - s"$ftstr => ${typeNameWOParams(tt)}" - case t => t.toString - } - - def instantiateTypeParameters(tpMap: Map[TypeParameter, TypeTree])(t: TypeTree): TypeTree = { - t match { - case tp: TypeParameter => tpMap.getOrElse(tp, tp) - case NAryType(subtypes, tcons) => - tcons(subtypes map instantiateTypeParameters(tpMap) _) - } - } - - /** - * `gamma` is the initial type environment which has - * type bindings for free variables of `ine`. - * It is not necessary that gamma should match the types of the - * identifiers of the free variables. - * Set and Maps are not supported yet - */ - def inferTypesOfLocals(ine: Expr, initGamma: Map[Identifier, TypeTree]): Expr = { - var idmap = Map[Identifier, Identifier]() - var gamma = initGamma - - /** - * Note this method has side-effects - */ - def makeIdOfType(oldId: Identifier, tpe: TypeTree): Identifier = { - if (oldId.getType != tpe) { - val freshid = FreshIdentifier(oldId.name, tpe, true) - idmap += (oldId -> freshid) - gamma += (oldId -> tpe) - freshid - } else oldId - } - - def rec(e: Expr): (TypeTree, Expr) = { - val res = e match { - case Let(id, value, body) => - val (valType, nval) = rec(value) - val nid = makeIdOfType(id, valType) - val (btype, nbody) = rec(body) - (btype, Let(nid, nval, nbody)) - - case Ensuring(body, Lambda(Seq(resdef @ ValDef(resid, _)), postBody)) => - val (btype, nbody) = rec(body) - val nres = makeIdOfType(resid, btype) - (btype, Ensuring(nbody, Lambda(Seq(ValDef(nres)), rec(postBody)._2))) - - case MatchExpr(scr, mcases) => - val (scrtype, nscr) = rec(scr) - val ncases = mcases.map { - case MatchCase(pat, optGuard, rhs) => - // resetting the type of patterns in the matches - def mapPattern(p: Pattern, expType: TypeTree): (Pattern, TypeTree) = { - p match { - case InstanceOfPattern(bopt, ict) => - // choose the subtype of the `expType` that - // has the same constructor as `ict` - val ntype = subcast(ict, expType.asInstanceOf[ClassType]) - if (!ntype.isDefined) - throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") - val nbopt = bopt.map(makeIdOfType(_, ntype.get)) - (InstanceOfPattern(nbopt, ntype.get), ntype.get) - - case CaseClassPattern(bopt, ict, subpats) => - val ntype = subcast(ict, expType.asInstanceOf[ClassType]) - if (!ntype.isDefined) - throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") - val cct = ntype.get.asInstanceOf[CaseClassType] - val nbopt = bopt.map(makeIdOfType(_, cct)) - val npats = (subpats zip cct.fieldsTypes).map { - case (p, t) => - //println(s"Subpat: $p expected type: $t") - mapPattern(p, t)._1 - } - (CaseClassPattern(nbopt, cct, npats), cct) - - case TuplePattern(bopt, subpats) => - val TupleType(subts) = scrtype - val patnTypes = (subpats zip subts).map { - case (p, t) => mapPattern(p, t) - } - val npats = patnTypes.map(_._1) - val ntype = TupleType(patnTypes.map(_._2)) - val nbopt = bopt.map(makeIdOfType(_, ntype)) - (TuplePattern(nbopt, npats), ntype) - - case WildcardPattern(bopt) => - val nbopt = bopt.map(makeIdOfType(_, expType)) - (WildcardPattern(nbopt), expType) - - case LiteralPattern(bopt, lit) => - val ntype = lit.getType - val nbopt = bopt.map(makeIdOfType(_, ntype)) - (LiteralPattern(nbopt, lit), ntype) - case _ => - throw new IllegalStateException("Not supported yet!") - } - } - val npattern = mapPattern(pat, scrtype)._1 - val nguard = optGuard.map(rec(_)._2) - val nrhs = rec(rhs)._2 - //println(s"New rhs: $nrhs inferred type: ${nrhs.getType}") - MatchCase(npattern, nguard, nrhs) - } - val nmatch = MatchExpr(nscr, ncases) - (nmatch.getType, nmatch) - - case cs @ CaseClassSelector(cltype, clExpr, fld) => - val (ncltype: CaseClassType, nclExpr) = rec(clExpr) - (ncltype, CaseClassSelector(ncltype, nclExpr, fld)) - - case AsInstanceOf(clexpr, cltype) => - val (ncltype: ClassType, nexpr) = rec(clexpr) - subcast(cltype, ncltype) match { - case Some(ntype) => (ntype, AsInstanceOf(nexpr, ntype)) - case _ => - //println(s"asInstanceOf type of $clExpr is: $cltype inferred type of $nclExpr : $ct") - throw new IllegalStateException(s"$nexpr : $ncltype cannot be cast to case class type: $cltype") - } - - case v @ Variable(id) => - if (gamma.contains(id)) { - if (idmap.contains(id)) - (gamma(id), idmap(id).toVariable) - else { - (gamma(id), v) - } - } else (id.getType, v) - - // need to handle tuple select specially - case TupleSelect(tup, i) => - val nop = TupleSelect(rec(tup)._2, i) - (nop.getType, nop) - case Operator(args, op) => - val nop = op(args.map(arg => rec(arg)._2)) - (nop.getType, nop) - case t: Terminal => - (t.getType, t) - } - //println(s"Inferred type of $e : ${res._1} new expression: ${res._2}") - if (res._1 == Untyped) { - throw new IllegalStateException(s"Cannot infer type for expression: $e") - } - res - } - - def subcast(oldType: ClassType, newType: ClassType): Option[ClassType] = { - newType match { - case AbstractClassType(absClass, tps) if absClass.knownCCDescendants.contains(oldType.classDef) => - //here oldType.classDef <: absClass - Some(CaseClassType(oldType.classDef.asInstanceOf[CaseClassDef], tps)) - case cct: CaseClassType => - Some(cct) - case _ => - None - } - } - rec(ine)._2 - } -} \ No newline at end of file diff --git a/src/main/scala/leon/laziness/ClosurePreAsserter.scala b/src/main/scala/leon/laziness/ClosurePreAsserter.scala new file mode 100644 index 000000000..a0f68b29a --- /dev/null +++ b/src/main/scala/leon/laziness/ClosurePreAsserter.scala @@ -0,0 +1,78 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import LazinessUtil._ + +/** + * Generate lemmas that ensure that preconditions hold for closures. + */ +class ClosurePreAsserter(p: Program) { + + def hasClassInvariants(cc: CaseClass): Boolean = { + val opname = ccNameToOpName(cc.ct.classDef.id.name) + functionByName(opname, p).get.hasPrecondition + } + + // A nasty way of finding anchor functions + // Fix this soon !! + var anchorfd: Option[FunDef] = None + val lemmas = p.definedFunctions.flatMap { + case fd if (fd.hasBody && !fd.isLibrary) => + //println("collection closure creation preconditions for: "+fd) + val closures = CollectorWithPaths { + case FunctionInvocation(TypedFunDef(fund, _), + Seq(cc: CaseClass, st)) if isClosureCons(fund) && hasClassInvariants(cc) => + (cc, st) + } traverse (fd.body.get) // Note: closures cannot be created in specs + // Note: once we have separated normal preconditions from state preconditions + // it suffices to just consider state preconditions here + closures.map { + case ((CaseClass(CaseClassType(ccd, _), args), st), path) => + anchorfd = Some(fd) + val target = functionByName(ccNameToOpName(ccd.id.name), p).get //find the target corresponding to the closure + val pre = target.precondition.get + val nargs = + if (target.params.size > args.size) // target takes state ? + args :+ st + else args + val pre2 = replaceFromIDs((target.params.map(_.id) zip nargs).toMap, pre) + val vc = Implies(And(Util.precOrTrue(fd), path), pre2) + // create a function for each vc + val lemmaid = FreshIdentifier(ccd.id.name + fd.id.name + "Lem", Untyped, true) + val params = variablesOf(vc).toSeq.map(v => ValDef(v)) + val tparams = params.flatMap(p => getTypeParameters(p.getType)).distinct map TypeParameterDef + val lemmafd = new FunDef(lemmaid, tparams, BooleanType, params) + // reset the types of locals + val initGamma = params.map(vd => vd.id -> vd.getType).toMap + lemmafd.body = Some(TypeChecker.inferTypesOfLocals(vc, initGamma)) + // assert the lemma is true + val resid = FreshIdentifier("holds", BooleanType) + lemmafd.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) + //println("Created lemma function: "+lemmafd) + lemmafd + } + case _ => Seq() + } + + def apply: Program = { + if (!lemmas.isEmpty) + addFunDefs(p, lemmas, anchorfd.get) + else p + } +} + diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala new file mode 100644 index 000000000..defad548c --- /dev/null +++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala @@ -0,0 +1,235 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter +import leon.LeonContext +import leon.LeonOptionDef +import leon.Main +import leon.TransformationPhase +import LazinessUtil._ +import leon.invariant.datastructure.DisjointSets + +object LazinessEliminationPhase extends TransformationPhase { + val debugLifting = false + val dumpProgramWithClosures = false + val dumpTypeCorrectProg = false + val dumpFinalProg = false + val debugSolvers = false + + val skipVerification = false + val prettyPrint = true + + val name = "Laziness Elimination Phase" + val description = "Coverts a program that uses lazy construct" + + " to a program that does not use lazy constructs" + + /** + * TODO: enforce that the specs do not create lazy closures + * TODO: we are forced to make an assumption that lazy ops takes as type parameters only those + * type parameters of their return type and not more. (enforce this) + * TODO: Check that lazy types are not nested + */ + def apply(ctx: LeonContext, prog: Program): Program = { + + val nprog = liftLazyExpressions(prog) + val progWithClosures = (new LazyClosureConverter(nprog, new LazyClosureFactory(nprog))).apply + if (dumpProgramWithClosures) + println("After closure conversion: \n" + ScalaPrinter.apply(progWithClosures)) + + //Rectify type parameters and local types + val typeCorrectProg = (new TypeRectifier(progWithClosures, tp => tp.id.name.endsWith("@"))).apply + if (dumpTypeCorrectProg) + println("After rectifying types: \n" + ScalaPrinter.apply(typeCorrectProg)) + + val transProg = (new ClosurePreAsserter(typeCorrectProg)).apply + if (dumpFinalProg) + println("After asserting closure preconditions: \n" + ScalaPrinter.apply(transProg)) + + // check specifications (to be moved to a different phase) + if (!skipVerification) + checkSpecifications(transProg) + if (prettyPrint) + prettyPrintProgramToFile(transProg, ctx) + transProg + } + + /** + * convert the argument of every lazy constructors to a procedure + */ + def liftLazyExpressions(prog: Program): Program = { + var newfuns = Map[ExprStructure, (FunDef, ModuleDef)]() + var anchorFun: Option[FunDef] = None // Fix this way of finding anchor functions + val fdmap = prog.modules.flatMap { md => + md.definedFunctions.map { + case fd if fd.hasBody && !fd.isLibrary => + //println("FunDef: "+fd) + val nfd = preMapOnFunDef { + case finv @ FunctionInvocation(lazytfd, Seq(arg)) if isLazyInvocation(finv)(prog) && !arg.isInstanceOf[FunctionInvocation] => + val freevars = variablesOf(arg).toList + val tparams = freevars.map(_.getType) flatMap getTypeParameters + val argstruc = new ExprStructure(arg) + val argfun = + if (newfuns.contains(argstruc)) { + newfuns(argstruc)._1 + } else { + //construct type parameters for the function + val nfun = new FunDef(FreshIdentifier("lazyarg", arg.getType, true), + tparams map TypeParameterDef.apply, arg.getType, freevars.map(ValDef(_))) + nfun.body = Some(arg) + newfuns += (argstruc -> (nfun, md)) + nfun + } + Some(FunctionInvocation(lazytfd, Seq(FunctionInvocation(TypedFunDef(argfun, tparams), + freevars.map(_.toVariable))))) + case _ => None + }(fd) + (fd -> Some(nfd)) + case fd => fd -> Some(fd.duplicate) + } + }.toMap + // map the new functions to themselves + val newfdMap = newfuns.values.map { case (nfd, _) => nfd -> None }.toMap + val (repProg, _) = replaceFunDefs(prog)(fdmap ++ newfdMap) + val nprog = + if (!newfuns.isEmpty) { + val modToNewDefs = newfuns.values.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }.toMap + appendDefsToModules(repProg, modToNewDefs) + } else + throw new IllegalStateException("Cannot find any lazy computation") + if (debugLifting) + println("After lifiting arguments of lazy constructors: \n" + ScalaPrinter.apply(nprog)) + nprog + } + + import leon.solvers._ + import leon.solvers.z3._ + + def checkSpecifications(prog: Program) { + // convert 'axiom annotation to library + prog.definedFunctions.foreach { fd => + if (fd.annotations.contains("axiom")) + fd.addFlag(Annotation("library", Seq())) + } + val functions = Seq() // Seq("--functions=Rotate@rotateLem") + val solverOptions = if(debugSolvers) Seq("--debug=solver") else Seq() + val ctx = Main.processOptions(Seq("--solvers=smt-cvc4") ++ solverOptions ++ functions) + val report = AnalysisPhase.run(ctx)(prog) + println(report.summaryString) + /*val timeout = 10 + val rep = ctx.reporter + * val fun = prog.definedFunctions.find(_.id.name == "firstUnevaluated").get + // create a verification context. + val solver = new FairZ3Solver(ctx, prog) with TimeoutSolver + val solverF = new TimeoutSolverFactory(SolverFactory(() => solver), timeout * 1000) + val vctx = VerificationContext(ctx, prog, solverF, rep) + val vc = (new DefaultTactic(vctx)).generatePostconditions(fun)(0) + val s = solverF.getNewSolver() + try { + rep.info(s" - Now considering '${vc.kind}' VC for ${vc.fd.id} @${vc.getPos}...") + val tStart = System.currentTimeMillis + s.assertVC(vc) + val satResult = s.check + val dt = System.currentTimeMillis - tStart + val res = satResult match { + case None => + rep.info("Cannot prove or disprove specifications") + case Some(false) => + rep.info("Valid") + case Some(true) => + println("Invalid - counter-example: " + s.getModel) + } + } finally { + s.free() + }*/ + } + + /** + * NOT USED CURRENTLY + * Lift the specifications on functions to the invariants corresponding + * case classes. + * Ideally we should class invariants here, but it is not currently supported + * so we create a functions that can be assume in the pre and post of functions. + * TODO: can this be optimized + */ + /* def liftSpecsToClosures(opToAdt: Map[FunDef, CaseClassDef]) = { + val invariants = opToAdt.collect { + case (fd, ccd) if fd.hasPrecondition => + val transFun = (args: Seq[Identifier]) => { + val argmap: Map[Expr, Expr] = (fd.params.map(_.id.toVariable) zip args.map(_.toVariable)).toMap + replace(argmap, fd.precondition.get) + } + (ccd -> transFun) + }.toMap + val absTypes = opToAdt.values.collect { + case cd if cd.parent.isDefined => cd.parent.get.classDef + } + val invFuns = absTypes.collect { + case abs if abs.knownCCDescendents.exists(invariants.contains) => + val absType = AbstractClassType(abs, abs.tparams.map(_.tp)) + val param = ValDef(FreshIdentifier("$this", absType)) + val tparams = abs.tparams + val invfun = new FunDef(FreshIdentifier(abs.id.name + "$Inv", Untyped), + tparams, BooleanType, Seq(param)) + (abs -> invfun) + }.toMap + // assign bodies for the 'invfuns' + invFuns.foreach { + case (abs, fd) => + val bodyCases = abs.knownCCDescendents.collect { + case ccd if invariants.contains(ccd) => + val ctype = CaseClassType(ccd, fd.tparams.map(_.tp)) + val cvar = FreshIdentifier("t", ctype) + val fldids = ctype.fields.map { + case ValDef(fid, Some(fldtpe)) => + FreshIdentifier(fid.name, fldtpe) + } + val pattern = CaseClassPattern(Some(cvar), ctype, + fldids.map(fid => WildcardPattern(Some(fid)))) + val rhsInv = invariants(ccd)(fldids) + // assert the validity of substructures + val rhsValids = fldids.flatMap { + case fid if fid.getType.isInstanceOf[ClassType] => + val t = fid.getType.asInstanceOf[ClassType] + val rootDef = t match { + case absT: AbstractClassType => absT.classDef + case _ if t.parent.isDefined => + t.parent.get.classDef + } + if (invFuns.contains(rootDef)) { + List(FunctionInvocation(TypedFunDef(invFuns(rootDef), t.tps), + Seq(fid.toVariable))) + } else + List() + case _ => List() + } + val rhs = Util.createAnd(rhsInv +: rhsValids) + MatchCase(pattern, None, rhs) + } + // create a default case + val defCase = MatchCase(WildcardPattern(None), None, Util.tru) + val matchExpr = MatchExpr(fd.params.head.id.toVariable, bodyCases :+ defCase) + fd.body = Some(matchExpr) + } + invFuns + }*/ +} + diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala new file mode 100644 index 000000000..dd65f2df4 --- /dev/null +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -0,0 +1,166 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter +import leon.LeonContext +import leon.LeonOptionDef +import leon.Main +import leon.TransformationPhase + +object LazinessUtil { + + def prettyPrintProgramToFile(p: Program, ctx: LeonContext) { + val optOutputDirectory = new LeonOptionDef[String] { + val name = "o" + val description = "Output directory" + val default = "leon.out" + val usageRhs = "dir" + val parser = (x: String) => x + } + val outputFolder = ctx.findOptionOrDefault(optOutputDirectory) + try { + new File(outputFolder).mkdir() + } catch { + case _: java.io.IOException => ctx.reporter.fatalError("Could not create directory " + outputFolder) + } + + for (u <- p.units if u.isMainUnit) { + val outputFile = s"$outputFolder${File.separator}${u.id.toString}.scala" + try { + val out = new BufferedWriter(new FileWriter(outputFile)) + // remove '@' from the end of the identifier names + val pat = new Regex("""(\S+)(@)""", "base", "suffix") + val pgmText = pat.replaceAllIn(ScalaPrinter.apply(p), m => m.group("base")) + out.write(pgmText) + out.close() + } catch { + case _: java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) + } + } + ctx.reporter.info("Output written on " + outputFolder) + } + + def isLazyInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.apply" + case _ => + false + } + + def isEvaluatedInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.isEvaluated" + case _ => false + } + + def isValueInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.value" + case _ => false + } + + def isStarInvocation(e: Expr)(implicit p: Program): Boolean = e match { + case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => + fullName(fd)(p) == "leon.lazyeval.$.*" + case _ => false + } + + def isLazyType(tpe: TypeTree): Boolean = tpe match { + case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => + cid.name == "$" + case _ => false + } + + /** + * TODO: Check that lazy types are not nested + */ + def unwrapLazyType(tpe: TypeTree) = tpe match { + case ctype @ CaseClassType(_, Seq(innerType)) if isLazyType(ctype) => + Some(innerType) + case _ => None + } + + def rootType(t: TypeTree): Option[AbstractClassType] = t match { + case absT: AbstractClassType => Some(absT) + case ct: ClassType => ct.parent + case _ => None + } + + def opNameToCCName(name: String) = { + name.capitalize + "@" + } + + /** + * Convert the first character to lower case + * and remove the last character. + */ + def ccNameToOpName(name: String) = { + name.substring(0, 1).toLowerCase() + + name.substring(1, name.length() - 1) + } + + def typeNameToADTName(name: String) = { + "Lazy" + name + } + + def adtNameToTypeName(name: String) = { + name.substring(4) + } + + def closureConsName(typeName: String) = { + "new@" + typeName + } + + def isClosureCons(fd: FunDef) = { + fd.id.name.startsWith("new@") + } + + /** + * Returns all functions that 'need' states to be passed in + * and those that return a new state. + * TODO: implement backwards BFS by reversing the graph + */ + def funsNeedingnReturningState(prog: Program) = { + val cg = CallGraphUtil.constructCallGraph(prog, false, true) + var needRoots = Set[FunDef]() + var retRoots = Set[FunDef]() + prog.definedFunctions.foreach { + case fd if fd.hasBody && !fd.isLibrary => + postTraversal { + case finv: FunctionInvocation if isEvaluatedInvocation(finv)(prog) => + needRoots += fd + case finv: FunctionInvocation if isValueInvocation(finv)(prog) => + needRoots += fd + retRoots += fd + case _ => + ; + }(fd.body.get) + case _ => ; + } + val funsNeedStates = prog.definedFunctions.filterNot(fd => + cg.transitiveCallees(fd).toSet.intersect(needRoots).isEmpty).toSet + val funsRetStates = prog.definedFunctions.filterNot(fd => + cg.transitiveCallees(fd).toSet.intersect(retRoots).isEmpty).toSet + (funsNeedStates, funsRetStates) + } +} + diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala new file mode 100644 index 000000000..6b11c654c --- /dev/null +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -0,0 +1,526 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter +import leon.LeonContext +import leon.LeonOptionDef +import leon.Main +import leon.TransformationPhase +import LazinessUtil._ + +/** + * (a) add state to every function in the program + * (b) thread state through every expression in the program sequentially + * (c) replace lazy constructions with case class creations + * (d) replace isEvaluated with currentState.contains() + * (e) replace accesses to $.value with calls to dispatch with current state + */ +class LazyClosureConverter(p: Program, closureFactory: LazyClosureFactory) { + val debug = false + // flags + val removeRecursionViaEval = false + + val (funsNeedStates, funsRetStates) = funsNeedingnReturningState(p) + val tnames = closureFactory.lazyTypeNames + + // create a mapping from functions to new functions + lazy val funMap = p.definedFunctions.collect { + case fd if (fd.hasBody && !fd.isLibrary) => + // replace lazy types in parameters and return values + val nparams = fd.params map { vd => + ValDef(vd.id, Some(replaceLazyTypes(vd.getType))) // override type of identifier + } + val nretType = replaceLazyTypes(fd.returnType) + val nfd = if (funsNeedStates(fd)) { + var newTParams = Seq[TypeParameterDef]() + val stTypes = tnames map { tn => + val absClass = closureFactory.absClosureType(tn) + val tparams = absClass.tparams.map(_ => + TypeParameter.fresh("P@")) + newTParams ++= tparams map TypeParameterDef + SetType(AbstractClassType(absClass, tparams)) + } + val stParams = stTypes.map { stType => + ValDef(FreshIdentifier("st@", stType, true)) + } + val retTypeWithState = + if (funsRetStates(fd)) + TupleType(nretType +: stTypes) + else + nretType + // the type parameters will be unified later + new FunDef(FreshIdentifier(fd.id.name, Untyped), + fd.tparams ++ newTParams, retTypeWithState, nparams ++ stParams) + // body of these functions are defined later + } else { + new FunDef(FreshIdentifier(fd.id.name, Untyped), fd.tparams, nretType, nparams) + } + // copy annotations + fd.flags.foreach(nfd.addFlag(_)) + (fd -> nfd) + }.toMap + + /** + * A set of uninterpreted functions that may be used as targets + * of closures in the eval functions, for efficiency reasons. + */ + lazy val uninterpretedTargets = funMap.map { + case (k, v) => + val ufd = new FunDef(FreshIdentifier(v.id.name, v.id.getType, true), + v.tparams, v.returnType, v.params) + (k -> ufd) + }.toMap + + /** + * A function for creating a state type for every lazy type. Note that Leon + * doesn't support 'Any' type yet. So we have to have multiple + * state (though this is much clearer it results in more complicated code) + */ + def getStateType(tname: String, tparams: Seq[TypeParameter]) = { + //val (_, absdef, _) = tpeToADT(tname) + SetType(AbstractClassType(closureFactory.absClosureType(tname), tparams)) + } + + def replaceLazyTypes(t: TypeTree): TypeTree = { + unwrapLazyType(t) match { + case None => + val NAryType(tps, tcons) = t + tcons(tps map replaceLazyTypes) + case Some(btype) => + val absClass = closureFactory.absClosureType(typeNameWOParams(btype)) + val ntype = AbstractClassType(absClass, getTypeParameters(btype)) + val NAryType(tps, tcons) = ntype + tcons(tps map replaceLazyTypes) + } + } + + /** + * Create dispatch functions for each lazy type. + * Note: the dispatch functions will be annotated as library so that + * their pre/posts are not checked (the fact that they hold are verified separately) + * Note by using 'assume-pre' we can also assume the preconditions of closures at + * the call-sites. + */ + val evalFunctions = tnames.map { tname => + val tpe = closureFactory.lazyType(tname) + val absdef = closureFactory.absClosureType(tname) + val cdefs = closureFactory.closures(tname) + + // construct parameters and return types + val tparams = getTypeParameters(tpe) + val tparamDefs = tparams map TypeParameterDef.apply + val param1 = FreshIdentifier("cl", AbstractClassType(absdef, tparams)) + val stType = getStateType(tname, tparams) + val param2 = FreshIdentifier("st@", stType) + val retType = TupleType(Seq(tpe, stType)) + + // create a eval function + val dfun = new FunDef(FreshIdentifier("eval" + absdef.id.name, Untyped), + tparamDefs, retType, Seq(ValDef(param1), ValDef(param2))) + + // assign body of the eval fucntion + // create a match case to switch over the possible class defs and invoke the corresponding functions + val bodyMatchCases = cdefs map { cdef => + val ctype = CaseClassType(cdef, tparams) // we assume that the type parameters of cdefs are same as absdefs + val binder = FreshIdentifier("t", ctype) + val pattern = InstanceOfPattern(Some(binder), ctype) + // create a body of the match + val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } + val op = closureFactory.caseClassToOp(cdef) + // TODO: here we are assuming that only one state is used, fix this. + val stArgs = + if (funsNeedStates(op)) + // Note: it is important to use empty state here to eliminate + // dependency on state for the result value + Seq(FiniteSet(Set(), stType.base)) + else Seq() + val targetFun = + if (removeRecursionViaEval && op.hasPostcondition) { + // checking for postcondition is a hack to make sure that we + // do not make all targets uninterpreted + uninterpretedTargets(op) + } else funMap(op) + val invoke = FunctionInvocation(TypedFunDef(targetFun, tparams), args ++ stArgs) // we assume that the type parameters of lazy ops are same as absdefs + val invPart = if (funsRetStates(op)) { + TupleSelect(invoke, 1) // we are only interested in the value + } else invoke + val newst = SetUnion(param2.toVariable, FiniteSet(Set(binder.toVariable), stType.base)) + val rhs = Tuple(Seq(invPart, newst)) + MatchCase(pattern, None, rhs) + } + dfun.body = Some(MatchExpr(param1.toVariable, bodyMatchCases)) + dfun.addFlag(Annotation("axiom", Seq())) + (tname -> dfun) + }.toMap + + /** + * These are evalFunctions that do not affect the state + */ + val computeFunctions = evalFunctions.map { + case (tname, evalfd) => + val tpe = closureFactory.lazyType(tname) + val param1 = evalfd.params.head + val fun = new FunDef(FreshIdentifier(evalfd.id.name + "*", Untyped), + evalfd.tparams, tpe, Seq(param1)) + val invoke = FunctionInvocation(TypedFunDef(evalfd, evalfd.tparams.map(_.tp)), + Seq(param1.id.toVariable, FiniteSet(Set(), + getStateType(tname, getTypeParameters(tpe)).base))) + fun.body = Some(TupleSelect(invoke, 1)) + (tname -> fun) + }.toMap + + /** + * Create closure construction functions that ensures a postcondition. + * They are defined for each lazy class since it avoids generics and + * simplifies the type inference (which is not full-fledged in Leon) + */ + val closureCons = tnames.map { tname => + val adt = closureFactory.absClosureType(tname) + val param1Type = AbstractClassType(adt, adt.tparams.map(_.tp)) + val param1 = FreshIdentifier("cc", param1Type) + val stType = SetType(param1Type) + val param2 = FreshIdentifier("st@", stType) + val tparamDefs = adt.tparams + val fun = new FunDef(FreshIdentifier(closureConsName(tname)), adt.tparams, param1Type, + Seq(ValDef(param1), ValDef(param2))) + fun.body = Some(param1.toVariable) + val resid = FreshIdentifier("res", param1Type) + val postbody = Not(SubsetOf(FiniteSet(Set(resid.toVariable), param1Type), param2.toVariable)) + fun.postcondition = Some(Lambda(Seq(ValDef(resid)), postbody)) + fun.addFlag(Annotation("axiom", Seq())) + (tname -> fun) + }.toMap + + def mapBody(body: Expr): (Map[String, Expr] => Expr, Boolean) = body match { + + case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // lazy construction ? + if isLazyInvocation(finv)(p) => + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val adt = closureFactory.closureOfLazyOp(argfd) + val cc = CaseClass(CaseClassType(adt, tparams), nargs) + val baseLazyTypeName = closureFactory.lazyTypeNameOfClosure(adt) + FunctionInvocation(TypedFunDef(closureCons(baseLazyTypeName), tparams), + Seq(cc, st(baseLazyTypeName))) + }, false) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isEvaluatedInvocation(finv)(p) => // isEval function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val narg = nargs(0) // there must be only one argument here + val baseType = unwrapLazyType(narg.getType).get + val tname = typeNameWOParams(baseType) + val adtType = AbstractClassType(closureFactory.absClosureType(tname), getTypeParameters(baseType)) + SubsetOf(FiniteSet(Set(narg), adtType), st(tname)) + }, false) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isValueInvocation(finv)(p) => // is value function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here + val tname = typeNameWOParams(baseType) + val dispFun = evalFunctions(tname) + val dispFunInv = FunctionInvocation(TypedFunDef(dispFun, + getTypeParameters(baseType)), nargs :+ st(tname)) + val dispRes = FreshIdentifier("dres", dispFun.returnType) + val nstates = tnames map { + case `tname` => + TupleSelect(dispRes.toVariable, 2) + case other => st(other) + } + Let(dispRes, dispFunInv, Tuple(TupleSelect(dispRes.toVariable, 1) +: nstates)) + }, true) + mapNAryOperator(args, op) + + case finv @ FunctionInvocation(_, args) if isStarInvocation(finv)(p) => // is * function ? + val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here + val tname = typeNameWOParams(baseType) + val dispFun = computeFunctions(tname) + FunctionInvocation(TypedFunDef(dispFun, getTypeParameters(baseType)), nargs) + }, false) + mapNAryOperator(args, op) + + case FunctionInvocation(TypedFunDef(fd, tparams), args) if funMap.contains(fd) => + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { + val stArgs = + if (funsNeedStates(fd)) { + (tnames map st.apply) + } else Seq() + FunctionInvocation(TypedFunDef(funMap(fd), tparams), nargs ++ stArgs) + }, funsRetStates(fd))) + + case Let(id, value, body) => + val (valCons, valUpdatesState) = mapBody(value) + val (bodyCons, bodyUpdatesState) = mapBody(body) + ((st: Map[String, Expr]) => { + val nval = valCons(st) + if (valUpdatesState) { + val freshid = FreshIdentifier(id.name, nval.getType, true) + val nextStates = tnames.zipWithIndex.map { + case (tn, i) => + TupleSelect(freshid.toVariable, i + 2) + }.toSeq + val nsMap = (tnames zip nextStates).toMap + val transBody = replace(Map(id.toVariable -> TupleSelect(freshid.toVariable, 1)), + bodyCons(nsMap)) + if (bodyUpdatesState) + Let(freshid, nval, transBody) + else + Let(freshid, nval, Tuple(transBody +: nextStates)) + } else + Let(id, nval, bodyCons(st)) + }, valUpdatesState || bodyUpdatesState) + + case IfExpr(cond, thn, elze) => + val (condCons, condState) = mapBody(cond) + val (thnCons, thnState) = mapBody(thn) + val (elzeCons, elzeState) = mapBody(elze) + ((st: Map[String, Expr]) => { + val (ncondCons, nst) = + if (condState) { + val cndExpr = condCons(st) + val bder = FreshIdentifier("c", cndExpr.getType) + val condst = tnames.zipWithIndex.map { + case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) + }.toMap + ((th: Expr, el: Expr) => + Let(bder, cndExpr, IfExpr(TupleSelect(bder.toVariable, 1), th, el)), + condst) + } else { + ((th: Expr, el: Expr) => IfExpr(condCons(st), th, el), st) + } + val nelze = + if ((condState || thnState) && !elzeState) + Tuple(elzeCons(nst) +: tnames.map(nst.apply)) + else elzeCons(nst) + val nthn = + if (!thnState && (condState || elzeState)) + Tuple(thnCons(nst) +: tnames.map(nst.apply)) + else thnCons(nst) + ncondCons(nthn, nelze) + }, condState || thnState || elzeState) + + case MatchExpr(scr, cases) => + val (scrCons, scrUpdatesState) = mapBody(scr) + val casesRes = cases.foldLeft(Seq[(Map[String, Expr] => Expr, Boolean)]()) { + case (acc, MatchCase(pat, None, rhs)) => + acc :+ mapBody(rhs) + case mcase => + throw new IllegalStateException("Match case with guards are not supported yet: " + mcase) + } + val casesUpdatesState = casesRes.exists(_._2) + ((st: Map[String, Expr]) => { + val scrExpr = scrCons(st) + val (nscrCons, scrst) = if (scrUpdatesState) { + val bder = FreshIdentifier("scr", scrExpr.getType) + val scrst = tnames.zipWithIndex.map { + case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) + }.toMap + ((ncases: Seq[MatchCase]) => + Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), + scrst) + } else { + ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) + } + val ncases = (cases zip casesRes).map { + case (MatchCase(pat, None, _), (caseCons, caseUpdatesState)) => + val nrhs = + if ((scrUpdatesState || casesUpdatesState) && !caseUpdatesState) + Tuple(caseCons(scrst) +: tnames.map(scrst.apply)) + else caseCons(scrst) + MatchCase(pat, None, nrhs) + } + nscrCons(ncases) + }, scrUpdatesState || casesUpdatesState) + + // need to reset types in the case of case class constructor calls + case CaseClass(cct, args) => + val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => CaseClass(ntype, nargs), false)) + + case Operator(args, op) => + // here, 'op' itself does not create a new state + mapNAryOperator(args, + (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => op(nargs), false)) + + case t: Terminal => (_ => t, false) + } + + def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Map[String, Expr] => Expr, Boolean)) = { + // create n variables to model n lets + val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) + (args zip letvars).foldRight(op(letvars)) { + case ((arg, letvar), (accCons, stUpdatedBefore)) => + val (argCons, stUpdateFlag) = mapBody(arg) + val cl = if (!stUpdateFlag) { + // here arg does not affect the newstate + (st: Map[String, Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) + } else { + // here arg does affect the newstate + (st: Map[String, Expr]) => + { + val narg = argCons(st) + val argres = FreshIdentifier("a", narg.getType, true).toVariable + val nstateSeq = tnames.zipWithIndex.map { + // states start from index 2 + case (tn, i) => TupleSelect(argres, i + 2) + } + val nstate = (tnames zip nstateSeq).map { + case (tn, st) => (tn -> st) + }.toMap[String, Expr] + val letbody = + if (stUpdatedBefore) accCons(nstate) // here, 'acc' already returns a superseeding state + else Tuple(accCons(nstate) +: nstateSeq) // here, 'acc; only retruns the result + Let(argres.id, narg, + Let(letvar.id, TupleSelect(argres, 1), letbody)) + } + } + (cl, stUpdatedBefore || stUpdateFlag) + } + } + + def assignBodiesToFunctions = funMap foreach { + case (fd, nfd) => + // /println("Considering function: "+fd) + // Here, using name to identify 'state' parameters, also relying + // on fact that nfd.params are in the same order as tnames + val stateParams = nfd.params.foldLeft(Seq[Expr]()) { + case (acc, ValDef(id, _)) if id.name.startsWith("st@") => + acc :+ id.toVariable + case (acc, _) => acc + } + val initStateMap = tnames zip stateParams toMap + val (nbodyFun, bodyUpdatesState) = mapBody(fd.body.get) + val nbody = nbodyFun(initStateMap) + val bodyWithState = + if (!bodyUpdatesState && funsRetStates(fd)) { + val freshid = FreshIdentifier("bres", nbody.getType) + Let(freshid, nbody, Tuple(freshid.toVariable +: stateParams)) + } else nbody + nfd.body = Some(simplifyLets(bodyWithState)) + //println(s"Body of ${fd.id.name} after conversion&simp: ${nfd.body}") + + // Important: specifications use lazy semantics but + // their state changes are ignored after their execution. + // This guarantees their observational purity/transparency + // collect class invariants that need to be added + if (fd.hasPrecondition) { + val (npreFun, preUpdatesState) = mapBody(fd.precondition.get) + nfd.precondition = + if (preUpdatesState) + Some(TupleSelect(npreFun(initStateMap), 1)) // ignore state updated by pre + else Some(npreFun(initStateMap)) + } + + if (fd.hasPostcondition) { + val Lambda(arg @ Seq(ValDef(resid, _)), post) = fd.postcondition.get + val (npostFun, postUpdatesState) = mapBody(post) + val newres = FreshIdentifier(resid.name, bodyWithState.getType) + val npost1 = + if (bodyUpdatesState || funsRetStates(fd)) { + val stmap = tnames.zipWithIndex.map { + case (tn, i) => (tn -> TupleSelect(newres.toVariable, i + 2)) + }.toMap + replace(Map(resid.toVariable -> TupleSelect(newres.toVariable, 1)), npostFun(stmap)) + } else + replace(Map(resid.toVariable -> newres.toVariable), npostFun(initStateMap)) + val npost = + if (postUpdatesState) + TupleSelect(npost1, 1) // ignore state updated by post + else npost1 + nfd.postcondition = Some(Lambda(Seq(ValDef(newres)), npost)) + } + } + + def assignContractsForEvals = evalFunctions.foreach { + case (tname, evalfd) => + val cdefs = closureFactory.closures(tname) + val tparams = evalfd.tparams.map(_.tp) + val postres = FreshIdentifier("res", evalfd.returnType) + // create a match case to switch over the possible class defs and invoke the corresponding functions + val postMatchCases = cdefs map { cdef => + val ctype = CaseClassType(cdef, tparams) + val binder = FreshIdentifier("t", ctype) + val pattern = InstanceOfPattern(Some(binder), ctype) + // create a body of the match + val op = closureFactory.lazyopOfClosure(cdef) + val targetFd = funMap(op) + val rhs = if (targetFd.hasPostcondition) { + val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } + val stArgs = + if (funsNeedStates(op)) // TODO: here we are assuming that only one state is used, fix this. + Seq(evalfd.params.last.toVariable) + else Seq() + val Lambda(Seq(resarg), targetPost) = targetFd.postcondition.get + val replaceMap = (targetFd.params.map(_.toVariable) zip (args ++ stArgs)).toMap[Expr, Expr] + + (resarg.toVariable -> postres.toVariable) + replace(replaceMap, targetPost) + } else + Util.tru + MatchCase(pattern, None, rhs) + } + evalfd.postcondition = Some( + Lambda(Seq(ValDef(postres)), + MatchExpr(evalfd.params.head.toVariable, postMatchCases))) + } + + /** + * Overrides the types of the lazy fields in the case class definitions + */ + def transformCaseClasses = p.definedClasses.foreach { + case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) => + val nfields = ccd.fields.map { fld => + unwrapLazyType(fld.getType) match { + case None => fld + case Some(btype) => + val adtType = AbstractClassType(closureFactory.absClosureType(typeNameWOParams(btype)), + getTypeParameters(btype)) + ValDef(fld.id, Some(adtType)) // overriding the field type + } + } + ccd.setFields(nfields) + case _ => ; + } + + def apply: Program = { + // TODO: for now pick a arbitrary point to add new defs. But ideally the lazy closure will be added to a separate module + // and imported every where + val anchor = funMap.values.last + transformCaseClasses + assignBodiesToFunctions + assignContractsForEvals + addDefs( + copyProgram(p, + (defs: Seq[Definition]) => defs.flatMap { + case fd: FunDef if funMap.contains(fd) => + if (removeRecursionViaEval) + Seq(funMap(fd), uninterpretedTargets(fd)) + else Seq(funMap(fd)) + case d => Seq(d) + }), + closureFactory.allClosuresAndParents ++ closureCons.values ++ + evalFunctions.values ++ computeFunctions.values, anchor) + } +} + diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala new file mode 100644 index 000000000..a0d0e089b --- /dev/null +++ b/src/main/scala/leon/laziness/LazyClosureFactory.scala @@ -0,0 +1,115 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter +import leon.LeonContext +import leon.LeonOptionDef +import leon.Main +import leon.TransformationPhase +import LazinessUtil._ + +//case class ClosureData(tpe: TypeTree, absDef: AbstractClassDef, caseClass: Seq[CaseClassDef]) + +class LazyClosureFactory(p: Program) { + val debug = false + implicit val prog = p + /** + * Create a mapping from types to the lazyops that may produce a value of that type + * TODO: relax that requirement that type parameters of return type of a function + * lazy evaluated should include all of its type parameters + */ + private val (tpeToADT, opToCaseClass) = { + // collect all the operations that could be lazily evaluated. + val lazyops = p.definedFunctions.flatMap { + case fd if (fd.hasBody) => + filter(isLazyInvocation)(fd.body.get) map { + case FunctionInvocation(_, Seq(FunctionInvocation(tfd, _))) => tfd.fd + } + case _ => Seq() + }.distinct + if (debug) { + println("Lazy operations found: \n" + lazyops.map(_.id).mkString("\n")) + } + val tpeToLazyops = lazyops.groupBy(_.returnType) + val tpeToAbsClass = tpeToLazyops.map(_._1).map { tpe => + val name = typeNameWOParams(tpe) + val absTParams = getTypeParameters(tpe) map TypeParameterDef.apply + // using tpe name below to avoid mismatches due to type parameters + name -> AbstractClassDef(FreshIdentifier(typeNameToADTName(name), Untyped), + absTParams, None) + }.toMap + var opToAdt = Map[FunDef, CaseClassDef]() + val tpeToADT = tpeToLazyops map { + case (tpe, ops) => + val name = typeNameWOParams(tpe) + val absClass = tpeToAbsClass(name) + val absType = AbstractClassType(absClass, absClass.tparams.map(_.tp)) + val absTParams = absClass.tparams + // create a case class for every operation + val cdefs = ops map { opfd => + assert(opfd.tparams.size == absTParams.size) + val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) + val cdef = CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) + val nfields = opfd.params.map { vd => + val fldType = vd.getType + unwrapLazyType(fldType) match { + case None => + ValDef(FreshIdentifier(vd.id.name, fldType)) + case Some(btype) => + val adtType = AbstractClassType(absClass, getTypeParameters(btype)) + ValDef(FreshIdentifier(vd.id.name, adtType)) + } + } + cdef.setFields(nfields) + absClass.registerChild(cdef) + opToAdt += (opfd -> cdef) + cdef + } + (name -> (tpe, absClass, cdefs)) + } + (tpeToADT, opToAdt) + } + + // this fixes an ordering on lazy types + lazy val lazyTypeNames = tpeToADT.keys.toSeq + + def allClosuresAndParents = tpeToADT.values.flatMap(v => v._2 +: v._3) + + def lazyType(tn: String) = tpeToADT(tn)._1 + + def absClosureType(tn: String) = tpeToADT(tn)._2 + + def closures(tn: String) = tpeToADT(tn)._3 + + lazy val caseClassToOp = opToCaseClass map { case (k, v) => v -> k } + + def lazyopOfClosure(cl: CaseClassDef) = caseClassToOp(cl) + + def closureOfLazyOp(op: FunDef) = opToCaseClass(op) + + /** + * Here, the lazy type name is recovered from the closure's name. + * This avoids the use of additional maps. + */ + def lazyTypeNameOfClosure(cl: CaseClassDef) = adtNameToTypeName(cl.parent.get.classDef.id.name) +} + diff --git a/src/main/scala/leon/laziness/TypeChecker.scala b/src/main/scala/leon/laziness/TypeChecker.scala new file mode 100644 index 000000000..5bd216dcc --- /dev/null +++ b/src/main/scala/leon/laziness/TypeChecker.scala @@ -0,0 +1,188 @@ +package leon +package laziness + +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ + +object TypeChecker { + /** + * `gamma` is the initial type environment which has + * type bindings for free variables of `ine`. + * It is not necessary that gamma should match the types of the + * identifiers of the free variables. + * Set and Maps are not supported yet + */ + def inferTypesOfLocals(ine: Expr, initGamma: Map[Identifier, TypeTree]): Expr = { + var idmap = Map[Identifier, Identifier]() + var gamma = initGamma + + /** + * Note this method has side-effects + */ + def makeIdOfType(oldId: Identifier, tpe: TypeTree): Identifier = { + if (oldId.getType != tpe) { + val freshid = FreshIdentifier(oldId.name, tpe, true) + idmap += (oldId -> freshid) + gamma += (oldId -> tpe) + freshid + } else oldId + } + + def rec(e: Expr): (TypeTree, Expr) = { + val res = e match { + case Let(id, value, body) => + val (valType, nval) = rec(value) + val nid = makeIdOfType(id, valType) + val (btype, nbody) = rec(body) + (btype, Let(nid, nval, nbody)) + + case Ensuring(body, Lambda(Seq(resdef @ ValDef(resid, _)), postBody)) => + val (btype, nbody) = rec(body) + val nres = makeIdOfType(resid, btype) + (btype, Ensuring(nbody, Lambda(Seq(ValDef(nres)), rec(postBody)._2))) + + case MatchExpr(scr, mcases) => + val (scrtype, nscr) = rec(scr) + val ncases = mcases.map { + case MatchCase(pat, optGuard, rhs) => + // resetting the type of patterns in the matches + def mapPattern(p: Pattern, expType: TypeTree): (Pattern, TypeTree) = { + p match { + case InstanceOfPattern(bopt, ict) => + // choose the subtype of the `expType` that + // has the same constructor as `ict` + val ntype = subcast(ict, expType.asInstanceOf[ClassType]) + if (!ntype.isDefined) + throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") + val nbopt = bopt.map(makeIdOfType(_, ntype.get)) + (InstanceOfPattern(nbopt, ntype.get), ntype.get) + + case CaseClassPattern(bopt, ict, subpats) => + val ntype = subcast(ict, expType.asInstanceOf[ClassType]) + if (!ntype.isDefined) + throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") + val cct = ntype.get.asInstanceOf[CaseClassType] + val nbopt = bopt.map(makeIdOfType(_, cct)) + val npats = (subpats zip cct.fieldsTypes).map { + case (p, t) => + //println(s"Subpat: $p expected type: $t") + mapPattern(p, t)._1 + } + (CaseClassPattern(nbopt, cct, npats), cct) + + case TuplePattern(bopt, subpats) => + val TupleType(subts) = scrtype + val patnTypes = (subpats zip subts).map { + case (p, t) => mapPattern(p, t) + } + val npats = patnTypes.map(_._1) + val ntype = TupleType(patnTypes.map(_._2)) + val nbopt = bopt.map(makeIdOfType(_, ntype)) + (TuplePattern(nbopt, npats), ntype) + + case WildcardPattern(bopt) => + val nbopt = bopt.map(makeIdOfType(_, expType)) + (WildcardPattern(nbopt), expType) + + case LiteralPattern(bopt, lit) => + val ntype = lit.getType + val nbopt = bopt.map(makeIdOfType(_, ntype)) + (LiteralPattern(nbopt, lit), ntype) + case _ => + throw new IllegalStateException("Not supported yet!") + } + } + val npattern = mapPattern(pat, scrtype)._1 + val nguard = optGuard.map(rec(_)._2) + val nrhs = rec(rhs)._2 + //println(s"New rhs: $nrhs inferred type: ${nrhs.getType}") + MatchCase(npattern, nguard, nrhs) + } + val nmatch = MatchExpr(nscr, ncases) + (nmatch.getType, nmatch) + + case cs @ CaseClassSelector(cltype, clExpr, fld) => + val (ncltype: ClassType, nclExpr) = rec(clExpr) + // this is a hack. TODO: fix this + subcast(cltype, ncltype) match { + case Some(ntype : CaseClassType) => + (ntype, CaseClassSelector(ntype, nclExpr, fld)) + case _ => + throw new IllegalStateException(s"$nclExpr : $ncltype cannot be cast to case class type: $cltype") + } + + case AsInstanceOf(clexpr, cltype) => + val (ncltype: ClassType, nexpr) = rec(clexpr) + subcast(cltype, ncltype) match { + case Some(ntype) => (ntype, AsInstanceOf(nexpr, ntype)) + case _ => + //println(s"asInstanceOf type of $clExpr is: $cltype inferred type of $nclExpr : $ct") + throw new IllegalStateException(s"$nexpr : $ncltype cannot be cast to case class type: $cltype") + } + + case v @ Variable(id) => + if (gamma.contains(id)) { + if (idmap.contains(id)) + (gamma(id), idmap(id).toVariable) + else { + (gamma(id), v) + } + } else (id.getType, v) + + case FunctionInvocation(TypedFunDef(fd, tparams), args) => + val nargs = args.map(arg => rec(arg)._2) + var tpmap = Map[TypeParameter, TypeTree]() + (fd.params zip nargs).foreach { x => + (x._1.getType, x._2.getType) match { + case (t1, t2) => + getTypeArguments(t1) zip getTypeArguments(t2) foreach { + case (tf : TypeParameter, ta) => + tpmap += (tf -> ta) + case _ => ; + } + /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" + + s"do not match for call: $call")*/ + } + } + val ntparams = fd.tparams.map(tpd => tpmap(tpd.tp)) + val nexpr = FunctionInvocation(TypedFunDef(fd, ntparams), nargs) + (nexpr.getType, nexpr) + + // need to handle tuple select specially + case TupleSelect(tup, i) => + val nop = TupleSelect(rec(tup)._2, i) + (nop.getType, nop) + case Operator(args, op) => + val nop = op(args.map(arg => rec(arg)._2)) + (nop.getType, nop) + case t: Terminal => + (t.getType, t) + } + //println(s"Inferred type of $e : ${res._1} new expression: ${res._2}") + if (res._1 == Untyped) { + throw new IllegalStateException(s"Cannot infer type for expression: $e") + } + res + } + + def subcast(oldType: ClassType, newType: ClassType): Option[ClassType] = { + newType match { + case AbstractClassType(absClass, tps) if absClass.knownCCDescendants.contains(oldType.classDef) => + //here oldType.classDef <: absClass + Some(CaseClassType(oldType.classDef.asInstanceOf[CaseClassDef], tps)) + case cct: CaseClassType => + Some(cct) + case _ => + None + } + } + rec(ine)._2 + } +} + diff --git a/src/main/scala/leon/laziness/TypeRectifier.scala b/src/main/scala/leon/laziness/TypeRectifier.scala new file mode 100644 index 000000000..171dd4f3d --- /dev/null +++ b/src/main/scala/leon/laziness/TypeRectifier.scala @@ -0,0 +1,152 @@ +package leon +package laziness + +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.invariant.util.TypeUtil._ +import leon.invariant.util.LetTupleSimplification._ +import leon.verification.AnalysisPhase +import java.io.File +import java.io.FileWriter +import java.io.BufferedWriter +import scala.util.matching.Regex +import leon.purescala.PrettyPrinter +import leon.LeonContext +import leon.LeonOptionDef +import leon.Main +import leon.TransformationPhase +import LazinessUtil._ +import leon.invariant.datastructure.DisjointSets + +/** + * This performs a little bit of Hindley-Milner type Inference + * to correct the local types and also unify type parameters + * @param placeHolderParameter Expected to returns true iff a type parameter + * is meant as a placeholder and cannot be used + * to represent a unified type + */ +class TypeRectifier(p: Program, placeHolderParameter: TypeParameter => Boolean) { + + val typeClasses = { + var tc = new DisjointSets[TypeTree]() + p.definedFunctions.foreach { + case fd if fd.hasBody && !fd.isLibrary => + postTraversal { + case call @ FunctionInvocation(TypedFunDef(fd, tparams), args) => + // unify formal type parameters with actual type arguments + (fd.tparams zip tparams).foreach(x => tc.union(x._1.tp, x._2)) + // unify the type parameters of types of formal parameters with + // type arguments of actual arguments + (fd.params zip args).foreach { x => + (x._1.getType, x._2.getType) match { + case (SetType(tf: ClassType), SetType(ta: ClassType)) => + (tf.tps zip ta.tps).foreach { x => tc.union(x._1, x._2) } + case (tf: TypeParameter, ta: TypeParameter) => + tc.union(tf, ta) + case (t1, t2) => + // others could be ignored for now as they are not part of the state + //TODO: handle this case + ; + /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" + + s"do not match for call: $call")*/ + } + } + case _ => ; + }(fd.fullBody) + case _ => ; + } + tc + } + + val equivTypeParams = typeClasses.toMap + + val fdMap = p.definedFunctions.collect { + case fd if !fd.isLibrary => + val (tempTPs, otherTPs) = fd.tparams.map(_.tp).partition { + case tp if placeHolderParameter(tp) => true + case _ => false + } + val others = otherTPs.toSet[TypeTree] + // for each of the type parameter pick one representative from its equvialence class + val tpMap = fd.tparams.map { + case TypeParameterDef(tp) => + val tpclass = equivTypeParams.getOrElse(tp, Set(tp)) + val candReps = tpclass.filter(r => others.contains(r) || !r.isInstanceOf[TypeParameter]) + val concRep = candReps.find(!_.isInstanceOf[TypeParameter]) + val rep = + if (concRep.isDefined) // there exists a concrete type ? + concRep.get + else if (!candReps.isEmpty) + candReps.head + else + throw new IllegalStateException(s"Cannot find a non-placeholder in equivalence class $tpclass for fundef: \n $fd") + tp -> rep + }.toMap + val instf = instantiateTypeParameters(tpMap) _ + val paramMap = fd.params.map { + case vd @ ValDef(id, _) => + (id -> FreshIdentifier(id.name, instf(vd.getType))) + }.toMap + val ntparams = tpMap.values.toSeq.distinct.collect { case tp: TypeParameter => tp } map TypeParameterDef + val nfd = new FunDef(fd.id.freshen, ntparams, instf(fd.returnType), + fd.params.map(vd => ValDef(paramMap(vd.id)))) + fd -> (nfd, tpMap, paramMap) + }.toMap + + /** + * Replace fundefs and unify type parameters in function invocations. + * Replace old parameters by new parameters + */ + def transformFunBody(ifd: FunDef) = { + val (nfd, tpMap, paramMap) = fdMap(ifd) + // need to handle tuple select specially as it breaks if the type of + // the tupleExpr if it is not TupleTyped. + // cannot use simplePostTransform because of this + def rec(e: Expr): Expr = e match { + case FunctionInvocation(TypedFunDef(callee, targsOld), args) => + val targs = targsOld.map { + case tp: TypeParameter => tpMap.getOrElse(tp, tp) + case t => t + }.distinct + val ncallee = + if (fdMap.contains(callee)) + fdMap(callee)._1 + else callee + FunctionInvocation(TypedFunDef(ncallee, targs), args map rec) + case Variable(id) if paramMap.contains(id) => + paramMap(id).toVariable + case TupleSelect(tup, index) => TupleSelect(rec(tup), index) + case Operator(args, op) => op(args map rec) + case t: Terminal => t + } + val nbody = rec(ifd.fullBody) + //println("Inferring types for: "+ifd.id) + val initGamma = nfd.params.map(vd => vd.id -> vd.getType).toMap + TypeChecker.inferTypesOfLocals(nbody, initGamma) + } + + def apply: Program = { + copyProgram(p, (defs: Seq[Definition]) => defs.map { + case fd: FunDef if fdMap.contains(fd) => + val nfd = fdMap(fd)._1 + if (!fd.fullBody.isInstanceOf[NoTree]) { + nfd.fullBody = simplifyLetsAndLetsWithTuples(transformFunBody(fd)) + } + fd.flags.foreach(nfd.addFlag(_)) + //println("New fun: "+fd) + nfd + case d => d + }) + } +} + diff --git a/src/main/scala/leon/transformations/LazinessEliminationPhase.scala b/src/main/scala/leon/transformations/LazinessEliminationPhase.scala deleted file mode 100644 index 4bae04842..000000000 --- a/src/main/scala/leon/transformations/LazinessEliminationPhase.scala +++ /dev/null @@ -1,1092 +0,0 @@ -package leon -package transformations - -import invariant.factories._ -import invariant.util.Util._ -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.ScalaPrinter -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.FunctionClosure -import scala.collection.mutable.{ Map => MutableMap } -import leon.invariant.util.TypeUtil._ -import leon.invariant.util.LetTupleSimplification._ -import leon.solvers.TimeoutSolverFactory -import leon.verification.VerificationContext -import leon.verification.DefaultTactic -import leon.verification.AnalysisPhase -import java.io.File -import java.io.FileWriter -import java.io.BufferedWriter -import scala.util.matching.Regex -import leon.purescala.PrettyPrinter - -object LazinessEliminationPhase extends TransformationPhase { - val debug = false - val dumpProgramWithClosures = false - val dumpTypeCorrectProg = false - val dumpFinalProg = false - - // flags - val removeRecursionViaEval = false - val skipVerification = false - val prettyPrint = true - val optOutputDirectory = new LeonOptionDef[String] { - val name = "o" - val description = "Output directory" - val default = "leon.out" - val usageRhs = "dir" - val parser = (x: String) => x - } - - val name = "Laziness Elimination Phase" - val description = "Coverts a program that uses lazy construct" + - " to a program that does not use lazy constructs" - - /** - * TODO: enforce that the specs do not create lazy closures - * TODO: we are forced to make an assumption that lazy ops takes as type parameters only those - * type parameters of their return type and not more. (enforce this) - * TODO: Check that lazy types are not nested - */ - def apply(ctx: LeonContext, prog: Program): Program = { - - val nprog = liftLazyExpressions(prog) - val (tpeToADT, opToAdt) = createClosures(nprog) - //Override the types of the lazy fields in the case class definition - nprog.definedClasses.foreach { - case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) => - val nfields = ccd.fields.map { fld => - unwrapLazyType(fld.getType) match { - case None => fld - case Some(btype) => - val adtType = AbstractClassType(tpeToADT(typeNameWOParams(btype))._2, - getTypeParameters(btype)) - ValDef(fld.id, Some(adtType)) // overriding the field type - } - } - ccd.setFields(nfields) - case _ => ; - } - // TODO: for now pick one suitable module. But ideally the lazy closure will be added to a separate module - // and imported every where - val progWithClasses = addDefs(nprog, - tpeToADT.values.flatMap(v => v._2 +: v._3), - opToAdt.keys.last) - if (debug) - println("After adding case class corresponding to lazyops: \n" + ScalaPrinter.apply(progWithClasses)) - val progWithClosures = (new TransformProgramUsingClosures(progWithClasses, tpeToADT, opToAdt))() - //Rectify type parameters and local types - val typeCorrectProg = rectifyLocalTypeAndTypeParameters(progWithClosures) - if (dumpTypeCorrectProg) - println("After rectifying types: \n" + ScalaPrinter.apply(typeCorrectProg)) - - val transProg = assertClosurePres(typeCorrectProg) - if (dumpFinalProg) - println("After asserting closure preconditions: \n" + ScalaPrinter.apply(transProg)) - // handle 'axiom annotation - transProg.definedFunctions.foreach { fd => - if (fd.annotations.contains("axiom")) - fd.addFlag(Annotation("library", Seq())) - } - // check specifications (to be moved to a different phase) - if (!skipVerification) - checkSpecifications(transProg) - if (prettyPrint) - prettyPrintProgramToFile(transProg, ctx) - transProg - } - - def prettyPrintProgramToFile(p: Program, ctx: LeonContext) { - val outputFolder = ctx.findOptionOrDefault(optOutputDirectory) - try { - new File(outputFolder).mkdir() - } catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not create directory " + outputFolder) - } - - for (u <- p.units if u.isMainUnit) { - val outputFile = s"$outputFolder${File.separator}${u.id.toString}.scala" - try { - val out = new BufferedWriter(new FileWriter(outputFile)) - // remove '@' from the end of the identifier names - val pat = new Regex("""(\S+)(@)""", "base", "suffix") - val pgmText = pat.replaceAllIn(PrettyPrinter.apply(p), m => m.group("base")) - out.write(pgmText) - out.close() - } - catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) - } - } - ctx.reporter.info("Output written on " + outputFolder) - } - - def isLazyInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.$.apply" - case _ => - false - } - - def isEvaluatedInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.$.isEvaluated" - case _ => false - } - - def isValueInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.$.value" - case _ => false - } - - def isStarInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.$.*" - case _ => false - } - - def isLazyType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => - cid.name == "$" - case _ => false - } - - /** - * TODO: Check that lazy types are not nested - */ - def unwrapLazyType(tpe: TypeTree) = tpe match { - case ctype @ CaseClassType(_, Seq(innerType)) if isLazyType(ctype) => - Some(innerType) - case _ => None - } - - def rootType(t: TypeTree): Option[AbstractClassType] = t match { - case absT: AbstractClassType => Some(absT) - case ct: ClassType => ct.parent - case _ => None - } - - def opNameToCCName(name: String) = { - name.capitalize + "@" - } - - /** - * Convert the first character to lower case - * and remove the last character. - */ - def ccNameToOpName(name: String) = { - name.substring(0, 1).toLowerCase() + - name.substring(1, name.length() - 1) - } - - def typeNameToADTName(name: String) = { - "Lazy" + name - } - - def adtNameToTypeName(name: String) = { - name.substring(4) - } - - def closureConsName(typeName: String) = { - "new@" + typeName - } - - def isClosureCons(fd: FunDef) = { - fd.id.name.startsWith("new@") - } - - /** - * convert the argument of every lazy constructors to a procedure - */ - def liftLazyExpressions(prog: Program): Program = { - var newfuns = Map[ExprStructure, (FunDef, ModuleDef)]() - var anchorFun: Option[FunDef] = None - val fdmap = prog.modules.flatMap { md => - md.definedFunctions.map { - case fd if fd.hasBody && !fd.isLibrary => - //println("FunDef: "+fd) - val nfd = preMapOnFunDef { - case finv @ FunctionInvocation(lazytfd, Seq(arg)) if isLazyInvocation(finv)(prog) && !arg.isInstanceOf[FunctionInvocation] => - val freevars = variablesOf(arg).toList - val tparams = freevars.map(_.getType) flatMap getTypeParameters - val argstruc = new ExprStructure(arg) - val argfun = - if (newfuns.contains(argstruc)) { - newfuns(argstruc)._1 - } else { - //construct type parameters for the function - val nfun = new FunDef(FreshIdentifier("lazyarg", arg.getType, true), - tparams map TypeParameterDef.apply, arg.getType, freevars.map(ValDef(_))) - nfun.body = Some(arg) - newfuns += (argstruc -> (nfun, md)) - nfun - } - Some(FunctionInvocation(lazytfd, Seq(FunctionInvocation(TypedFunDef(argfun, tparams), - freevars.map(_.toVariable))))) - case _ => None - }(fd) - (fd -> Some(nfd)) - case fd => fd -> Some(fd.duplicate) - } - }.toMap - // map the new functions to themselves - val newfdMap = newfuns.values.map { case (nfd, _) => nfd -> None }.toMap - val (repProg, _) = replaceFunDefs(prog)(fdmap ++ newfdMap) - val nprog = - if (!newfuns.isEmpty) { - val modToNewDefs = newfuns.values.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }.toMap - appendDefsToModules(repProg, modToNewDefs) - } else - throw new IllegalStateException("Cannot find any lazy computation") - if (debug) - println("After lifiting arguments of lazy constructors: \n" + ScalaPrinter.apply(nprog)) - nprog - } - - /** - * Create a mapping from types to the lazyops that may produce a value of that type - * TODO: relax that requirement that type parameters of return type of a function - * lazy evaluated should include all of its type parameters - */ - type ClosureData = (TypeTree, AbstractClassDef, Seq[CaseClassDef]) - def createClosures(implicit p: Program) = { - // collect all the operations that could be lazily evaluated. - val lazyops = p.definedFunctions.flatMap { - case fd if (fd.hasBody) => - filter(isLazyInvocation)(fd.body.get) map { - case FunctionInvocation(_, Seq(FunctionInvocation(tfd, _))) => tfd.fd - } - case _ => Seq() - }.distinct - if (debug) { - println("Lazy operations found: \n" + lazyops.map(_.id).mkString("\n")) - } - val tpeToLazyops = lazyops.groupBy(_.returnType) - val tpeToAbsClass = tpeToLazyops.map(_._1).map { tpe => - val name = typeNameWOParams(tpe) - val absTParams = getTypeParameters(tpe) map TypeParameterDef.apply - // using tpe name below to avoid mismatches due to type parameters - name -> AbstractClassDef(FreshIdentifier(typeNameToADTName(name), Untyped), - absTParams, None) - }.toMap - var opToAdt = Map[FunDef, CaseClassDef]() - val tpeToADT = tpeToLazyops map { - case (tpe, ops) => - val name = typeNameWOParams(tpe) - val absClass = tpeToAbsClass(name) - val absType = AbstractClassType(absClass, absClass.tparams.map(_.tp)) - val absTParams = absClass.tparams - // create a case class for every operation - val cdefs = ops map { opfd => - assert(opfd.tparams.size == absTParams.size) - val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) - val cdef = CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) - val nfields = opfd.params.map { vd => - val fldType = vd.getType - unwrapLazyType(fldType) match { - case None => - ValDef(FreshIdentifier(vd.id.name, fldType)) - case Some(btype) => - val adtType = AbstractClassType(absClass, getTypeParameters(btype)) - ValDef(FreshIdentifier(vd.id.name, adtType)) - } - } - cdef.setFields(nfields) - absClass.registerChild(cdef) - opToAdt += (opfd -> cdef) - cdef - } - (name -> (tpe, absClass, cdefs)) - } - (tpeToADT, opToAdt) - } - - /** - * (a) add state to every function in the program - * (b) thread state through every expression in the program sequentially - * (c) replace lazy constructions with case class creations - * (d) replace isEvaluated with currentState.contains() - * (e) replace accesses to $.value with calls to dispatch with current state - */ - class TransformProgramUsingClosures(p: Program, - tpeToADT: Map[String, ClosureData], - opToAdt: Map[FunDef, CaseClassDef]) { - - val (funsNeedStates, funsRetStates) = funsNeedingnReturningState(p) - // fix an ordering on types so that we can instrument programs with their states in the same order - val tnames = tpeToADT.keys.toSeq - // create a mapping from functions to new functions - val funMap = p.definedFunctions.collect { - case fd if (fd.hasBody && !fd.isLibrary) => - // replace lazy types in parameters and return values - val nparams = fd.params map { vd => - ValDef(vd.id, Some(replaceLazyTypes(vd.getType))) // override type of identifier - } - val nretType = replaceLazyTypes(fd.returnType) - // does the function require implicit state ? - val nfd = if (funsNeedStates(fd)) { - var newTParams = Seq[TypeParameterDef]() - val stTypes = tnames map { tn => - val absClass = tpeToADT(tn)._2 - val tparams = absClass.tparams.map(_ => - TypeParameter.fresh("P@")) - newTParams ++= tparams map TypeParameterDef - SetType(AbstractClassType(absClass, tparams)) - } - val stParams = stTypes.map { stType => - ValDef(FreshIdentifier("st@", stType, true)) - } - val retTypeWithState = - if (funsRetStates(fd)) - TupleType(nretType +: stTypes) - else - nretType - // the type parameters will be unified later - new FunDef(FreshIdentifier(fd.id.name, Untyped), - fd.tparams ++ newTParams, retTypeWithState, nparams ++ stParams) - // body of these functions are defined later - } else { - new FunDef(FreshIdentifier(fd.id.name, Untyped), fd.tparams, nretType, nparams) - } - // copy annotations - fd.flags.foreach(nfd.addFlag(_)) - (fd -> nfd) - }.toMap - - /** - * A set of uninterpreted functions that may be used as targets - * of closures in the eval functions, for efficiency reasons. - */ - lazy val uninterpretedTargets = { - funMap.map { - case (k, v) => - val ufd = new FunDef(FreshIdentifier(v.id.name, v.id.getType, true), - v.tparams, v.returnType, v.params) - (k -> ufd) - }.toMap - } - - /** - * A function for creating a state type for every lazy type. Note that Leon - * doesn't support 'Any' type yet. So we have to have multiple - * state (though this is much clearer it results in more complicated code) - */ - def getStateType(tname: String, tparams: Seq[TypeParameter]) = { - val (_, absdef, _) = tpeToADT(tname) - SetType(AbstractClassType(absdef, tparams)) - } - - def replaceLazyTypes(t: TypeTree): TypeTree = { - unwrapLazyType(t) match { - case None => - val NAryType(tps, tcons) = t - tcons(tps map replaceLazyTypes) - case Some(btype) => - val ntype = AbstractClassType(tpeToADT( - typeNameWOParams(btype))._2, getTypeParameters(btype)) - val NAryType(tps, tcons) = ntype - tcons(tps map replaceLazyTypes) - } - } - - /** - * Create dispatch functions for each lazy type. - * Note: the dispatch functions will be annotated as library so that - * their pre/posts are not checked (the fact that they hold are verified separately) - * Note by using 'assume-pre' we can also assume the preconditions of closures at - * the call-sites. - */ - val adtToOp = opToAdt map { case (k, v) => v -> k } - val evalFunctions = { - tpeToADT map { - case (tname, (tpe, absdef, cdefs)) => - val tparams = getTypeParameters(tpe) - val tparamDefs = tparams map TypeParameterDef.apply - val param1 = FreshIdentifier("cl", AbstractClassType(absdef, tparams)) - val stType = getStateType(tname, tparams) - val param2 = FreshIdentifier("st@", stType) - val retType = TupleType(Seq(tpe, stType)) - val dfun = new FunDef(FreshIdentifier("eval" + absdef.id.name, Untyped), - tparamDefs, retType, Seq(ValDef(param1), ValDef(param2))) - // assign body - // create a match case to switch over the possible class defs and invoke the corresponding functions - val bodyMatchCases = cdefs map { cdef => - val ctype = CaseClassType(cdef, tparams) // we assume that the type parameters of cdefs are same as absdefs - val binder = FreshIdentifier("t", ctype) - val pattern = InstanceOfPattern(Some(binder), ctype) - // create a body of the match - val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } - val op = adtToOp(cdef) - val stArgs = // TODO: here we are assuming that only one state is used, fix this. - if (funsNeedStates(op)) - // Note: it is important to use empty state here to eliminate - // dependency on state for the result value - Seq(FiniteSet(Set(), stType.base)) - else Seq() - val targetFun = - if (removeRecursionViaEval && op.hasPostcondition) { - // checking for postcondition is a hack to make sure that we - // do not make all targets uninterpreted - uninterpretedTargets(op) - } else funMap(op) - val invoke = FunctionInvocation(TypedFunDef(targetFun, tparams), args ++ stArgs) // we assume that the type parameters of lazy ops are same as absdefs - val invPart = if (funsRetStates(op)) { - TupleSelect(invoke, 1) // we are only interested in the value - } else invoke - val newst = SetUnion(param2.toVariable, FiniteSet(Set(binder.toVariable), stType.base)) - val rhs = Tuple(Seq(invPart, newst)) - MatchCase(pattern, None, rhs) - } - dfun.body = Some(MatchExpr(param1.toVariable, bodyMatchCases)) - dfun.addFlag(Annotation("axiom", Seq())) - (tname -> dfun) - } - } - - /** - * These are evalFunctions that do not affect the state - */ - val computeFunctions = { - evalFunctions map { - case (tname, evalfd) => - val (tpe, _, _) = tpeToADT(tname) - val param1 = evalfd.params.head - val fun = new FunDef(FreshIdentifier(evalfd.id.name + "*", Untyped), - evalfd.tparams, tpe, Seq(param1)) - val invoke = FunctionInvocation(TypedFunDef(evalfd, evalfd.tparams.map(_.tp)), - Seq(param1.id.toVariable, FiniteSet(Set(), - getStateType(tname, getTypeParameters(tpe)).base))) - fun.body = Some(TupleSelect(invoke, 1)) - (tname -> fun) - } - } - - /** - * Create closure construction functions that ensures a postcondition. - * They are defined for each lazy class since it avoids generics and - * simplifies the type inference (which is not full-fledged in Leon) - */ - val closureCons = tpeToADT map { - case (tname, (_, adt, _)) => - val param1Type = AbstractClassType(adt, adt.tparams.map(_.tp)) - val param1 = FreshIdentifier("cc", param1Type) - val stType = SetType(param1Type) - val param2 = FreshIdentifier("st@", stType) - val tparamDefs = adt.tparams - val fun = new FunDef(FreshIdentifier(closureConsName(tname)), adt.tparams, param1Type, - Seq(ValDef(param1), ValDef(param2))) - fun.body = Some(param1.toVariable) - val resid = FreshIdentifier("res", param1Type) - val postbody = Not(SubsetOf(FiniteSet(Set(resid.toVariable), param1Type), param2.toVariable)) - fun.postcondition = Some(Lambda(Seq(ValDef(resid)), postbody)) - fun.addFlag(Annotation("axiom", Seq())) - (tname -> fun) - } - - def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Map[String, Expr] => Expr, Boolean)) = { - // create n variables to model n lets - val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) - (args zip letvars).foldRight(op(letvars)) { - case ((arg, letvar), (accCons, stUpdatedBefore)) => - val (argCons, stUpdateFlag) = mapBody(arg) - val cl = if (!stUpdateFlag) { - // here arg does not affect the newstate - (st: Map[String, Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) // here, we don't have to create a let - } else { - // here arg does affect the newstate - (st: Map[String, Expr]) => - { - val narg = argCons(st) - val argres = FreshIdentifier("a", narg.getType, true).toVariable - val nstateSeq = tnames.zipWithIndex.map { - // states start from index 2 - case (tn, i) => TupleSelect(argres, i + 2) - } - val nstate = (tnames zip nstateSeq).map { - case (tn, st) => (tn -> st) - }.toMap[String, Expr] - val letbody = - if (stUpdatedBefore) accCons(nstate) // here, 'acc' already returns a superseeding state - else Tuple(accCons(nstate) +: nstateSeq) // here, 'acc; only retruns the result - Let(argres.id, narg, - Let(letvar.id, TupleSelect(argres, 1), letbody)) - } - } - (cl, stUpdatedBefore || stUpdateFlag) - } - } - - def mapBody(body: Expr): (Map[String, Expr] => Expr, Boolean) = body match { - - case finv @ FunctionInvocation(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) // lazy construction ? - if isLazyInvocation(finv)(p) => - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { - val adt = opToAdt(argfd) - val cc = CaseClass(CaseClassType(adt, tparams), nargs) - val baseLazyType = adtNameToTypeName(adt.parent.get.classDef.id.name) - FunctionInvocation(TypedFunDef(closureCons(baseLazyType), tparams), - Seq(cc, st(baseLazyType))) - }, false) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, args) if isEvaluatedInvocation(finv)(p) => // isEval function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { - val narg = nargs(0) // there must be only one argument here - val baseType = unwrapLazyType(narg.getType).get - val tname = typeNameWOParams(baseType) - val adtType = AbstractClassType(tpeToADT(tname)._2, getTypeParameters(baseType)) - SubsetOf(FiniteSet(Set(narg), adtType), st(tname)) - }, false) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, args) if isValueInvocation(finv)(p) => // is value function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { - val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here - val tname = typeNameWOParams(baseType) - val dispFun = evalFunctions(tname) - val dispFunInv = FunctionInvocation(TypedFunDef(dispFun, - getTypeParameters(baseType)), nargs :+ st(tname)) - val dispRes = FreshIdentifier("dres", dispFun.returnType) - val nstates = tnames map { - case `tname` => - TupleSelect(dispRes.toVariable, 2) - case other => st(other) - } - Let(dispRes, dispFunInv, Tuple(TupleSelect(dispRes.toVariable, 1) +: nstates)) - }, true) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, args) if isStarInvocation(finv)(p) => // is * function ? - val op = (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { - val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here - val tname = typeNameWOParams(baseType) - val dispFun = computeFunctions(tname) - FunctionInvocation(TypedFunDef(dispFun, getTypeParameters(baseType)), nargs) - }, false) - mapNAryOperator(args, op) - - case FunctionInvocation(TypedFunDef(fd, tparams), args) if funMap.contains(fd) => - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => { - val stArgs = - if (funsNeedStates(fd)) { - (tnames map st.apply) - } else Seq() - FunctionInvocation(TypedFunDef(funMap(fd), tparams), nargs ++ stArgs) - }, funsRetStates(fd))) - - case Let(id, value, body) => - val (valCons, valUpdatesState) = mapBody(value) - val (bodyCons, bodyUpdatesState) = mapBody(body) - ((st: Map[String, Expr]) => { - val nval = valCons(st) - if (valUpdatesState) { - val freshid = FreshIdentifier(id.name, nval.getType, true) - val nextStates = tnames.zipWithIndex.map { - case (tn, i) => - TupleSelect(freshid.toVariable, i + 2) - }.toSeq - val nsMap = (tnames zip nextStates).toMap - val transBody = replace(Map(id.toVariable -> TupleSelect(freshid.toVariable, 1)), - bodyCons(nsMap)) - if (bodyUpdatesState) - Let(freshid, nval, transBody) - else - Let(freshid, nval, Tuple(transBody +: nextStates)) - } else - Let(id, nval, bodyCons(st)) - }, valUpdatesState || bodyUpdatesState) - - case IfExpr(cond, thn, elze) => - val (condCons, condState) = mapBody(cond) - val (thnCons, thnState) = mapBody(thn) - val (elzeCons, elzeState) = mapBody(elze) - ((st: Map[String, Expr]) => { - val (ncondCons, nst) = - if (condState) { - val cndExpr = condCons(st) - val bder = FreshIdentifier("c", cndExpr.getType) - val condst = tnames.zipWithIndex.map { - case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) - }.toMap - ((th: Expr, el: Expr) => - Let(bder, cndExpr, IfExpr(TupleSelect(bder.toVariable, 1), th, el)), - condst) - } else { - ((th: Expr, el: Expr) => IfExpr(condCons(st), th, el), st) - } - val nelze = - if ((condState || thnState) && !elzeState) - Tuple(elzeCons(nst) +: tnames.map(nst.apply)) - else elzeCons(nst) - val nthn = - if (!thnState && (condState || elzeState)) - Tuple(thnCons(nst) +: tnames.map(nst.apply)) - else thnCons(nst) - ncondCons(nthn, nelze) - }, condState || thnState || elzeState) - - case MatchExpr(scr, cases) => - val (scrCons, scrUpdatesState) = mapBody(scr) - val casesRes = cases.foldLeft(Seq[(Map[String, Expr] => Expr, Boolean)]()) { - case (acc, MatchCase(pat, None, rhs)) => - acc :+ mapBody(rhs) - case mcase => - throw new IllegalStateException("Match case with guards are not supported yet: " + mcase) - } - val casesUpdatesState = casesRes.exists(_._2) - ((st: Map[String, Expr]) => { - val scrExpr = scrCons(st) - val (nscrCons, scrst) = if (scrUpdatesState) { - val bder = FreshIdentifier("scr", scrExpr.getType) - val scrst = tnames.zipWithIndex.map { - case (tn, i) => tn -> TupleSelect(bder.toVariable, i + 2) - }.toMap - ((ncases: Seq[MatchCase]) => - Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), - scrst) - } else { - ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) - } - val ncases = (cases zip casesRes).map { - case (MatchCase(pat, None, _), (caseCons, caseUpdatesState)) => - val nrhs = - if ((scrUpdatesState || casesUpdatesState) && !caseUpdatesState) - Tuple(caseCons(scrst) +: tnames.map(scrst.apply)) - else caseCons(scrst) - MatchCase(pat, None, nrhs) - } - nscrCons(ncases) - }, scrUpdatesState || casesUpdatesState) - - // need to reset types in the case of case class constructor calls - case CaseClass(cct, args) => - val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => CaseClass(ntype, nargs), false)) - - case Operator(args, op) => - // here, 'op' itself does not create a new state - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Map[String, Expr]) => op(nargs), false)) - - case t: Terminal => (_ => t, false) - } - - /** - * Assign bodies for the functions in funMap. - */ - def transformFunctions = { - funMap foreach { - case (fd, nfd) => - // /println("Considering function: "+fd) - // Here, using name to identify 'state' parameters, also relying - // on fact that nfd.params are in the same order as tnames - val stateParams = nfd.params.foldLeft(Seq[Expr]()) { - case (acc, ValDef(id, _)) if id.name.startsWith("st@") => - acc :+ id.toVariable - case (acc, _) => acc - } - val initStateMap = tnames zip stateParams toMap - val (nbodyFun, bodyUpdatesState) = mapBody(fd.body.get) - val nbody = nbodyFun(initStateMap) - val bodyWithState = - if (!bodyUpdatesState && funsRetStates(fd)) { - val freshid = FreshIdentifier("bres", nbody.getType) - Let(freshid, nbody, Tuple(freshid.toVariable +: stateParams)) - } else nbody - nfd.body = Some(simplifyLets(bodyWithState)) - //println(s"Body of ${fd.id.name} after conversion&simp: ${nfd.body}") - - // Important: specifications use lazy semantics but - // their state changes are ignored after their execution. - // This guarantees their observational purity/transparency - // collect class invariants that need to be added - if (fd.hasPrecondition) { - val (npreFun, preUpdatesState) = mapBody(fd.precondition.get) - nfd.precondition = - if (preUpdatesState) - Some(TupleSelect(npreFun(initStateMap), 1)) // ignore state updated by pre - else Some(npreFun(initStateMap)) - } - - if (fd.hasPostcondition) { - val Lambda(arg @ Seq(ValDef(resid, _)), post) = fd.postcondition.get - val (npostFun, postUpdatesState) = mapBody(post) - val newres = FreshIdentifier(resid.name, bodyWithState.getType) - val npost1 = - if (bodyUpdatesState || funsRetStates(fd)) { - val stmap = tnames.zipWithIndex.map { - case (tn, i) => (tn -> TupleSelect(newres.toVariable, i + 2)) - }.toMap - replace(Map(resid.toVariable -> TupleSelect(newres.toVariable, 1)), npostFun(stmap)) - } else - replace(Map(resid.toVariable -> newres.toVariable), npostFun(initStateMap)) - val npost = - if (postUpdatesState) - TupleSelect(npost1, 1) // ignore state updated by post - else npost1 - nfd.postcondition = Some(Lambda(Seq(ValDef(newres)), npost)) - } - } - } - - /** - * Assign specs for the eval functions - */ - def transformEvals = { - evalFunctions.foreach { - case (tname, evalfd) => - val cdefs = tpeToADT(tname)._3 - val tparams = evalfd.tparams.map(_.tp) - val postres = FreshIdentifier("res", evalfd.returnType) - // create a match case to switch over the possible class defs and invoke the corresponding functions - val postMatchCases = cdefs map { cdef => - val ctype = CaseClassType(cdef, tparams) - val binder = FreshIdentifier("t", ctype) - val pattern = InstanceOfPattern(Some(binder), ctype) - // create a body of the match - val op = adtToOp(cdef) - val targetFd = funMap(op) - val rhs = if (targetFd.hasPostcondition) { - val args = cdef.fields map { fld => CaseClassSelector(ctype, binder.toVariable, fld.id) } - val stArgs = - if (funsNeedStates(op)) // TODO: here we are assuming that only one state is used, fix this. - Seq(evalfd.params.last.toVariable) - else Seq() - val Lambda(Seq(resarg), targetPost) = targetFd.postcondition.get - val replaceMap = (targetFd.params.map(_.toVariable) zip (args ++ stArgs)).toMap[Expr, Expr] + - (resarg.toVariable -> postres.toVariable) - replace(replaceMap, targetPost) - } else - Util.tru - MatchCase(pattern, None, rhs) - } - evalfd.postcondition = Some( - Lambda(Seq(ValDef(postres)), - MatchExpr(evalfd.params.head.toVariable, postMatchCases))) - } - } - - def apply() = { - transformFunctions - transformEvals - val progWithClosures = addFunDefs(copyProgram(p, - (defs: Seq[Definition]) => defs.flatMap { - case fd: FunDef if funMap.contains(fd) => - if (removeRecursionViaEval) - Seq(funMap(fd), uninterpretedTargets(fd)) - else Seq(funMap(fd)) - case d => Seq(d) - }), closureCons.values ++ evalFunctions.values ++ computeFunctions.values, - funMap.values.last) - if (dumpProgramWithClosures) - println("Program with closures \n" + ScalaPrinter(progWithClosures)) - progWithClosures - } - } - - /** - * Generate lemmas that ensure that preconditions hold for closures. - */ - def assertClosurePres(p: Program): Program = { - - def hasClassInvariants(cc: CaseClass): Boolean = { - val opname = ccNameToOpName(cc.ct.classDef.id.name) - functionByName(opname, p).get.hasPrecondition - } - - var anchorfd: Option[FunDef] = None - val lemmas = p.definedFunctions.flatMap { - case fd if (fd.hasBody && !fd.isLibrary) => - //println("collection closure creation preconditions for: "+fd) - val closures = CollectorWithPaths { - case FunctionInvocation(TypedFunDef(fund, _), - Seq(cc: CaseClass, st)) if isClosureCons(fund) && hasClassInvariants(cc) => - (cc, st) - } traverse (fd.body.get) // Note: closures cannot be created in specs - // Note: once we have separated normal preconditions from state preconditions - // it suffices to just consider state preconditions here - closures.map { - case ((CaseClass(CaseClassType(ccd, _), args), st), path) => - anchorfd = Some(fd) - val target = functionByName(ccNameToOpName(ccd.id.name), p).get //find the target corresponding to the closure - val pre = target.precondition.get - val nargs = - if (target.params.size > args.size) // target takes state ? - args :+ st - else args - val pre2 = replaceFromIDs((target.params.map(_.id) zip nargs).toMap, pre) - val vc = Implies(And(Util.precOrTrue(fd), path), pre2) - // create a function for each vc - val lemmaid = FreshIdentifier(ccd.id.name + fd.id.name + "Lem", Untyped, true) - val params = variablesOf(vc).toSeq.map(v => ValDef(v)) - val tparams = params.flatMap(p => getTypeParameters(p.getType)).distinct map TypeParameterDef - val lemmafd = new FunDef(lemmaid, tparams, BooleanType, params) - lemmafd.body = Some(vc) - // assert the lemma is true - val resid = FreshIdentifier("holds", BooleanType) - lemmafd.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) -// /println("Created lemma function: "+fd) - lemmafd - } - case _ => Seq() - } - if (!lemmas.isEmpty) - addFunDefs(p, lemmas, anchorfd.get) - else p - } - - /** - * Returns all functions that 'need' states to be passed in - * and those that return a new state. - * TODO: implement backwards BFS by reversing the graph - */ - def funsNeedingnReturningState(prog: Program) = { - val cg = CallGraphUtil.constructCallGraph(prog, false, true) - var needRoots = Set[FunDef]() - var retRoots = Set[FunDef]() - prog.definedFunctions.foreach { - case fd if fd.hasBody && !fd.isLibrary => - postTraversal { - case finv: FunctionInvocation if isEvaluatedInvocation(finv)(prog) => - needRoots += fd - case finv: FunctionInvocation if isValueInvocation(finv)(prog) => - needRoots += fd - retRoots += fd - case _ => - ; - }(fd.body.get) - case _ => ; - } - val funsNeedStates = prog.definedFunctions.filterNot(fd => - cg.transitiveCallees(fd).toSet.intersect(needRoots).isEmpty).toSet - val funsRetStates = prog.definedFunctions.filterNot(fd => - cg.transitiveCallees(fd).toSet.intersect(retRoots).isEmpty).toSet - (funsNeedStates, funsRetStates) - } - - /** - * This performs a little bit of Hindley-Milner type Inference - * to correct the local types and also unify type parameters - */ - def rectifyLocalTypeAndTypeParameters(prog: Program): Program = { - var typeClasses = new DisjointSets[TypeTree]() - prog.definedFunctions.foreach { - case fd if fd.hasBody && !fd.isLibrary => - postTraversal { - case call @ FunctionInvocation(TypedFunDef(fd, tparams), args) => - // unify formal type parameters with actual type arguments - (fd.tparams zip tparams).foreach(x => typeClasses.union(x._1.tp, x._2)) - // unify the type parameters of types of formal parameters with - // type arguments of actual arguments - (fd.params zip args).foreach { x => - (x._1.getType, x._2.getType) match { - case (SetType(tf: ClassType), SetType(ta: ClassType)) => - (tf.tps zip ta.tps).foreach { x => typeClasses.union(x._1, x._2) } - case (tf: TypeParameter, ta: TypeParameter) => - typeClasses.union(tf, ta) - case _ => - // others could be ignored as they are not part of the state - ; - /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" - + s"do not match for call: $call")*/ - } - } - case _ => ; - }(fd.fullBody) - case _ => ; - } - - val equivTPs = typeClasses.toMap - val fdMap = prog.definedFunctions.collect { - case fd if !fd.isLibrary => //fd.hasBody && - - val (tempTPs, otherTPs) = fd.tparams.map(_.tp).partition { - case tp if tp.id.name.endsWith("@") => true - case _ => false - } - val others = otherTPs.toSet[TypeTree] - // for each of the type parameter pick one representative from its equvialence class - val tpMap = fd.tparams.map { - case TypeParameterDef(tp) => - val tpclass = equivTPs.getOrElse(tp, Set(tp)) - val candReps = tpclass.filter(r => others.contains(r) || !r.isInstanceOf[TypeParameter]) - val concRep = candReps.find(!_.isInstanceOf[TypeParameter]) - val rep = - if (concRep.isDefined) // there exists a concrete type ? - concRep.get - else if (!candReps.isEmpty) - candReps.head - else - throw new IllegalStateException(s"Cannot find a non-placeholder in equivalence class $tpclass for fundef: \n $fd") - tp -> rep - }.toMap - val instf = instantiateTypeParameters(tpMap) _ - val paramMap = fd.params.map { - case vd @ ValDef(id, _) => - (id -> FreshIdentifier(id.name, instf(vd.getType))) - }.toMap - val ntparams = tpMap.values.toSeq.distinct.collect { case tp: TypeParameter => tp } map TypeParameterDef - val nfd = new FunDef(fd.id.freshen, ntparams, instf(fd.returnType), - fd.params.map(vd => ValDef(paramMap(vd.id)))) - fd -> (nfd, tpMap, paramMap) - }.toMap - - /* - * Replace fundefs and unify type parameters in function invocations. - * Replace old parameters by new parameters - */ - def transformFunBody(ifd: FunDef) = { - val (nfd, tpMap, paramMap) = fdMap(ifd) - // need to handle tuple select specially as it breaks if the type of - // the tupleExpr if it is not TupleTyped. - // cannot use simplePostTransform because of this - def rec(e: Expr): Expr = e match { - case FunctionInvocation(TypedFunDef(callee, targsOld), args) => - val targs = targsOld.map { - case tp: TypeParameter => tpMap.getOrElse(tp, tp) - case t => t - }.distinct - val ncallee = - if (fdMap.contains(callee)) - fdMap(callee)._1 - else callee - FunctionInvocation(TypedFunDef(ncallee, targs), args map rec) - case Variable(id) if paramMap.contains(id) => - paramMap(id).toVariable - case TupleSelect(tup, index) => TupleSelect(rec(tup), index) - case Operator(args, op) => op(args map rec) - case t: Terminal => t - } - val nbody = rec(ifd.fullBody) - //println("Inferring types for: "+ifd.id) - val initGamma = nfd.params.map(vd => vd.id -> vd.getType).toMap - inferTypesOfLocals(nbody, initGamma) - } - copyProgram(prog, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if fdMap.contains(fd) => - val nfd = fdMap(fd)._1 - if (!fd.fullBody.isInstanceOf[NoTree]) { - nfd.fullBody = simplifyLetsAndLetsWithTuples(transformFunBody(fd)) - } - fd.flags.foreach(nfd.addFlag(_)) - nfd - case d => d - }) - } - - import leon.solvers._ - import leon.solvers.z3._ - - def checkSpecifications(prog: Program) { - val ctx = Main.processOptions(Seq("--solvers=smt-cvc4","--debug=solver")) - val report = AnalysisPhase.run(ctx)(prog) - println(report.summaryString) - /*val timeout = 10 - val rep = ctx.reporter - * val fun = prog.definedFunctions.find(_.id.name == "firstUnevaluated").get - // create a verification context. - val solver = new FairZ3Solver(ctx, prog) with TimeoutSolver - val solverF = new TimeoutSolverFactory(SolverFactory(() => solver), timeout * 1000) - val vctx = VerificationContext(ctx, prog, solverF, rep) - val vc = (new DefaultTactic(vctx)).generatePostconditions(fun)(0) - val s = solverF.getNewSolver() - try { - rep.info(s" - Now considering '${vc.kind}' VC for ${vc.fd.id} @${vc.getPos}...") - val tStart = System.currentTimeMillis - s.assertVC(vc) - val satResult = s.check - val dt = System.currentTimeMillis - tStart - val res = satResult match { - case None => - rep.info("Cannot prove or disprove specifications") - case Some(false) => - rep.info("Valid") - case Some(true) => - println("Invalid - counter-example: " + s.getModel) - } - } finally { - s.free() - }*/ - } - - /** - * NOT USED CURRENTLY - * Lift the specifications on functions to the invariants corresponding - * case classes. - * Ideally we should class invariants here, but it is not currently supported - * so we create a functions that can be assume in the pre and post of functions. - * TODO: can this be optimized - */ - /* def liftSpecsToClosures(opToAdt: Map[FunDef, CaseClassDef]) = { - val invariants = opToAdt.collect { - case (fd, ccd) if fd.hasPrecondition => - val transFun = (args: Seq[Identifier]) => { - val argmap: Map[Expr, Expr] = (fd.params.map(_.id.toVariable) zip args.map(_.toVariable)).toMap - replace(argmap, fd.precondition.get) - } - (ccd -> transFun) - }.toMap - val absTypes = opToAdt.values.collect { - case cd if cd.parent.isDefined => cd.parent.get.classDef - } - val invFuns = absTypes.collect { - case abs if abs.knownCCDescendents.exists(invariants.contains) => - val absType = AbstractClassType(abs, abs.tparams.map(_.tp)) - val param = ValDef(FreshIdentifier("$this", absType)) - val tparams = abs.tparams - val invfun = new FunDef(FreshIdentifier(abs.id.name + "$Inv", Untyped), - tparams, BooleanType, Seq(param)) - (abs -> invfun) - }.toMap - // assign bodies for the 'invfuns' - invFuns.foreach { - case (abs, fd) => - val bodyCases = abs.knownCCDescendents.collect { - case ccd if invariants.contains(ccd) => - val ctype = CaseClassType(ccd, fd.tparams.map(_.tp)) - val cvar = FreshIdentifier("t", ctype) - val fldids = ctype.fields.map { - case ValDef(fid, Some(fldtpe)) => - FreshIdentifier(fid.name, fldtpe) - } - val pattern = CaseClassPattern(Some(cvar), ctype, - fldids.map(fid => WildcardPattern(Some(fid)))) - val rhsInv = invariants(ccd)(fldids) - // assert the validity of substructures - val rhsValids = fldids.flatMap { - case fid if fid.getType.isInstanceOf[ClassType] => - val t = fid.getType.asInstanceOf[ClassType] - val rootDef = t match { - case absT: AbstractClassType => absT.classDef - case _ if t.parent.isDefined => - t.parent.get.classDef - } - if (invFuns.contains(rootDef)) { - List(FunctionInvocation(TypedFunDef(invFuns(rootDef), t.tps), - Seq(fid.toVariable))) - } else - List() - case _ => List() - } - val rhs = Util.createAnd(rhsInv +: rhsValids) - MatchCase(pattern, None, rhs) - } - // create a default case - val defCase = MatchCase(WildcardPattern(None), None, Util.tru) - val matchExpr = MatchExpr(fd.params.head.id.toVariable, bodyCases :+ defCase) - fd.body = Some(matchExpr) - } - invFuns - }*/ -} - -- GitLab