diff --git a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala index 78fc476dd8ab9f3b8bc70010a40e021e2424f367..66a4bdfa16d80ea4a69926cc62704b92c6121535 100644 --- a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala +++ b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala @@ -22,11 +22,12 @@ object InferInvariantsPhase extends SimpleLeonPhase[Program, InferenceReport] { val optNLTimeout = LeonLongOptionDef("nlTimeout", "Timeout after T seconds when trying to solve nonlinear constraints", 20, "s") val optDisableInfer = LeonFlagOptionDef("disableInfer", "Disable automatic inference of auxiliary invariants", false) val optAssumePre = LeonFlagOptionDef("assumepreInf", "Assume preconditions of callees during unrolling", false) + val optStats = LeonFlagOptionDef("stats", "Tracks and prints detailed statistics", false) override val definedOptions: Set[LeonOptionDef[Any]] = Set(optFunctionUnroll, optWithMult, optUseReals, optMinBounds, optInferTemp, optCegis, optStatsSuffix, optVCTimeout, - optNLTimeout, optDisableInfer, optAssumePre) + optNLTimeout, optDisableInfer, optAssumePre, optStats) def apply(ctx: LeonContext, program: Program): InferenceReport = { val inferctx = new InferenceContext(program, ctx) diff --git a/src/main/scala/leon/invariant/engine/InferenceContext.scala b/src/main/scala/leon/invariant/engine/InferenceContext.scala index 96603581a689d20dcb50017c00819691e3f6710f..ef560e9f8e195233dd5ab3eecb1a5b2312ea0d39 100644 --- a/src/main/scala/leon/invariant/engine/InferenceContext.scala +++ b/src/main/scala/leon/invariant/engine/InferenceContext.scala @@ -33,7 +33,7 @@ class InferenceContext(val initProgram: Program, val leonContext: LeonContext) { val withmult = leonContext.findOption(optWithMult).getOrElse(false) val usereals = leonContext.findOption(optUseReals).getOrElse(false) val useCegis: Boolean = leonContext.findOption(optCegis).getOrElse(false) - val dumpStats = false + val dumpStats = leonContext.findOption(optStats).getOrElse(false) // the following options have default values val vcTimeout = leonContext.findOption(optVCTimeout).getOrElse(30L) // in secs diff --git a/src/main/scala/leon/invariant/engine/InferenceEngine.scala b/src/main/scala/leon/invariant/engine/InferenceEngine.scala index e42a545babaecf90854104fd6a0626ef559264b1..1af3c2c56e83125dab7f2da859136667f460c946 100644 --- a/src/main/scala/leon/invariant/engine/InferenceEngine.scala +++ b/src/main/scala/leon/invariant/engine/InferenceEngine.scala @@ -14,6 +14,7 @@ import transformations._ import leon.utils._ import Util._ import ProgramUtil._ +import Stats._ /** * @author ravi @@ -50,46 +51,43 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { val reporter = ctx.reporter val program = ctx.inferProgram reporter.info("Running Inference Engine...") - - //register a shutdownhook - if (ctx.dumpStats) { + if (ctx.dumpStats) { //register a shutdownhook sys.ShutdownHookThread({ dumpStats(ctx.statsSuffix) }) } - val t1 = System.currentTimeMillis() - //compute functions to analyze by sorting based on topological order (this is an ascending topological order) - val callgraph = CallGraphUtil.constructCallGraph(program, withTemplates = true) - val functionsToAnalyze = ctx.functionsToInfer match { - case Some(rootfuns) => - val rootset = rootfuns.toSet - val rootfds = program.definedFunctions.filter(fd => rootset(InstUtil.userFunctionName(fd))) - val relfuns = rootfds.flatMap(callgraph.transitiveCallees _).toSet - callgraph.topologicalOrder.filter { fd => relfuns(fd) } - case _ => - callgraph.topologicalOrder - } - //reporter.info("Analysis Order: " + functionsToAnalyze.map(_.id)) var results: Map[FunDef, InferenceCondition] = null - if (!ctx.useCegis) { - results = analyseProgram(program, functionsToAnalyze, defaultVCSolver, progressCallback) - //println("Inferrence did not succeeded for functions: "+functionsToAnalyze.filterNot(succeededFuncs.contains _).map(_.id)) - } else { - var remFuncs = functionsToAnalyze - var b = 200 - val maxCegisBound = 200 - breakable { - while (b <= maxCegisBound) { - Stats.updateCumStats(1, "CegisBoundsTried") - val succeededFuncs = analyseProgram(program, remFuncs, defaultVCSolver, progressCallback) - remFuncs = remFuncs.filterNot(succeededFuncs.contains _) - if (remFuncs.isEmpty) break - b += 5 //increase bounds in steps of 5 + time { + //compute functions to analyze by sorting based on topological order (this is an ascending topological order) + val callgraph = CallGraphUtil.constructCallGraph(program, withTemplates = true) + val functionsToAnalyze = ctx.functionsToInfer match { + case Some(rootfuns) => + val rootset = rootfuns.toSet + val rootfds = program.definedFunctions.filter(fd => rootset(InstUtil.userFunctionName(fd))) + val relfuns = rootfds.flatMap(callgraph.transitiveCallees _).toSet + callgraph.topologicalOrder.filter { fd => relfuns(fd) } + case _ => + callgraph.topologicalOrder + } + //reporter.info("Analysis Order: " + functionsToAnalyze.map(_.id)) + + if (!ctx.useCegis) { + results = analyseProgram(program, functionsToAnalyze, defaultVCSolver, progressCallback) + //println("Inferrence did not succeeded for functions: "+functionsToAnalyze.filterNot(succeededFuncs.contains _).map(_.id)) + } else { + var remFuncs = functionsToAnalyze + var b = 200 + val maxCegisBound = 200 + breakable { + while (b <= maxCegisBound) { + Stats.updateCumStats(1, "CegisBoundsTried") + val succeededFuncs = analyseProgram(program, remFuncs, defaultVCSolver, progressCallback) + remFuncs = remFuncs.filterNot(succeededFuncs.contains _) + if (remFuncs.isEmpty) break + b += 5 //increase bounds in steps of 5 + } + //println("Inferrence did not succeeded for functions: " + remFuncs.map(_.id)) } - //println("Inferrence did not succeeded for functions: " + remFuncs.map(_.id)) } - } - val t2 = System.currentTimeMillis() - Stats.updateCumTime(t2 - t1, "TotalTime") - //dump stats + } { totTime => updateCumTime(totTime, "TotalTime") } if (ctx.dumpStats) { reporter.info("- Dumping statistics") dumpStats(ctx.statsSuffix) @@ -102,13 +100,15 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { def dumpStats(statsSuffix: String) = { //pick the module id. - val modid = ctx.inferProgram.modules.last.id - val pw = new PrintWriter(modid + statsSuffix + ".txt") + val modid = ctx.inferProgram.modules.find(_.definedFunctions.exists(!_.isLibrary)).get.id + val filename = modid + statsSuffix + ".txt" + val pw = new PrintWriter(filename) Stats.dumpStats(pw) SpecificStats.dumpOutputs(pw) if (ctx.tightBounds) { SpecificStats.dumpMinimizationStats(pw) } + ctx.reporter.info("Stats dumped to file: "+filename) } def defaultVCSolver = @@ -169,9 +169,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { // for stats Stats.updateCounter(1, "procs") val solver = vcSolver(funDef, prog) - val t1 = System.currentTimeMillis() - val infRes = solver() - val funcTime = (System.currentTimeMillis() - t1) / 1000.0 + val (infRes, funcTime) = getTime { solver() } infRes match { case Some(InferResult(true, model, inferredFuns)) => val origFds = inferredFuns.map { fd => @@ -193,11 +191,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { case fd if !invs.contains(fd) && fd.hasTemplate => fd -> fd.getTemplate }.toMap - /*println("Inferred Funs: " + inferredFuns) - println("inv map: " + invs) - println("Templ map: " + funToTmpl)*/ val nextProg = assignTemplateAndCojoinPost(funToTmpl, prog, invs) - // create a inference condition for reporting var first = true inferredFuns.foreach { fd => @@ -221,7 +215,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { ic } else { val ic = new InferenceCondition(Seq(inv), origFd) - ic.time = if (first) Some(funcTime) else Some(0.0) + ic.time = if (first) Some(funcTime / 1000.0) else Some(0.0) // update analyzed set analyzedSet += (origFd -> ic) first = false @@ -236,7 +230,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { case _ => reporter.info("- Exhausted all templates, cannot infer invariants") val ic = new InferenceCondition(Seq(), origFun) - ic.time = Some(funcTime) + ic.time = Some(funcTime / 1000.0) analyzedSet += (origFun -> ic) prog } diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala index 88d7cdb9bb1e64c843bc3ff8b82281f1e56c2fbb..0f3e39265d9b252b776d4ff8a40f720b40d3784c 100644 --- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala +++ b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala @@ -35,7 +35,6 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val tru = BooleanLiteral(true) val axiomFactory = new AxiomFactory(ctx) //handles instantiation of axiomatic specification - //the guards of the set of calls that were already processed protected var exploredGuards = Set[Variable]() @@ -270,7 +269,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons * Note: taking a formula as input may not be necessary. We can store it as a part of the state * TODO: can we use transitivity here to optimize ? */ - def axiomsForCalls(formula: Formula, calls: Set[Call], model: Model): Seq[Constraint] = { + def axiomsForCalls(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { //note: unary axioms need not be instantiated //consider only binary axioms (for (x <- calls; y <- calls) yield (x, y)).foldLeft(Seq[Constraint]())((acc, pair) => { diff --git a/src/main/scala/leon/invariant/factories/TemplateFactory.scala b/src/main/scala/leon/invariant/factories/TemplateFactory.scala index b44fcfa684d66168cad5de3e0cca7d8df5d11f50..ea0f0c63368aaaac74470811265b9f99d46ea883 100644 --- a/src/main/scala/leon/invariant/factories/TemplateFactory.scala +++ b/src/main/scala/leon/invariant/factories/TemplateFactory.scala @@ -13,6 +13,7 @@ import invariant.structure._ import FunctionUtils._ import PredicateUtil._ import ProgramUtil._ +import TypeUtil._ object TemplateIdFactory { //a set of template ids diff --git a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala index fcf786d2845ea699fa541c131a81eaabbaf5649f..34bdd7431c8827ca726467f65182ef10fef67e1d 100644 --- a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala +++ b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala @@ -26,7 +26,7 @@ object TemplateInstantiator { (v, model(v.id)) }).toMap val instTemplate = instantiate(template, tempVarMap, prettyInv) - val comprTemp = ExpressionTransformer.unFlatten(instTemplate) + val comprTemp = ExpressionTransformer.unflatten(instTemplate) (fd, comprTemp) }) invs diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala index b2226b13238d239d1860e09fe0609c8dafa17b91..79c9a42ca40642be9f3e0834bb95865c5978e550 100644 --- a/src/main/scala/leon/invariant/structure/Constraint.scala +++ b/src/main/scala/leon/invariant/structure/Constraint.scala @@ -6,6 +6,8 @@ import purescala.ExprOps._ import purescala.Types._ import invariant.util._ import PredicateUtil._ +import TypeUtil._ +import purescala.Extractors._ trait Constraint { def toExpr: Expr @@ -164,7 +166,6 @@ class LinearConstraint(opr: Seq[Expr] => Expr, cMap: Map[Expr, Expr], constant: //TODO: here we should try to simplify the constant expressions cMap } - val const = constant.map((c) => { //check if constant does not have any variables assert(variablesOf(c).isEmpty) @@ -183,64 +184,41 @@ case class BoolConstraint(e: Expr) extends Constraint { }) e } - - override def toString(): String = { - expr.toString - } - + override def toString(): String = expr.toString def toExpr: Expr = expr } object ADTConstraint { - - def apply(e: Expr): ADTConstraint = e match { - - //is this a tuple or case class select ? - // case Equals(Variable(_), CaseClassSelector(_, _, _)) | Iff(Variable(_), CaseClassSelector(_, _, _)) => { - case Equals(Variable(_), CaseClassSelector(_, _, _)) => { - val ccExpr = ExpressionTransformer.classSelToCons(e) - new ADTConstraint(ccExpr, Some(ccExpr)) - } - // case Equals(Variable(_),TupleSelect(_,_)) | Iff(Variable(_),TupleSelect(_,_)) => { - case Equals(Variable(_), TupleSelect(_, _)) => { - val tpExpr = ExpressionTransformer.tupleSelToCons(e) - new ADTConstraint(tpExpr, Some(tpExpr)) - } - //is this a tuple or case class def ? - case Equals(Variable(_), CaseClass(_, _)) | Equals(Variable(_), Tuple(_)) => { - new ADTConstraint(e, Some(e)) - } - //is this an instanceOf ? - case Equals(v @ Variable(_), ci @ IsInstanceOf(_, _)) => { - new ADTConstraint(e, None, Some(e)) - } - // considering asInstanceOf as equalities - case Equals(lhs @ Variable(_), ci @ AsInstanceOf(rhs @ Variable(_), _)) => { - val eq = Equals(lhs, rhs) - new ADTConstraint(eq, None, None, Some(eq)) - } - //equals and disequalities betweeen variables - case Equals(lhs @ Variable(_), rhs @ Variable(_)) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { - new ADTConstraint(e, None, None, Some(e)) - } - case Not(Equals(lhs @ Variable(_), rhs @ Variable(_))) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { - new ADTConstraint(e, None, None, Some(e)) - } - case _ => { - throw new IllegalStateException("Expression not an ADT constraint: " + e) - } + // note: we consider even type parameters as ADT type + def adtType(e: Expr) = { + val tpe = e.getType + tpe.isInstanceOf[ClassType] || tpe.isInstanceOf[TupleType] || tpe.isInstanceOf[TypeParameter] + } + def apply(e: Expr): ADTConstraint = e match { + case Equals(_: Variable, _: CaseClassSelector | _: TupleSelect) => + new ADTConstraint(e, sel = true) + case Equals(_: Variable, _: CaseClass | _: Tuple) => + new ADTConstraint(e, cons = true) + case Equals(_: Variable, _: IsInstanceOf) => + new ADTConstraint(e, inst = true) + case Equals(lhs @ Variable(_), AsInstanceOf(rhs @ Variable(_), _)) => + new ADTConstraint(Equals(lhs, rhs), comp= true) + case Equals(lhs: Variable, _: Variable) if adtType(lhs) => + new ADTConstraint(e, comp = true) + case Not(Equals(lhs: Variable, _: Variable)) if adtType(lhs) => + new ADTConstraint(e, comp = true) + case _ => + throw new IllegalStateException(s"Expression not an ADT constraint: $e") } } class ADTConstraint(val expr: Expr, - val cons: Option[Expr] = None, - val inst: Option[Expr] = None, - val comp: Option[Expr] = None) extends Constraint { - - override def toString(): String = { - expr.toString - } - + val cons: Boolean = false, + val inst: Boolean = false, + val comp: Boolean = false, + val sel: Boolean = false) extends Constraint { + + override def toString(): String = expr.toString override def toExpr = expr } @@ -291,30 +269,24 @@ case class SetConstraint(expr: Expr) extends Constraint { override def toExpr = expr } -object ConstraintUtil { +object ConstraintUtil { def createConstriant(ie: Expr): Constraint = { ie match { - case Variable(_) | Not(Variable(_)) | BooleanLiteral(_) | Not(BooleanLiteral(_)) => BoolConstraint(ie) - case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => Call(v, fi) - case Equals(Variable(_), CaseClassSelector(_, _, _)) - | Equals(Variable(_), CaseClass(_, _)) - | Equals(Variable(_), TupleSelect(_, _)) - | Equals(Variable(_), Tuple(_)) - | Equals(Variable(_), IsInstanceOf(_, _)) => { - - ADTConstraint(ie) - } + case Variable(_) | Not(Variable(_)) | BooleanLiteral(_) | Not(BooleanLiteral(_)) => + BoolConstraint(ie) + case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => + Call(v, fi) + case Equals(_: Variable, _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple |_: IsInstanceOf) => + ADTConstraint(ie) case _ if SetConstraint.isSetConstraint(ie) => SetConstraint(ie) - // every other equality will be considered an ADT constraint (including TypeParameter equalities) - case Equals(lhs, rhs) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { + // every non-integer equality will be considered an ADT constraint (including TypeParameter equalities) + case Equals(lhs, rhs) if !isNumericType(lhs.getType) => //println("ADT constraint: "+ie) - ADTConstraint(ie) - } - case Not(Equals(lhs, rhs)) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { - ADTConstraint(ie) - } + ADTConstraint(ie) + case Not(Equals(lhs, rhs)) if !isNumericType(lhs.getType) => + ADTConstraint(ie) case _ => { val simpe = simplifyArithmetic(ie) simpe match { diff --git a/src/main/scala/leon/invariant/structure/Formula.scala b/src/main/scala/leon/invariant/structure/Formula.scala index 83440a9fc9fd6ac54d54ceab6070e4110643ad02..d0995c71647a9c0f0fbdffe7cb2a69281583ae3e 100644 --- a/src/main/scala/leon/invariant/structure/Formula.scala +++ b/src/main/scala/leon/invariant/structure/Formula.scala @@ -19,6 +19,9 @@ import leon.solvers.Model import Util._ import PredicateUtil._ import TVarFactory._ +import ExpressionTransformer._ +import evaluators._ +import invariant.factories._ /** * Data associated with a call @@ -26,6 +29,7 @@ import TVarFactory._ class CallData(val guard : Variable, val parents: List[FunDef]) object Formula { + val debugUnflatten = false // a context for creating blockers val blockContext = newContext } @@ -47,6 +51,7 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { protected var disjuncts = Map[Variable, Seq[Constraint]]() //a mapping from guards to conjunction of atoms protected var conjuncts = Map[Variable, Expr]() //a mapping from guards to disjunction of atoms private var callDataMap = Map[Call, CallData]() //a mapping from a 'call' to the 'guard' guarding the call plus the list of transitive callers of 'call' + private var paramBlockers = Set[Variable]() val firstRoot : Variable = addConstraints(initexpr, List(fd))._1 protected var roots : Seq[Variable] = Seq(firstRoot) //a list of roots, the formula is a conjunction of formula of each root @@ -88,6 +93,19 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { }) } + /** + * Creates disjunct of the form b == exprs and updates the necessary mutable states + */ + def addToDisjunct(exprs: Seq[Expr], isTemplate: Boolean) = { + val g = createTemp("b", BooleanType, blockContext).toVariable + newDisjGuards :+= g + val ctrs = getCtrsFromExprs(g, exprs) + disjuncts += (g -> ctrs) + if(isTemplate) + paramBlockers += g + g + } + val f1 = simplePostTransform { case e@Or(args) => { val newargs = args.map { @@ -98,11 +116,8 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { case And(atms) => atms case _ => Seq(arg) } - val g = createTemp("b", BooleanType, blockContext).toVariable - newDisjGuards :+= g - //println("atoms: "+atoms) - val ctrs = getCtrsFromExprs(g, atoms) - disjuncts += (g -> ctrs) + val g = addToDisjunct(atoms, !getTemplateIds(arg).isEmpty) + //println(s"creating a new OR blocker $g for "+atoms) g } } @@ -113,16 +128,15 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { gor } case e@And(args) => { - val newargs = args.map(arg => if (getTemplateVars(e).isEmpty) { - arg - } else { - //if the expression has template variables then we separate it using guards - val g = createTemp("b", BooleanType, blockContext).toVariable - newDisjGuards :+= g - val ctrs = getCtrsFromExprs(g, Seq(arg)) - disjuncts += (g -> ctrs) - g - }) + //if the expression has template variables then we separate it using guards + val (nonparams, params) = args.partition(getTemplateIds(_).isEmpty) + val newargs = + if (!params.isEmpty) { + val g = addToDisjunct(params, true) + //println(s"creating a new Temp blocker $g for "+arg) + paramBlockers += g + g +: nonparams + } else nonparams createAnd(newargs) } case e => e @@ -141,10 +155,7 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { case And(atms) => atms case _ => Seq(f1) } - val g = createTemp("b", BooleanType, blockContext).toVariable - val ctrs = getCtrsFromExprs(g, atoms) - newDisjGuards :+= g - disjuncts += (g -> ctrs) + val g = addToDisjunct(atoms, !getTemplateIds(f1).isEmpty) g } } @@ -152,9 +163,9 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { } //'satGuard' is required to a guard variable - def pickSatDisjunct(startGaurd : Variable, model: Model): Seq[Constraint] = { + def pickSatDisjunct(startGaurd : Variable, model: LazyModel): Seq[Constraint] = { - def traverseOrs(gd: Variable, model: Model): Seq[Variable] = { + def traverseOrs(gd: Variable, model: LazyModel): Seq[Variable] = { val e @ Or(guards) = conjuncts(gd) //pick one guard that is true val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } @@ -163,7 +174,7 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { guard.get +: traverseAnds(guard.get, model) } - def traverseAnds(gd: Variable, model: Model): Seq[Variable] = { + def traverseAnds(gd: Variable, model: LazyModel): Seq[Variable] = { val ctrs = disjuncts(gd) val guards = ctrs.collect { case BoolConstraint(v @ Variable(_)) if (conjuncts.contains(v) || disjuncts.contains(v)) => v @@ -209,24 +220,24 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { (exprRoot, newGaurds) } + def templateIdsInFormula = paramBlockers.flatMap { g => + getTemplateIds(createAnd(disjuncts(g).map(_.toExpr))) + }.toSet + /** * The first return value is param part and the second one is the * non-parametric part */ def splitParamPart : (Expr, Expr) = { - var paramPart = Seq[Expr]() - var rest = Seq[Expr]() - disjuncts.foreach(entry => { - val (g,ctrs) = entry - val ctrExpr = combiningOp(g,createAnd(ctrs.map(_.toExpr))) - if(getTemplateVars(ctrExpr).isEmpty) - rest :+= ctrExpr - else - paramPart :+= ctrExpr - - }) + val paramPart = paramBlockers.toSeq.map{ g => + combiningOp(g,createAnd(disjuncts(g).map(_.toExpr))) + } + val rest = disjuncts.collect { + case (g, ctrs) if !paramBlockers(g) => + combiningOp(g, createAnd(ctrs.map(_.toExpr))) + }.toSeq val conjs = conjuncts.map((entry) => combiningOp(entry._1, entry._2)).toSeq ++ roots - (createAnd(paramPart), createAnd(rest ++ conjs ++ roots)) + (createAnd(paramPart), createAnd(rest ++ conjs)) } def toExpr : Expr={ @@ -238,8 +249,52 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { createAnd(disjs ++ conjs ++ roots) } + /** + * Creates an unflat expr of the non-param part, + * and returns a constructor for the flat model from unflat models + */ + def toUnflatExpr = { + val paramPart = paramBlockers.toSeq.map{ g => + combiningOp(g,createAnd(disjuncts(g).map(_.toExpr))) + } + // compute variables used in more than one disjunct + var uniqueVars = Set[Identifier]() + var sharedVars = Set[Identifier]() + var freevars = Set[Identifier]() + disjuncts.foreach{ + case (g, ctrs) => + val fvs = ctrs.flatMap(c => variablesOf(c.toExpr)).toSet + val candUniques = fvs -- sharedVars + val newShared = uniqueVars.intersect(candUniques) + freevars ++= fvs + sharedVars ++= newShared + uniqueVars = (uniqueVars ++ candUniques) -- newShared + } + // unflatten rest + var flatIdMap = Map[Identifier, Expr]() + val unflatRest = (disjuncts collect { + case (g, ctrs) if !paramBlockers(g) => + val rhs = createAnd(ctrs.map(_.toExpr)) + val (unflatRhs, idmap) = simpleUnflattenWithMap(rhs, sharedVars, includeFuns = false) + // sanity checks + if (debugUnflatten) { + val rhsvars = variablesOf(rhs) + if(!rhsvars.filter(TemplateIdFactory.IsTemplateIdentifier).isEmpty) + throw new IllegalStateException(s"Non-param part has template identifiers ${toString}") + val seenKeys = flatIdMap.keySet.intersect(rhsvars) + if (!seenKeys.isEmpty) + throw new IllegalStateException(s"flat ids used across clauses $seenKeys in ${toString}") + } + flatIdMap ++= idmap + combiningOp(g, unflatRhs) + }).toSeq + val modelCons = (m: Model, eval: DefaultEvaluator) => new FlatModel(freevars, flatIdMap, m, eval) + val conjs = conjuncts.map{ case(g,rhs) => combiningOp(g, rhs) }.toSeq ++ roots + (createAnd(paramPart), createAnd(unflatRest ++ conjs), modelCons) + } + //unpack the disjunct and conjuncts by removing all guards - def unpackedExpr : Expr = { + def eliminateBlockers : Expr = { //replace all conjunct guards in disjuncts by their mapping val disjs : Map[Expr,Expr] = disjuncts.map((entry) => { val (g,ctrs) = entry @@ -265,8 +320,6 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { val guards = variablesOf(d).collect{ case id@_ if disjuncts.contains(id.toVariable) => id.toVariable } if (guards.isEmpty) entry else { - /*println("Disunct: "+d) - println("guard replaced: "+guards)*/ replacedGuard = true //removeGuards ++= guards (g, replace(unpackedDisjs, d)) @@ -287,4 +340,62 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { val rootStrs = roots.map(_.toString) (disjStrs ++ conjStrs ++ rootStrs).foldLeft("")((acc,str) => acc + "\n" + str) } + + /** + * Functions for stats + */ + def atomsCount = disjuncts.map(_._2.size).sum + conjuncts.map(i => atomNum(i._2)).sum + def funsCount = disjuncts.map(_._2.filter(_.isInstanceOf[Call]).size).sum + + /** + * Functions solely used for debugging + */ + import solvers.SimpleSolverAPI + def checkUnflattening(tempMap: Map[Expr, Expr], sol: SimpleSolverAPI, eval: DefaultEvaluator) = { + // solve unflat formula + val (temp, rest, modelCons) = toUnflatExpr + val packedFor = TemplateInstantiator.instantiate(And(Seq(rest, temp)), tempMap) + val (unflatSat, unflatModel) = sol.solveSAT(packedFor) + // solve flat formula (using the same values for the uncompressed vars) + val flatVCInst = simplifyArithmetic(TemplateInstantiator.instantiate(toExpr, tempMap)) + val modelExpr = SolverUtil.modelToExpr(unflatModel) + val (flatSat, flatModel) = sol.solveSAT(And(flatVCInst, modelExpr)) + //println("Formula: "+unpackedFor) + //println("packed formula: "+packedFor) + val satdisj = + if (unflatSat == Some(true)) + Some(pickSatDisjunct(firstRoot, new SimpleLazyModel(unflatModel))) + else None + if (unflatSat != flatSat) { + if (satdisj.isDefined) { + val preds = satdisj.get.filter { ctr => + if (getTemplateIds(ctr.toExpr).isEmpty) { + val exp = And(Seq(ctr.toExpr, modelExpr)) + sol.solveSAT(exp)._1 == Some(false) + } else false + } + println(s"Conflicting preds: ${preds.map(_.toExpr)}") + } + throw new IllegalStateException(s"VC produces different result with flattening: unflatSat: $unflatSat flatRes: $flatSat") + } else { + if (satdisj.isDefined) { + // print all differences between the models (only along the satisfiable path, values of other variables may not be computable) + val satExpr = createAnd(satdisj.get.map(_.toExpr)) + val lazyModel = modelCons(unflatModel, eval) + val allvars = variablesOf(satExpr) + val elimIds = allvars -- variablesOf(packedFor) + val diffs = allvars.filterNot(TemplateIdFactory.IsTemplateIdentifier).flatMap { + case id if !flatModel.isDefinedAt(id) => + println("Did not find a solver model for: " + id + " elimIds: " + elimIds(id)) + Seq() + case id if lazyModel(id) != flatModel(id) => + println(s"diff $id : flat: ${lazyModel(id)} solver: ${flatModel(id)}" + " elimIds: " + elimIds(id)) + Seq(id) + case _ => Seq() + } + if (!diffs.isEmpty) + throw new IllegalStateException("Model do not agree on diffs: " + diffs) + } + } + } } diff --git a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala index 15ae673019e9c50eb3ff257fc88114c935230342..ee31167305c42e020fac4dc8bafcec53ff8f8eeb 100644 --- a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala @@ -14,13 +14,13 @@ import invariant.structure.FunctionUtils._ import leon.invariant.util.RealValuedExprEvaluator._ import PredicateUtil._ import SolverUtil._ +import Stats._ class CegisSolver(ctx: InferenceContext, program: Program, rootFun: FunDef, ctrTracker: ConstraintTracker, timeout: Int, bound: Option[Int] = None) extends TemplateSolver(ctx, rootFun, ctrTracker) { - override def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) = { - + override def solve(tempIds: Set[Identifier], funcs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) = { val initCtr = if (bound.isDefined) { //use a predefined bound on the template variables createAnd(tempIds.map((id) => { @@ -30,10 +30,7 @@ class CegisSolver(ctx: InferenceContext, program: Program, }).toSeq) } else tru - - val funcs = funcVCs.keys - val formula = createOr(funcs.map(funcVCs.apply _).toSeq) - + val formula = createOr(funcs.map(getVCForFun _).toSeq) //using reals with bounds does not converge and also results in overflow val (res, _, model) = (new CegisCore(ctx, program, timeout, this)).solve(tempIds, formula, initCtr, solveAsInt = true) res match { @@ -126,13 +123,10 @@ class CegisCore(ctx: InferenceContext, throw new IllegalStateException("Reals in instFormula: " + instFormula) //println("solving instantiated vcs...") - val t1 = System.currentTimeMillis() val solver1 = new ExtendedUFSolver(context, program) solver1.assertCnstr(instFormula) - val res = solver1.check - val t2 = System.currentTimeMillis() - println("1: " + (if (res.isDefined) "solved" else "timedout") + "... in " + (t2 - t1) / 1000.0 + "s") - + val (res, solTime) = getTime{ solver1.check } + println("1: " + (if (res.isDefined) "solved" else "timedout") + "... in " + solTime / 1000.0 + "s") res match { case Some(true) => { //simplify the tempctrs, evaluate every atom that does not involve a template variable @@ -149,56 +143,42 @@ class CegisCore(ctx: InferenceContext, case e => e }(Not(formula)) solver1.free() - //sanity checks val spuriousProgIds = variablesOf(satctrs).filterNot(TemplateIdFactory.IsTemplateIdentifier _) if (spuriousProgIds.nonEmpty) throw new IllegalStateException("Found a progam variable in tempctrs: " + spuriousProgIds) - val tempctrs = if (!solveAsInt) ExpressionTransformer.IntLiteralToReal(satctrs) else satctrs val newctr = And(tempctrs, prevctr) - //println("Newctr: " +newctr) - if (ctx.dumpStats) { Stats.updateCounterStats(atomNum(newctr), "CegisTemplateCtrs", "CegisIters") } - - //println("solving template constraints...") val t3 = System.currentTimeMillis() val elapsedTime = (t3 - startTime) val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis - elapsedTime)) - val (res1, newModel) = if (solveAsInt) { //convert templates to integers and solve. Finally, re-convert integer models for templates to real models val rti = new RealToInt() val intctr = rti.mapRealToInt(And(newctr, initRealCtr)) val intObjective = rti.mapRealToInt(tempVarSum) - val (res1, intModel) = if (minimizeSum) { - minimizeIntegers(intctr, intObjective) - } else { - solver2.solveSAT(intctr) - } + val (res1, intModel) = + if (minimizeSum) + minimizeIntegers(intctr, intObjective) + else + solver2.solveSAT(intctr) (res1, rti.unmapModel(intModel)) } else { - - /*if(InvarianthasInts(tempctrs)) - throw new IllegalStateException("Template constraints have integer terms: " + tempctrs)*/ if (minimizeSum) { minimizeReals(And(newctr, initRealCtr), tempVarSum) } else { solver2.solveSAT(And(newctr, initRealCtr)) } } - - val t4 = System.currentTimeMillis() - println("2: " + (if (res1.isDefined) "solved" else "timed out") + "... in " + (t4 - t3) / 1000.0 + "s") - + println("2: " + (if (res1.isDefined) "solved" else "timed out") + "... in " + (System.currentTimeMillis() - t3) / 1000.0 + "s") if (res1.isDefined) { if (!res1.get) { //there exists no solution for templates (Some(false), newctr, Model.empty) - } else { //this is for sanity check addModel(newModel) @@ -238,7 +218,6 @@ class CegisCore(ctx: InferenceContext, val debugMinimization = false def minimizeReals(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { - //val t1 = System.currentTimeMillis() val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) val (res, model1) = sol.solveSAT(inputCtr) res match { @@ -278,11 +257,8 @@ class CegisCore(ctx: InferenceContext, } val boundCtr = LessEquals(objective, currval) - //val t1 = System.currentTimeMillis() val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) - //val t2 = System.currentTimeMillis() - //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") res match { case Some(true) => { //here we have a new upper bound @@ -324,7 +300,6 @@ class CegisCore(ctx: InferenceContext, } def minimizeIntegers(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { - //val t1 = System.currentTimeMillis() val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) val (res, model1) = sol.solveSAT(inputCtr) res match { @@ -357,11 +332,8 @@ class CegisCore(ctx: InferenceContext, } else 2 * upperBound } val boundCtr = LessEquals(objective, InfiniteIntegerLiteral(currval)) - //val t1 = System.currentTimeMillis() val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) - //val t2 = System.currentTimeMillis() - //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") res match { case Some(true) => { //here we have a new upper bound diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala index 05067d675264b854ff7f4f279187988696b9e348..1a4d1cd06a916eaf8197b97471e47f920d66f4cb 100644 --- a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala @@ -10,6 +10,7 @@ import solvers.SimpleSolverAPI import invariant.engine._ import invariant.util._ import Util._ +import Stats._ import SolverUtil._ import PredicateUtil._ import invariant.structure._ @@ -276,14 +277,14 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver), timeout * 1000)) } if (verbose) reporter.info("solving...") - val t1 = System.currentTimeMillis() val (res, model) = if (ctx.abort) (None, Model.empty) - else solver.solveSAT(simpctrs) - val t2 = System.currentTimeMillis() - if (verbose) reporter.info((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") - Stats.updateCounterTime((t2 - t1), "NL-solving-time", "disjuncts") - + else { + val (r, solTime) = getTime { solver.solveSAT(simpctrs) } + if (verbose) reporter.info((if (r._1.isDefined) "solved" else "timed out") + "... in " + solTime / 1000.0 + "s") + Stats.updateCounterTime(solTime, "NL-solving-time", "disjuncts") + r + } res match { case Some(true) => // construct assignments for the variables that were removed during nonlinearity reduction diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala index 82255bbfaebf40c90138004c5307cc244d62db00..5b46b2df17271a16847b7c9498c9e18f01905c51 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala @@ -24,17 +24,19 @@ import invariant.util._ import invariant.util.ExpressionTransformer._ import invariant.structure._ import invariant.structure.FunctionUtils._ -import RealValuedExprEvaluator._ +import Stats._ + import Util._ import PredicateUtil._ import SolverUtil._ class NLTemplateSolver(ctx: InferenceContext, program: Program, - rootFun: FunDef, ctrTracker: ConstraintTracker, - minimizer: Option[(Expr, Model) => Model]) - extends TemplateSolver(ctx, rootFun, ctrTracker) { + rootFun: FunDef, ctrTracker: ConstraintTracker, + minimizer: Option[(Expr, Model) => Model]) + extends TemplateSolver(ctx, rootFun, ctrTracker) { //flags controlling debugging + val debugUnflattening = false val debugIncrementalVC = false val debugElimination = false val debugChooseDisjunct = false @@ -42,7 +44,7 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, val debugAxioms = false val verifyInvariant = false val debugReducedFormula = false - val trackUnpackedVCCTime = false + val trackCompressedVCCTime = false //print flags val verbose = true @@ -62,684 +64,595 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, private val useIncrementalSolvingForVCs = true private val usePortfolio = false // portfolio has a bug in incremental solving - //this is private mutable state used by initialized during every call to 'solve' and used by 'solveUNSAT' - protected var funcVCs = Map[FunDef, Expr]() - protected var vcSolvers = Map[FunDef, Solver with TimeoutSolver]() - protected var paramParts = Map[FunDef, Expr]() - protected var simpleParts = Map[FunDef, Expr]() - private var lastFoundModel: Option[Model] = None - - //for miscellaneous things - val trackNumericalDisjuncts = false - var numericalDisjuncts = List[Expr]() - - protected def splitVC(fd: FunDef): (Expr, Expr) = { - ctrTracker.getVC(fd).splitParamPart - } - + // an evaluator for extracting models + val defaultEval = new DefaultEvaluator(leonctx, program) + // an evaluator for quicky checking the result of linear predicates + val linearEval = new LinearRelationEvaluator(ctx) + // solver factory val solverFactory = - if (this.usePortfolio) { - if (this.useIncrementalSolvingForVCs) + if (usePortfolio) { + if (useIncrementalSolvingForVCs) throw new IllegalArgumentException("Cannot perform incremental solving with portfolio solvers!") SolverFactory(() => new PortfolioSolver(leonctx, Seq(new SMTLIBCVC4Solver(leonctx, program), new SMTLIBZ3Solver(leonctx, program))) with TimeoutSolver) } else SolverFactory.uninterpreted(leonctx, program) - def initVCSolvers = { - funcVCs.keys.foreach { fd => - val (paramPart, rest) = if (ctx.usereals) { - val (pp, r) = splitVC(fd) - (IntLiteralToReal(pp), IntLiteralToReal(r)) - } else - splitVC(fd) - - if (hasReals(rest) && hasInts(rest)) - throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) - - if (!ctx.abort) { // this is required to ensure that solvers are not created after interrupts - val vcSolver = solverFactory.getNewSolver() - vcSolver.assertCnstr(rest) - - if (debugIncrementalVC) { - assert(getTemplateVars(rest).isEmpty) - println("For function: " + fd.id) - println("Param part: " + paramPart) - /*vcSolver.check match { - case Some(false) => throw new IllegalStateException("Non param-part is unsat "+rest) - case _ => ; - }*/ - } - vcSolvers += (fd -> vcSolver) - paramParts += (fd -> paramPart) - simpleParts += (fd -> rest) - } - } - } - - def freeVCSolvers { - vcSolvers.foreach(entry => entry._2.free) - } + // state for tracking the last model + private var lastFoundModel: Option[Model] = None /** * This function computes invariants belonging to the given templates incrementally. * The result is a mapping from function definitions to the corresponding invariants. */ - override def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) = { - //initialize vcs of functions - this.funcVCs = funcVCs - if (useIncrementalSolvingForVCs) { - initVCSolvers - } - val initModel = if (this.startFromEarlierModel && lastFoundModel.isDefined) { - val candModel = lastFoundModel.get - new Model(tempIds.map(id => - (id -> candModel.getOrElse(id, simplestValue(id.getType)))).toMap) - } else { - new Model(tempIds.map((id) => - (id -> simplestValue(id.getType))).toMap) - } - val sol = solveUNSAT(initModel, tru, Seq(), Set()) - - if (useIncrementalSolvingForVCs) { - freeVCSolvers - } - //set lowerbound map - //TODO: find a way to record lower bound stats - /*if (ctx.tightBounds) - SpecificStats.addLowerBoundStats(rootFun, minimizer.lowerBoundMap, "")*/ - //miscellaneous stuff - if (trackNumericalDisjuncts) { - this.numericalDisjuncts = List[Expr]() - } - sol + override def solve(tempIds: Set[Identifier], funs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) = { + val initModel = + if (this.startFromEarlierModel && lastFoundModel.isDefined) { + val candModel = lastFoundModel.get + new Model(tempIds.map(id => + (id -> candModel.getOrElse(id, simplestValue(id.getType)))).toMap) + } else { + new Model(tempIds.map((id) => + (id -> simplestValue(id.getType))).toMap) + } + val cgSolver = new CGSolver(funs) + val (resModel, seenCalls, lastModel) = cgSolver.solveUNSAT(initModel) + lastFoundModel = Some(lastModel) + cgSolver.free + (resModel, seenCalls) } - //state for minimization - var minStarted = false - var minStartTime: Long = 0 - var minimized = false + /** + * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. + * This method can overloaded by the subclasses. + */ + protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = Seq() - def minimizationInProgress { - if (!minStarted) { - minStarted = true - minStartTime = System.currentTimeMillis() + //a helper method + //TODO: this should also handle reals + protected def doesSatisfyExpr(expr: Expr, model: LazyModel): Boolean = { + val compModel = variablesOf(expr).map { k => k -> model(k) }.toMap + defaultEval.eval(expr, new Model(compModel)).result match { + case Some(BooleanLiteral(true)) => true + case _ => false } } - def minimizationCompleted { - minStarted = false - val mintime = (System.currentTimeMillis() - minStartTime) - /*Stats.updateCounterTime(mintime, "minimization-time", "procs") - Stats.updateCumTime(mintime, "Total-Min-Time")*/ + def splitVC(fd: FunDef) = { + val (paramPart, rest, modCons) = ctrTracker.getVC(fd).toUnflatExpr + if (ctx.usereals) { + (IntLiteralToReal(paramPart), IntLiteralToReal(rest), modCons) + } else (paramPart, rest, modCons) } - def solveUNSAT(model: Model, inputCtr: Expr, solvedDisjs: Seq[Expr], seenCalls: Set[Call]): (Option[Model], Option[Set[Call]]) = { + /** + * Counter-example guided solver + */ + class CGSolver(funs: Seq[FunDef]) { + + //for miscellaneous things + val trackNumericalDisjuncts = false + var numericalDisjuncts = List[Expr]() + + case class FunData(vcSolver: Solver with TimeoutSolver, modelCons: (Model, DefaultEvaluator) => FlatModel, + paramParts: Expr, simpleParts: Expr) + val funInfos = + if (useIncrementalSolvingForVCs) { + funs.foldLeft(Map[FunDef, FunData]()) { + case (acc, fd) => + val (paramPart, rest, modelCons) = splitVC(fd) + if (hasReals(rest) && hasInts(rest)) + throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) + if (debugIncrementalVC) { + assert(getTemplateVars(rest).isEmpty) + println("For function: " + fd.id) + println("Param part: " + paramPart) + } + if (!ctx.abort) { // this is required to ensure that solvers are not created after interrupts + val vcSolver = solverFactory.getNewSolver() + vcSolver.assertCnstr(rest) + acc + (fd -> FunData(vcSolver, modelCons, paramPart, rest)) + } else acc + } + } else Map[FunDef, FunData]() - if (verbose) { - reporter.info("Candidate invariants") - val candInvs = getAllInvariants(model) - candInvs.foreach((entry) => reporter.info(entry._1.id + "-->" + entry._2)) + def free = { + if (useIncrementalSolvingForVCs) + funInfos.foreach(entry => entry._2.vcSolver.free) + if (trackNumericalDisjuncts) + this.numericalDisjuncts = List[Expr]() } - if (this.startFromEarlierModel) this.lastFoundModel = Some(model) + //state for minimization + var minStarted = false + var minStartTime: Long = 0 + var minimized = false - val (res, newCtr, newModel, newdisjs, newcalls) = invalidateSATDisjunct(inputCtr, model) - res match { - case _ if ctx.abort => - (None, None) - case None => { - //here, we cannot proceed and have to return unknown - //However, we can return the calls that need to be unrolled - (None, Some(seenCalls ++ newcalls)) - } - case Some(false) => { - //here, the vcs are unsatisfiable when instantiated with the invariant - if (minimizer.isDefined) { - //for stats - minimizationInProgress - if (minimized) { - minimizationCompleted - (Some(model), None) - } else { - val minModel = minimizer.get(inputCtr, model) - minimized = true - if (minModel == model) { - minimizationCompleted - (Some(model), None) - } else { - solveUNSAT(minModel, inputCtr, solvedDisjs, seenCalls) - } - } - } else { - (Some(model), None) - } - } - case Some(true) => { - //here, we have found a new candidate invariant. Hence, the above process needs to be repeated - minimized = false - solveUNSAT(newModel, newCtr, solvedDisjs ++ newdisjs, seenCalls ++ newcalls) + def minimizationInProgress { + if (!minStarted) { + minStarted = true + minStartTime = System.currentTimeMillis() } } - } - //TODO: this code does too much imperative update. - //TODO: use guards to block a path and not use the path itself - def invalidateSATDisjunct(inputCtr: Expr, model: Model): (Option[Boolean], Expr, Model, Seq[Expr], Set[Call]) = { - - val tempIds = model.map(_._1) - val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap - val inputSize = atomNum(inputCtr) - - var disjsSolvedInIter = Seq[Expr]() - var callsInPaths = Set[Call]() - var conflictingFuns = funcVCs.keySet - //mapping from the functions to the counter-example paths that were seen - var seenPaths = MutableMap[FunDef, Seq[Expr]]() - def updateSeenPaths(fd: FunDef, cePath: Expr): Unit = { - if (seenPaths.contains(fd)) { - seenPaths.update(fd, cePath +: seenPaths(fd)) - } else { - seenPaths += (fd -> Seq(cePath)) - } + def minimizationCompleted { + minStarted = false + val mintime = (System.currentTimeMillis() - minStartTime) + /*Stats.updateCounterTime(mintime, "minimization-time", "procs") + Stats.updateCumTime(mintime, "Total-Min-Time")*/ } - def invalidateDisjRecr(prevCtr: Expr): (Option[Boolean], Expr, Model) = { - - Stats.updateCounter(1, "disjuncts") - - var blockedCEs = false - var confFunctions = Set[FunDef]() - var confDisjuncts = Seq[Expr]() - - val newctrsOpt = conflictingFuns.foldLeft(Some(Seq()): Option[Seq[Expr]]) { - case (None, _) => None - case (Some(acc), fd) => - val disableCounterExs = if (seenPaths.contains(fd)) { - blockedCEs = true - Not(createOr(seenPaths(fd))) - } else tru - if (ctx.abort) None - else - getUNSATConstraints(fd, model, disableCounterExs) match { - case None => - None - case Some(((disjunct, callsInPath), ctrsForFun)) => - if (ctrsForFun == tru) Some(acc) - else { - confFunctions += fd - confDisjuncts :+= disjunct - callsInPaths ++= callsInPath - //instantiate the disjunct - val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) - //some sanity checks - if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) - throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) - updateSeenPaths(fd, cePath) - Some(acc :+ ctrsForFun) - } - } + def solveUNSAT(initModel: Model): (Option[Model], Option[Set[Call]], Model) = solveUNSAT(initModel, tru, Seq(), Set()) + + def solveUNSAT(model: Model, inputCtr: Expr, solvedDisjs: Seq[Expr], seenCalls: Set[Call]): (Option[Model], Option[Set[Call]], Model) = { + if (verbose) { + reporter.info("Candidate invariants") + val candInvs = getAllInvariants(model) + candInvs.foreach((entry) => reporter.info(entry._1.id + "-->" + entry._2)) } - newctrsOpt match { + val (res, newCtr, newModel, newdisjs, newcalls) = invalidateSATDisjunct(inputCtr, model) + res match { + case _ if ctx.abort => + (None, None, model) case None => - // give up, the VC cannot be decided - (None, tru, Model.empty) - case Some(newctrs) => - //update conflicting functions - conflictingFuns = confFunctions - if (newctrs.isEmpty) { - if (!blockedCEs) { - //yes, hurray,found an inductive invariant - (Some(false), prevCtr, model) + //here, we cannot proceed and have to return unknown + //However, we can return the calls that need to be unrolled + (None, Some(seenCalls ++ newcalls), model) + case Some(false) => + //here, the vcs are unsatisfiable when instantiated with the invariant + if (minimizer.isDefined) { + //for stats + minimizationInProgress + if (minimized) { + minimizationCompleted + (Some(model), None, model) } else { - //give up, only hard paths remaining - reporter.info("- Exhausted all easy paths !!") - reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) - //TODO: what to unroll here ? - (None, tru, Model.empty) + val minModel = minimizer.get(inputCtr, model) + minimized = true + if (minModel == model) { + minimizationCompleted + (Some(model), None, model) + } else { + solveUNSAT(minModel, inputCtr, solvedDisjs, seenCalls) + } } } else { - //check that the new constraints does not have any reals - val newPart = createAnd(newctrs) - val newSize = atomNum(newPart) - Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") - if (verbose) - reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) - //here we need to solve for the newctrs + inputCtrs - val combCtr = And(prevCtr, newPart) - val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) - res match { - case _ if ctx.abort => - // stop immediately + (Some(model), None, model) + } + case Some(true) => + //here, we have found a new candidate invariant. Hence, the above process needs to be repeated + minimized = false + solveUNSAT(newModel, newCtr, solvedDisjs ++ newdisjs, seenCalls ++ newcalls) + } + } + + //TODO: this code does too much imperative update. + //TODO: use guards to block a path and not use the path itself + def invalidateSATDisjunct(inputCtr: Expr, model: Model): (Option[Boolean], Expr, Model, Seq[Expr], Set[Call]) = { + + val tempIds = model.map(_._1) + val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap + val inputSize = atomNum(inputCtr) + + var disjsSolvedInIter = Seq[Expr]() + var callsInPaths = Set[Call]() + var conflictingFuns = funs.toSet + //mapping from the functions to the counter-example paths that were seen + var seenPaths = MutableMap[FunDef, Seq[Expr]]() + def updateSeenPaths(fd: FunDef, cePath: Expr): Unit = { + if (seenPaths.contains(fd)) { + seenPaths.update(fd, cePath +: seenPaths(fd)) + } else { + seenPaths += (fd -> Seq(cePath)) + } + } + + def invalidateDisjRecr(prevCtr: Expr): (Option[Boolean], Expr, Model) = { + + Stats.updateCounter(1, "disjuncts") + + var blockedCEs = false + var confFunctions = Set[FunDef]() + var confDisjuncts = Seq[Expr]() + val newctrsOpt = conflictingFuns.foldLeft(Some(Seq()): Option[Seq[Expr]]) { + case (None, _) => None + case (Some(acc), fd) => + val disableCounterExs = if (seenPaths.contains(fd)) { + blockedCEs = true + Not(createOr(seenPaths(fd))) + } else tru + if (ctx.abort) None + else + getUNSATConstraints(fd, model, disableCounterExs) match { + case None => + None + case Some(((disjunct, callsInPath), ctrsForFun)) => + if (ctrsForFun == tru) Some(acc) + else { + confFunctions += fd + confDisjuncts :+= disjunct + callsInPaths ++= callsInPath + //instantiate the disjunct + val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) + //some sanity checks + if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) + throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) + updateSeenPaths(fd, cePath) + Some(acc :+ ctrsForFun) + } + } + } + newctrsOpt match { + case None => + // give up, the VC cannot be decided + (None, tru, Model.empty) + case Some(newctrs) => + //update conflicting functions + conflictingFuns = confFunctions + if (newctrs.isEmpty) { + if (!blockedCEs) { + //yes, hurray,found an inductive invariant + (Some(false), prevCtr, model) + } else { + //give up, only hard paths remaining + reporter.info("- Exhausted all easy paths !!") + reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) + //TODO: what to unroll here ? (None, tru, Model.empty) - case None => { - //here we have timed out while solving the non-linear constraints - if (verbose) - if (!this.disableCegis) - reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") - else - reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") - if (!this.disableCegis) { - val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, createOr(confDisjuncts), inputCtr, Some(model)) - cres match { - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - (Some(true), And(inputCtr, cctr), cmodel) - } - case Some(false) => { - disjsSolvedInIter ++= confDisjuncts - //here also return the calls that needs to be unrolled - (None, fls, Model.empty) - } - case _ => { - if (verbose) reporter.info("retrying...") - Stats.updateCumStats(1, "retries") - //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration - invalidateDisjRecr(And(inputCtr, cctr)) + } + } else { + //check that the new constraints does not have any reals + val newPart = createAnd(newctrs) + val newSize = atomNum(newPart) + Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") + if (verbose) + reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) + val combCtr = And(prevCtr, newPart) + val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) + res match { + case _ if ctx.abort => + // stop immediately + (None, tru, Model.empty) + case None => { + //here we have timed out while solving the non-linear constraints + if (verbose) + if (!disableCegis) + reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") + else + reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") + if (!disableCegis) { + val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, createOr(confDisjuncts), inputCtr, Some(model)) + cres match { + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + (Some(true), And(inputCtr, cctr), cmodel) + } + case Some(false) => { + disjsSolvedInIter ++= confDisjuncts + //here also return the calls that needs to be unrolled + (None, fls, Model.empty) + } + case _ => { + if (verbose) reporter.info("retrying...") + Stats.updateCumStats(1, "retries") + //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration + invalidateDisjRecr(And(inputCtr, cctr)) + } } + } else { + if (verbose) reporter.info("retrying...") + Stats.updateCumStats(1, "retries") + invalidateDisjRecr(inputCtr) } - } else { - if (verbose) reporter.info("retrying...") - Stats.updateCumStats(1, "retries") - invalidateDisjRecr(inputCtr) } - } - case Some(false) => { - //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) - disjsSolvedInIter ++= confDisjuncts - (None, fls, Model.empty) - } - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - //new model may not have mappings for all the template variables, hence, use the mappings from earlier models - val compModel = new Model(tempIds.map((id) => { - if (newModel.isDefinedAt(id)) - (id -> newModel(id)) - else - (id -> model(id)) - }).toMap) - (Some(true), combCtr, compModel) + case Some(false) => { + //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) + disjsSolvedInIter ++= confDisjuncts + (None, fls, Model.empty) + } + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + //new model may not have mappings for all the template variables, hence, use the mappings from earlier models + val compModel = new Model(tempIds.map((id) => { + if (newModel.isDefinedAt(id)) + (id -> newModel(id)) + else + (id -> model(id)) + }).toMap) + (Some(true), combCtr, compModel) + } } } - } + } } + val (res, newctr, newmodel) = invalidateDisjRecr(inputCtr) + (res, newctr, newmodel, disjsSolvedInIter, callsInPaths) } - val (res, newctr, newmodel) = invalidateDisjRecr(inputCtr) - (res, newctr, newmodel, disjsSolvedInIter, callsInPaths) - } - - def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { - - val cegisSolver = new CegisCore(ctx, program, timeout.toInt, this) - val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) - if (res.isEmpty) - reporter.info("cegis timed-out on the disjunct...") - (res, ctr, model) - } - protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { - if (ctx.usereals) replace(tempVarMap, e) - else - simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) - } - - /** - * Constructs a quantifier-free non-linear constraint for unsatisfiability - */ - def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): Option[((Expr, Set[Call]), Expr)] = { - - val tempVarMap: Map[Expr, Expr] = inModel.map((elem) => (elem._1.toVariable, elem._2)).toMap - val solver = - if (this.useIncrementalSolvingForVCs) vcSolvers(fd) - else solverFactory.getNewSolver() - val instExpr = if (this.useIncrementalSolvingForVCs) { - val instParamPart = instantiateTemplate(this.paramParts(fd), tempVarMap) - And(instParamPart, disableCounterExs) - } else { - val instVC = instantiateTemplate(funcVCs(fd), tempVarMap) - And(instVC, disableCounterExs) - } - //For debugging - if (this.dumpInstantiatedVC) { - // println("Plain vc: "+funcVCs(fd)) - val wr = new PrintWriter(new File("formula-dump.txt")) - val fullExpr = if (this.useIncrementalSolvingForVCs) { - And(simpleParts(fd), instExpr) - } else - instExpr - // println("Instantiated VC of " + fd.id + " is: " + fullExpr) - wr.println("Function name: " + fd.id) - wr.println("Formula expr: ") - ExpressionTransformer.PrintWithIndentation(wr, fullExpr) - wr.flush() - wr.close() - } - if (hasMixedIntReals(instExpr)) { - throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) + def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { + val cegisSolver = new CegisCore(ctx, program, timeout.toInt, NLTemplateSolver.this) + val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) + if (res.isEmpty) + reporter.info("cegis timed-out on the disjunct...") + (res, ctr, model) } - //reporter.info("checking VC inst ...") - var t1 = System.currentTimeMillis() - solver.setTimeout(timeout * 1000) - val (res, model) = if (this.useIncrementalSolvingForVCs) { - solver.push - solver.assertCnstr(instExpr) - // new InterruptOnSignal(solver).interruptOnSignal(ctx.abort)( - val solRes = solver.check match { - case _ if ctx.abort => - (None, Model.empty) - case r @ Some(true) => - (r, solver.getModel) - case r => (r, Model.empty) - } - solver.pop() - solRes - } else { - SimpleSolverAPI(SolverFactory(() => solver)).solveSAT(instExpr) - } - val vccTime = (System.currentTimeMillis() - t1) - - if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") - Stats.updateCounterTime(vccTime, "VC-check-time", "disjuncts") - Stats.updateCumTime(vccTime, "TotalVCCTime") - - //for debugging - if (this.trackUnpackedVCCTime) { - val upVCinst = unFlatten(simplifyArithmetic( - TemplateInstantiator.instantiate(ctrTracker.getVC(fd).unpackedExpr, tempVarMap))) - Stats.updateCounterStats(atomNum(upVCinst), "UP-VC-size", "disjuncts") - t1 = System.currentTimeMillis() - val (res2, _) = SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())).solveSAT(upVCinst) - val unpackedTime = System.currentTimeMillis() - t1 - if (res != res2) { - throw new IllegalStateException("Unpacked VC produces different result: " + upVCinst) - } - Stats.updateCumTime(unpackedTime, "TotalUPVCCTime") - reporter.info("checked UP-VC inst... in " + unpackedTime / 1000.0 + "s") + protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { + if (ctx.usereals) replace(tempVarMap, e) + else + simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) } - t1 = System.currentTimeMillis() - res match { - case None => { - //throw new IllegalStateException("cannot check the satisfiability of " + funcVCs(fd)) - None + /** + * Constructs a quantifier-free non-linear constraint for unsatisfiability + */ + def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): Option[((Expr, Set[Call]), Expr)] = { + + val tempVarMap: Map[Expr, Expr] = inModel.map((elem) => (elem._1.toVariable, elem._2)).toMap + val (solver, instExpr, modelCons) = + if (useIncrementalSolvingForVCs) { + val funData = funInfos(fd) + val instParamPart = instantiateTemplate(funData.paramParts, tempVarMap) + (funData.vcSolver, And(instParamPart, disableCounterExs), funData.modelCons) + } else { + val (paramPart, rest, modCons) = ctrTracker.getVC(fd).toUnflatExpr + val instPart = instantiateTemplate(paramPart, tempVarMap) + (solverFactory.getNewSolver(), createAnd(Seq(rest, instPart, disableCounterExs)), modCons) + } + //For debugging + if (dumpInstantiatedVC) { + val filename = "vcInst-" + FileCountGUID.getID + val wr = new PrintWriter(new File(filename+".txt")) + val fullExpr = + if (useIncrementalSolvingForVCs) { + And(funInfos(fd).simpleParts, instExpr) + } else instExpr + wr.println("Function name: " + fd.id+" \nFormula expr: ") + ExpressionTransformer.PrintWithIndentation(wr, fullExpr) + wr.close() } - case Some(false) => { - //do not generate any constraints - Some(((fls, Set()), tru)) + if(debugUnflattening){ + ctrTracker.getVC(fd).checkUnflattening(tempVarMap, + SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())), + defaultEval) } - case Some(true) => { - //For debugging purposes. - if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") - if (this.printCounterExample) { - reporter.info("Model: " + model) + // sanity check + if (hasMixedIntReals(instExpr)) + throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) + + //reporter.info("checking VC inst ...") + solver.setTimeout(timeout * 1000) + val (res, packedModel) = + time { + if (useIncrementalSolvingForVCs) { + solver.push + solver.assertCnstr(instExpr) + val solRes = solver.check match { + case _ if ctx.abort => + (None, Model.empty) + case r @ Some(true) => + (r, solver.getModel) + case r => (r, Model.empty) + } + solver.pop() + solRes + } else + SimpleSolverAPI(SolverFactory(() => solver)).solveSAT(instExpr) + } { vccTime => + if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") + updateCounterTime(vccTime, "VC-check-time", "disjuncts") + updateCumTime(vccTime, "TotalVCCTime") + } + //for statistics + if (trackCompressedVCCTime) { + val compressedVC = + unflatten(simplifyArithmetic(instantiateTemplate( + ctrTracker.getVC(fd).eliminateBlockers, tempVarMap))) + Stats.updateCounterStats(atomNum(compressedVC), "Compressed-VC-size", "disjuncts") + time { + SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())).solveSAT(compressedVC) + } { compTime => + Stats.updateCumTime(compTime, "TotalCompressVCCTime") + reporter.info("checked compressed VC... in " + compTime / 1000.0 + "s") } - //get the disjuncts that are satisfied - val (data, newctr) = generateCtrsFromDisjunct(fd, model) - if (newctr == tru) - throw new IllegalStateException("Cannot find a counter-example path!!") - - val t2 = System.currentTimeMillis() - Stats.updateCounterTime((t2 - t1), "Disj-choosing-time", "disjuncts") - Stats.updateCumTime((t2 - t1), "Total-Choose-Time") - - Some((data, newctr)) + } + res match { + case None => None // cannot check satisfiability of VCinst !! + case Some(false) => + Some(((fls, Set()), tru)) //do not generate any constraints + case Some(true) => + //For debugging purposes. + if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") + if (printCounterExample) { + reporter.info("Model: " + packedModel) + } + //get the disjuncts that are satisfied + val model = modelCons(packedModel, defaultEval) + val (data, newctr) = + time { generateCtrsFromDisjunct(fd, model) } { chTime => + updateCounterTime(chTime, "Disj-choosing-time", "disjuncts") + updateCumTime(chTime, "Total-Choose-Time") + } + if (newctr == tru) throw new IllegalStateException("Cannot find a counter-example path!!") + Some((data, newctr)) } } - } - - lazy val evaluator = new DefaultEvaluator(leonctx, program) //as of now used only for debugging - //a helper method - //TODO: this should also handle reals - protected def doesSatisfyModel(expr: Expr, model: Model): Boolean = { - evaluator.eval(expr, model).result match { - case Some(BooleanLiteral(true)) => true - case _ => false - } - } - - /** - * Evaluator for a predicate that is a simple equality/inequality between two variables. - * Some expressions may not be evaluatable, so we return none in those cases. - */ - protected def predEval(model: Model): (Expr => Option[Boolean]) = { - if (ctx.usereals) realEval(model) - else intEval(model) - } - protected def intEval(model: Model): (Expr => Option[Boolean]) = { - def modelVal(id: Identifier): BigInt = { - val InfiniteIntegerLiteral(v) = model(id) - v - } - def eval: (Expr => Option[Boolean]) = { - case And(args) => - val argres = args.map(eval) - if(argres.exists(!_.isDefined)) None - else - Some(argres.forall(_.get)) - case Equals(Variable(id1), Variable(id2)) => - if(model.isDefinedAt(id1) && - model.isDefinedAt(id2)) - Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality - else None - case LessEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) <= modelVal(id2)) - case GreaterEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) >= modelVal(id2)) - case GreaterThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) > modelVal(id2)) - case LessThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) < modelVal(id2)) - case e => throw new IllegalStateException("Predicate not handled: " + e) - } - eval - } + protected def generateCtrsFromDisjunct(fd: FunDef, initModel: LazyModel): ((Expr, Set[Call]), Expr) = { - protected def realEval(model: Model): (Expr => Option[Boolean]) = { - def modelVal(id: Identifier): FractionalLiteral = { - //println("Identifier: "+id) - model(id).asInstanceOf[FractionalLiteral] - } - { - case Equals(Variable(id1), Variable(id2)) => Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality - case e@Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] - || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] - || e.isInstanceOf[GreaterEquals]) => { - Some(evaluateRealPredicate(op(Seq(modelVal(id1), modelVal(id2))))) + val formula = ctrTracker.getVC(fd) + //this picks the satisfiable disjunct of the VC modulo axioms + val satCtrs = formula.pickSatDisjunct(formula.firstRoot, initModel) + //for debugging + if (debugChooseDisjunct || printPathToConsole || dumpPathAsSMTLIB || verifyInvariant) { + val pathctrs = satCtrs.map(_.toExpr) + val plainFormula = createAnd(pathctrs) + val pathcond = simplifyArithmetic(plainFormula) + + if (debugChooseDisjunct) { + satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyExpr(ctr, initModel)) + throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) + }) + } + if (verifyInvariant) { + println("checking invariant for path...") + val sat = checkInvariant(pathcond, leonctx, program) + } + if (printPathToConsole) { + //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) + val simpcond = pathcond + println("Full-path: " + ScalaPrinter(simpcond)) + val filename = "full-path-" + FileCountGUID.getID + ".txt" + val wr = new PrintWriter(new File(filename)) + ExpressionTransformer.PrintWithIndentation(wr, simpcond) + println("Printed to file: " + filename) + wr.flush() + wr.close() + } + if (dumpPathAsSMTLIB) { + val filename = "pathcond" + FileCountGUID.getID + ".smt2" + toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) + println("Path dumped to: " + filename) + } } - case e => throw new IllegalStateException("Predicate not handled: " + e) - } - } - /** - * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. - * This method can overloaded by the subclasses. - */ - protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: Model): Seq[Constraint] = Seq() - - protected def generateCtrsFromDisjunct(fd: FunDef, model: Model): ((Expr, Set[Call]), Expr) = { - - val formula = ctrTracker.getVC(fd) - //this picks the satisfiable disjunct of the VC modulo axioms - val satCtrs = formula.pickSatDisjunct(formula.firstRoot, model) - //for debugging - if (this.debugChooseDisjunct || this.printPathToConsole || this.dumpPathAsSMTLIB || this.verifyInvariant) { - val pathctrs = satCtrs.map(_.toExpr) - val plainFormula = createAnd(pathctrs) - val pathcond = simplifyArithmetic(plainFormula) - - if (this.debugChooseDisjunct) { - satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyModel(ctr, model)) - throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) - }) + var calls = Set[Call]() + var adtExprs = Seq[Expr]() + satCtrs.foreach { + case t: Call => calls += t + case t: ADTConstraint if (t.cons || t.sel) => adtExprs :+= t.expr + // TODO: ignoring all set constraints here, fix this + case _ => ; } - - if (this.verifyInvariant) { - println("checking invariant for path...") - val sat = checkInvariant(pathcond, leonctx, program) + val callExprs = calls.map(_.toExpr) + + val axiomCtrs = time { + ctrTracker.specInstantiator.axiomsForCalls(formula, calls, initModel) + } { updateCumTime(_, "Total-AxiomChoose-Time") } + + //here, handle theory operations by reducing them to axioms. + //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle + //other theory axioms like: multiplication, sets, arrays, maps etc. + val theoryCtrs = time { + axiomsForTheory(formula, calls, initModel) + } { updateCumTime(_, "Total-TheoryAxiomatization-Time") } + + //Finally, eliminate UF/ADT + // convert all adt constraints to 'cons' ctrs, and expand the model + val selTrans = new SelectorToCons() + val cons = selTrans.selToCons(adtExprs) + val expModel = selTrans.getModel(initModel) + // get constraints for UFADTs + val callCtrs = time { + (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), + linearEval.predEval(expModel)).map(ConstraintUtil.createConstriant _) + } { updateCumTime(_, "Total-ElimUF-Time") } + + //exclude guards, separate calls and cons from the rest + var lnctrs = Set[LinearConstraint]() + var temps = Set[LinearTemplate]() + (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { + case t: LinearConstraint => lnctrs += t + case t: LinearTemplate => temps += t + case _ => ; } - - if (this.printPathToConsole) { - //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) - val simpcond = pathcond - println("Full-path: " + ScalaPrinter(simpcond)) - val filename = "full-path-" + FileCountGUID.getID + ".txt" - val wr = new PrintWriter(new File(filename)) - ExpressionTransformer.PrintWithIndentation(wr, simpcond) - println("Printed to file: " + filename) - wr.flush() - wr.close() + if (debugChooseDisjunct) { + lnctrs.map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyExpr(ctr, expModel)) + throw new IllegalStateException("Ctr not satisfied by model: " + ctr) + }) } - - if (this.dumpPathAsSMTLIB) { - val filename = "pathcond" + FileCountGUID.getID + ".smt2" - toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) - println("Path dumped to: " + filename) + if (debugTheoryReduction) { + val simpPathCond = createAnd((lnctrs ++ temps).map(_.template).toSeq) + if (verifyInvariant) { + println("checking invariant for simp-path...") + checkInvariant(simpPathCond, leonctx, program) + } } - } - - var calls = Set[Call]() - var cons = Set[Expr]() - satCtrs.foreach { - case t: Call => calls += t - case t: ADTConstraint if (t.cons.isDefined) => cons += t.cons.get - // TODO: ignoring all set constraints here, fix this - case _ => ; - } - val callExprs = calls.map(_.toExpr) - - var t1 = System.currentTimeMillis() - val axiomCtrs = ctrTracker.specInstantiator.axiomsForCalls(formula, calls, model) - var t2 = System.currentTimeMillis() - Stats.updateCumTime((t2 - t1), "Total-AxiomChoose-Time") - - //here, handle theory operations by reducing them to axioms. - //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle - //other theory axioms like: multiplication, sets, arrays, maps etc. - t1 = System.currentTimeMillis() - val theoryCtrs = axiomsForTheory(formula, calls, model) - t2 = System.currentTimeMillis() - Stats.updateCumTime((t2 - t1), "Total-TheoryAxiomatization-Time") - - //Finally, eliminate UF/ADT - t1 = System.currentTimeMillis() - val callCtrs = (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), - predEval(model)).map(ConstraintUtil.createConstriant _) - t2 = System.currentTimeMillis() - Stats.updateCumTime((t2 - t1), "Total-ElimUF-Time") - - //exclude guards, separate calls and cons from the rest - var lnctrs = Set[LinearConstraint]() - var temps = Set[LinearTemplate]() - (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { - case t: LinearConstraint => lnctrs += t - case t: LinearTemplate => temps += t - case _ => ; - } - - if (this.debugChooseDisjunct) { - lnctrs.map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyModel(ctr, model)) - throw new IllegalStateException("Ctr not satisfied by model: " + ctr) - }) - } - - if (this.debugTheoryReduction) { - val simpPathCond = createAnd((lnctrs ++ temps).map(_.template).toSeq) - if (this.verifyInvariant) { - println("checking invariant for simp-path...") - checkInvariant(simpPathCond, leonctx, program) + if (trackNumericalDisjuncts) { + numericalDisjuncts :+= createAnd((lnctrs ++ temps).map(_.template).toSeq) } + val (data, nlctr) = processNumCtrs(lnctrs.toSeq, temps.toSeq) + ((data, calls), nlctr) } - if (this.trackNumericalDisjuncts) { - numericalDisjuncts :+= createAnd((lnctrs ++ temps).map(_.template).toSeq) - } - - val (data, nlctr) = processNumCtrs(lnctrs.toSeq, temps.toSeq) - ((data, calls), nlctr) - } - - /** - * Endpoint of the pipeline. Invokes the Farkas Lemma constraint generation. - */ - def processNumCtrs(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): (Expr, Expr) = { - //here we are invalidating A^~(B) - if (temps.isEmpty) { - //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path - (createAnd(lnctrs.map(_.toExpr)), fls) - } else { - - if (this.debugElimination) { - //println("Path Constraints (before elim): "+(lnctrs ++ temps)) - if (this.verifyInvariant) { - println("checking invariant for disjunct before elimination...") - checkInvariant(createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) + /** + * Endpoint of the pipeline. Invokes the Farkas Lemma constraint generation. + */ + def processNumCtrs(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): (Expr, Expr) = { + //here we are invalidating A^~(B) + if (temps.isEmpty) { + //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path + (createAnd(lnctrs.map(_.toExpr)), fls) + } else { + if (debugElimination) { + //println("Path Constraints (before elim): "+(lnctrs ++ temps)) + if (verifyInvariant) { + println("checking invariant for disjunct before elimination...") + checkInvariant(createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) + } } - } - //compute variables to be eliminated - val t1 = System.currentTimeMillis() - val ctrVars = lnctrs.foldLeft(Set[Identifier]())((acc, lc) => acc ++ variablesOf(lc.toExpr)) - val tempVars = temps.foldLeft(Set[Identifier]())((acc, lt) => acc ++ variablesOf(lt.template)) - val elimVars = ctrVars.diff(tempVars) - - val debugger = if (debugElimination && verifyInvariant) { - Some((ctrs: Seq[LinearConstraint]) => { - //println("checking disjunct before elimination...") - //println("ctrs: "+ctrs) - val debugRes = checkInvariant(createAnd((ctrs ++ temps).map(_.template)), leonctx, program) - }) - } else None - val elimLnctrs = LinearConstraintUtil.apply1PRuleOnDisjunct(lnctrs, elimVars, debugger) - val t2 = System.currentTimeMillis() - - if (this.debugElimination) { - println("Path constriants (after elimination): " + elimLnctrs) - if (this.verifyInvariant) { - println("checking invariant for disjunct after elimination...") - checkInvariant(createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) + //compute variables to be eliminated + val ctrVars = lnctrs.foldLeft(Set[Identifier]())((acc, lc) => acc ++ variablesOf(lc.toExpr)) + val tempVars = temps.foldLeft(Set[Identifier]())((acc, lt) => acc ++ variablesOf(lt.template)) + val elimVars = ctrVars.diff(tempVars) + // for debugging + val debugger = + if (debugElimination && verifyInvariant) { + Some((ctrs: Seq[LinearConstraint]) => { + val debugRes = checkInvariant(createAnd((ctrs ++ temps).map(_.template)), leonctx, program) + }) + } else None + val elimLnctrs = time { + LinearConstraintUtil.apply1PRuleOnDisjunct(lnctrs, elimVars, debugger) + } { updateCumTime(_, "ElimTime") } + + if (debugElimination) { + println("Path constriants (after elimination): " + elimLnctrs) + if (verifyInvariant) { + println("checking invariant for disjunct after elimination...") + checkInvariant(createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) + } } - } - //for stats - if (ctx.dumpStats) { - var elimCtrCount = 0 - var elimCtrs = Seq[LinearConstraint]() - var elimRems = Set[Identifier]() - elimLnctrs.foreach((lc) => { - val evars = variablesOf(lc.toExpr).intersect(elimVars) - if (evars.nonEmpty) { - elimCtrs :+= lc - elimCtrCount += 1 - elimRems ++= evars + //for stats + if (ctx.dumpStats) { + var elimCtrCount = 0 + var elimCtrs = Seq[LinearConstraint]() + var elimRems = Set[Identifier]() + elimLnctrs.foreach((lc) => { + val evars = variablesOf(lc.toExpr).intersect(elimVars) + if (evars.nonEmpty) { + elimCtrs :+= lc + elimCtrCount += 1 + elimRems ++= evars + } + }) + Stats.updateCounterStats((elimVars.size - elimRems.size), "Eliminated-Vars", "disjuncts") + Stats.updateCounterStats((lnctrs.size - elimLnctrs.size), "Eliminated-Atoms", "disjuncts") + Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") + Stats.updateCounterStats(lnctrs.size, "NonParam-Atoms", "disjuncts") + } + val newLnctrs = elimLnctrs.toSet.toSeq + + //TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c + //TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) + //e.g, remove: a <= b if we have a = b or if a < b + //Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. + //TODO: Use the dependence chains in the formulas to identify what to assertionize + // and what can never be implied by solving for the templates + val disjunct = createAnd((newLnctrs ++ temps).map(_.template)) + val implCtrs = farkasSolver.constraintsForUnsat(newLnctrs, temps) + //for debugging + if (debugReducedFormula) { + println("Final Path Constraints: " + disjunct) + if (verifyInvariant) { + println("checking invariant for final disjunct... ") + checkInvariant(disjunct, leonctx, program) } - }) - Stats.updateCounterStats((elimVars.size - elimRems.size), "Eliminated-Vars", "disjuncts") - Stats.updateCounterStats((lnctrs.size - elimLnctrs.size), "Eliminated-Atoms", "disjuncts") - Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") - Stats.updateCounterStats(lnctrs.size, "NonParam-Atoms", "disjuncts") - Stats.updateCumTime((t2 - t1), "ElimTime") - } - val newLnctrs = elimLnctrs.toSet.toSeq - - //TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c - //TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) - //e.g, remove: a <= b if we have a = b or if a < b - //Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. - - //TODO: Use the dependence chains in the formulas to identify what to assertionize - // and what can never be implied by solving for the templates - - val disjunct = createAnd((newLnctrs ++ temps).map(_.template)) - val implCtrs = farkasSolver.constraintsForUnsat(newLnctrs, temps) - - //for debugging - if (this.debugReducedFormula) { - println("Final Path Constraints: " + disjunct) - if (this.verifyInvariant) { - println("checking invariant for final disjunct... ") - checkInvariant(disjunct, leonctx, program) } + (disjunct, implCtrs) } - - (disjunct, implCtrs) } } } diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala index 4fc681eb284daf192c3477b11f24078888fba9b4..3ae146a8a862bafefe488241e8dd3384099664eb 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala @@ -25,16 +25,16 @@ class NLTemplateSolverWithMult(ctx: InferenceContext, program: Program, rootFun: nlvc } - override def splitVC(fd: FunDef): (Expr, Expr) = { - val (paramPart, rest) = ctrTracker.getVC(fd).splitParamPart - (multToTimes(paramPart), multToTimes(rest)) + override def splitVC(fd: FunDef) = { + val (paramPart, rest, modCons) = super.splitVC(fd) + (multToTimes(paramPart), multToTimes(rest), modCons) } - override def axiomsForTheory(formula: Formula, calls: Set[Call], model: Model): Seq[Constraint] = { + override def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { //in the sequel we instantiate axioms for multiplication - val inst1 = unaryMultAxioms(formula, calls, predEval(model)) - val inst2 = binaryMultAxioms(formula, calls, predEval(model)) + val inst1 = unaryMultAxioms(formula, calls, linearEval.predEval(model)) + val inst2 = binaryMultAxioms(formula, calls, linearEval.predEval(model)) val multCtrs = (inst1 ++ inst2).flatMap { case And(args) => args.map(ConstraintUtil.createConstriant _) case e => Seq(ConstraintUtil.createConstriant(e)) diff --git a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala index c9876ffd161d94fd10de75c93f3e4c6a574008da..4bd170f8c9a49a1ce15111431e02cb174cf1be3a 100644 --- a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala @@ -33,12 +33,13 @@ abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, /** * Completes a model by adding mapping to new template variables */ - def completeModel(model: Map[Identifier, Expr], tempIds: Set[Identifier]): Map[Identifier, Expr] = { - tempIds.map((id) => { - if (!model.contains(id)) { + def completeModel(model: Model, ids: Set[Identifier]) = { + val idmap = ids.map((id) => { + if (!model.isDefinedAt(id)) { (id, simplestValue(id.getType)) } else (id, model(id)) }).toMap + new Model(idmap) } /** @@ -53,8 +54,16 @@ abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, TemplateInstantiator.getAllInvariants(model, templates.toMap) } + var vcCache = Map[FunDef, Expr]() protected def getVCForFun(fd: FunDef): Expr = { - ctrTracker.getVC(fd).toExpr + vcCache.getOrElse(fd, { + val vcInit = ctrTracker.getVC(fd).toExpr + val vc = if (ctx.usereals) + ExpressionTransformer.IntLiteralToReal(vcInit) + else vcInit + vcCache += (fd -> vc) + vc + }) } /** @@ -62,47 +71,32 @@ abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, * The result is a mapping from function definitions to the corresponding invariants. */ def solveTemplates(): (Option[Model], Option[Set[Call]]) = { - //traverse each of the functions and collect the VCs val funcs = ctrTracker.getFuncs - val funcExprs = funcs.map((fd) => { - val vc = if (ctx.usereals) - ExpressionTransformer.IntLiteralToReal(getVCForFun(fd)) - else getVCForFun(fd) + val tempIds = funcs.flatMap { fd => + val vc = ctrTracker.getVC(fd) if (dumpVCtoConsole || dumpVCasText) { - //val simpForm = simplifyArithmetic(vc) val filename = "vc-" + FileCountGUID.getID if (dumpVCtoConsole) { println("Func: " + fd.id + " VC: " + vc) } if (dumpVCasText) { val wr = new PrintWriter(new File(filename + ".txt")) - //ExpressionTransformer.PrintWithIndentation(wr, vcstr) println("Printed VC of " + fd.id + " to file: " + filename) - wr.println(vc.toString) - wr.flush() + wr.println(vc.toString()) wr.close() } } if (ctx.dumpStats) { - Stats.updateCounterStats(atomNum(vc), "VC-size", "VC-refinement") - Stats.updateCounterStats(numUIFADT(vc), "UIF+ADT", "VC-refinement") + Stats.updateCounterStats(vc.atomsCount, "VC-size", "VC-refinement") + Stats.updateCounterStats(vc.funsCount, "UIF+ADT", "VC-refinement") } - (fd -> vc) - }).toMap - //Assign some values for the template variables at random (actually use the simplest value for the type) - val tempIds = funcExprs.foldLeft(Set[Identifier]()) { - case (acc, (_, vc)) => - //val tempOption = if (fd.hasTemplate) Some(fd.getTemplate) else None - //if (!tempOption.isDefined) acc - //else - acc ++ getTemplateIds(vc) - } + vc.templateIdsInFormula + }.toSet + Stats.updateCounterStats(tempIds.size, "TemplateIds", "VC-refinement") - val solution = - if (ctx.abort) (None, None) - else solve(tempIds, funcExprs) - solution + if (ctx.abort) (None, None) + else solve(tempIds, funcs) } - def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) + def solve(tempIds: Set[Identifier], funcVCs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) } \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala index 8f72e12ec20826d70b932f749d879282b8accdb7..abda26d430a851bbdaec38f338cb4d5aea5340fb 100644 --- a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala +++ b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala @@ -27,10 +27,10 @@ object ExpressionTransformer { val fls = BooleanLiteral(false) val bone = BigInt(1) - // identifier for temporaries that are generated during flattening + // identifier for temporaries that are generated during flattening of terms other than functions val flatContext = newContext - // temporaries generated during conversion of field selects to ADT constructions - val fieldSelContext = newContext + // temporaries used in the function flattening + val funFlatContext = newContext // conversion of other language constructs val langContext = newContext @@ -138,7 +138,6 @@ object ExpressionTransformer { //TODO: do we have to consider reuse of let variables ? val (resbody, bodycjs) = transform(body, true) val (resvalue, valuecjs) = transform(value, true) - (resbody, (valuecjs + Equals(binder.toVariable, resvalue)) ++ bodycjs) } //the value is a tuple in the following case @@ -146,7 +145,6 @@ object ExpressionTransformer { //TODO: do we have to consider reuse of let variables ? val (resbody, bodycjs) = transform(body, true) val (resvalue, valuecjs) = transform(value, true) - //here we optimize the case where resvalue itself has tuples val newConjuncts = resvalue match { case Tuple(args) => { @@ -172,7 +170,6 @@ object ExpressionTransformer { (cjs ++ cjs2) } } - (resbody, (valuecjs ++ newConjuncts) ++ bodycjs) } case _ => { @@ -208,7 +205,7 @@ object ExpressionTransformer { case fi @ FunctionInvocation(fd, args) => val (newargs, newConjuncts) = flattenArgs(args, true) val newfi = FunctionInvocation(fd, newargs) - val freshResVar = Variable(createFlatTemp("r", fi.getType)) + val freshResVar = Variable(createTemp("r", fi.getType, funFlatContext)) val res = (freshResVar, newConjuncts + Equals(freshResVar, newfi)) res @@ -228,6 +225,7 @@ object ExpressionTransformer { val freshArg = newargs(0) val newCS = CaseClassSelector(cd, freshArg, sel) val freshResVar = Variable(createFlatTemp("cs", cs.getType)) + //val freshResVar = Variable(createTemp("cs", cs.getType, funFlatContext)) // we cannot flatten these as they will converted to cons newConjuncts += Equals(freshResVar, newCS) (freshResVar, newConjuncts) @@ -237,6 +235,7 @@ object ExpressionTransformer { val freshArg = newargs(0) val newTS = TupleSelect(freshArg, index) val freshResVar = Variable(createFlatTemp("ts", ts.getType)) + //val freshResVar = Variable(createTemp("ts", ts.getType, funFlatContext)) newConjuncts += Equals(freshResVar, newTS) (freshResVar, newConjuncts) @@ -414,40 +413,6 @@ object ExpressionTransformer { }(expr) } - def classSelToCons(e: Expr): Expr = { - val (r, cd, ccvar, ccfld) = e match { - case Equals(r0 @ Variable(_), CaseClassSelector(cd0, ccvar0, ccfld0)) => (r0, cd0, ccvar0, ccfld0) - case _ => throw new IllegalStateException("Not a case-class-selector call") - } - //convert this to a cons by creating dummy variables - val args = cd.fields.map((fld) => { - if (fld.id == ccfld) r - else { - //create a dummy identifier there - createTemp("fld", fld.getType, fieldSelContext).toVariable - } - }) - Equals(ccvar, CaseClass(cd, args)) - } - - def tupleSelToCons(e: Expr): Expr = { - val (r, tpvar, index) = e match { - case Equals(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) - // case Iff(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) - case _ => throw new IllegalStateException("Not a tuple-selector call") - } - //convert this to a Tuple by creating dummy variables - val tupleType = tpvar.getType.asInstanceOf[TupleType] - val args = (1 until tupleType.dimension + 1).map((i) => { - if (i == index) r - else { - //create a dummy identifier there (note that here we have to use i-1) - createTemp("fld", tupleType.bases(i - 1), fieldSelContext).toVariable - } - }) - Equals(tpvar, Tuple(args)) - } - /** * Normalizes the expressions */ @@ -470,23 +435,48 @@ object ExpressionTransformer { * This is the inverse operation of flattening. * This is used to produce a readable formula or more efficiently * solvable formulas. + * Note: this is a helper method that assumes that 'flatIds' + * are not shared across disjuncts. + * If this is not guaranteed to hold, use the 'unflatten' method */ - def unFlattenWithMap(ine: Expr): (Expr, Map[Identifier,Expr]) = { + def simpleUnflattenWithMap(ine: Expr, excludeIds: Set[Identifier] = Set(), + includeFuns: Boolean): (Expr, Map[Identifier,Expr]) = { + + def isFlatTemp(id: Identifier) = + isTemp(id, flatContext) || (includeFuns && isTemp(id, funFlatContext)) + var idMap = Map[Identifier, Expr]() - val newe = simplePostTransform { - case e @ Equals(Variable(id), rhs @ _) if isTemp(id, flatContext) => - if (idMap.contains(id)) e + /** + * Here, relying on library transforms is dangerous as they + * can perform additional simplifications to the expression on-the-fly, + * which is not desirable here. + */ + def rec(e: Expr): Expr = e match { + case e @ Equals(Variable(id), rhs @ _) if isFlatTemp(id) && !excludeIds(id) => + val nrhs = rec(rhs) + if (idMap.contains(id)) Equals(Variable(id), nrhs) else { - idMap += (id -> rhs) + idMap += (id -> nrhs) tru } - case e => e - }(ine) + // specially handle boolean function to prevent unnecessary simplifications + case Or(args) => Or(args map rec) + case And(args) => And(args map rec) + case Not(arg) => Not(rec(arg)) + case Operator(args, op) => op(args map rec) + } + val newe = rec(ine) val closure = (e: Expr) => replaceFromIDs(idMap, e) - (fix(closure)(newe), idMap) + val rese = fix(closure)(newe) + (rese, idMap) + } + + def unflattenWithMap(ine: Expr, excludeIds: Set[Identifier] = Set(), + includeFuns: Boolean = true): (Expr, Map[Identifier,Expr]) = { + simpleUnflattenWithMap(ine, sharedIds(ine) ++ excludeIds, includeFuns) } - def unFlatten(ine: Expr) = unFlattenWithMap(ine)._1 + def unflatten(ine: Expr) = unflattenWithMap(ine)._1 /** * convert all integer constants to real constants diff --git a/src/main/scala/leon/invariant/util/FlatToUnflatExpr.scala b/src/main/scala/leon/invariant/util/FlatToUnflatExpr.scala deleted file mode 100644 index 563a7feb5e90477b94b2618b7ecc7522f9409bda..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/FlatToUnflatExpr.scala +++ /dev/null @@ -1,60 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.factories._ -import solvers._ -import scala.collection.immutable._ -import scala.collection.mutable.{Set => MutableSet, Map => MutableMap} -import ExpressionTransformer._ -import leon.evaluators._ -import EvaluationResults._ - -/** - *A class that can used to compress a flattened expression - * and also expand the compressed models to the flat forms - */ -class FlatToUnflatExpr(ine: Expr, eval: DeterministicEvaluator) { - - val (unflate, flatIdMap) = unFlattenWithMap(ine) - def getModel(m: Model) = new FlatModel(m) - - /** - * Expands the model into a model with mappings for identifiers. - * Note: this class cannot be accessed in parallel. - */ - class FlatModel(initModel: Model) { - var idModel = initModel.toMap - - def apply(iden: Identifier) = { - var seen = Set[Identifier]() - def recBind(id: Identifier): Expr = { - val idv = idModel.get(id) - if (idv.isDefined) idv.get - else { - if (seen(id)) { - //we are in a cycle here - throw new IllegalStateException(s"$id depends on itself $id, input expression: $ine") - } else if (flatIdMap.contains(id)) { - val rhs = flatIdMap(id) - // recursively bind all freevars to values (we can ignore the return values) - seen += id - variablesOf(rhs).map(recBind) - eval.eval(rhs, idModel) match { - case Successful(v) => - idModel += (id -> v) - v - case _ => - throw new IllegalStateException(s"Evaluation Falied for $rhs") - } - } else - throw new IllegalStateException(s"Cannot extract model $id as it is not a part of the input expression: $ine") - } - } - recBind(iden) - } - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala b/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..7c5886faf3ed4a05dc1845910eb149e6188ef81c --- /dev/null +++ b/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala @@ -0,0 +1,82 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import java.io._ +import solvers._ +import solvers.combinators._ +import solvers.smtlib._ +import solvers.z3._ +import scala.util.control.Breaks._ +import purescala.ScalaPrinter +import scala.collection.mutable.{ Map => MutableMap } +import scala.reflect.runtime.universe +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ + +import Util._ +import PredicateUtil._ +import SolverUtil._ +import RealValuedExprEvaluator._ + +/** + * Evaluator for a predicate that is a simple equality/inequality between two variables. + * Some expressions cannot be evaluated, so we return none in those cases. + */ +class LinearRelationEvaluator(ctx: InferenceContext) { + + def predEval(model: LazyModel): (Expr => Option[Boolean]) = { + if (ctx.usereals) realEval(model) + else intEval(model) + } + + def intEval(model: LazyModel): (Expr => Option[Boolean]) = { + def modelVal(id: Identifier): BigInt = { + val InfiniteIntegerLiteral(v) = model(id) + v + } + def eval: (Expr => Option[Boolean]) = { + case And(args) => + val argres = args.map(eval) + if (argres.exists(!_.isDefined)) None + else + Some(argres.forall(_.get)) + case Equals(Variable(id1), Variable(id2)) => + if (model.isDefinedAt(id1) && model.isDefinedAt(id2)) + Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality + else None + case LessEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) <= modelVal(id2)) + case GreaterEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) >= modelVal(id2)) + case GreaterThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) > modelVal(id2)) + case LessThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) < modelVal(id2)) + case e => throw new IllegalStateException("Predicate not handled: " + e) + } + eval + } + + def realEval(model: LazyModel): (Expr => Option[Boolean]) = { + def modelVal(id: Identifier): FractionalLiteral = { + //println("Identifier: "+id) + model(id).asInstanceOf[FractionalLiteral] + } + { + case Equals(Variable(id1), Variable(id2)) => Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality + case e @ Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] + || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] + || e.isInstanceOf[GreaterEquals]) => { + Some(evaluateRealPredicate(op(Seq(modelVal(id1), modelVal(id2))))) + } + case e => throw new IllegalStateException("Predicate not handled: " + e) + } + } +} diff --git a/src/main/scala/leon/invariant/util/Minimizer.scala b/src/main/scala/leon/invariant/util/Minimizer.scala index 3794f55463490ae048574f3613f720ba092d0ff2..267f66bae0043a0265dd8eda71dc9bd40eef7119 100644 --- a/src/main/scala/leon/invariant/util/Minimizer.scala +++ b/src/main/scala/leon/invariant/util/Minimizer.scala @@ -10,6 +10,7 @@ import solvers.smtlib.SMTLIBZ3Solver import invariant.engine.InferenceContext import invariant.factories._ import leon.invariant.util.RealValuedExprEvaluator._ +import Stats._ class Minimizer(ctx: InferenceContext, program: Program) { @@ -44,10 +45,13 @@ class Minimizer(ctx: InferenceContext, program: Program) { minimizeBounds(computeCompositionLevel(timeTemplate))(inputCtr, initModel) } + /** + * TODO: use incremental solving of z3 when it is supported in nlsat + * Do a binary search sequentially on the tempvars ordered by the rate of growth of the term they + * are a coefficient for. + */ def minimizeBounds(nestMap: Map[Variable, Int])(inputCtr: Expr, initModel: Model): Model = { val orderedTempVars = nestMap.toSeq.sortWith((a, b) => a._2 >= b._2).map(_._1) - //do a binary search sequentially on each of these tempvars - // note: use smtlib solvers so that they can be timedout lazy val solver = new SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver), ctx.vcTimeout * 1000)) @@ -64,42 +68,36 @@ class Minimizer(ctx: InferenceContext, program: Program) { if (tvar == orderedTempVars(0) && lowerBoundMap.contains(tvar)) lowerBoundMap(tvar) else realzero - //a helper method def updateState(nmodel: Model) = { upperBound = nmodel(tvar.id).asInstanceOf[FractionalLiteral] currentModel = nmodel - if (this.debugMinimization) { + if (this.debugMinimization) reporter.info("Found new upper bound: " + upperBound) - //reporter.info("Model: "+currentModel) - } } if (this.debugMinimization) reporter.info(s"Minimizing variable: $tvar Initial Bounds: [$upperBound,$lowerBound]") - //TODO: use incremental solving of z3 when it is supported in nlsat var continue = true var iter = 0 do { iter += 1 if (continue) { - //we make sure that curr val is an integer - val currval = floor(evaluate(Times(half, Plus(upperBound, lowerBound)))) - //check if the lowerbound, if it exists, is < currval - if (evaluateRealPredicate(GreaterEquals(lowerBound, currval))) + val currval = floor(evaluate(Times(half, Plus(upperBound, lowerBound)))) //make sure that curr val is an integer + if (evaluateRealPredicate(GreaterEquals(lowerBound, currval))) //check if the lowerbound, if it exists, is < currval continue = false else { val boundCtr = And(LessEquals(tvar, currval), GreaterEquals(tvar, lowerBound)) - //val t1 = System.currentTimeMillis() val (res, newModel) = if (ctx.abort) (None, Model.empty) - else solver.solveSAT(And(acc, boundCtr)) - //val t2 = System.currentTimeMillis() - //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") + else { + time { solver.solveSAT(And(acc, boundCtr)) }{minTime => + updateCumTime(minTime, "BinarySearchTime") + } + } res match { case Some(true) => updateState(newModel) - case _ => - //here we have a new lower bound: currval + case _ => //here we have a new lower bound: currval lowerBound = currval if (this.debugMinimization) reporter.info("Found new lower bound: " + currval) @@ -107,14 +105,12 @@ class Minimizer(ctx: InferenceContext, program: Program) { } } } while (!ctx.abort && continue && iter < MaxIter) - //this is the last ditch effort to make the upper bound constant smaller. - //check if the floor of the upper-bound is a solution + //A last ditch effort to make the upper bound an integer. val currval @ FractionalLiteral(n, d) = - if (currentModel.isDefinedAt(tvar.id)) { + if (currentModel.isDefinedAt(tvar.id)) currentModel(tvar.id).asInstanceOf[FractionalLiteral] - } else { + else initModel(tvar.id).asInstanceOf[FractionalLiteral] - } if (d != 1 && !ctx.abort) { val (res, newModel) = solver.solveSAT(And(acc, Equals(tvar, floor(currval)))) if (res == Some(true)) diff --git a/src/main/scala/leon/invariant/util/SelectorToCons.scala b/src/main/scala/leon/invariant/util/SelectorToCons.scala new file mode 100644 index 0000000000000000000000000000000000000000..a458d8197d7ea9658b65b73a225de29a3c9c9e17 --- /dev/null +++ b/src/main/scala/leon/invariant/util/SelectorToCons.scala @@ -0,0 +1,116 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import invariant.factories._ +import solvers._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } +import ExpressionTransformer._ +import TVarFactory._ + +object SelectToCons { + // temporaries generated during conversion of field selects to ADT constructions + val fieldSelContext = newContext +} + +/** + * A class that converts case-class or tuple selectors in an expression + * to constructors, and updates a given lazy model. + * We assume that all the arguments are flattened in the input expression. + */ +class SelectorToCons { + + import SelectToCons._ + + var fldIdMap = Map[Identifier, (Variable, Int)]() + + /** + * For now this works only on a disjunct + */ + def selToCons(disjunct: Seq[Expr]): Seq[Expr] = { + def classSelToCons(eq: Equals) = eq match { + case Equals(r: Variable, CaseClassSelector(ctype, cc: Variable, selfld)) => + //convert this to a cons by creating dummy variables + val args = ctype.fields.zipWithIndex.map { + case (fld, i) if fld.id == selfld => r + case (fld, i) => + val t = createTemp("fld", fld.getType, fieldSelContext) //create a dummy identifier there + fldIdMap += (t -> (cc, i)) + t.toVariable + } + Equals(cc, CaseClass(ctype, args)) + case _ => + throw new IllegalStateException("Selector not flattened: " + eq) + } + def tupleSelToCons(eq: Equals) = eq match { + case Equals(r: Variable, TupleSelect(tp: Variable, idx)) => + val tupleType = tp.getType.asInstanceOf[TupleType] + //convert this to a Tuple by creating dummy variables + val args = (1 until tupleType.dimension + 1).map { i => + if (i == idx) r + else { + val t = createTemp("fld", tupleType.bases(i - 1), fieldSelContext) //note: we have to use i-1 + fldIdMap += (t -> (tp, i - 1)) + t.toVariable + } + } + Equals(tp, Tuple(args)) + case _ => + throw new IllegalStateException("Selector not flattened: " + eq) + } + //println("Input expression: "+ine) + disjunct.map { // we need to traverse top-down + case eq @ Equals(_, _: CaseClassSelector) => + classSelToCons(eq) + case eq @ Equals(_, _: TupleSelect) => + tupleSelToCons(eq) + case _: CaseClassSelector | _: TupleSelect => + throw new IllegalStateException("Selector not flattened") + case e => e + } +// println("Output expression: "+rese) +// rese + } + + // def tupleSelToCons(e: Expr): Expr = { + // val (r, tpvar, index) = e match { + // case Equals(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) + // // case Iff(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) + // case _ => throw new IllegalStateException("Not a tuple-selector call") + // } + // //convert this to a Tuple by creating dummy variables + // val tupleType = tpvar.getType.asInstanceOf[TupleType] + // val args = (1 until tupleType.dimension + 1).map((i) => { + // if (i == index) r + // else { + // //create a dummy identifier there (note that here we have to use i-1) + // createTemp("fld", tupleType.bases(i - 1), fieldSelContext).toVariable + // } + // }) + // Equals(tpvar, Tuple(args)) + // } + + /** + * Expands a given model into a model with mappings for identifiers introduced during flattening. + * Note: this class cannot be accessed in parallel. + */ + def getModel(initModel: LazyModel) = new LazyModel { + override def get(iden: Identifier) = { + val idv = initModel.get(iden) + if (idv.isDefined) idv + else { + fldIdMap.get(iden) match { + case Some((Variable(inst), fldIdx)) => + initModel(inst) match { + case CaseClass(_, args) => Some(args(fldIdx)) + case Tuple(args) => Some(args(fldIdx)) + } + case None => None + } + } + } + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/SolverUtil.scala b/src/main/scala/leon/invariant/util/SolverUtil.scala index 6b4641e039887bd734fee18e0a33e37f83221b23..9d798ac020a953e5f138bc7fd35494c30ce12358 100644 --- a/src/main/scala/leon/invariant/util/SolverUtil.scala +++ b/src/main/scala/leon/invariant/util/SolverUtil.scala @@ -106,7 +106,7 @@ object SolverUtil { solver.free //cores - ExpressionTransformer.unFlatten(cores) + ExpressionTransformer.unflatten(cores) } //tests if the solver uses nlsat diff --git a/src/main/scala/leon/invariant/util/Stats.scala b/src/main/scala/leon/invariant/util/Stats.scala index 72239ac32227122c5f65c1ea73ac231d04014a11..eb5dc6efd83c35226fe90409482955357843ea89 100644 --- a/src/main/scala/leon/invariant/util/Stats.scala +++ b/src/main/scala/leon/invariant/util/Stats.scala @@ -82,6 +82,19 @@ object Stats { }) }) } + + def time[T](code: => T)(cont: Long => Unit): T = { + var t1 = System.currentTimeMillis() + val r = code + cont((System.currentTimeMillis() - t1)) + r + } + + def getTime[T](code: => T): (T, Long) = { + var t1 = System.currentTimeMillis() + val r = code + (r, (System.currentTimeMillis() - t1)) + } } /** diff --git a/src/main/scala/leon/invariant/util/TreeUtil.scala b/src/main/scala/leon/invariant/util/TreeUtil.scala index fa7edc9ff183c501b1a7c2b73548adac7558bd8a..5949082d339c198d57a7d17255ce2db5f5368498 100644 --- a/src/main/scala/leon/invariant/util/TreeUtil.scala +++ b/src/main/scala/leon/invariant/util/TreeUtil.scala @@ -16,6 +16,7 @@ import FunctionUtils._ import scala.annotation.tailrec import PredicateUtil._ import ProgramUtil._ +import TypeUtil._ import Util._ import solvers._ import purescala.DefOps._ @@ -34,7 +35,7 @@ object ProgramUtil { def copyProgram(prog: Program, mapdefs: (Seq[Definition] => Seq[Definition])): Program = { prog.copy(units = prog.units.collect { case unit if unit.defs.nonEmpty => unit.copy(defs = unit.defs.collect { - case module : ModuleDef if module.defs.nonEmpty => + case module: ModuleDef if module.defs.nonEmpty => module.copy(defs = mapdefs(module.defs)) case other => other }) @@ -77,7 +78,7 @@ object ProgramUtil { } res } - + def createTemplateFun(plainTemp: Expr): FunctionInvocation = { val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), @@ -120,7 +121,7 @@ object ProgramUtil { * will be removed */ def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, - funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay : Boolean = true): Program = { + funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay: Boolean = true): Program = { val funMap = functionsWOFields(prog.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { case (accMap, fd) if fd.isTheoryOperation => @@ -185,7 +186,7 @@ object ProgramUtil { } def updatePost(funToPost: Map[FunDef, Lambda], prog: Program, - uniqueIdDisplay: Boolean = true, excludeLibrary: Boolean = true): Program = { + uniqueIdDisplay: Boolean = true, excludeLibrary: Boolean = true): Program = { val funMap = functionsWOFields(prog.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { case (accMap, fd) if fd.isTheoryOperation || fd.isLibrary => @@ -291,12 +292,12 @@ object PredicateUtil { def isTemplateExpr(expr: Expr): Boolean = { var foundVar = false simplePostTransform { - case e@Variable(id) => { + case e @ Variable(id) => { if (!TemplateIdFactory.IsTemplateIdentifier(id)) foundVar = true e } - case e@ResultVariable(_) => { + case e @ ResultVariable(_) => { foundVar = true e } @@ -354,11 +355,11 @@ object PredicateUtil { def atomNum(e: Expr): Int = { var count: Int = 0 simplePostTransform { - case e@And(args) => { + case e @ And(args) => { count += args.size e } - case e@Or(args) => { + case e @ Or(args) => { count += args.size e } @@ -370,7 +371,7 @@ object PredicateUtil { def numUIFADT(e: Expr): Int = { var count: Int = 0 simplePostTransform { - case e@(FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_)) => { + case e @ (FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_)) => { count += 1 e } @@ -401,8 +402,8 @@ object PredicateUtil { def isADTConstructor(e: Expr): Boolean = e match { case Equals(Variable(_), CaseClass(_, _)) => true - case Equals(Variable(_), Tuple(_)) => true - case _ => false + case Equals(Variable(_), Tuple(_)) => true + case _ => false } def isMultFunctions(fd: FunDef) = { @@ -423,28 +424,43 @@ object PredicateUtil { def createAnd(exprs: Seq[Expr]): Expr = { val newExprs = exprs.filterNot(conj => conj == tru) newExprs match { - case Seq() => tru + case Seq() => tru case Seq(e) => e - case _ => And(newExprs) + case _ => And(newExprs) } } def createOr(exprs: Seq[Expr]): Expr = { val newExprs = exprs.filterNot(disj => disj == fls) newExprs match { - case Seq() => fls + case Seq() => fls case Seq(e) => e - case _ => Or(newExprs) + case _ => Or(newExprs) } } - def isNumericType(t: TypeTree) = t match { - case IntegerType | RealType => true - case _ => false - } - def precOrTrue(fd: FunDef): Expr = fd.precondition match { case Some(pre) => pre case None => BooleanLiteral(true) } + + /** + * Computes the set of variables that are shared across disjunctions. + * This may return bound variables as well + */ + def sharedIds(e: Expr): Set[Identifier] = e match { + case Or(args) => + var uniqueVars = Set[Identifier]() + var sharedVars = Set[Identifier]() + args.foreach { arg => + val candUniques = variablesOf(arg) -- sharedVars + val newShared = uniqueVars.intersect(candUniques) + sharedVars ++= newShared + uniqueVars = (uniqueVars ++ candUniques) -- newShared + } + sharedVars ++ (args flatMap sharedIds) + case Variable(_) => Set() + case Operator(args, op) => + (args flatMap sharedIds).toSet + } } diff --git a/src/main/scala/leon/invariant/util/TypeUtil.scala b/src/main/scala/leon/invariant/util/TypeUtil.scala index bde5cb5cd6453a18881c8b76034ddaec09a2d209..bb1a70e326c43e362641cbd042d6f7e7d8862149 100644 --- a/src/main/scala/leon/invariant/util/TypeUtil.scala +++ b/src/main/scala/leon/invariant/util/TypeUtil.scala @@ -42,4 +42,11 @@ object TypeUtil { tcons(subtypes map instantiateTypeParameters(tpMap) _) } } + + def isNumericType(t: TypeTree) = t match { + case IntegerType | RealType => true + case Int32Type => + throw new IllegalStateException("BitVector types not supported yet!") + case _ => false + } } \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/UnflatHelper.scala b/src/main/scala/leon/invariant/util/UnflatHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..16ca818eda7092e22d62669dfb4b304d230dd8ea --- /dev/null +++ b/src/main/scala/leon/invariant/util/UnflatHelper.scala @@ -0,0 +1,85 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import invariant.factories._ +import solvers._ +import scala.collection.immutable._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } +import ExpressionTransformer._ +import leon.evaluators._ +import EvaluationResults._ + +trait LazyModel { + def get(iden: Identifier): Option[Expr] + + def apply(iden: Identifier): Expr = + get(iden) match { + case Some(e) => e + case _ => throw new IllegalStateException(s"Cannot create mapping for $iden") + } + + def isDefinedAt(iden: Identifier) = get(iden).isDefined +} + +class SimpleLazyModel(m: Model) extends LazyModel { + def get(iden: Identifier): Option[Expr] = m.get(iden) +} + +/** + * Expands a given model into a model with mappings for identifiers introduced during flattening. + * Note: this class cannot be accessed in parallel. + */ +class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], initModel: Model, eval: DefaultEvaluator) extends LazyModel { + var idModel = initModel.toMap + + override def get(iden: Identifier) = { + var seen = Set[Identifier]() + def recBind(id: Identifier): Option[Expr] = { + val idv = idModel.get(id) + if (idv.isDefined) idv + else { + if (seen(id)) { + //we are in a cycle here + throw new IllegalStateException(s"$id depends on itself") + } else if (flatIdMap.contains(id)) { + val rhs = flatIdMap(id) + // recursively bind all freevars to values (we can ignore the return values) + seen += id + variablesOf(rhs).filterNot(idModel.contains).map(recBind) + eval.eval(rhs, idModel) match { + case Successful(v) => + idModel += (id -> v) + Some(v) + case _ => + throw new IllegalStateException(s"Evaluation Falied for $id -> $rhs") + } + } else if (freeVars(id)) { + // here, `id` either belongs to values of the flatIdMap, or to flate or was lost in unflattening + println(s"Completing $id with simplest value") + val simpv = simplestValue(id.getType) + idModel += (id -> simpv) + Some(simpv) + } else + None + //throw new IllegalStateException(s"Cannot extract model $id as it not contained in the input expression: $ine") + } + } + recBind(iden) + } +} + +/** + * A class that can used to compress a flattened expression + * and also expand the compressed models to the flat forms + */ +class UnflatHelper(ine: Expr, excludeIds: Set[Identifier], eval: DefaultEvaluator) { + + val (unflate, flatIdMap) = unflattenWithMap(ine, excludeIds, includeFuns = false) + val invars = variablesOf(ine) + + def getModel(m: Model) = new FlatModel(invars, flatIdMap, m, eval) +} \ No newline at end of file diff --git a/src/main/scala/leon/laziness/LazyVerificationPhase.scala b/src/main/scala/leon/laziness/LazyVerificationPhase.scala index 3f1a8ce00c9830eae59eb171e1d6096c61018091..c39a34581dbbf8523752ddeb51203eb5d9b46781 100644 --- a/src/main/scala/leon/laziness/LazyVerificationPhase.scala +++ b/src/main/scala/leon/laziness/LazyVerificationPhase.scala @@ -229,9 +229,6 @@ object LazyVerificationPhase { (matchToIfThenElse(ants), conseq) } - /** - * TODO: fix this!! - */ override def verifyInvariant(newposts: Map[FunDef, Expr]) = (Some(false), Model.empty) } } diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala index 0c313f2ff3ce79809420969a49bc2d3b6f11b2f3..f585cf042177bdec7dea20585fb6d5a505825c9a 100644 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -14,6 +14,7 @@ import invariant.util._ import Util._ import ProgramUtil._ import PredicateUtil._ +import TypeUtil._ import invariant.structure._ abstract class ProgramTypeTransformer { diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala index a696dae0580baf44c6ef955c72ccdcce9791bdab..ecd0bca4b1dbf343be42f8c1e43b5e099de4b66f 100644 --- a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala +++ b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala @@ -6,6 +6,7 @@ import invariant.util._ import Util._ import ProgramUtil._ import PredicateUtil._ +import TypeUtil._ import invariant.structure.FunctionUtils._ import purescala.ScalaPrinter diff --git a/testcases/lazy-datastructures/withOrb/Concat.scala b/testcases/lazy-datastructures/withOrb/Concat.scala index 3b961a1bb22b708a8ee738c8676e43452158dd92..bd4978f6a4c6145470a3fd2983e8a4e27420a20c 100644 --- a/testcases/lazy-datastructures/withOrb/Concat.scala +++ b/testcases/lazy-datastructures/withOrb/Concat.scala @@ -6,8 +6,6 @@ import leon.annotation._ import leon.instrumentation._ import leon.collection._ import leon.invariant._ -import scala.BigInt -import scala.math.BigInt.int2bigInt object Concat { sealed abstract class LList[T] { diff --git a/vcs-temp b/vcs-temp new file mode 100644 index 0000000000000000000000000000000000000000..26603979f76790531e4cc68c6b3383567abc1358 --- /dev/null +++ b/vcs-temp @@ -0,0 +1,48 @@ +b376 == (!ci50 && ci50 == l1.isInstanceOf[Cons0]) && +b388 == (r219 - r220 <= BigInt(0) && r217 == (ts53, fld42) && r219 == size23(ts53) && r220 == size23(cs17)) && +b382 == (ci48 && ci48 == l1.isInstanceOf[Cons0] && ci49 && l1 == Cons0(fld38, cs14) && ci49 == cs14.isInstanceOf[Nil0] && ifres24 == tp8 && cc4 == Nil0() && arg846 + BigInt(-6) == BigInt(0) && tp8 == (cc4, arg846)) && +b387 == (cs17 == cc7 && cc7 == Nil0()) && +b385 == (-ts52 + BigInt(8) < BigInt(0)) && +b379 == (b389 && ci52 && ci52 == l1.isInstanceOf[Cons0] && ifres25 == tp9 && cc5 == Cons0(cs16, ts50) && arg847 == Cons0(cs16, fld34) && tp9 == (cc5, arg848) && e20 == (ts50, fld35) && e20 == (fld36, ts51) && (arg848 - ts51) + BigInt(-9) == BigInt(0) && arg847 == l1 && e20 == r217 && arg849 == l1 && arg849 == Cons0(fld37, cs17) && r217 == removeLast-time10(cs17)) && +b383 == (b378 && ifres24 == ifres25 && b381) && +b386 == (l1 != cc3 && cc3 == Nil0() && res117 == tp7 && bd1 == (ts48, fld39) && bd1 == (fld40, ts49) && tp7 == (ts48, ts49) && b384 && bd1 == ifres24 && b385 && res117 == (fld41, ts52) && r218 == size23(l1)) && +b377 == (!ci51 && l1 == Cons0(fld33, cs15) && ci51 == cs15.isInstanceOf[Nil0]) && +b380 == (!ci53 && ci53 == l1.isInstanceOf[Cons0] && ifres25 == tp10 && cc6 == Nil0() && arg850 + BigInt(-6) == BigInt(0) && tp10 == (cc6, arg850)) && +b378 == (b376 || b377) && +b381 == (b379 || b380) && +b384 == (b382 || b383) && +b389 == (b387 || b388) && +b386 + + +b376 == !l1.isInstanceOf[Cons0] && +b388 == (size23(ts53) - size23(Nil0()) <= BigInt(0)) && +b382 == (l1.isInstanceOf[Cons0] && cs14.isInstanceOf[Nil0] && l1 == Cons0(fld38, cs14) && ifres24 == (Nil0(), arg846) && arg846 + BigInt(-6) == BigInt(0)) && +b387 == (true && true) && +b385 == (-ts52 + BigInt(8) < BigInt(0)) && +b379 == (b389 && l1.isInstanceOf[Cons0] && ifres25 == (Cons0(cs16, ts50), arg848) && e20 == (ts50, fld35) && e20 == (fld36, ts51) && (arg848 - ts51) + BigInt(-9) == BigInt(0) && Cons0(cs16, fld34) == l1 && e20 == (ts53, fld42) && l1 == Cons0(fld37, Nil0()) && (ts53, fld42) == removeLast-time10(Nil0())) && +b383 == (b378 && ifres24 == ifres25 && b381) && +b386 == (l1 != Nil0() && res117 == (ts48, ts49) && bd1 == (ts48, fld39) && bd1 == (fld40, ts49) && b384 && bd1 == ifres24 && b385 && res117 == (fld41, ts52)) && +b377 == (!cs15.isInstanceOf[Nil0] && l1 == Cons0(fld33, cs15)) && +b380 == (!l1.isInstanceOf[Cons0] && ifres25 == (Nil0(), arg850) && arg850 + BigInt(-6) == BigInt(0)) && +b378 == (b376 || b377) && +b381 == (b379 || b380) && +b384 == (b382 || b383) && +b389 == (b387 || b388) && +b386 + + +b389 == ((cs17 == Nil0()) || (size23(ts53) - size23(cs17) <= BigInt(0) && r217 == (ts53, fld42))) && + +b379 == (b389 && l1.isInstanceOf[Cons0] && ifres25 == (Cons0(cs16, ts50), arg848) && e20 == (ts50, fld35) && e20 == (fld36, ts51) && (arg848 - ts51) + BigInt(-9) == BigInt(0) && Cons0(cs16, fld34) == l1 && e20 == (ts53, fld42) && l1 == Cons0(fld37, Nil0()) && (ts53, fld42) == removeLast-time10(Nil0())) && + +b380 == (!l1.isInstanceOf[Cons0] && ifres25 == (Nil0(), arg850) && arg850 + BigInt(-6) == BigInt(0)) && + +(l1 != Nil0() && res117 == (ts48, ts49) && bd1 == (ts48, fld39) && bd1 == (fld40, ts49) && ((l1.isInstanceOf[Cons0] && cs14.isInstanceOf[Nil0] && l1 == Cons0(fld38, cs14) && ifres24 == (Nil0(), arg846) && arg846 + BigInt(-6) == BigInt(0)) || +((!l1.isInstanceOf[Cons0] || (!cs15.isInstanceOf[Nil0] && l1 == Cons0(fld33, cs15))) && ifres24 == ifres25 && +(b379 || b380))) +&& bd1 == ifres24 && +(-ts52 + BigInt(8) < BigInt(0)) && res117 == (fld41, ts52)) && + +What happenned to +???? r218 == size23(l1)) ???