diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index b9c0dfd94e449ac5a7d130e1279da8ebb7c19b92..1235d69c6d16a2726971cac386732fdcea501b95 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -16,6 +16,7 @@ import utils._ import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} import templates._ import evaluators._ +import Template._ class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) extends Solver @@ -116,7 +117,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val optEnabler = evaluator.eval(b, model).result if (optEnabler == Some(BooleanLiteral(true))) { - val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg), model).result) + val optArgs = m.args.map(arg => evaluator.eval(arg.encoded, model).result) if (optArgs.forall(_.isDefined)) { Set(optArgs.map(_.get)) } else { diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index d1aa0377a8f939830cb877c5aba78071d74c2de6..036654c25ab64839efa5af4cb0d49a6566ffcf20 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -12,9 +12,10 @@ import purescala.Types._ import utils._ import Instantiation._ +import Template._ -case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { - override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") +case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]]) { + override def toString = "(" + caller + " : " + tpe + ")" + args.map(_.encoded).mkString("(", ",", ")") } object LambdaTemplate { @@ -92,14 +93,6 @@ trait KeyedTemplate[T, E <: Expr] { structuralKey -> rec(structuralKey).distinct.map(dependencies) } - - override def equals(that: Any): Boolean = that match { - case t: KeyedTemplate[T, E] => - key == t.key - case _ => false - } - - override def hashCode: Int = key.hashCode } class LambdaTemplate[T] private ( @@ -124,27 +117,32 @@ class LambdaTemplate[T] private ( val args = arguments.map(_._2) val tpe = ids._1.getType.asInstanceOf[FunctionType] - def substitute(substituter: T => T): LambdaTemplate[T] = { + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): LambdaTemplate[T] = { val newStart = substituter(start) val newClauses = clauses.map(substituter) val newBlockers = blockers.map { case (b, fis) => val bp = if (b == start) newStart else b - bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + bp -> fis.map(fi => fi.copy( + args = fi.args.map(_.substitute(substituter, matcherSubst)) + )) } val newApplications = applications.map { case (b, fas) => val bp = if (b == start) newStart else b - bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + bp -> fas.map(fa => fa.copy( + caller = substituter(fa.caller), + args = fa.args.map(_.substitute(substituter, matcherSubst)) + )) } - val newQuantifications = quantifications.map(_.substitute(substituter)) + val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) val newMatchers = matchers.map { case (b, ms) => val bp = if (b == start) newStart else b - bp -> ms.map(_.substitute(substituter)) + bp -> ms.map(_.substitute(substituter, matcherSubst)) } - val newLambdas = lambdas.map(_.substitute(substituter)) + val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) @@ -182,7 +180,7 @@ class LambdaTemplate[T] private ( private lazy val str : String = stringRepr() override def toString : String = str - override def instantiate(substMap: Map[T, T]): Instantiation[T] = { + override def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) } } @@ -221,7 +219,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(enco var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) // make sure the new lambda isn't equal to any free lambda var - clauses ++= freeLambdas(newTemplate.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT))) + clauses ++= freeLambdas(newTemplate.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(idT, pIdT))) byID += idT -> newTemplate byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.key -> newTemplate)) diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index 314c7549c0899488ba9e78c9739e4718e350b608..bc31267c09d8023390ac541b0cb99ecece6691ea 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -14,23 +14,17 @@ import purescala.Types._ import purescala.Quantification.{QuantificationTypeMatcher => QTM} import Instantiation._ +import Template._ import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} -object Matcher { - def argValue[T](arg: Either[T, Matcher[T]]): T = arg match { - case Left(value) => value - case Right(matcher) => matcher.encoded - } -} - -case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Either[T, Matcher[T]]], encoded: T) { +case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Arg[T]], encoded: T) { override def toString = caller + args.map { case Right(m) => m.toString case Left(v) => v.toString }.mkString("(",",",")") - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]] = Map.empty): Matcher[T] = copy( + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): Matcher[T] = copy( caller = substituter(caller), args = args.map { case Left(v) => matcherSubst.get(v) match { @@ -64,7 +58,7 @@ class QuantificationTemplate[T]( lazy val start = pathVar._2 - def substitute(substituter: T => T): QuantificationTemplate[T] = { + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): QuantificationTemplate[T] = { new QuantificationTemplate[T]( quantificationManager, pathVar._1 -> substituter(start), @@ -78,15 +72,20 @@ class QuantificationTemplate[T]( condTree, clauses.map(substituter), blockers.map { case (b, fis) => - substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + substituter(b) -> fis.map(fi => fi.copy( + args = fi.args.map(_.substitute(substituter, matcherSubst)) + )) }, applications.map { case (b, apps) => - substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + substituter(b) -> apps.map(app => app.copy( + caller = substituter(app.caller), + args = app.args.map(_.substitute(substituter, matcherSubst)) + )) }, matchers.map { case (b, ms) => - substituter(b) -> ms.map(_.substitute(substituter)) + substituter(b) -> ms.map(_.substitute(substituter, matcherSubst)) }, - lambdas.map(_.substitute(substituter)), + lambdas.map(_.substitute(substituter, matcherSubst)), dependencies.map { case (id, value) => id -> substituter(value) }, structuralKey ) @@ -397,7 +396,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage * matchers in the argument and quorum positions */ allMappings.filter { s => - def expand(ms: Traversable[(Either[T,Matcher[T]], Either[T,Matcher[T]])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { + def expand(ms: Traversable[(Arg[T], Arg[T])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) case _ => Set.empty[(Matcher[T], Matcher[T])] }.toSet @@ -409,11 +408,11 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Either[T, Matcher[T]]], Boolean) = { + private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Arg[T]], Boolean) = { var constraints: Set[T] = Set.empty var eqConstraints: Set[(T, T)] = Set.empty var matcherEqs: List[(T, T)] = Nil - var subst: Map[T, Either[T, Matcher[T]]] = Map.empty + var subst: Map[T, Arg[T]] = Map.empty for { (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping @@ -422,16 +421,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (qarg, arg) <- (qargs zip args) } qarg match { case Left(quant) if subst.isDefinedAt(quant) => - eqConstraints += (quant -> Matcher.argValue(arg)) + eqConstraints += (quant -> arg.encoded) case Left(quant) if quantified(quant) => subst += quant -> arg case Right(qam) => - val argVal = Matcher.argValue(arg) + val argVal = arg.encoded eqConstraints += (qam.encoded -> argVal) matcherEqs :+= qam.encoded -> argVal } - val substituter = encoder.substitute(subst.mapValues(Matcher.argValue)) + val substituter = encoder.substitute(subst.mapValues(_.encoded)) val substConstraints = constraints.filter(_ != trueT).map(substituter) val substEqs = eqConstraints.map(p => substituter(p._1) -> p._2) .filter(p => p._1 != p._2).map(p => encoder.mkEquals(p._1, p._2)) @@ -448,25 +447,25 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val (enablers, subst, isStrict) = extractSubst(mapping) val (enabler, optEnabler) = freshBlocker(enablers) - if (optEnabler.isDefined) { - instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) - } - - val baseSubst = subst.mapValues(Matcher.argValue) ++ instanceSubst(enablers) + val baseSubst = subst ++ instanceSubst(enablers).mapValues(Left(_)) val (substMap, inst) = Template.substitution(encoder, QuantificationManager.this, - exprVars, condVars, condTree, Seq.empty, lambdas, baseSubst, pathVar._1, enabler) + condVars, exprVars, condTree, Seq.empty, lambdas, baseSubst, pathVar._1, enabler) if (!skip(substMap)) { + if (optEnabler.isDefined) { + instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + } + instantiation ++= inst instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) - val msubst = subst.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(substMap) + val msubst = substMap.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(substMap.mapValues(_.encoded)) for ((b,ms) <- allMatchers; m <- ms) { val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) - val sm = m.substitute(substituter, matcherSubst = msubst) + val sm = m.substitute(substituter, msubst) if (matchers(m)) { handled += sb -> sm @@ -484,7 +483,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage protected def instanceSubst(enablers: Set[T]): Map[T, T] - protected def skip(subst: Map[T, T]): Boolean = false + protected def skip(subst: Map[T, Arg[T]]): Boolean = false } private class Quantification ( @@ -556,10 +555,11 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage Map(guardVar -> guardT, blocker -> newBlocker) } - override protected def skip(subst: Map[T, T]): Boolean = { - val substituter = encoder.substitute(subst) + override protected def skip(subst: Map[T, Arg[T]]): Boolean = { + val substituter = encoder.substitute(subst.mapValues(_.encoded)) + val msubst = subst.collect { case (c, Right(m)) => c -> m } allMatchers.forall { case (b, ms) => - ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter))) + ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter, msubst))) } } } @@ -588,9 +588,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } - def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, T]): Instantiation[T] = { - val quantifiers = template.arguments map { - case (id, idT) => id -> substMap(idT) + def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, Arg[T]]): Instantiation[T] = { + val quantifiers = template.arguments flatMap { + case (id, idT) => substMap(idT).left.toOption.map(id -> _) } filter (p => isQuantifier(p._2)) if (quantifiers.isEmpty || lambdaAxioms(template -> quantifiers)) { @@ -602,19 +602,24 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val guard = FreshIdentifier("guard", BooleanType, true) val guardT = encoder.encodeId(guard) - val substituter = encoder.substitute(substMap + (template.start -> blockerT)) - val allMatchers = template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) } + val substituter = encoder.substitute(substMap.mapValues(_.encoded) + (template.start -> blockerT)) + val msubst = substMap.collect { case (c, Right(m)) => c -> m } + + val allMatchers = template.matchers map { case (b, ms) => + substituter(b) -> ms.map(_.substitute(substituter, msubst)) + } + val qMatchers = allMatchers.flatMap(_._2).toSet - val encArgs = template.args map substituter + val encArgs = template.args map (arg => Left(arg).substitute(substituter, msubst)) val app = Application(Variable(template.ids._1), template.arguments.map(_._1.toVariable)) - val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs).toMap + template.ids)(app) - val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs.map(Left(_)), appT) + val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs.map(_.encoded)).toMap + template.ids)(app) + val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs, appT) val enablingClause = encoder.mkImplies(guardT, blockerT) instantiateAxiom( - template.pathVar._1 -> substMap(template.start), + template.pathVar._1 -> substituter(template.start), blockerT, guardT, quantifiers, @@ -625,12 +630,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.condTree, (template.clauses map substituter) :+ enablingClause, template.blockers map { case (b, fis) => - substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + substituter(b) -> fis.map(fi => fi.copy( + args = fi.args.map(_.substitute(substituter, msubst)) + )) }, template.applications map { case (b, apps) => - substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + substituter(b) -> apps.map(app => app.copy( + caller = substituter(app.caller), + args = app.args.map(_.substitute(substituter, msubst)) + )) }, - template.lambdas map (_.substitute(substituter)) + template.lambdas map (_.substitute(substituter, msubst)) ) } } @@ -675,7 +685,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for { m <- matchers - sm = m.substitute(substituter) + sm = m.substitute(substituter, Map.empty) if !instCtx.corresponding(sm).exists(_._2.args == sm.args) } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) @@ -732,7 +742,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for { (_, ms) <- template.matchers; m <- ms - sm = m.substitute(substituter) + sm = m.substitute(substituter, Map.empty) if !instCtx.corresponding(sm).exists(_._2.args == sm.args) } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) @@ -787,7 +797,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage var prev = emptyT for ((b, m) <- insts.toSeq) { val next = encoder.encodeId(setNext) - val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> Matcher.argValue(arg) } + val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> arg.encoded } val substMap = Map(guardT -> b, setPrevT -> prev, setNextT -> next) ++ argsMap prev = next clauses += encoder.substitute(substMap)(setT) @@ -795,7 +805,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val setMap = Map(setPrevT -> prev) for ((b, m) <- ctx.toSeq) { - val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> Matcher.argValue(p._2)) + val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> p._2.encoded) clauses += encoder.substitute(substMap)(encoder.mkImplies(b, containsT)) } } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index fc570b83d85e5a4298ae52f01ead895f928a2061..9648adfe45f0655097e1daaa84e7804edbc1aee8 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -65,7 +65,9 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val invocationEqualsBody : Option[Expr] = lambdaBody match { case Some(body) if isRealFunDef => - val b : Expr = And(Equals(invocation, body), liftedEquals(invocation, body, lambdaArguments)) + val b : Expr = And( + liftedEquals(invocation, body, lambdaArguments), + Equals(invocation, body)) Some(if(prec.isDefined) { Implies(prec.get, b) @@ -121,7 +123,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], } private def lambdaArgs(expr: Expr): Seq[Identifier] = expr match { - case Lambda(args, body) => args.map(_.id) ++ lambdaArgs(body) + case Lambda(args, body) => args.map(_.id.freshen) ++ lambdaArgs(body) case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda chains") case _ => Seq.empty } @@ -133,7 +135,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val arguments = currArgs.map(_.toVariable) val apply = if (inline) application _ else Application val (appliedInv, appliedBody) = (apply(i, arguments), apply(b, arguments)) - Equals(appliedInv, appliedBody) +: rec(appliedInv, appliedBody, nextArgs, false) + rec(appliedInv, appliedBody, nextArgs, false) :+ Equals(appliedInv, appliedBody) case _ => assert(args.isEmpty, "liftedEquals should consume all provided arguments") Seq.empty diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index 033f15dd6f251026a260ed6344212239e1714a37..27df9b25d13412c106f2bf30d6c75b1266b19d93 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -5,15 +5,22 @@ package solvers package templates import purescala.Definitions.TypedFunDef +import Template.Arg -case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { +case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[Arg[T]]) { override def toString = { - tfd.signature+args.mkString("(", ", ", ")") + tfd.signature + args.map { + case Right(m) => m.toString + case Left(v) => v.toString + }.mkString("(", ", ", ")") } } -case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[T]) { +case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[Arg[T]]) { override def toString = { - template.ids._2 + "|" + equals + args.mkString("(", ",", ")") + template.ids._2 + "|" + equals + args.map { + case Right(m) => m.toString + case Left(v) => v.toString + }.mkString("(", ",", ")") } } diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 32f07a271ce45cbb76e2444f425b593561ce3358..2b75f08f0480cf272515bb8d8393e01e29d4dbf1 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -57,6 +57,7 @@ object Instantiation { } import Instantiation.{empty => _, _} +import Template.Arg trait Template[T] { self => val encoder : TemplateEncoder[T] @@ -76,14 +77,14 @@ trait Template[T] { self => lazy val start = pathVar._2 - def instantiate(aVar: T, args: Seq[T]): Instantiation[T] = { + def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { val (substMap, instantiation) = Template.substitution(encoder, manager, condVars, exprVars, condTree, quantifications, lambdas, - (this.args zip args).toMap + (start -> aVar), pathVar._1, aVar) + (this.args zip args).toMap + (start -> Left(aVar)), pathVar._1, aVar) instantiation ++ instantiate(substMap) } - protected def instantiate(substMap: Map[T, T]): Instantiation[T] = { + protected def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { Template.instantiate(encoder, manager, clauses, blockers, applications, quantifications, matchers, lambdas, substMap) } @@ -93,6 +94,23 @@ trait Template[T] { self => object Template { + type Arg[T] = Either[T, Matcher[T]] + + implicit class ArgWrapper[T](arg: Arg[T]) { + def encoded: T = arg match { + case Left(value) => value + case Right(matcher) => matcher.encoded + } + + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): Arg[T] = arg match { + case Left(v) => matcherSubst.get(v) match { + case Some(m) => Right(m) + case None => Left(substituter(v)) + } + case Right(m) => Right(m.substitute(substituter, matcherSubst)) + } + } + private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -131,9 +149,9 @@ object Template { encodeExpr(Implies(Variable(b), e)) }).toSeq - val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) + val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2)))) val optIdApp = optApp.map { case (idT, tpe) => - App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(_._2)) + App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2))) } lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) @@ -150,12 +168,7 @@ object Template { var matchInfos : Set[Matcher[T]] = Set.empty for (e <- es) { - funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeExpr))) - appInfos ++= firstOrderAppsOf(e).map { case (c, args) => - App(encodeExpr(c), bestRealType(c.getType).asInstanceOf[FunctionType], args.map(encodeExpr)) - } - - matchInfos ++= fold[Map[Expr, Matcher[T]]] { (expr, res) => + val exprToMatcher = fold[Map[Expr, Matcher[T]]] { (expr, res) => val result = res.flatten.toMap result ++ (expr match { @@ -170,7 +183,19 @@ object Template { Some(expr -> Matcher(encodeExpr(c), bestRealType(c.getType), encodedArgs, encodeExpr(expr))) case _ => None }) - }(e).values + }(e) + + def encodeArg(arg: Expr): Arg[T] = exprToMatcher.get(arg) match { + case Some(matcher) => Right(matcher) + case None => Left(encodeExpr(arg)) + } + + funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg))) + appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + App(encodeExpr(c), bestRealType(c.getType).asInstanceOf[FunctionType], args.map(encodeArg)) + } + + matchInfos ++= exprToMatcher.values } val calls = funInfos -- optIdCall @@ -181,7 +206,7 @@ object Template { val matchs = matchInfos.filter { case m @ Matcher(mc, mtpe, margs, _) => !optIdApp.exists { case App(ac, atpe, aargs) => - mc == ac && mtpe == atpe && margs.map(Matcher.argValue) == aargs + mc == ac && mtpe == atpe && margs == aargs } } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None) @@ -199,8 +224,9 @@ object Template { " * Activating boolean : " + pathVar._1 + "\n" + " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + - " * Clauses : " + - (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + + " * Clauses : " + (if (guardedExprs.isEmpty) "\n" else { + "\n " + (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + }) + " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" }) + @@ -226,13 +252,14 @@ object Template { condTree: Map[Identifier, Set[Identifier]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], - baseSubst: Map[T, T], + baseSubst: Map[T, Arg[T]], pathVar: Identifier, aVar: T - ): (Map[T, T], Instantiation[T]) = { - var subst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ - manager.freshConds(pathVar -> aVar, condVars, condTree) ++ - baseSubst + ): (Map[T, Arg[T]], Instantiation[T]) = { + val freshSubst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + manager.freshConds(pathVar -> aVar, condVars, condTree) + val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } + var subst = freshSubst.mapValues(Left(_)) ++ baseSubst // /!\ CAREFUL /!\ // We have to be wary while computing the lambda subst map since lambdas can @@ -248,22 +275,24 @@ object Template { if !seen(dep) } extractSubst(dep) - if (!seen(lambda)) { - val substLambda = lambda.substitute(encoder.substitute(subst)) - val (idT, inst) = manager.instantiateLambda(substLambda) - instantiation ++= inst - subst += lambda.ids._2 -> idT - seen += lambda - } + if (!seen(lambda)) { + val substMap = subst.mapValues(_.encoded) + val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) + val (idT, inst) = manager.instantiateLambda(substLambda) + instantiation ++= inst + subst += lambda.ids._2 -> Left(idT) + seen += lambda + } } for (l <- lambdas) extractSubst(l) for (q <- quantifications) { - val substQuant = q.substitute(encoder.substitute(subst)) + val substMap = subst.mapValues(_.encoded) + val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) val (qT, inst) = manager.instantiateQuantification(substQuant) instantiation ++= inst - subst += q.qs._2 -> qT + subst += q.qs._2 -> Left(qT) } (subst, instantiation) @@ -278,25 +307,27 @@ object Template { quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], lambdas: Seq[LambdaTemplate[T]], - substMap: Map[T, T] + substMap: Map[T, Arg[T]] ): Instantiation[T] = { - val substituter : T => T = encoder.substitute(substMap) + val substituter : T => T = encoder.substitute(substMap.mapValues(_.encoded)) + val msubst = substMap.collect { case (c, Right(m)) => c -> m } val newClauses = clauses.map(substituter) val newBlockers = blockers.map { case (b,fis) => - substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) } var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) for ((b,apps) <- applications; bp = substituter(b); app <- apps) { - val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(substituter)) + val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(_.substitute(substituter, msubst))) instantiation ++= manager.instantiateApp(bp, newApp) } for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { - instantiation ++= manager.instantiateMatcher(bp, m.substitute(substituter)) + val newMatcher = m.substitute(substituter, msubst) + instantiation ++= manager.instantiateMatcher(bp, newMatcher) } instantiation @@ -373,8 +404,8 @@ class FunctionTemplate[T] private( private lazy val str : String = stringRepr() override def toString : String = str - override def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args) + override def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { + if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args.map(_.left.get)) super.instantiate(aVar, args) } } diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index f96d44239eef30f2c9ef4d758924c322036591f9..4543262d74204dd9f77d3eeea8882bf543956a90 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -171,7 +171,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // define an activating boolean... val template = templateGenerator.mkTemplate(expr) - val trArgs = template.tfd.params.map(vd => bindings(Variable(vd.id))) + val trArgs = template.tfd.params.map(vd => Left(bindings(Variable(vd.id)))) for (vd <- template.tfd.params if vd.getType.isInstanceOf[FunctionType]) { functionVars += vd.getType -> (functionVars.getOrElse(vd.getType, Set()) + bindings(vd.toVariable)) @@ -291,7 +291,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } for ((app, newInfos) <- nextApps) { - println(app -> newInfos) registerAppBlocker(nextGeneration(gen), app, newInfos) } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 6a96848c5fb283ec1a6ae22817afff95f9300122..3d39ae7e7e99e45b344859d46496c31baa17ad8c 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -17,6 +17,7 @@ import purescala.ExprOps._ import purescala.Types._ import solvers.templates._ +import Template._ import evaluators._ @@ -66,7 +67,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) if (optEnabler == Some(true)) { val optArgs = (m.args zip fromTypes).map { - p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) + p => softFromZ3Formula(model, model.eval(p._1.encoded, true).get, p._2) } if (optArgs.forall(_.isDefined)) { @@ -243,7 +244,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) solver.assertCnstr(clause) } - reporter.debug(" - Verifying model transitivity") + reporter.debug(" - Enforcing model transitivity") val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) diff --git a/src/test/resources/regression/verification/purescala/valid/Formulas.scala b/src/test/resources/regression/verification/purescala/valid/Formulas.scala new file mode 100644 index 0000000000000000000000000000000000000000..0fafe4158a2fcbd1b6654d21dbfa072ec42d614f --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Formulas.scala @@ -0,0 +1,50 @@ +import leon.lang._ +import leon._ + +object Formulas { + abstract class Expr + case class And(lhs: Expr, rhs: Expr) extends Expr + case class Or(lhs: Expr, rhs: Expr) extends Expr + case class Implies(lhs: Expr, rhs: Expr) extends Expr + case class Not(e : Expr) extends Expr + case class BoolLiteral(i: BigInt) extends Expr + + def exists(e: Expr, f: Expr => Boolean): Boolean = { + f(e) || (e match { + case And(lhs, rhs) => exists(lhs, f) || exists(rhs, f) + case Or(lhs, rhs) => exists(lhs, f) || exists(rhs, f) + case Implies(lhs, rhs) => exists(lhs, f) || exists(rhs, f) + case Not(e) => exists(e, f) + case _ => false + }) + } + + def existsImplies(e: Expr): Boolean = { + e.isInstanceOf[Implies] || (e match { + case And(lhs, rhs) => existsImplies(lhs) || existsImplies(rhs) + case Or(lhs, rhs) => existsImplies(lhs) || existsImplies(rhs) + case Implies(lhs, rhs) => existsImplies(lhs) || existsImplies(rhs) + case Not(e) => existsImplies(e) + case _ => false + }) + } + + abstract class Value + case class BoolValue(b: Boolean) extends Value + case class IntValue(i: BigInt) extends Value + case object Error extends Value + + def desugar(e: Expr): Expr = { + e match { + case And(lhs, rhs) => And(desugar(lhs), desugar(rhs)) + case Or(lhs, rhs) => Or(desugar(lhs), desugar(rhs)) + case Implies(lhs, rhs) => + Or(Not(desugar(lhs)), desugar(rhs)) + case Not(e) => Not(desugar(e)) + case e => e + } + } ensuring { out => + !existsImplies(out) && + !exists(out, f => f.isInstanceOf[Implies]) + } +}