Skip to content
Snippets Groups Projects
Commit 0a67f203 authored by Etienne Kneuss's avatar Etienne Kneuss Committed by Etienne Kneuss
Browse files

Unify and simplify function templates and unrolling banks

- Both Expr-based solvers and z3-based solvers rely on same template
  generation routines and unrolling strategy. Templates and unrolling
  bank is refactored in solvers.templates._

- Encoding is done through TemplateEncoder. Performance highly relies on
  an efficient TemplateEncoder.substitute() implementation.
parent a451ad7b
No related branches found
No related tags found
No related merge requests found
...@@ -34,12 +34,10 @@ object SolverFactory { ...@@ -34,12 +34,10 @@ object SolverFactory {
SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver)
case "smt" | "smt-z3" => case "smt" | "smt-z3" =>
val smtf = SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3Target) SolverFactory(() => new UnrollingSolver(ctx, new SMTLIBSolver(ctx, program) with SMTLIBZ3Target) with TimeoutSolver)
SolverFactory(() => new UnrollingSolver(ctx, smtf) with TimeoutSolver)
case "smt-cvc4" => case "smt-cvc4" =>
val smtf = SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) SolverFactory(() => new UnrollingSolver(ctx, new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) with TimeoutSolver)
SolverFactory(() => new UnrollingSolver(ctx, smtf) with TimeoutSolver)
case _ => case _ =>
ctx.reporter.fatalError("Unknown solver "+name) ctx.reporter.fatalError("Unknown solver "+name)
......
...@@ -10,152 +10,182 @@ import purescala.Trees._ ...@@ -10,152 +10,182 @@ import purescala.Trees._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.TypeTrees._ import purescala.TypeTrees._
import solvers.templates._
import utils.Interruptible import utils.Interruptible
import scala.collection.mutable.{Map=>MutableMap} import scala.collection.mutable.{Map=>MutableMap}
class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[IncrementalSolver]) class UnrollingSolver(val context: LeonContext, underlying: IncrementalSolver)
extends Solver with Interruptible { extends Solver with Interruptible {
private var theConstraint : Option[Expr] = None private var lastCheckResult: (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None)
private var theModel : Option[Map[Identifier,Expr]] = None
val reporter = context.reporter val reporter = context.reporter
private var stop: Boolean = false private var interrupted: Boolean = false
def name = "U:"+underlyings.name def name = "U:"+underlying.name
def free {} def free {}
import context.reporter._ var varsInVC = List[Set[Identifier]](Set())
def assertCnstr(expression : Expr) { val templateGenerator = new TemplateGenerator(new TemplateEncoder[Expr] {
if(!theConstraint.isEmpty) { def encodeId(id: Identifier): Expr= {
fatalError("Multiple assertCnstr(...).") Variable(id.freshen)
}
def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = {
replaceFromIDs(bindings, e)
} }
theConstraint = Some(expression)
}
def check : Option[Boolean] = theConstraint.map { expr => def substitute(substMap: Map[Expr, Expr]): Expr => Expr = {
val solver = underlyings.getNewSolver (e: Expr) => replace(substMap, e)
}
val template = getTemplate(expr) def not(e: Expr) = Not(e)
def implies(l: Expr, r: Expr) = Implies(l, r)
})
val aVar : Identifier = template.activatingBool val unrollingBank = new UnrollingBank(reporter, templateGenerator)
var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty
def unrollOneStep() : List[Expr] = { val solver = underlying
val blockersBefore = allBlockers
var newClauses : List[Seq[Expr]] = Nil def assertCnstr(expression: Expr) {
var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty val freeIds = variablesOf(expression)
for(blocker <- allBlockers.keySet; fi @ FunctionInvocation(tfd, args) <- allBlockers(blocker)) { val freeVars = freeIds.map(_.toVariable: Expr)
val tmpl = getTemplate(tfd)
val (nc, nb) = tmpl.instantiate(blocker, args) val bindings = freeVars.zip(freeVars).toMap
newClauses = nc :: newClauses
newBlockers = newBlockers ++ nb val newClauses = unrollingBank.getClauses(expression, bindings)
//reporter.debug("Unrolling behind "+fi+" ("+nc.size+")")
//for (c <- nc) {
// reporter.debug(" . "+c)
//}
}
allBlockers = newBlockers for (cl <- newClauses) {
newClauses.flatten solver.assertCnstr(cl)
} }
val (nc, nb) = template.instantiate(aVar, template.tfd.params.map(a => Variable(a.id))) varsInVC = (varsInVC.head ++ freeIds) :: varsInVC.tail
}
def push() {
unrollingBank.push()
solver.push()
varsInVC = Set[Identifier]() :: varsInVC
}
def pop(lvl: Int = 1) {
unrollingBank.pop(lvl)
solver.pop(lvl)
varsInVC = varsInVC.drop(lvl)
}
def check: Option[Boolean] = {
genericCheck(Set())
}
def hasFoundAnswer = lastCheckResult._1
allBlockers = nb def foundAnswer(res: Option[Boolean], model: Option[Map[Identifier, Expr]] = None) = {
lastCheckResult = (true, res, model)
}
def genericCheck(assumptions: Set[Expr]): Option[Boolean] = {
lastCheckResult = (false, None, None)
var unrollingCount : Int = 0 while(!hasFoundAnswer && !interrupted) {
var done : Boolean = false reporter.debug(" - Running search...")
var result : Option[Boolean] = None
solver.assertCnstr(Variable(aVar))
solver.assertCnstr(And(nc))
// We're now past the initial step.
while(!done && !stop) {
solver.push() solver.push()
reporter.debug(" - Searching with blocked literals") solver.assertCnstr(And((assumptions ++ unrollingBank.currentBlockers).toSeq))
solver.assertCnstr(And(allBlockers.keySet.toSeq.map(id => Not(id.toVariable)))) val res = solver.check
solver.check match { solver.pop()
case Some(false) => reporter.debug(" - Finished search with blocked literals")
solver.pop(1)
reporter.debug(" - Searching with unblocked literals") res match {
//val open = fullOpenExpr case None =>
solver.check match { reporter.ifDebug { debug =>
case Some(false) => reporter.debug("Solver returned unknown!?")
done = true
result = Some(false)
case r =>
unrollingCount += 1
val model = solver.getModel
reporter.debug(" - Tentative model: "+model)
reporter.debug(" - more unrollings")
val newClauses = unrollOneStep()
reporter.debug(s" - ${newClauses.size} new clauses")
//readLine()
solver.assertCnstr(And(newClauses))
} }
foundAnswer(None)
case Some(true) => case Some(true) => // SAT
val model = solver.getModel val model = solver.getModel
done = true
result = Some(true)
theModel = Some(model)
case None => foundAnswer(Some(true), Some(model))
val model = solver.getModel
done = true case Some(false) if !unrollingBank.canUnroll =>
result = Some(true) foundAnswer(Some(false))
theModel = Some(model)
case Some(false) =>
//debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n"))
//debug("UNSAT BECAUSE: "+core.mkString(" AND "))
if (!interrupted) {
reporter.debug(" - Running search without blocked literals (w/o lucky test)")
solver.push()
solver.assertCnstr(And(assumptions.toSeq))
val res2 = solver.check
solver.pop()
res2 match {
case Some(false) =>
//reporter.debug("UNSAT WITHOUT Blockers")
foundAnswer(Some(false))
case Some(true) =>
case None =>
foundAnswer(None)
}
}
if(interrupted) {
foundAnswer(None)
}
if(!hasFoundAnswer) {
reporter.debug("- We need to keep going.")
val toRelease = unrollingBank.getBlockersToUnlock
reporter.debug(" - more unrollings")
val newClauses = unrollingBank.unrollBehind(toRelease)
for(ncl <- newClauses) {
solver.assertCnstr(ncl)
}
reporter.debug(" - finished unrolling")
}
} }
} }
solver.free
result
} getOrElse { if(interrupted) {
Some(true) None
} else {
lastCheckResult._2
}
} }
def getModel : Map[Identifier,Expr] = { def getModel: Map[Identifier,Expr] = {
val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) val allVars = varsInVC.flatten.toSet
theModel.getOrElse(Map.empty).filter(p => vs(p._1)) lastCheckResult match {
case (true, Some(true), Some(m)) =>
m.filterKeys(allVars)
case _ =>
Map()
}
} }
override def interrupt(): Unit = { override def interrupt(): Unit = {
stop = true interrupted = true
} }
override def recoverInterrupt(): Unit = { override def recoverInterrupt(): Unit = {
stop = false interrupted = false
}
private val tfdTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty
private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty
private def getTemplate(tfd: TypedFunDef): FunctionTemplate = {
tfdTemplateCache.getOrElse(tfd, {
val res = FunctionTemplate.mkTemplate(tfd, true)
tfdTemplateCache += tfd -> res
res
})
}
private def getTemplate(body: Expr): FunctionTemplate = {
exprTemplateCache.getOrElse(body, {
val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => ValDef(id, id.getType)), DefType.MethodDef)
fakeFunDef.body = Some(body)
val res = FunctionTemplate.mkTemplate(fakeFunDef.typed, false)
exprTemplateCache += body -> res
res
})
} }
} }
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package solvers
package templates
import utils._
import purescala.Common._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._
import evaluators._
class FunctionTemplate[T](
val tfd: TypedFunDef,
val encoder: TemplateEncoder[T],
activatingBool: Identifier,
condVars: Set[Identifier],
exprVars: Set[Identifier],
guardedExprs: Map[Identifier,Seq[Expr]],
isRealFunDef: Boolean) {
val evalGroundApps = false
val clauses: Seq[Expr] = {
(for((b,es) <- guardedExprs; e <- es) yield {
Implies(Variable(b), e)
}).toSeq
}
val trActivatingBool = encoder.encodeId(activatingBool)
val trFunDefArgs = tfd.params.map( ad => encoder.encodeId(ad.id))
val zippedCondVars = condVars.map(id => (id -> encoder.encodeId(id)))
val zippedExprVars = exprVars.map(id => (id -> encoder.encodeId(id)))
val zippedFunDefArgs = tfd.params.map(_.id) zip trFunDefArgs
val idToTrId: Map[Identifier, T] = {
Map(activatingBool -> trActivatingBool) ++
zippedCondVars ++
zippedExprVars ++
zippedFunDefArgs
}
val encodeExpr = encoder.encodeExpr(idToTrId) _
val trClauses: Seq[T] = clauses.map(encodeExpr)
val trBlockers: Map[T, Set[TemplateCallInfo[T]]] = {
val idCall = TemplateCallInfo[T](tfd, trFunDefArgs)
Map((for((b, es) <- guardedExprs) yield {
val allCalls = es.map(functionCallsOf).flatten.toSet
val calls = (for (c <- allCalls) yield {
TemplateCallInfo[T](c.tfd, c.args.map(encodeExpr))
}) - idCall
if(calls.isEmpty) {
None
} else {
Some(idToTrId(b) -> calls)
}
}).flatten.toSeq : _*)
}
// We use a cache to create the same boolean variables.
var cache = Map[Seq[T], Map[T, T]]()
def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]]) = {
assert(args.size == tfd.params.size)
// The "isRealFunDef" part is to prevent evaluation of "fake"
// function templates, as generated from FairZ3Solver.
//if(evalGroundApps && isRealFunDef) {
// val ga = args.view.map(solver.asGround)
// if(ga.forall(_.isDefined)) {
// val leonArgs = ga.map(_.get).force
// val invocation = FunctionInvocation(tfd, leonArgs)
// solver.getEvaluator.eval(invocation) match {
// case EvaluationResults.Successful(result) =>
// val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*)
// val z3Value = solver.toZ3Formula(result).get
// val asZ3 = z3.mkEq(z3Invocation, z3Value)
// return (Seq(asZ3), Map.empty)
// case _ => throw new Exception("Evaluation of ground term should have succeeded.")
// }
// }
//}
// ...end of ground evaluation part.
val baseSubstMap = cache.get(args) match {
case Some(m) => m
case None =>
val newMap: Map[T, T] =
(zippedCondVars ++ zippedExprVars).map{ case (id, idT) => idT -> encoder.encodeId(id) }.toMap ++
(trFunDefArgs zip args)
cache += args -> newMap
newMap
}
val substMap : Map[T, T] = baseSubstMap + (trActivatingBool -> aVar)
val substituter = encoder.substitute(substMap)
val newClauses = trClauses.map(substituter)
val newBlockers = trBlockers.map { case (b, funs) =>
val bp = substituter(b)
val newFuns = funs.map(fi => fi.copy(args = fi.args.map(substituter)))
bp -> newFuns
}
(newClauses, newBlockers)
}
override def toString : String = {
"Template for def " + tfd.signature + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" +
" * Activating boolean : " + trActivatingBool + "\n" +
" * Control booleans : " + zippedCondVars.map(_._2.toString).mkString(", ") + "\n" +
" * Expression vars : " + zippedExprVars.map(_._2.toString).mkString(", ") + "\n" +
" * Clauses : " + "\n " +trClauses.mkString("\n ") + "\n" +
" * Block-map : " + trBlockers.toString
}
}
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package solvers
package templates
import purescala.Definitions.TypedFunDef
case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) {
override def toString = {
tfd.signature+args.mkString("(", ", ", ")")
}
}
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package solvers
package templates
import purescala.Common.Identifier
import purescala.Trees.Expr
trait TemplateEncoder[T] {
def encodeId(id: Identifier): T
def encodeExpr(bindings: Map[Identifier, T])(e: Expr): T
def substitute(map: Map[T, T]): T => T
// Encodings needed for unrollingbank
def not(v: T): T
def implies(l: T, r: T): T
}
/* Copyright 2009-2014 EPFL, Lausanne */ /* Copyright 2009-2014 EPFL, Lausanne */
package leon package leon
package solvers.combinators package solvers
package templates
import utils._
import purescala.Common._ import purescala.Common._
import purescala.Trees._ import purescala.Trees._
import purescala.Extractors._ import purescala.Extractors._
...@@ -12,105 +14,45 @@ import purescala.Definitions._ ...@@ -12,105 +14,45 @@ import purescala.Definitions._
import evaluators._ import evaluators._
import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} class TemplateGenerator[T](val encoder: TemplateEncoder[T]) {
private var cache = Map[TypedFunDef, FunctionTemplate[T]]()
private var cacheExpr = Map[Expr, FunctionTemplate[T]]()
class FunctionTemplate private( def mkTemplate(body: Expr): FunctionTemplate[T] = {
val tfd : TypedFunDef, if (cacheExpr contains body) {
val activatingBool : Identifier, return cacheExpr(body);
condVars : Set[Identifier],
exprVars : Set[Identifier],
guardedExprs : Map[Identifier,Seq[Expr]],
isRealFunDef : Boolean) {
private val funDefArgsIDs : Seq[Identifier] = tfd.params.map(_.id)
private val asClauses : Seq[Expr] = {
(for((b,es) <- guardedExprs; e <- es) yield {
Implies(Variable(b), e)
}).toSeq
}
val blockers : Map[Identifier,Set[FunctionInvocation]] = {
val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable))
Map((for((b, es) <- guardedExprs) yield {
val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall
if(calls.isEmpty) {
None
} else {
Some((b, calls))
}
}).flatten.toSeq : _*)
}
private def idToFreshID(id : Identifier) : Identifier = {
FreshIdentifier(id.name, true).setType(id.getType)
}
// We use a cache to create the same boolean variables.
private val cache : MutableMap[Seq[Expr],Map[Identifier,Expr]] = MutableMap.empty
def instantiate(aVar : Identifier, args : Seq[Expr]) : (Seq[Expr], Map[Identifier,Set[FunctionInvocation]]) = {
assert(args.size == tfd.params.size)
val (wasHit,baseIDSubstMap) = cache.get(args) match {
case Some(m) => (true,m)
case None =>
val newMap : Map[Identifier,Expr] =
(exprVars ++ condVars).map(id => id -> Variable(idToFreshID(id))).toMap ++
(funDefArgsIDs zip args)
cache(args) = newMap
(false, newMap)
} }
val idSubstMap : Map[Identifier,Expr] = baseIDSubstMap + (activatingBool -> Variable(aVar)) val fakeFunDef = new FunDef(FreshIdentifier("fake", true),
val exprSubstMap : Map[Expr,Expr] = idSubstMap.map(p => (Variable(p._1), p._2)) Nil,
body.getType,
val newClauses = asClauses.map(replace(exprSubstMap, _)) variablesOf(body).toSeq.map(id => ValDef(id, id.getType)))
val newBlockers = blockers.map { case (id, funs) =>
val bp = if (id == activatingBool) {
aVar
} else {
// That's not exactly safe...
idSubstMap(id).asInstanceOf[Variable].id
}
val newFuns = funs.map(fi => fi.copy(args = fi.args.map(replace(exprSubstMap, _))))
bp -> newFuns fakeFunDef.body = Some(body)
}
(newClauses, newBlockers) val res = mkTemplate(fakeFunDef.typed, false)
cacheExpr += body -> res
res
} }
override def toString : String = { def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] = {
"Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + if (cache contains tfd) {
" * Activating boolean : " + activatingBool + "\n" + return cache(tfd)
" * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + }
" * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" +
" * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" +
" * Block-map : " + blockers.toString
}
}
object FunctionTemplate { var condVars = Set[Identifier]()
def mkTemplate(tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { var exprVars = Set[Identifier]()
val condVars : MutableSet[Identifier] = MutableSet.empty
val exprVars : MutableSet[Identifier] = MutableSet.empty
// Represents clauses of the form: // Represents clauses of the form:
// id => expr && ... && expr // id => expr && ... && expr
val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty var guardedExprs = Map[Identifier, Seq[Expr]]()
def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = {
assert(expr.getType == BooleanType) assert(expr.getType == BooleanType)
if(guardedExprs.isDefinedAt(guardVar)) {
val prev : Seq[Expr] = guardedExprs(guardVar) val prev = guardedExprs.getOrElse(guardVar, Nil)
guardedExprs(guardVar) = expr +: prev
} else { guardedExprs += guardVar -> (expr +: prev)
guardedExprs(guardVar) = Seq(expr)
}
} }
// Group elements that satisfy p toghether // Group elements that satisfy p toghether
...@@ -143,7 +85,7 @@ object FunctionTemplate { ...@@ -143,7 +85,7 @@ object FunctionTemplate {
}(e) }(e)
} }
def rec(pathVar : Identifier, expr : Expr) : Expr = { def rec(pathVar: Identifier, expr: Expr): Expr = {
expr match { expr match {
case a @ Assert(cond, _, body) => case a @ Assert(cond, _, body) =>
storeGuarded(pathVar, rec(pathVar, cond)) storeGuarded(pathVar, rec(pathVar, cond))
...@@ -285,7 +227,14 @@ object FunctionTemplate { ...@@ -285,7 +227,14 @@ object FunctionTemplate {
} }
new FunctionTemplate(tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), val template = new FunctionTemplate[T](tfd,
isRealFunDef) encoder,
activatingBool,
Set(condVars.toSeq : _*),
Set(exprVars.toSeq : _*),
Map(guardedExprs.toSeq : _*),
isRealFunDef)
cache += tfd -> template
template
} }
} }
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package solvers
package templates
import utils._
import purescala.Common._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._
import evaluators._
class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[T]) {
implicit val debugSection = utils.DebugSectionSolver
private val encoder = templateGenerator.encoder
// Keep which function invocation is guarded by which guard,
// also specify the generation of the blocker.
private var blockersInfoStack = List[Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]](Map())
// Function instantiations have their own defblocker
private var defBlockers = Map[TemplateCallInfo[T], T]()
def blockersInfo = blockersInfoStack.head
def blockersInfo_= (v: Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]) = {
blockersInfoStack = v :: blockersInfoStack.tail
}
def push() {
blockersInfoStack = blockersInfo :: blockersInfoStack
}
def pop(lvl: Int) {
blockersInfoStack = blockersInfoStack.drop(lvl)
}
def dumpBlockers = {
blockersInfo.groupBy(_._2._1).toSeq.sortBy(_._1).foreach { case (gen, entries) =>
reporter.debug("--- "+gen)
for (((bast), (gen, origGen, ast, fis)) <- entries) {
reporter.debug(f". $bast%15s ~> "+fis.mkString(", "))
}
}
}
def canUnroll = !blockersInfo.isEmpty
def currentBlockers = blockersInfo.map(_._2._3)
def getBlockersToUnlock: Seq[T] = {
if (!blockersInfo.isEmpty) {
val minGeneration = blockersInfo.values.map(_._1).min
blockersInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1)
} else {
Seq()
}
}
private def registerBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) {
val notId = encoder.not(id)
blockersInfo.get(id) match {
case Some((exGen, origGen, _, exFis)) =>
// PS: when recycling `b`s, this assertion becomes dangerous.
// It's better to simply take the max of the generations.
// assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen)
val minGen = gen min exGen
blockersInfo += id -> (minGen, origGen, notId, fis++exFis)
case None =>
blockersInfo += id -> (gen, gen, notId, fis)
}
}
def getClauses(expr: Expr, bindings: Map[Expr, T]): Seq[T] = {
// OK, now this is subtle. This `getTemplate` will return
// a template for a "fake" function. Now, this template will
// define an activating boolean...
val template = templateGenerator.mkTemplate(expr)
val trArgs = template.tfd.params.map(vd => bindings(Variable(vd.id)))
// ...now this template defines clauses that are all guarded
// by that activating boolean. If that activating boolean is
// undefined (or false) these clauses have no effect...
val (newClauses, newBlocks) =
template.instantiate(template.trActivatingBool, trArgs)
for((i, fis) <- newBlocks) {
registerBlocker(nextGeneration(0), i, fis)
}
// ...so we must force it to true!
template.trActivatingBool +: newClauses
}
def nextGeneration(gen: Int) = gen + 3
def decreaseAllGenerations() = {
for ((block, (gen, origGen, ast, finvs)) <- blockersInfo) {
// We also decrease the original generation here
blockersInfo += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, finvs)
}
}
def promoteBlocker(b: T) = {
if (blockersInfo contains b) {
val (gen, origGen, ast, fis) = blockersInfo(b)
blockersInfo += b -> (1, origGen, ast, fis)
}
}
def unrollBehind(ids: Seq[T]): Seq[T] = {
assert(ids.forall(id => blockersInfo contains id))
var newClauses : Seq[T] = Seq.empty
for (id <- ids) {
val (gen, _, _, fis) = blockersInfo(id)
blockersInfo = blockersInfo - id
var reintroducedSelf = false
for (fi <- fis) {
var newCls = Seq[T]()
val defBlocker = defBlockers.get(fi) match {
case Some(defBlocker) =>
// we already have defBlocker => f(args) = body
defBlocker
case None =>
// we need to define this defBlocker and link it to definition
val defBlocker = encoder.encodeId(FreshIdentifier("d").setType(BooleanType))
defBlockers += fi -> defBlocker
val template = templateGenerator.mkTemplate(fi.tfd)
reporter.debug(template)
val (newExprs, newBlocks) = template.instantiate(defBlocker, fi.args)
for((i, fis2) <- newBlocks) {
registerBlocker(nextGeneration(gen), i, fis2)
}
newCls ++= newExprs
defBlocker
}
// We connect it to the defBlocker: blocker => defBlocker
if (defBlocker != id) {
newCls ++= List(encoder.implies(id, defBlocker))
}
reporter.debug("Unrolling behind "+fi+" ("+newCls.size+")")
for (cl <- newCls) {
reporter.debug(" . "+cl)
}
newClauses ++= newCls
}
}
reporter.debug(s" - ${newClauses.size} new clauses")
//context.reporter.ifDebug { debug =>
// debug(s" - new clauses:")
// debug("@@@@")
// for (cl <- newClauses) {
// debug(""+cl)
// }
// debug("////")
//}
//dumpBlockers
//readLine()
newClauses
}
}
/* Copyright 2009-2014 EPFL, Lausanne */ /* Copyright 2009-2014 EPFL, Lausanne */
package leon package leon
package solvers.z3 package solvers
package z3
import leon.utils._ import leon.utils._
import z3.scala._ import _root_.z3.scala._
import leon.solvers.{Solver, IncrementalSolver}
import purescala.Common._ import purescala.Common._
import purescala.Definitions._ import purescala.Definitions._
...@@ -16,13 +15,12 @@ import purescala.Extractors._ ...@@ -16,13 +15,12 @@ import purescala.Extractors._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.TypeTrees._ import purescala.TypeTrees._
import solvers.templates._
import evaluators._ import evaluators._
import termination._ import termination._
import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable.{Set => MutableSet}
class FairZ3Solver(val context : LeonContext, val program: Program) class FairZ3Solver(val context : LeonContext, val program: Program)
extends AbstractZ3Solver extends AbstractZ3Solver
with Z3ModelReconstruction with Z3ModelReconstruction
...@@ -133,206 +131,28 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ...@@ -133,206 +131,28 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
} }
} }
private val funDefTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty val templateGenerator = new TemplateGenerator(new TemplateEncoder[Z3AST] {
private val exprTemplateCache : MutableMap[Expr , FunctionTemplate] = MutableMap.empty def encodeId(id: Identifier): Z3AST = {
idToFreshZ3Id(id)
private def getTemplate(tfd: TypedFunDef): FunctionTemplate = {
funDefTemplateCache.getOrElse(tfd, {
val res = FunctionTemplate.mkTemplate(this, tfd, true)
funDefTemplateCache += tfd -> res
res
})
}
private def getTemplate(body: Expr): FunctionTemplate = {
exprTemplateCache.getOrElse(body, {
val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => ValDef(id, id.getType)), DefType.MethodDef)
fakeFunDef.body = Some(body)
val res = FunctionTemplate.mkTemplate(this, fakeFunDef.typed, false)
exprTemplateCache += body -> res
res
})
}
class UnrollingBank {
// Keep which function invocation is guarded by which guard,
// also specify the generation of the blocker.
private var blockersInfoStack : List[MutableMap[Z3AST,(Int,Int,Z3AST,Set[Z3FunctionInvocation])]] = List(MutableMap())
def blockersInfo = blockersInfoStack.head
def push() {
blockersInfoStack = (MutableMap() ++ blockersInfo) :: blockersInfoStack
}
def pop(lvl: Int) {
blockersInfoStack = blockersInfoStack.drop(lvl)
}
def z3CurrentZ3Blockers = blockersInfo.map(_._2._3)
def finfo(fi: Z3FunctionInvocation) = {
fi.tfd.id.uniqueName+fi.args.mkString("(", ", ", ")")
} }
def dumpBlockers = { def encodeExpr(bindings: Map[Identifier, Z3AST])(e: Expr): Z3AST = {
blockersInfo.groupBy(_._2._1).toSeq.sortBy(_._1).foreach { case (gen, entries) => toZ3Formula(e, bindings).getOrElse {
reporter.debug("--- "+gen) reporter.fatalError("Failed to translate "+e+" to z3 ("+e.getClass+")")
for (((bast), (gen, origGen, ast, fis)) <- entries) {
reporter.debug(f". $bast%15s ~> "+fis.map(finfo).mkString(", "))
}
}
}
def canUnroll = !blockersInfo.isEmpty
def getZ3BlockersToUnlock: Seq[Z3AST] = {
if (!blockersInfo.isEmpty) {
val minGeneration = blockersInfo.values.map(_._1).min
blockersInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1)
} else {
Seq()
} }
} }
private def registerBlocker(gen: Int, id: Z3AST, fis: Set[Z3FunctionInvocation]) { def substitute(substMap: Map[Z3AST, Z3AST]): Z3AST => Z3AST = {
val z3ast = z3.mkNot(id) val (from, to) = substMap.unzip
blockersInfo.get(id) match { val (fromArray, toArray) = (from.toArray, to.toArray)
case Some((exGen, origGen, _, exFis)) =>
// PS: when recycling `b`s, this assertion becomes dangerous.
// It's better to simply take the max of the generations.
// assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen)
val minGen = gen min exGen
blockersInfo(id) = ((minGen, origGen, z3ast, fis++exFis)) (c: Z3AST) => z3.substitute(c, fromArray, toArray)
case None =>
blockersInfo(id) = ((gen, gen, z3ast, fis))
}
} }
def scanForNewTemplates(expr: Expr): Seq[Z3AST] = { def not(e: Z3AST) = z3.mkNot(e)
// OK, now this is subtle. This `getTemplate` will return def implies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r)
// a template for a "fake" function. Now, this template will })
// define an activating boolean...
val template = getTemplate(expr)
val z3args = for (vd <- template.tfd.params) yield {
variables.getZ3(Variable(vd.id)) match {
case Some(ast) =>
ast
case None =>
val ast = idToFreshZ3Id(vd.id)
variables += Variable(vd.id) -> ast
ast
}
}
// ...now this template defines clauses that are all guarded
// by that activating boolean. If that activating boolean is
// undefined (or false) these clauses have no effect...
val (newClauses, newBlocks) =
template.instantiate(template.z3ActivatingBool, z3args)
for((i, fis) <- newBlocks) {
registerBlocker(nextGeneration(0), i, fis)
}
// ...so we must force it to true!
template.z3ActivatingBool +: newClauses
}
def nextGeneration(gen: Int) = gen + 3
def decreaseAllGenerations() = {
for ((block, (gen, origGen, ast, finvs)) <- blockersInfo) {
// We also decrease the original generation here
blockersInfo(block) = (math.max(1,gen-1), math.max(1,origGen-1), ast, finvs)
}
}
def promoteBlocker(b: Z3AST) = {
if (blockersInfo contains b) {
val (gen, origGen, ast, finvs) = blockersInfo(b)
blockersInfo(b) = (1, origGen, ast, finvs)
}
}
private var defBlockers = Map[Z3FunctionInvocation, Z3AST]()
def unlock(ids: Seq[Z3AST]) : Seq[Z3AST] = {
assert(ids.forall(id => blockersInfo contains id))
var newClauses : Seq[Z3AST] = Seq.empty
for (id <- ids) {
val (gen, _, _, fis) = blockersInfo(id)
blockersInfo -= id
var reintroducedSelf = false
for (fi <- fis) {
var newCls = Seq[Z3AST]()
val defBlocker = defBlockers.get(fi) match {
case Some(defBlocker) =>
// we already have defBlocker => f(args) = body
defBlocker
case None =>
// we need to define this defBlocker and link it to definition
val defBlocker = z3.mkFreshConst("d", z3.mkBoolSort)
defBlockers += fi -> defBlocker
val template = getTemplate(fi.tfd)
reporter.debug(template)
val (newExprs, newBlocks) = template.instantiate(defBlocker, fi.args)
for((i, fis2) <- newBlocks) {
registerBlocker(nextGeneration(gen), i, fis2)
}
newCls ++= newExprs
defBlocker
}
// We connect it to the defBlocker: blocker => defBlocker
if (defBlocker != id) {
newCls ++= List(z3.mkImplies(id, defBlocker))
}
reporter.debug("Unrolling behind "+fi+" ("+newCls.size+")")
for (cl <- newCls) {
reporter.debug(" . "+cl)
}
newClauses ++= newCls
}
}
context.reporter.debug(s" - ${newClauses.size} new clauses")
//context.reporter.ifDebug { debug =>
// debug(s" - new clauses:")
// debug("@@@@")
// for (cl <- newClauses) {
// debug(""+cl)
// }
// debug("////")
//}
//dumpBlockers
//readLine()
newClauses
}
}
initZ3 initZ3
...@@ -342,7 +162,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ...@@ -342,7 +162,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
private var frameExpressions = List[List[Expr]](Nil) private var frameExpressions = List[List[Expr]](Nil)
val unrollingBank = new UnrollingBank() val unrollingBank = new UnrollingBank(reporter, templateGenerator)
def push() { def push() {
solver.push() solver.push()
...@@ -370,11 +190,19 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ...@@ -370,11 +190,19 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
var definitiveCore : Set[Expr] = Set.empty var definitiveCore : Set[Expr] = Set.empty
def assertCnstr(expression: Expr) { def assertCnstr(expression: Expr) {
varsInVC ++= variablesOf(expression) val freeVars = variablesOf(expression)
varsInVC ++= freeVars
// We make sure all free variables are registered as variables
freeVars.foreach { v =>
variables.toZ3OrCompute(Variable(v)) {
templateGenerator.encoder.encodeId(v)
}
}
frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail
val newClauses = unrollingBank.scanForNewTemplates(expression) val newClauses = unrollingBank.getClauses(expression, variables.leonToZ3)
for (cl <- newClauses) { for (cl <- newClauses) {
solver.assertCnstr(cl) solver.assertCnstr(cl)
...@@ -426,7 +254,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ...@@ -426,7 +254,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
val timer = context.timers.solvers.z3.check.start() val timer = context.timers.solvers.z3.check.start()
solver.push() // FIXME: remove when z3 bug is fixed solver.push() // FIXME: remove when z3 bug is fixed
val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.currentBlockers) :_*)
solver.pop() // FIXME: remove when z3 bug is fixed solver.pop() // FIXME: remove when z3 bug is fixed
timer.stop() timer.stop()
...@@ -548,11 +376,11 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ...@@ -548,11 +376,11 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
if(!foundDefinitiveAnswer) { if(!foundDefinitiveAnswer) {
reporter.debug("- We need to keep going.") reporter.debug("- We need to keep going.")
val toRelease = unrollingBank.getZ3BlockersToUnlock val toRelease = unrollingBank.getBlockersToUnlock
reporter.debug(" - more unrollings") reporter.debug(" - more unrollings")
val newClauses = unrollingBank.unlock(toRelease) val newClauses = unrollingBank.unrollBehind(toRelease)
for(ncl <- newClauses) { for(ncl <- newClauses) {
solver.assertCnstr(ncl) solver.assertCnstr(ncl)
......
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package solvers.z3
import purescala.Common._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._
import evaluators._
import z3.scala._
import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap}
case class Z3FunctionInvocation(tfd: TypedFunDef, args: Seq[Z3AST]) {
override def toString = tfd.signature + args.mkString("(", ",", ")")
}
class FunctionTemplate private(
solver: FairZ3Solver,
val tfd : TypedFunDef,
activatingBool : Identifier,
condVars : Set[Identifier],
exprVars : Set[Identifier],
guardedExprs : Map[Identifier,Seq[Expr]],
isRealFunDef : Boolean) {
private def isTerminatingForAllInputs : Boolean = (
isRealFunDef
&& !tfd.hasPrecondition
&& solver.getTerminator.terminates(tfd.fd).isGuaranteed
)
private val z3 = solver.z3
private val asClauses : Seq[Expr] = {
(for((b,es) <- guardedExprs; e <- es) yield {
Implies(Variable(b), e)
}).toSeq
}
val z3ActivatingBool = solver.idToFreshZ3Id(activatingBool)
private val z3FunDefArgs = tfd.params.map( ad => solver.idToFreshZ3Id(ad.id))
private val zippedCondVars = condVars.map(id => (id, solver.idToFreshZ3Id(id)))
private val zippedExprVars = exprVars.map(id => (id, solver.idToFreshZ3Id(id)))
private val zippedFunDefArgs = tfd.params.map(_.id) zip z3FunDefArgs
val idToZ3Ids: Map[Identifier, Z3AST] = {
Map(activatingBool -> z3ActivatingBool) ++
zippedCondVars ++
zippedExprVars ++
zippedFunDefArgs
}
val asZ3Clauses: Seq[Z3AST] = asClauses.map {
cl => solver.toZ3Formula(cl, idToZ3Ids).getOrElse(sys.error("Could not translate to z3. Did you forget --xlang? @"+cl.getPos))
}
private val blockers : Map[Identifier,Set[FunctionInvocation]] = {
val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable))
Map((for((b, es) <- guardedExprs) yield {
val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall
if(calls.isEmpty) {
None
} else {
Some((b, calls))
}
}).flatten.toSeq : _*)
}
val z3Blockers: Map[Z3AST,Set[Z3FunctionInvocation]] = blockers.map {
case (b, funs) =>
(idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.tfd, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get))))
}
// We use a cache to create the same boolean variables.
private val cache : MutableMap[Seq[Z3AST],Map[Z3AST,Z3AST]] = MutableMap.empty
def instantiate(aVar : Z3AST, args : Seq[Z3AST]) : (Seq[Z3AST], Map[Z3AST,Set[Z3FunctionInvocation]]) = {
assert(args.size == tfd.params.size)
// The "isRealFunDef" part is to prevent evaluation of "fake"
// function templates, as generated from FairZ3Solver.
if(solver.evalGroundApps && isRealFunDef) {
val ga = args.view.map(solver.asGround)
if(ga.forall(_.isDefined)) {
val leonArgs = ga.map(_.get).force
val invocation = FunctionInvocation(tfd, leonArgs)
solver.getEvaluator.eval(invocation) match {
case EvaluationResults.Successful(result) =>
val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*)
val z3Value = solver.toZ3Formula(result).get
val asZ3 = z3.mkEq(z3Invocation, z3Value)
return (Seq(asZ3), Map.empty)
case _ => throw new Exception("Evaluation of ground term should have succeeded.")
}
}
}
// ...end of ground evaluation part.
val (wasHit,baseIDSubstMap) = cache.get(args) match {
case Some(m) => (true,m)
case None =>
val newMap : Map[Z3AST,Z3AST] =
(zippedExprVars ++ zippedCondVars).map(p => p._2 -> solver.idToFreshZ3Id(p._1)).toMap ++
(z3FunDefArgs zip args)
cache(args) = newMap
(false,newMap)
}
val idSubstMap : Map[Z3AST,Z3AST] = baseIDSubstMap + (z3ActivatingBool -> aVar)
val (from, to) = idSubstMap.unzip
val (fromArray, toArray) = (from.toArray, to.toArray)
val newClauses = asZ3Clauses.map(z3.substitute(_, fromArray, toArray))
val newBlockers = z3Blockers.map { case (b, funs) =>
val bp = if (b == z3ActivatingBool) {
aVar
} else {
idSubstMap(b)
}
val newFuns = funs.map(fi => fi.copy(args = fi.args.map(z3.substitute(_, fromArray, toArray))))
bp -> newFuns
}
(newClauses, newBlockers)
}
override def toString : String = {
"Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" +
" * Activating boolean : " + activatingBool + "\n" +
" * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" +
" * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" +
" * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" +
" * Block-map : " + blockers.toString
}
}
object FunctionTemplate {
def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = {
val condVars : MutableSet[Identifier] = MutableSet.empty
val exprVars : MutableSet[Identifier] = MutableSet.empty
// Represents clauses of the form:
// id => expr && ... && expr
val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty
def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = {
assert(expr.getType == BooleanType)
if(guardedExprs.isDefinedAt(guardVar)) {
val prev : Seq[Expr] = guardedExprs(guardVar)
guardedExprs(guardVar) = expr +: prev
} else {
guardedExprs(guardVar) = Seq(expr)
}
}
// Group elements that satisfy p toghether
// List(a, a, a, b, c, a, a), with p = _ == a will produce:
// List(List(a,a,a), List(b), List(c), List(a, a))
def groupWhile[T](p: T => Boolean, l: Seq[T]): Seq[Seq[T]] = {
var res: Seq[Seq[T]] = Nil
var c = l
while(!c.isEmpty) {
val (span, rest) = c.span(p)
if (span.isEmpty) {
res = res :+ Seq(rest.head)
c = rest.tail
} else {
res = res :+ span
c = rest
}
}
res
}
def requireDecomposition(e: Expr) = {
exists{
case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true
case _ => false
}(e)
}
def rec(pathVar : Identifier, expr : Expr) : Expr = {
expr match {
case a @ Assert(cond, _, body) =>
storeGuarded(pathVar, rec(pathVar, cond))
rec(pathVar, body)
case e @ Ensuring(body, id, post) =>
rec(pathVar, Let(id, body, Assert(post, None, Variable(id))))
case l @ Let(i, e, b) =>
val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType)
exprVars += newExpr
val re = rec(pathVar, e)
storeGuarded(pathVar, Equals(Variable(newExpr), re))
val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b))
rb
case l @ LetTuple(is, e, b) =>
val tuple : Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType)))
exprVars += tuple
val re = rec(pathVar, e)
storeGuarded(pathVar, Equals(Variable(tuple), re))
val mapping = for ((id, i) <- is.zipWithIndex) yield {
val newId = FreshIdentifier("ti", true).setType(id.getType)
exprVars += newId
storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1)))
(Variable(id) -> Variable(newId))
}
val rb = rec(pathVar, replace(mapping.toMap, b))
rb
case m : MatchExpr => sys.error("MatchExpr's should have been eliminated.")
case i @ Implies(lhs, rhs) =>
Implies(rec(pathVar, lhs), rec(pathVar, rhs))
case a @ And(parts) =>
And(parts.map(rec(pathVar, _)))
case o @ Or(parts) =>
Or(parts.map(rec(pathVar, _)))
case i @ IfExpr(cond, thenn, elze) => {
if(!requireDecomposition(i)) {
i
} else {
val newBool1 : Identifier = FreshIdentifier("b", true).setType(BooleanType)
val newBool2 : Identifier = FreshIdentifier("b", true).setType(BooleanType)
val newExpr : Identifier = FreshIdentifier("e", true).setType(i.getType)
condVars += newBool1
condVars += newBool2
exprVars += newExpr
val crec = rec(pathVar, cond)
val trec = rec(newBool1, thenn)
val erec = rec(newBool2, elze)
storeGuarded(pathVar, Or(Variable(newBool1), Variable(newBool2)))
storeGuarded(pathVar, Or(Not(Variable(newBool1)), Not(Variable(newBool2))))
// TODO can we improve this? i.e. make it more symmetrical?
// Probably it's symmetrical enough to Z3.
storeGuarded(pathVar, Iff(Variable(newBool1), crec))
storeGuarded(newBool1, Equals(Variable(newExpr), trec))
storeGuarded(newBool2, Equals(Variable(newExpr), erec))
Variable(newExpr)
}
}
case c @ Choose(ids, cond) =>
val cid = FreshIdentifier("choose", true).setType(c.getType)
exprVars += cid
val m: Map[Expr, Expr] = if (ids.size == 1) {
Map(Variable(ids.head) -> Variable(cid))
} else {
ids.zipWithIndex.map{ case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i+1) }.toMap
}
storeGuarded(pathVar, replace(m, cond))
Variable(cid)
case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType)
case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType)
case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType)
case t : Terminal => t
}
}
// The precondition if it exists.
val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p))
val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b))
val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable))
val invocationEqualsBody : Option[Expr] = newBody match {
case Some(body) if isRealFunDef =>
val b : Expr = Equals(invocation, body)
Some(if(prec.isDefined) {
Implies(prec.get, b)
} else {
b
})
case _ =>
None
}
val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType)
if (isRealFunDef) {
val finalPred : Option[Expr] = invocationEqualsBody.map(expr => rec(activatingBool, expr))
finalPred.foreach(p => storeGuarded(activatingBool, p))
} else {
val newFormula = rec(activatingBool, newBody.get)
storeGuarded(activatingBool, newFormula)
}
// Now the postcondition.
tfd.postcondition match {
case Some((id, post)) =>
val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post))
val postHolds : Expr =
if(tfd.hasPrecondition) {
Implies(prec.get, newPost)
} else {
newPost
}
val finalPred2 : Expr = rec(activatingBool, postHolds)
storeGuarded(activatingBool, finalPred2)
case None =>
}
new FunctionTemplate(solver, tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*),
isRealFunDef)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment