// RefinementEngine.scala (8.11 KiB)
package leon
package invariant.engine
import purescala.Common._
import purescala.Definitions._
import purescala.Expressions._
import purescala.ExprOps._
import purescala.TypeOps._
import purescala.Extractors._
import purescala.Types._
import java.io._
import invariant.templateSolvers._
import invariant.factories._
import invariant.util._
import invariant.util.Util._
import invariant.structure._
import FunctionUtils._
import Util._
import PredicateUtil._
import ProgramUtil._
//TODO: the parts of the code that collect the new head functions are ugly and have many side-effects. Fix this.
//TODO: there is a better way to compute heads, which is to consider all guards not previously seen
class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: ConstraintTracker) {
val tru = BooleanLiteral(true)
val reporter = ctx.reporter
val cg = CallGraphUtil.constructCallGraph(prog)
//this count indicates the number of times we unroll a recursive call
private val MAX_UNROLLS = 2
//debugging flags
private val dumpInlinedSummary = false
//print flags
val verbose = false
//the guards of disjuncts that were already processed
private var exploredGuards = Set[Variable]()
//a set of calls that have not been unrolled (these are potential unroll candidates)
//However, these calls except those given by the unspecdCalls have been assumed specifications
private var headCalls = Map[FunDef, Set[Call]]()
def getHeads(fd: FunDef) = if (headCalls.contains(fd)) headCalls(fd) else Set()
def resetHeads(fd: FunDef, heads: Set[Call]) = {
if (headCalls.contains(fd)) {
headCalls -= fd
headCalls += (fd -> heads)
} else {
headCalls += (fd -> heads)
}
}
/**
* This procedure refines the existing abstraction.
* Currently, the refinement happens by unrolling the head functions.
*/
def refineAbstraction(toRefineCalls: Option[Set[Call]]): Set[Call] = {
ctrTracker.getFuncs.flatMap((fd) => {
val formula = ctrTracker.getVC(fd)
val disjuncts = formula.disjunctsInFormula
val newguards = formula.disjunctsInFormula.keySet.diff(exploredGuards)
exploredGuards ++= newguards
val newheads = newguards.flatMap(g => disjuncts(g).collect { case c: Call => c })
val allheads = getHeads(fd) ++ newheads
//unroll each call in the head pointers and in toRefineCalls
val callsToProcess = if (toRefineCalls.isDefined) {
//pick only those calls that have been least unrolled
val relevCalls = allheads.intersect(toRefineCalls.get)
var minCalls = Set[Call]()
var minUnrollings = MAX_UNROLLS
relevCalls.foreach((call) => {
val calldata = formula.callData(call)
val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd)
if (recInvokes < minUnrollings) {
minUnrollings = recInvokes
minCalls = Set(call)
} else if (recInvokes == minUnrollings) {
minCalls += call
}
})
minCalls
} else allheads
if (verbose)
reporter.info("Unrolling: " + callsToProcess.size + "/" + allheads.size)
val unrolls = callsToProcess.foldLeft(Set[Call]())((acc, call) => {
val calldata = formula.callData(call)
val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd)
//if the call is not a recursive call, unroll it unconditionally
if (recInvokes == 0) {
unrollCall(call, formula)
acc + call
} else {
//if the call is recursive, unroll iff the number of times the recursive function occurs in the context is < MAX-UNROLL
if (recInvokes < MAX_UNROLLS) {
unrollCall(call, formula)
acc + call
} else {
//otherwise, do not unroll the call
acc
}
}
//TODO: are there better ways of unrolling ?? Yes. Akask Lal "dag Inlining". Implement that!
})
//update the head functions
resetHeads(fd, allheads.diff(callsToProcess))
unrolls
}).toSet
}
def shouldCreateVC(recFun: FunDef): Boolean = {
if (ctrTracker.hasVC(recFun)) false
else {
//need not create vcs for theory operations
!recFun.isTheoryOperation && recFun.hasTemplate &&
!recFun.annotations.contains("library")
}
}
/**
* Returns a set of unrolled calls and a set of new head functions
* here we unroll the methods in the current abstraction by one step.
* This procedure has side-effects on 'headCalls' and 'callDataMap'
*/
def unrollCall(call: Call, formula: Formula) = {
val fi = call.fi
if (fi.tfd.fd.hasBody) {
//freshen the body and the post
val isRecursive = cg.isRecursive(fi.tfd.fd)
if (isRecursive) {
val recFun = fi.tfd.fd
val recFunTyped = fi.tfd
//check if we need to create a VC formula for the call's target
if (shouldCreateVC(recFun)) {
reporter.info("Creating VC for " + recFun.id)
// instantiate the body with new types
val tparamMap = (recFun.tparams zip recFunTyped.tps).toMap
val paramMap = recFun.params.map{pdef =>
pdef.id -> FreshIdentifier(pdef.id.name, instantiateType(pdef.id.getType, tparamMap))
}.toMap
val newbody = freshenLocals(matchToIfThenElse(recFun.body.get))
val freshBody = instantiateType(newbody, tparamMap, paramMap)
val resvar = if (recFun.hasPostcondition) {
//create a new result variable here for the same reason as freshening the locals,
//which is to avoid variable capturing during unrolling
val origRes = getResId(recFun).get
Variable(FreshIdentifier(origRes.name, recFunTyped.returnType, true))
} else {
//create a new resvar
Variable(FreshIdentifier("res", recFunTyped.returnType, true))
}
val plainBody = Equals(resvar, freshBody)
val bodyExpr =
if (recFun.hasPrecondition) {
val pre = instantiateType(matchToIfThenElse(recFun.precondition.get), tparamMap, paramMap)
And(pre, plainBody)
} else plainBody
//note: here we are only adding the template as the postcondition (other post need not be proved again)
val idmap = formalToActual(Call(resvar, FunctionInvocation(recFunTyped,
paramMap.values.toSeq.map(_.toVariable))))
val postTemp = replace(idmap, recFun.getTemplate)
val vcExpr = ExpressionTransformer.normalizeExpr(And(bodyExpr, Not(postTemp)), ctx.multOp)
ctrTracker.addVC(recFun, vcExpr)
}
//Here, unroll the call into the caller tree
if (verbose) reporter.info("Unrolling " + Equals(call.retexpr, call.fi))
inilineCall(call, formula)
} else {
//here we are unrolling a function without template
if (verbose) reporter.info("Unfolding " + Equals(call.retexpr, call.fi))
inilineCall(call, formula)
}
} else Set()
}
def inilineCall(call: Call, formula: Formula) = {
val tfd = call.fi.tfd
val callee = tfd.fd
if (callee.isBodyVisible) {
//here inline the body and conjoin it with the guard
//Important: make sure we use a fresh body expression here, and freshenlocals
val tparamMap = (callee.tparams zip tfd.tps).toMap
val newbody = freshenLocals(matchToIfThenElse(callee.body.get))
val freshBody = instantiateType(newbody, tparamMap, Map())
val calleeSummary =
Equals(getFunctionReturnVariable(callee), freshBody)
val argmap1 = formalToActual(call)
val inlinedSummary = ExpressionTransformer.normalizeExpr(replace(argmap1, calleeSummary), ctx.multOp)
if (this.dumpInlinedSummary)
println("Inlined Summary: " + inlinedSummary)
//conjoin the summary with the disjunct corresponding to the 'guard'
//note: the parents of the summary are the parents of the call plus the callee function
val calldata = formula.callData(call)
formula.conjoinWithDisjunct(calldata.guard, inlinedSummary, (callee +: calldata.parents))
} else {
if (verbose)
reporter.info(s"Not inlining ${call.fi}: body invisible!")
}
}
}