diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index cfecfacda32f2ff7724b7e4cb27eff28887ea701..7249313f16c1f171d1ef15888616ebff1108c468 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -38,19 +38,11 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { freeLambdasStack = map :: freeLambdasStack.tail } - private type StructuralMap = Map[Lambda, List[(T, LambdaTemplate[T])]] - private var structuralLambdasStack : List[StructuralMap] = List(Map.empty.withDefaultValue(List.empty)) - private def structuralLambdas : StructuralMap = structuralLambdasStack.head - private def structuralLambdas_=(map: StructuralMap) : Unit = { - structuralLambdasStack = map :: structuralLambdasStack.tail - } - def push(): Unit = { byIDStack = byID :: byIDStack byTypeStack = byType :: byTypeStack applicationsStack = applications :: applicationsStack freeLambdasStack = freeLambdas :: freeLambdasStack - structuralLambdasStack = structuralLambdas :: structuralLambdasStack } def pop(lvl: Int): Unit = { @@ -58,7 +50,6 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { byTypeStack = byTypeStack.drop(lvl) applicationsStack = applicationsStack.drop(lvl) freeLambdasStack = freeLambdasStack.drop(lvl) - structuralLambdasStack = structuralLambdasStack.drop(lvl) } def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { @@ -79,10 +70,8 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { } for (lambda @ (idT, template) <- lambdas) { - // get all lambda references... - val lambdaRefs = freeLambdas(template.tpe) ++ byType(template.tpe).map(_._1) - // ... and make sure the new lambda isn't equal to one of them! - clauses ++= lambdaRefs.map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT))) + // make sure the new lambda isn't equal to any free lambda var + clauses ++= freeLambdas(template.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT))) byID += idT -> template byType += template.tpe -> (byType(template.tpe) + (idT -> template)) @@ -117,22 +106,15 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) { (clauses, callBlockers, appBlockers) } - def equalityClauses(template: LambdaTemplate[T], idT: T, substMap: Map[T,T]): Seq[T] = { - val key : Lambda = template.key - val t : LambdaTemplate[T] = template.substitute(substMap) - - val newClauses = structuralLambdas(key).map { case (thatIdT, that) => + def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { + byType(template.tpe).map { case (thatIdT, that) => val equals = encoder.mkEquals(idT, thatIdT) - if (t.dependencies.isEmpty) { - equals - } else { - encoder.mkImplies(t.contextEquality(that), equals) + template.contextEquality(that) match { + case None => encoder.mkNot(equals) + case Some(Seq()) => equals + case Some(seq) => encoder.mkImplies(encoder.mkAnd(seq : _*), equals) } - } - - structuralLambdas += key -> (structuralLambdas(key) :+ (idT -> t)) - - newClauses + }.toSeq } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index fb73c599d2bf805f0f60b796ed8b0e38f92b15a8..c7e79263f807141565644ac864cafa59ad70a8b7 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -45,13 +45,7 @@ trait Template[T] { self => subst } - val (lambdaSubstMap, lambdaClauses) = lambdas.foldLeft((Map.empty[T,T], Seq.empty[T])) { - case ((subst, clauses), (idT, lambda)) => - val newIdT = encoder.encodeId(lambda.id) - val eqClauses = lambdaManager.equalityClauses(lambda, newIdT, baseSubstMap) - - (subst + (idT -> newIdT), clauses ++ eqClauses) - } + val lambdaSubstMap = lambdas.map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap + (start -> aVar) val substituter : T => T = encoder.substitute(substMap) @@ -65,8 +59,12 @@ trait Template[T] { self => substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) } - val newLambdas = lambdas.map { case (idT, lambda) => - substituter(idT) -> lambda.substitute(substMap) + val (newLambdas, lambdaClauses) = lambdas.foldLeft((Map.empty[T,LambdaTemplate[T]], Seq.empty[T])) { + case ((newLambdas, clauses), (idT, lambda)) => + val newIdT = substituter(idT) + val newTemplate = lambda.substitute(substMap) + val eqClauses = lambdaManager.equalityClauses(newIdT, newTemplate) + (newLambdas + (newIdT -> newTemplate), clauses ++ eqClauses) } val (appClauses, appBlockers, appApps) = lambdaManager.instantiate(newApplications, newLambdas) @@ -229,7 +227,7 @@ object LambdaTemplate { private var typedIds : Map[TypeTree, List[Identifier]] = Map.empty.withDefaultValue(List.empty) - private def templateKey[T](lambda: LambdaTemplate[T]): Lambda = { + private def structuralKey[T](lambda: Lambda, dependencies: Map[Identifier, T]): (Lambda, Map[Identifier,T]) = { def closureIds(expr: Expr): Seq[Identifier] = { val vars = variablesOf(expr) @@ -243,7 +241,7 @@ object LambdaTemplate { allVars.filter(vars(_)).distinct } - val grouped : Map[TypeTree, Seq[Identifier]] = closureIds(lambda.lambda).groupBy(_.getType) + val grouped : Map[TypeTree, Seq[Identifier]] = closureIds(lambda).groupBy(_.getType) val subst : Map[Identifier, Identifier] = grouped.foldLeft(Map.empty[Identifier,Identifier]) { case (subst, (tpe, ids)) => val currentVars = typedIds(tpe) @@ -259,7 +257,10 @@ object LambdaTemplate { subst ++ (ids zip typedVars) } - replaceFromIDs(subst.mapValues(_.toVariable), lambda.lambda).asInstanceOf[Lambda] + val structuralLambda = replaceFromIDs(subst.mapValues(_.toVariable), lambda).asInstanceOf[Lambda] + val newDeps = dependencies.map { case (id, idT) => subst(id) -> idT } + + structuralLambda -> newDeps } def apply[T]( @@ -287,6 +288,8 @@ object LambdaTemplate { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } + val (key, keyDeps) = structuralKey(lambda, dependencies) + new LambdaTemplate[T]( ids._1, encoder, @@ -299,8 +302,8 @@ object LambdaTemplate { blockers, applications, lambdas, - dependencies, - lambda, + keyDeps, + key, lambdaString ) } @@ -319,7 +322,7 @@ class LambdaTemplate[T] private ( val applications: Map[T, Set[App[T]]], val lambdas: Map[T, LambdaTemplate[T]], private[templates] val dependencies: Map[Identifier, T], - private val lambda: Lambda, + private val structuralKey: Lambda, stringRepr: () => String) extends Template[T] { val tpe = id.getType @@ -356,7 +359,7 @@ class LambdaTemplate[T] private ( newApplications, newLambdas, newDependencies, - lambda, + structuralKey, stringRepr ) } @@ -364,32 +367,33 @@ class LambdaTemplate[T] private ( private lazy val str : String = stringRepr() override def toString : String = str - def contextEquality(that: LambdaTemplate[T]) : T = { - assert(key == that.key, "Can't generate equality clause for lambdas that don't share structure") - assert(dependencies.nonEmpty, "No closures implies obvious equality") - - def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match { - case (Variable(id1), Variable(id2)) => - if (dependencies.isDefinedAt(id1)) { - Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2))) - } else { - Seq.empty - } - - case (NAryOperator(es1, _), NAryOperator(es2, _)) => - (es1 zip es2).flatMap(p => rec(p._1, p._2)) - - case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) => - rec(e11, e21) ++ rec(e12, e22) - - case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) => - rec(ue1, ue2) + def contextEquality(that: LambdaTemplate[T]) : Option[Seq[T]] = { + if (structuralKey != that.structuralKey) { + None // can't be equal + } else if (dependencies.isEmpty) { + Some(Seq.empty) // must be equal + } else { + def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match { + case (Variable(id1), Variable(id2)) => + if (dependencies.isDefinedAt(id1)) { + Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2))) + } else { + Seq.empty + } + + case (NAryOperator(es1, _), NAryOperator(es2, _)) => + (es1 zip es2).flatMap(p => rec(p._1, p._2)) + + case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) => + rec(e11, e21) ++ rec(e12, e22) + + case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) => + rec(ue1, ue2) + + case _ => Seq.empty + } - case _ => Seq.empty + Some(rec(structuralKey, that.structuralKey)) } - - encoder.mkAnd(rec(lambda, that.lambda) : _*) } - - def key : Lambda = LambdaTemplate.templateKey(this) }