diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 202ecc7359e3b7a0c90ef3d0552e49c174c78bc8..0d2396ededd322a54291761e68dd3433ee168b9a 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -9,6 +9,8 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import Instantiation._ + class LambdaManager[T](encoder: TemplateEncoder[T]) { private type IdMap = Map[T, LambdaTemplate[T]] private var byIDStack : List[IdMap] = List(Map.empty) @@ -56,10 +58,10 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { lambdas.foreach(p => freeLambdas += p._1 -> (freeLambdas(p._1) + p._2)) } - def instantiate(apps: Map[T, Set[App[T]]], lambdas: Map[T, LambdaTemplate[T]]) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { - var clauses : Seq[T] = Seq.empty - var callBlockers : Map[T, Set[TemplateCallInfo[T]]] = Map.empty.withDefaultValue(Set.empty) - var appBlockers : Map[(T, App[T]), Set[TemplateAppInfo[T]]] = Map.empty.withDefaultValue(Set.empty) + private def instantiate(apps: Map[T, Set[App[T]]], lambdas: Map[T, LambdaTemplate[T]]) : Instantiation[T] = { + var clauses : Clauses[T] = Seq.empty + var callBlockers : CallBlockers[T] = Map.empty.withDefaultValue(Set.empty) + var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) def mkBlocker(blockedApp: (T, App[T]), lambda: (T, LambdaTemplate[T])) : Unit = { val (_, App(caller, tpe, args)) = blockedApp @@ -106,7 +108,15 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { (clauses, callBlockers, appBlockers) } - def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { + def instantiateLambda(idT: T, template: LambdaTemplate[T]): Instantiation[T] = { + val eqClauses = equalityClauses(idT, template) + val (clauses, blockers, apps) = instantiate(Map.empty, Map(idT -> template)) + (eqClauses ++ clauses, blockers, apps) + } + + def instantiateApps(apps: Map[T, Set[App[T]]]): Instantiation[T] = instantiate(apps, Map.empty) + + private def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { byType(template.tpe).map { case (thatIdT, that) => val equals = encoder.mkEquals(idT, thatIdT) template.contextEquality(that) match { diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index d5a5cbf49f3fc04d69d2dfd193f4de3383f317ec..aa8778f057e8b5dec6c8788f1e003d9bdf390ca8 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -17,6 +17,30 @@ case class App[T](caller: T, tpe: TypeTree, args: Seq[T]) { } } +object Instantiation { + type Clauses[T] = Seq[T] + type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] + type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] + type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) + + def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) + + implicit class InstantiationWrapper[T](i: Instantiation[T]) { + def merge(that: Instantiation[T]): Instantiation[T] = { + val (thisClauses, thisBlockers, thisApps) = i + val (thatClauses, thatBlockers, thatApps) = that + + ( + thisClauses ++ thatClauses, + (thisBlockers.keys ++ thatBlockers.keys).map(k => k -> (thisBlockers.getOrElse(k, Set.empty) ++ thatBlockers.getOrElse(k, Set.empty))).toMap, + (thisApps.keys ++ thatApps.keys).map(k => k -> (thisApps.getOrElse(k, Set.empty) ++ thatApps.getOrElse(k, Set.empty))).toMap + ) + } + } +} + +import Instantiation.{empty => _, _} + trait Template[T] { self => val encoder : TemplateEncoder[T] val lambdaManager : LambdaManager[T] @@ -32,7 +56,7 @@ trait Template[T] { self => private var substCache : Map[Seq[T],Map[T,T]] = Map.empty - def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { + def instantiate(aVar: T, args: Seq[T]): Instantiation[T] = { val baseSubstMap : Map[T,T] = substCache.get(args) match { case Some(subst) => subst @@ -57,22 +81,17 @@ trait Template[T] { self => substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) } - val (newLambdas, lambdaClauses) = lambdas.foldLeft((Map.empty[T,LambdaTemplate[T]], Seq.empty[T])) { - case ((newLambdas, clauses), (idT, lambda)) => + val lambdaInstantiation = lambdas.foldLeft(Instantiation.empty[T]) { + case (acc, (idT, lambda)) => val newIdT = substituter(idT) val newTemplate = lambda.substitute(substMap) - val eqClauses = lambdaManager.equalityClauses(newIdT, newTemplate) - (newLambdas + (newIdT -> newTemplate), clauses ++ eqClauses) + val instantiation = lambdaManager.instantiateLambda(newIdT, newTemplate) + acc merge instantiation } - val (appClauses, appBlockers, appApps) = lambdaManager.instantiate(newApplications, newLambdas) - - val allClauses = newClauses ++ appClauses ++ lambdaClauses - val allBlockers = (newBlockers.keys ++ appBlockers.keys).map { k => - k -> (newBlockers.getOrElse(k, Set.empty) ++ appBlockers.getOrElse(k, Set.empty)) - }.toMap + val appInstantiation = lambdaManager.instantiateApps(newApplications) - (allClauses, allBlockers, appApps) + (newClauses, newBlockers, Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) merge lambdaInstantiation merge appInstantiation } override def toString : String = "Instantiated template" @@ -368,7 +387,7 @@ class LambdaTemplate[T] private ( val applications: Map[T, Set[App[T]]], val lambdas: Map[T, LambdaTemplate[T]], private[templates] val dependencies: Map[Identifier, T], - private val structuralKey: Lambda, + private[templates] val structuralKey: Lambda, stringRepr: () => String) extends Template[T] { val tpe = id.getType diff --git a/src/test/resources/regression/verification/purescala/valid/LambdaEquality.scala b/src/test/resources/regression/verification/purescala/valid/LambdaEquality.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7f589cf74109044c8caf3d86b1d04c65fb7b180 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/LambdaEquality.scala @@ -0,0 +1,32 @@ + +import leon.lang._ +import leon.annotation._ +import leon.collection._ + +object LambdaEquality { + def app[A, B](x: A)(f: A => B): B = f(x) + + @induct + def mapId[A](xs: List[A]): Boolean = { + xs.map(id) == xs && { + xs match { + case Nil() => true + case Cons(x, xs) => + mapId(xs) && + id(x) :: xs.map(id) == x :: xs && + true + } + } + }.holds + + def id[A](a: A): A = a + + def mapEquality1[A, B](xs: List[A])(f: A => B): Boolean = { + xs.map(f) == xs.map(f) + }.holds + + def mapEquality2[A, B](xs: List[A])(f: A => B): Boolean = { + xs.map(id) == xs.map(id) + }.holds + +}