diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 5343be2d2cbc8f47613182b932b9f6be52c4b7ec..ce5bb8c8907dbb233d5ecb93714b510512bbe2b2 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -34,12 +34,10 @@ object SolverFactory { SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) case "smt" | "smt-z3" => - val smtf = SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3Target) - SolverFactory(() => new UnrollingSolver(ctx, smtf) with TimeoutSolver) + SolverFactory(() => new UnrollingSolver(ctx, new SMTLIBSolver(ctx, program) with SMTLIBZ3Target) with TimeoutSolver) case "smt-cvc4" => - val smtf = SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) - SolverFactory(() => new UnrollingSolver(ctx, smtf) with TimeoutSolver) + SolverFactory(() => new UnrollingSolver(ctx, new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) with TimeoutSolver) case _ => ctx.reporter.fatalError("Unknown solver "+name) diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 43b794e2d51fd229b66b99973dd96a6a9a4000f4..99dc09f0048f3d46773718d256d5b064d8517baa 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -10,152 +10,182 @@ import purescala.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ +import solvers.templates._ import utils.Interruptible import scala.collection.mutable.{Map=>MutableMap} -class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[IncrementalSolver]) +class UnrollingSolver(val context: LeonContext, underlying: IncrementalSolver) extends Solver with Interruptible { - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None + private var lastCheckResult: (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) val reporter = context.reporter - private var stop: Boolean = false + private var interrupted: Boolean = false - def name = "U:"+underlyings.name + def name = "U:"+underlying.name def free {} - import context.reporter._ + var varsInVC = List[Set[Identifier]](Set()) - def assertCnstr(expression : Expr) { - if(!theConstraint.isEmpty) { - fatalError("Multiple assertCnstr(...).") + val templateGenerator = new TemplateGenerator(new TemplateEncoder[Expr] { + def encodeId(id: Identifier): Expr= { + Variable(id.freshen) + } + + def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { + replaceFromIDs(bindings, e) } - theConstraint = Some(expression) - } - def check : Option[Boolean] = theConstraint.map { expr => - val solver = underlyings.getNewSolver + def substitute(substMap: Map[Expr, Expr]): Expr => Expr = { + (e: Expr) => replace(substMap, e) + } - val template = getTemplate(expr) + def not(e: Expr) = Not(e) + def implies(l: Expr, r: Expr) = Implies(l, r) + }) - val aVar : Identifier = template.activatingBool - var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + val unrollingBank = new UnrollingBank(reporter, templateGenerator) - def unrollOneStep() : List[Expr] = { - val blockersBefore = allBlockers + val solver = underlying - var newClauses : List[Seq[Expr]] = Nil - var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + def assertCnstr(expression: Expr) { + val freeIds = variablesOf(expression) - for(blocker <- allBlockers.keySet; fi @ FunctionInvocation(tfd, args) <- allBlockers(blocker)) { - val tmpl = getTemplate(tfd) + val freeVars = freeIds.map(_.toVariable: Expr) - val (nc, nb) = tmpl.instantiate(blocker, args) - newClauses = nc :: newClauses - newBlockers = newBlockers ++ nb - //reporter.debug("Unrolling behind "+fi+" ("+nc.size+")") - //for (c <- nc) { - // reporter.debug(" . "+c) - //} - } + val bindings = freeVars.zip(freeVars).toMap + + val newClauses = unrollingBank.getClauses(expression, bindings) - allBlockers = newBlockers - newClauses.flatten + for (cl <- newClauses) { + solver.assertCnstr(cl) } - val (nc, nb) = template.instantiate(aVar, template.tfd.params.map(a => Variable(a.id))) + varsInVC = (varsInVC.head ++ freeIds) :: varsInVC.tail + } + + + def push() { + unrollingBank.push() + solver.push() + varsInVC = Set[Identifier]() :: varsInVC + } + + def pop(lvl: Int = 1) { + unrollingBank.pop(lvl) + solver.pop(lvl) + varsInVC = varsInVC.drop(lvl) + } + + def check: Option[Boolean] = { + genericCheck(Set()) + } + + def hasFoundAnswer = lastCheckResult._1 - allBlockers = nb + def foundAnswer(res: Option[Boolean], model: Option[Map[Identifier, Expr]] = None) = { + lastCheckResult = (true, res, model) + } + + def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { + lastCheckResult = (false, None, None) - var unrollingCount : Int = 0 - var done : Boolean = false - var result : Option[Boolean] = None + while(!hasFoundAnswer && !interrupted) { + reporter.debug(" - Running search...") - solver.assertCnstr(Variable(aVar)) - solver.assertCnstr(And(nc)) - // We're now past the initial step. - while(!done && !stop) { solver.push() - reporter.debug(" - Searching with blocked literals") - solver.assertCnstr(And(allBlockers.keySet.toSeq.map(id => Not(id.toVariable)))) - solver.check match { + solver.assertCnstr(And((assumptions ++ unrollingBank.currentBlockers).toSeq)) + val res = solver.check + solver.pop() - case Some(false) => - solver.pop(1) - reporter.debug(" - Searching with unblocked literals") - //val open = fullOpenExpr - solver.check match { - case Some(false) => - done = true - result = Some(false) - - case r => - unrollingCount += 1 - val model = solver.getModel - reporter.debug(" - Tentative model: "+model) - reporter.debug(" - more unrollings") - val newClauses = unrollOneStep() - reporter.debug(s" - ${newClauses.size} new clauses") - //readLine() - solver.assertCnstr(And(newClauses)) + reporter.debug(" - Finished search with blocked literals") + + res match { + case None => + reporter.ifDebug { debug => + reporter.debug("Solver returned unknown!?") } + foundAnswer(None) - case Some(true) => + case Some(true) => // SAT val model = solver.getModel - done = true - result = Some(true) - theModel = Some(model) - case None => - val model = solver.getModel - done = true - result = Some(true) - theModel = Some(model) + foundAnswer(Some(true), Some(model)) + + case Some(false) if !unrollingBank.canUnroll => + foundAnswer(Some(false)) + + case Some(false) => + //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) + //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) + + if (!interrupted) { + reporter.debug(" - Running search without blocked literals (w/o lucky test)") + + solver.push() + solver.assertCnstr(And(assumptions.toSeq)) + val res2 = solver.check + solver.pop() + + res2 match { + case Some(false) => + //reporter.debug("UNSAT WITHOUT Blockers") + foundAnswer(Some(false)) + + case Some(true) => + + case None => + foundAnswer(None) + } + } + + if(interrupted) { + foundAnswer(None) + } + + if(!hasFoundAnswer) { + reporter.debug("- We need to keep going.") + + val toRelease = unrollingBank.getBlockersToUnlock + + reporter.debug(" - more unrollings") + + val newClauses = unrollingBank.unrollBehind(toRelease) + + for(ncl <- newClauses) { + solver.assertCnstr(ncl) + } + + reporter.debug(" - finished unrolling") + } } } - solver.free - result - } getOrElse { - Some(true) + if(interrupted) { + None + } else { + lastCheckResult._2 + } } - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + def getModel: Map[Identifier,Expr] = { + val allVars = varsInVC.flatten.toSet + lastCheckResult match { + case (true, Some(true), Some(m)) => + m.filterKeys(allVars) + case _ => + Map() + } } override def interrupt(): Unit = { - stop = true + interrupted = true } override def recoverInterrupt(): Unit = { - stop = false - } - - private val tfdTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty - private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty - - private def getTemplate(tfd: TypedFunDef): FunctionTemplate = { - tfdTemplateCache.getOrElse(tfd, { - val res = FunctionTemplate.mkTemplate(tfd, true) - tfdTemplateCache += tfd -> res - res - }) - } - - private def getTemplate(body: Expr): FunctionTemplate = { - exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => ValDef(id, id.getType)), DefType.MethodDef) - fakeFunDef.body = Some(body) - - val res = FunctionTemplate.mkTemplate(fakeFunDef.typed, false) - exprTemplateCache += body -> res - res - }) + interrupted = false } } diff --git a/src/main/scala/leon/solvers/templates/FunctionTemplate.scala b/src/main/scala/leon/solvers/templates/FunctionTemplate.scala new file mode 100644 index 0000000000000000000000000000000000000000..76fe5a4888cbcd097b1c06e3d2bb16afcbc6d091 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/FunctionTemplate.scala @@ -0,0 +1,131 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import utils._ +import purescala.Common._ +import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ + +import evaluators._ + +class FunctionTemplate[T]( + val tfd: TypedFunDef, + val encoder: TemplateEncoder[T], + activatingBool: Identifier, + condVars: Set[Identifier], + exprVars: Set[Identifier], + guardedExprs: Map[Identifier,Seq[Expr]], + isRealFunDef: Boolean) { + + val evalGroundApps = false + + val clauses: Seq[Expr] = { + (for((b,es) <- guardedExprs; e <- es) yield { + Implies(Variable(b), e) + }).toSeq + } + + val trActivatingBool = encoder.encodeId(activatingBool) + + val trFunDefArgs = tfd.params.map( ad => encoder.encodeId(ad.id)) + val zippedCondVars = condVars.map(id => (id -> encoder.encodeId(id))) + val zippedExprVars = exprVars.map(id => (id -> encoder.encodeId(id))) + val zippedFunDefArgs = tfd.params.map(_.id) zip trFunDefArgs + + val idToTrId: Map[Identifier, T] = { + Map(activatingBool -> trActivatingBool) ++ + zippedCondVars ++ + zippedExprVars ++ + zippedFunDefArgs + } + + val encodeExpr = encoder.encodeExpr(idToTrId) _ + + val trClauses: Seq[T] = clauses.map(encodeExpr) + + val trBlockers: Map[T, Set[TemplateCallInfo[T]]] = { + val idCall = TemplateCallInfo[T](tfd, trFunDefArgs) + + Map((for((b, es) <- guardedExprs) yield { + val allCalls = es.map(functionCallsOf).flatten.toSet + val calls = (for (c <- allCalls) yield { + TemplateCallInfo[T](c.tfd, c.args.map(encodeExpr)) + }) - idCall + + if(calls.isEmpty) { + None + } else { + Some(idToTrId(b) -> calls) + } + }).flatten.toSeq : _*) + } + + // We use a cache to create the same boolean variables. + var cache = Map[Seq[T], Map[T, T]]() + + def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]]) = { + assert(args.size == tfd.params.size) + + // The "isRealFunDef" part is to prevent evaluation of "fake" + // function templates, as generated from FairZ3Solver. + //if(evalGroundApps && isRealFunDef) { + // val ga = args.view.map(solver.asGround) + // if(ga.forall(_.isDefined)) { + // val leonArgs = ga.map(_.get).force + // val invocation = FunctionInvocation(tfd, leonArgs) + // solver.getEvaluator.eval(invocation) match { + // case EvaluationResults.Successful(result) => + // val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*) + // val z3Value = solver.toZ3Formula(result).get + // val asZ3 = z3.mkEq(z3Invocation, z3Value) + // return (Seq(asZ3), Map.empty) + + // case _ => throw new Exception("Evaluation of ground term should have succeeded.") + // } + // } + //} + // ...end of ground evaluation part. + + val baseSubstMap = cache.get(args) match { + case Some(m) => m + case None => + val newMap: Map[T, T] = + (zippedCondVars ++ zippedExprVars).map{ case (id, idT) => idT -> encoder.encodeId(id) }.toMap ++ + (trFunDefArgs zip args) + + cache += args -> newMap + newMap + } + + val substMap : Map[T, T] = baseSubstMap + (trActivatingBool -> aVar) + + val substituter = encoder.substitute(substMap) + + val newClauses = trClauses.map(substituter) + + val newBlockers = trBlockers.map { case (b, funs) => + val bp = substituter(b) + + val newFuns = funs.map(fi => fi.copy(args = fi.args.map(substituter))) + + bp -> newFuns + } + + (newClauses, newBlockers) + } + + override def toString : String = { + "Template for def " + tfd.signature + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + + " * Activating boolean : " + trActivatingBool + "\n" + + " * Control booleans : " + zippedCondVars.map(_._2.toString).mkString(", ") + "\n" + + " * Expression vars : " + zippedExprVars.map(_._2.toString).mkString(", ") + "\n" + + " * Clauses : " + "\n " +trClauses.mkString("\n ") + "\n" + + " * Block-map : " + trBlockers.toString + } +} diff --git a/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala b/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..ee5eb1b25ee3363cc9ceac06ae8ff78be65af2cf --- /dev/null +++ b/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Definitions.TypedFunDef + +case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { + override def toString = { + tfd.signature+args.mkString("(", ", ", ")") + } +} diff --git a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7ba69fbb3a91c9c22ce2ef4d8fa07fa241b53d9 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Common.Identifier +import purescala.Trees.Expr + +trait TemplateEncoder[T] { + def encodeId(id: Identifier): T + def encodeExpr(bindings: Map[Identifier, T])(e: Expr): T + def substitute(map: Map[T, T]): T => T + + // Encodings needed for unrollingbank + def not(v: T): T + def implies(l: T, r: T): T +} diff --git a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala similarity index 63% rename from src/main/scala/leon/solvers/combinators/FunctionTemplate.scala rename to src/main/scala/leon/solvers/templates/TemplateGenerator.scala index abe4b684159eab6af7e3f4c6fe0b2fa871ca4b48..72ffca0b55f3b2cb8bda6edcdec50f82138dfee8 100644 --- a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -1,8 +1,10 @@ /* Copyright 2009-2014 EPFL, Lausanne */ package leon -package solvers.combinators +package solvers +package templates +import utils._ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ @@ -12,105 +14,45 @@ import purescala.Definitions._ import evaluators._ -import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} +class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { + private var cache = Map[TypedFunDef, FunctionTemplate[T]]() + private var cacheExpr = Map[Expr, FunctionTemplate[T]]() -class FunctionTemplate private( - val tfd : TypedFunDef, - val activatingBool : Identifier, - condVars : Set[Identifier], - exprVars : Set[Identifier], - guardedExprs : Map[Identifier,Seq[Expr]], - isRealFunDef : Boolean) { - - private val funDefArgsIDs : Seq[Identifier] = tfd.params.map(_.id) - - private val asClauses : Seq[Expr] = { - (for((b,es) <- guardedExprs; e <- es) yield { - Implies(Variable(b), e) - }).toSeq - } - - val blockers : Map[Identifier,Set[FunctionInvocation]] = { - val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) - - Map((for((b, es) <- guardedExprs) yield { - val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall - if(calls.isEmpty) { - None - } else { - Some((b, calls)) - } - }).flatten.toSeq : _*) - } - - private def idToFreshID(id : Identifier) : Identifier = { - FreshIdentifier(id.name, true).setType(id.getType) - } - - // We use a cache to create the same boolean variables. - private val cache : MutableMap[Seq[Expr],Map[Identifier,Expr]] = MutableMap.empty - - def instantiate(aVar : Identifier, args : Seq[Expr]) : (Seq[Expr], Map[Identifier,Set[FunctionInvocation]]) = { - assert(args.size == tfd.params.size) - - val (wasHit,baseIDSubstMap) = cache.get(args) match { - case Some(m) => (true,m) - case None => - val newMap : Map[Identifier,Expr] = - (exprVars ++ condVars).map(id => id -> Variable(idToFreshID(id))).toMap ++ - (funDefArgsIDs zip args) - cache(args) = newMap - (false, newMap) + def mkTemplate(body: Expr): FunctionTemplate[T] = { + if (cacheExpr contains body) { + return cacheExpr(body); } - val idSubstMap : Map[Identifier,Expr] = baseIDSubstMap + (activatingBool -> Variable(aVar)) - val exprSubstMap : Map[Expr,Expr] = idSubstMap.map(p => (Variable(p._1), p._2)) - - val newClauses = asClauses.map(replace(exprSubstMap, _)) - - val newBlockers = blockers.map { case (id, funs) => - val bp = if (id == activatingBool) { - aVar - } else { - // That's not exactly safe... - idSubstMap(id).asInstanceOf[Variable].id - } - - val newFuns = funs.map(fi => fi.copy(args = fi.args.map(replace(exprSubstMap, _)))) + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), + Nil, + body.getType, + variablesOf(body).toSeq.map(id => ValDef(id, id.getType))) - bp -> newFuns - } + fakeFunDef.body = Some(body) - (newClauses, newBlockers) + val res = mkTemplate(fakeFunDef.typed, false) + cacheExpr += body -> res + res } - override def toString : String = { - "Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + - " * Activating boolean : " + activatingBool + "\n" + - " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + - " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" + - " * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" + - " * Block-map : " + blockers.toString - } -} + def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] = { + if (cache contains tfd) { + return cache(tfd) + } -object FunctionTemplate { - def mkTemplate(tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { - val condVars : MutableSet[Identifier] = MutableSet.empty - val exprVars : MutableSet[Identifier] = MutableSet.empty + var condVars = Set[Identifier]() + var exprVars = Set[Identifier]() // Represents clauses of the form: // id => expr && ... && expr - val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty + var guardedExprs = Map[Identifier, Seq[Expr]]() def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { assert(expr.getType == BooleanType) - if(guardedExprs.isDefinedAt(guardVar)) { - val prev : Seq[Expr] = guardedExprs(guardVar) - guardedExprs(guardVar) = expr +: prev - } else { - guardedExprs(guardVar) = Seq(expr) - } + + val prev = guardedExprs.getOrElse(guardVar, Nil) + + guardedExprs += guardVar -> (expr +: prev) } // Group elements that satisfy p toghether @@ -143,7 +85,7 @@ object FunctionTemplate { }(e) } - def rec(pathVar : Identifier, expr : Expr) : Expr = { + def rec(pathVar: Identifier, expr: Expr): Expr = { expr match { case a @ Assert(cond, _, body) => storeGuarded(pathVar, rec(pathVar, cond)) @@ -285,7 +227,14 @@ object FunctionTemplate { } - new FunctionTemplate(tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), -isRealFunDef) + val template = new FunctionTemplate[T](tfd, + encoder, + activatingBool, + Set(condVars.toSeq : _*), + Set(exprVars.toSeq : _*), + Map(guardedExprs.toSeq : _*), + isRealFunDef) + cache += tfd -> template + template } } diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala new file mode 100644 index 0000000000000000000000000000000000000000..b579b88ec645b145a150995d744b2bd4ab494e29 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -0,0 +1,191 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import utils._ +import purescala.Common._ +import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ + +import evaluators._ + +class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[T]) { + implicit val debugSection = utils.DebugSectionSolver + + private val encoder = templateGenerator.encoder + + // Keep which function invocation is guarded by which guard, + // also specify the generation of the blocker. + private var blockersInfoStack = List[Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]](Map()) + + // Function instantiations have their own defblocker + private var defBlockers = Map[TemplateCallInfo[T], T]() + + def blockersInfo = blockersInfoStack.head + + def blockersInfo_= (v: Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]) = { + blockersInfoStack = v :: blockersInfoStack.tail + } + + def push() { + blockersInfoStack = blockersInfo :: blockersInfoStack + } + + def pop(lvl: Int) { + blockersInfoStack = blockersInfoStack.drop(lvl) + } + + def dumpBlockers = { + blockersInfo.groupBy(_._2._1).toSeq.sortBy(_._1).foreach { case (gen, entries) => + reporter.debug("--- "+gen) + + + for (((bast), (gen, origGen, ast, fis)) <- entries) { + reporter.debug(f". $bast%15s ~> "+fis.mkString(", ")) + } + } + } + + def canUnroll = !blockersInfo.isEmpty + + def currentBlockers = blockersInfo.map(_._2._3) + + def getBlockersToUnlock: Seq[T] = { + if (!blockersInfo.isEmpty) { + val minGeneration = blockersInfo.values.map(_._1).min + + blockersInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1) + } else { + Seq() + } + } + + private def registerBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { + val notId = encoder.not(id) + + blockersInfo.get(id) match { + case Some((exGen, origGen, _, exFis)) => + // PS: when recycling `b`s, this assertion becomes dangerous. + // It's better to simply take the max of the generations. + // assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen) + + val minGen = gen min exGen + + blockersInfo += id -> (minGen, origGen, notId, fis++exFis) + case None => + blockersInfo += id -> (gen, gen, notId, fis) + } + } + + def getClauses(expr: Expr, bindings: Map[Expr, T]): Seq[T] = { + // OK, now this is subtle. This `getTemplate` will return + // a template for a "fake" function. Now, this template will + // define an activating boolean... + val template = templateGenerator.mkTemplate(expr) + + + val trArgs = template.tfd.params.map(vd => bindings(Variable(vd.id))) + + // ...now this template defines clauses that are all guarded + // by that activating boolean. If that activating boolean is + // undefined (or false) these clauses have no effect... + val (newClauses, newBlocks) = + template.instantiate(template.trActivatingBool, trArgs) + + for((i, fis) <- newBlocks) { + registerBlocker(nextGeneration(0), i, fis) + } + + // ...so we must force it to true! + template.trActivatingBool +: newClauses + } + + def nextGeneration(gen: Int) = gen + 3 + + def decreaseAllGenerations() = { + for ((block, (gen, origGen, ast, finvs)) <- blockersInfo) { + // We also decrease the original generation here + blockersInfo += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, finvs) + } + } + + def promoteBlocker(b: T) = { + if (blockersInfo contains b) { + val (gen, origGen, ast, fis) = blockersInfo(b) + + blockersInfo += b -> (1, origGen, ast, fis) + } + } + + def unrollBehind(ids: Seq[T]): Seq[T] = { + assert(ids.forall(id => blockersInfo contains id)) + + var newClauses : Seq[T] = Seq.empty + + for (id <- ids) { + val (gen, _, _, fis) = blockersInfo(id) + + blockersInfo = blockersInfo - id + + var reintroducedSelf = false + + for (fi <- fis) { + var newCls = Seq[T]() + + val defBlocker = defBlockers.get(fi) match { + case Some(defBlocker) => + // we already have defBlocker => f(args) = body + defBlocker + case None => + // we need to define this defBlocker and link it to definition + val defBlocker = encoder.encodeId(FreshIdentifier("d").setType(BooleanType)) + defBlockers += fi -> defBlocker + + val template = templateGenerator.mkTemplate(fi.tfd) + reporter.debug(template) + val (newExprs, newBlocks) = template.instantiate(defBlocker, fi.args) + + for((i, fis2) <- newBlocks) { + registerBlocker(nextGeneration(gen), i, fis2) + } + + newCls ++= newExprs + defBlocker + } + + // We connect it to the defBlocker: blocker => defBlocker + if (defBlocker != id) { + newCls ++= List(encoder.implies(id, defBlocker)) + } + + reporter.debug("Unrolling behind "+fi+" ("+newCls.size+")") + for (cl <- newCls) { + reporter.debug(" . "+cl) + } + + newClauses ++= newCls + } + + } + + reporter.debug(s" - ${newClauses.size} new clauses") + //context.reporter.ifDebug { debug => + // debug(s" - new clauses:") + // debug("@@@@") + // for (cl <- newClauses) { + // debug(""+cl) + // } + // debug("////") + //} + + //dumpBlockers + //readLine() + + newClauses + } +} diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 79d60b377dd68aa436bd1564e9abafceca55e45d..bca62952fb346f45794cff863060d5d3fdaf1b79 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -1,13 +1,12 @@ /* Copyright 2009-2014 EPFL, Lausanne */ package leon -package solvers.z3 +package solvers +package z3 import leon.utils._ -import z3.scala._ - -import leon.solvers.{Solver, IncrementalSolver} +import _root_.z3.scala._ import purescala.Common._ import purescala.Definitions._ @@ -16,13 +15,12 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ +import solvers.templates._ + import evaluators._ import termination._ -import scala.collection.mutable.{Map => MutableMap} -import scala.collection.mutable.{Set => MutableSet} - class FairZ3Solver(val context : LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction @@ -133,206 +131,28 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } } - private val funDefTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty - private val exprTemplateCache : MutableMap[Expr , FunctionTemplate] = MutableMap.empty - - private def getTemplate(tfd: TypedFunDef): FunctionTemplate = { - funDefTemplateCache.getOrElse(tfd, { - val res = FunctionTemplate.mkTemplate(this, tfd, true) - funDefTemplateCache += tfd -> res - res - }) - } - - private def getTemplate(body: Expr): FunctionTemplate = { - exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => ValDef(id, id.getType)), DefType.MethodDef) - fakeFunDef.body = Some(body) - - val res = FunctionTemplate.mkTemplate(this, fakeFunDef.typed, false) - exprTemplateCache += body -> res - res - }) - } - - class UnrollingBank { - // Keep which function invocation is guarded by which guard, - // also specify the generation of the blocker. - - private var blockersInfoStack : List[MutableMap[Z3AST,(Int,Int,Z3AST,Set[Z3FunctionInvocation])]] = List(MutableMap()) - - def blockersInfo = blockersInfoStack.head - - def push() { - blockersInfoStack = (MutableMap() ++ blockersInfo) :: blockersInfoStack - } - - def pop(lvl: Int) { - blockersInfoStack = blockersInfoStack.drop(lvl) - } - - def z3CurrentZ3Blockers = blockersInfo.map(_._2._3) - - def finfo(fi: Z3FunctionInvocation) = { - fi.tfd.id.uniqueName+fi.args.mkString("(", ", ", ")") + val templateGenerator = new TemplateGenerator(new TemplateEncoder[Z3AST] { + def encodeId(id: Identifier): Z3AST = { + idToFreshZ3Id(id) } - def dumpBlockers = { - blockersInfo.groupBy(_._2._1).toSeq.sortBy(_._1).foreach { case (gen, entries) => - reporter.debug("--- "+gen) - - - for (((bast), (gen, origGen, ast, fis)) <- entries) { - reporter.debug(f". $bast%15s ~> "+fis.map(finfo).mkString(", ")) - } - } - } - - def canUnroll = !blockersInfo.isEmpty - - def getZ3BlockersToUnlock: Seq[Z3AST] = { - if (!blockersInfo.isEmpty) { - val minGeneration = blockersInfo.values.map(_._1).min - - blockersInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1) - } else { - Seq() + def encodeExpr(bindings: Map[Identifier, Z3AST])(e: Expr): Z3AST = { + toZ3Formula(e, bindings).getOrElse { + reporter.fatalError("Failed to translate "+e+" to z3 ("+e.getClass+")") } } - private def registerBlocker(gen: Int, id: Z3AST, fis: Set[Z3FunctionInvocation]) { - val z3ast = z3.mkNot(id) - blockersInfo.get(id) match { - case Some((exGen, origGen, _, exFis)) => - // PS: when recycling `b`s, this assertion becomes dangerous. - // It's better to simply take the max of the generations. - // assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen) - - val minGen = gen min exGen + def substitute(substMap: Map[Z3AST, Z3AST]): Z3AST => Z3AST = { + val (from, to) = substMap.unzip + val (fromArray, toArray) = (from.toArray, to.toArray) - blockersInfo(id) = ((minGen, origGen, z3ast, fis++exFis)) - case None => - blockersInfo(id) = ((gen, gen, z3ast, fis)) - } + (c: Z3AST) => z3.substitute(c, fromArray, toArray) } - def scanForNewTemplates(expr: Expr): Seq[Z3AST] = { - // OK, now this is subtle. This `getTemplate` will return - // a template for a "fake" function. Now, this template will - // define an activating boolean... - val template = getTemplate(expr) - - - val z3args = for (vd <- template.tfd.params) yield { - variables.getZ3(Variable(vd.id)) match { - case Some(ast) => - ast - case None => - val ast = idToFreshZ3Id(vd.id) - variables += Variable(vd.id) -> ast - ast - } - } - - // ...now this template defines clauses that are all guarded - // by that activating boolean. If that activating boolean is - // undefined (or false) these clauses have no effect... - val (newClauses, newBlocks) = - template.instantiate(template.z3ActivatingBool, z3args) + def not(e: Z3AST) = z3.mkNot(e) + def implies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) + }) - for((i, fis) <- newBlocks) { - registerBlocker(nextGeneration(0), i, fis) - } - - // ...so we must force it to true! - template.z3ActivatingBool +: newClauses - } - - def nextGeneration(gen: Int) = gen + 3 - - def decreaseAllGenerations() = { - for ((block, (gen, origGen, ast, finvs)) <- blockersInfo) { - // We also decrease the original generation here - blockersInfo(block) = (math.max(1,gen-1), math.max(1,origGen-1), ast, finvs) - } - } - - def promoteBlocker(b: Z3AST) = { - if (blockersInfo contains b) { - val (gen, origGen, ast, finvs) = blockersInfo(b) - blockersInfo(b) = (1, origGen, ast, finvs) - } - } - - private var defBlockers = Map[Z3FunctionInvocation, Z3AST]() - - def unlock(ids: Seq[Z3AST]) : Seq[Z3AST] = { - assert(ids.forall(id => blockersInfo contains id)) - - var newClauses : Seq[Z3AST] = Seq.empty - - for (id <- ids) { - val (gen, _, _, fis) = blockersInfo(id) - - blockersInfo -= id - - var reintroducedSelf = false - - for (fi <- fis) { - var newCls = Seq[Z3AST]() - - val defBlocker = defBlockers.get(fi) match { - case Some(defBlocker) => - // we already have defBlocker => f(args) = body - defBlocker - case None => - // we need to define this defBlocker and link it to definition - val defBlocker = z3.mkFreshConst("d", z3.mkBoolSort) - defBlockers += fi -> defBlocker - - val template = getTemplate(fi.tfd) - reporter.debug(template) - val (newExprs, newBlocks) = template.instantiate(defBlocker, fi.args) - - for((i, fis2) <- newBlocks) { - registerBlocker(nextGeneration(gen), i, fis2) - } - - newCls ++= newExprs - defBlocker - } - - // We connect it to the defBlocker: blocker => defBlocker - if (defBlocker != id) { - newCls ++= List(z3.mkImplies(id, defBlocker)) - } - - reporter.debug("Unrolling behind "+fi+" ("+newCls.size+")") - for (cl <- newCls) { - reporter.debug(" . "+cl) - } - - newClauses ++= newCls - } - - } - - context.reporter.debug(s" - ${newClauses.size} new clauses") - //context.reporter.ifDebug { debug => - // debug(s" - new clauses:") - // debug("@@@@") - // for (cl <- newClauses) { - // debug(""+cl) - // } - // debug("////") - //} - - //dumpBlockers - //readLine() - - newClauses - } - } initZ3 @@ -342,7 +162,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) private var frameExpressions = List[List[Expr]](Nil) - val unrollingBank = new UnrollingBank() + val unrollingBank = new UnrollingBank(reporter, templateGenerator) def push() { solver.push() @@ -370,11 +190,19 @@ class FairZ3Solver(val context : LeonContext, val program: Program) var definitiveCore : Set[Expr] = Set.empty def assertCnstr(expression: Expr) { - varsInVC ++= variablesOf(expression) + val freeVars = variablesOf(expression) + varsInVC ++= freeVars + + // We make sure all free variables are registered as variables + freeVars.foreach { v => + variables.toZ3OrCompute(Variable(v)) { + templateGenerator.encoder.encodeId(v) + } + } frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail - val newClauses = unrollingBank.scanForNewTemplates(expression) + val newClauses = unrollingBank.getClauses(expression, variables.leonToZ3) for (cl <- newClauses) { solver.assertCnstr(cl) @@ -426,7 +254,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.currentBlockers) :_*) solver.pop() // FIXME: remove when z3 bug is fixed timer.stop() @@ -548,11 +376,11 @@ class FairZ3Solver(val context : LeonContext, val program: Program) if(!foundDefinitiveAnswer) { reporter.debug("- We need to keep going.") - val toRelease = unrollingBank.getZ3BlockersToUnlock + val toRelease = unrollingBank.getBlockersToUnlock reporter.debug(" - more unrollings") - val newClauses = unrollingBank.unlock(toRelease) + val newClauses = unrollingBank.unrollBehind(toRelease) for(ncl <- newClauses) { solver.assertCnstr(ncl) diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala deleted file mode 100644 index cf993c1b9054144987674d0ac2d8578d5e98b141..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package solvers.z3 - -import purescala.Common._ -import purescala.Trees._ -import purescala.Extractors._ -import purescala.TreeOps._ -import purescala.TypeTrees._ -import purescala.Definitions._ - -import evaluators._ - -import z3.scala._ - -import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} - -case class Z3FunctionInvocation(tfd: TypedFunDef, args: Seq[Z3AST]) { - override def toString = tfd.signature + args.mkString("(", ",", ")") -} - -class FunctionTemplate private( - solver: FairZ3Solver, - val tfd : TypedFunDef, - activatingBool : Identifier, - condVars : Set[Identifier], - exprVars : Set[Identifier], - guardedExprs : Map[Identifier,Seq[Expr]], - isRealFunDef : Boolean) { - - private def isTerminatingForAllInputs : Boolean = ( - isRealFunDef - && !tfd.hasPrecondition - && solver.getTerminator.terminates(tfd.fd).isGuaranteed - ) - - private val z3 = solver.z3 - - private val asClauses : Seq[Expr] = { - (for((b,es) <- guardedExprs; e <- es) yield { - Implies(Variable(b), e) - }).toSeq - } - - val z3ActivatingBool = solver.idToFreshZ3Id(activatingBool) - - private val z3FunDefArgs = tfd.params.map( ad => solver.idToFreshZ3Id(ad.id)) - - private val zippedCondVars = condVars.map(id => (id, solver.idToFreshZ3Id(id))) - private val zippedExprVars = exprVars.map(id => (id, solver.idToFreshZ3Id(id))) - private val zippedFunDefArgs = tfd.params.map(_.id) zip z3FunDefArgs - - val idToZ3Ids: Map[Identifier, Z3AST] = { - Map(activatingBool -> z3ActivatingBool) ++ - zippedCondVars ++ - zippedExprVars ++ - zippedFunDefArgs - } - - val asZ3Clauses: Seq[Z3AST] = asClauses.map { - cl => solver.toZ3Formula(cl, idToZ3Ids).getOrElse(sys.error("Could not translate to z3. Did you forget --xlang? @"+cl.getPos)) - } - - private val blockers : Map[Identifier,Set[FunctionInvocation]] = { - val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) - - Map((for((b, es) <- guardedExprs) yield { - val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall - if(calls.isEmpty) { - None - } else { - Some((b, calls)) - } - }).flatten.toSeq : _*) - } - - val z3Blockers: Map[Z3AST,Set[Z3FunctionInvocation]] = blockers.map { - case (b, funs) => - (idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.tfd, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get)))) - } - - // We use a cache to create the same boolean variables. - private val cache : MutableMap[Seq[Z3AST],Map[Z3AST,Z3AST]] = MutableMap.empty - - def instantiate(aVar : Z3AST, args : Seq[Z3AST]) : (Seq[Z3AST], Map[Z3AST,Set[Z3FunctionInvocation]]) = { - assert(args.size == tfd.params.size) - - // The "isRealFunDef" part is to prevent evaluation of "fake" - // function templates, as generated from FairZ3Solver. - if(solver.evalGroundApps && isRealFunDef) { - val ga = args.view.map(solver.asGround) - if(ga.forall(_.isDefined)) { - val leonArgs = ga.map(_.get).force - val invocation = FunctionInvocation(tfd, leonArgs) - solver.getEvaluator.eval(invocation) match { - case EvaluationResults.Successful(result) => - val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*) - val z3Value = solver.toZ3Formula(result).get - val asZ3 = z3.mkEq(z3Invocation, z3Value) - return (Seq(asZ3), Map.empty) - - case _ => throw new Exception("Evaluation of ground term should have succeeded.") - } - } - } - // ...end of ground evaluation part. - - val (wasHit,baseIDSubstMap) = cache.get(args) match { - case Some(m) => (true,m) - case None => - val newMap : Map[Z3AST,Z3AST] = - (zippedExprVars ++ zippedCondVars).map(p => p._2 -> solver.idToFreshZ3Id(p._1)).toMap ++ - (z3FunDefArgs zip args) - cache(args) = newMap - (false,newMap) - } - - val idSubstMap : Map[Z3AST,Z3AST] = baseIDSubstMap + (z3ActivatingBool -> aVar) - - val (from, to) = idSubstMap.unzip - val (fromArray, toArray) = (from.toArray, to.toArray) - - val newClauses = asZ3Clauses.map(z3.substitute(_, fromArray, toArray)) - val newBlockers = z3Blockers.map { case (b, funs) => - val bp = if (b == z3ActivatingBool) { - aVar - } else { - idSubstMap(b) - } - - val newFuns = funs.map(fi => fi.copy(args = fi.args.map(z3.substitute(_, fromArray, toArray)))) - - bp -> newFuns - } - - (newClauses, newBlockers) - } - - override def toString : String = { - "Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + - " * Activating boolean : " + activatingBool + "\n" + - " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + - " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" + - " * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" + - " * Block-map : " + blockers.toString - } -} - -object FunctionTemplate { - def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { - val condVars : MutableSet[Identifier] = MutableSet.empty - val exprVars : MutableSet[Identifier] = MutableSet.empty - - // Represents clauses of the form: - // id => expr && ... && expr - val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty - - def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { - assert(expr.getType == BooleanType) - if(guardedExprs.isDefinedAt(guardVar)) { - val prev : Seq[Expr] = guardedExprs(guardVar) - guardedExprs(guardVar) = expr +: prev - } else { - guardedExprs(guardVar) = Seq(expr) - } - } - - // Group elements that satisfy p toghether - // List(a, a, a, b, c, a, a), with p = _ == a will produce: - // List(List(a,a,a), List(b), List(c), List(a, a)) - def groupWhile[T](p: T => Boolean, l: Seq[T]): Seq[Seq[T]] = { - var res: Seq[Seq[T]] = Nil - - var c = l - - while(!c.isEmpty) { - val (span, rest) = c.span(p) - - if (span.isEmpty) { - res = res :+ Seq(rest.head) - c = rest.tail - } else { - res = res :+ span - c = rest - } - } - - res - } - - def requireDecomposition(e: Expr) = { - exists{ - case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true - case _ => false - }(e) - } - - def rec(pathVar : Identifier, expr : Expr) : Expr = { - expr match { - case a @ Assert(cond, _, body) => - storeGuarded(pathVar, rec(pathVar, cond)) - rec(pathVar, body) - - case e @ Ensuring(body, id, post) => - rec(pathVar, Let(id, body, Assert(post, None, Variable(id)))) - - case l @ Let(i, e, b) => - val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) - exprVars += newExpr - val re = rec(pathVar, e) - storeGuarded(pathVar, Equals(Variable(newExpr), re)) - val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b)) - rb - - case l @ LetTuple(is, e, b) => - val tuple : Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType))) - exprVars += tuple - val re = rec(pathVar, e) - storeGuarded(pathVar, Equals(Variable(tuple), re)) - - val mapping = for ((id, i) <- is.zipWithIndex) yield { - val newId = FreshIdentifier("ti", true).setType(id.getType) - exprVars += newId - storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1))) - - (Variable(id) -> Variable(newId)) - } - - val rb = rec(pathVar, replace(mapping.toMap, b)) - rb - - case m : MatchExpr => sys.error("MatchExpr's should have been eliminated.") - - case i @ Implies(lhs, rhs) => - Implies(rec(pathVar, lhs), rec(pathVar, rhs)) - - case a @ And(parts) => - And(parts.map(rec(pathVar, _))) - - case o @ Or(parts) => - Or(parts.map(rec(pathVar, _))) - - case i @ IfExpr(cond, thenn, elze) => { - if(!requireDecomposition(i)) { - i - } else { - val newBool1 : Identifier = FreshIdentifier("b", true).setType(BooleanType) - val newBool2 : Identifier = FreshIdentifier("b", true).setType(BooleanType) - val newExpr : Identifier = FreshIdentifier("e", true).setType(i.getType) - - condVars += newBool1 - condVars += newBool2 - - exprVars += newExpr - - val crec = rec(pathVar, cond) - val trec = rec(newBool1, thenn) - val erec = rec(newBool2, elze) - - storeGuarded(pathVar, Or(Variable(newBool1), Variable(newBool2))) - storeGuarded(pathVar, Or(Not(Variable(newBool1)), Not(Variable(newBool2)))) - // TODO can we improve this? i.e. make it more symmetrical? - // Probably it's symmetrical enough to Z3. - storeGuarded(pathVar, Iff(Variable(newBool1), crec)) - storeGuarded(newBool1, Equals(Variable(newExpr), trec)) - storeGuarded(newBool2, Equals(Variable(newExpr), erec)) - Variable(newExpr) - } - } - - case c @ Choose(ids, cond) => - val cid = FreshIdentifier("choose", true).setType(c.getType) - exprVars += cid - - val m: Map[Expr, Expr] = if (ids.size == 1) { - Map(Variable(ids.head) -> Variable(cid)) - } else { - ids.zipWithIndex.map{ case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i+1) }.toMap - } - - storeGuarded(pathVar, replace(m, cond)) - Variable(cid) - - case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType) - case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType) - case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType) - case t : Terminal => t - } - } - - // The precondition if it exists. - val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) - - val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) - - val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) - - val invocationEqualsBody : Option[Expr] = newBody match { - case Some(body) if isRealFunDef => - val b : Expr = Equals(invocation, body) - - Some(if(prec.isDefined) { - Implies(prec.get, b) - } else { - b - }) - - case _ => - None - } - - val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType) - - if (isRealFunDef) { - val finalPred : Option[Expr] = invocationEqualsBody.map(expr => rec(activatingBool, expr)) - finalPred.foreach(p => storeGuarded(activatingBool, p)) - } else { - val newFormula = rec(activatingBool, newBody.get) - storeGuarded(activatingBool, newFormula) - } - - // Now the postcondition. - tfd.postcondition match { - case Some((id, post)) => - val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) - - val postHolds : Expr = - if(tfd.hasPrecondition) { - Implies(prec.get, newPost) - } else { - newPost - } - - val finalPred2 : Expr = rec(activatingBool, postHolds) - storeGuarded(activatingBool, finalPred2) - case None => - - } - - new FunctionTemplate(solver, tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), -isRealFunDef) - } -}