diff --git a/library/annotation/package.scala b/library/annotation/package.scala index b1ac6dbf90d66bf8b2a2ae0a3f312f621fa9c301..0e844188f55e3c1737beab20c359309019793f3c 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -24,7 +24,7 @@ package object annotation { @ignore class monotonic extends StaticAnnotation @ignore - class compose extends StaticAnnotation + class compose extends StaticAnnotation @ignore class axiom extends StaticAnnotation @ignore @@ -34,5 +34,7 @@ package object annotation { @ignore class invisibleBody extends StaticAnnotation // do not unfold the body of the function @ignore + class usePost extends StaticAnnotation // assume the post-condition while proving time bounds + @ignore class unfoldFactor(f: Int=0) extends StaticAnnotation // 0 implies no bound on unfolding } \ No newline at end of file diff --git a/src/main/scala/leon/invariant/datastructure/Graph.scala b/src/main/scala/leon/invariant/datastructure/Graph.scala index 2b4055b53f88f5a58a96f6e398043417dafb111d..e8c8a729688e9ea6c79898113404229d0e4ff15f 100644 --- a/src/main/scala/leon/invariant/datastructure/Graph.scala +++ b/src/main/scala/leon/invariant/datastructure/Graph.scala @@ -88,63 +88,60 @@ class DirectedGraph[T] { def getSuccessors(src: T): Set[T] = adjlist(src) /** - * Change this to the verified component + * TODO: Change this to the verified component + * The computed nodes are also in reverse topological order. */ def sccs: List[List[T]] = { type Component = List[T] case class State(count: Int, - visited: Map[T, Boolean], + visited: Set[T], dfNumber: Map[T, Int], lowlinks: Map[T, Int], stack: List[T], components: List[Component]) def search(vertex: T, state: State): State = { - val newState = state.copy(visited = state.visited.updated(vertex, true), - dfNumber = state.dfNumber.updated(vertex, state.count), + val newState = state.copy(visited = state.visited + vertex, + dfNumber = state.dfNumber + (vertex -> state.count), count = state.count + 1, - lowlinks = state.lowlinks.updated(vertex, state.count), + lowlinks = state.lowlinks + (vertex -> state.count), stack = vertex :: state.stack) - def processVertex(st: State, w: T): State = { + def processNeighbor(st: State, w: T): State = { if (!st.visited(w)) { val st1 = search(w, st) val min = Math.min(st1.lowlinks(w), st1.lowlinks(vertex)) - st1.copy(lowlinks = st1.lowlinks.updated(vertex, min)) + st1.copy(lowlinks = st1.lowlinks + (vertex -> min)) } else { if ((st.dfNumber(w) < st.dfNumber(vertex)) && st.stack.contains(w)) { val min = Math.min(st.dfNumber(w), st.lowlinks(vertex)) - st.copy(lowlinks = st.lowlinks.updated(vertex, min)) + st.copy(lowlinks = st.lowlinks + (vertex -> min)) } else st } } - - val strslt = getSuccessors(vertex).foldLeft(newState)(processVertex) - + val strslt = getSuccessors(vertex).foldLeft(newState)(processNeighbor) if (strslt.lowlinks(vertex) == strslt.dfNumber(vertex)) { - val index = strslt.stack.indexOf(vertex) val (comp, rest) = strslt.stack.splitAt(index + 1) strslt.copy(stack = rest, components = strslt.components :+ comp) } else strslt } - val initial = State( count = 1, - visited = getNodes.map { (_, false) }.toMap, + visited = Set(), dfNumber = Map(), lowlinks = Map(), stack = Nil, components = Nil) var state = initial - while (state.visited.exists(_._2 == false)) { - state.visited.find(_._2 == false).foreach { tuple => - val (vertex, _) = tuple - state = search(vertex, state) + val totalNodes = getNodes + while (state.visited.size < totalNodes.size) { + totalNodes.find(n => !state.visited.contains(n)).foreach { n => + state = search(n, state) } } state.components @@ -166,16 +163,15 @@ class DirectedGraph[T] { class UndirectedGraph[T] extends DirectedGraph[T] { override def addEdge(src: T, dest: T): Unit = { - val newset1 = if (adjlist.contains(src)) adjlist(src) + dest - else Set(dest) - - val newset2 = if (adjlist.contains(dest)) adjlist(dest) + src - else Set(src) - + val newset1 = + if (adjlist.contains(src)) adjlist(src) + dest + else Set(dest) + val newset2 = + if (adjlist.contains(dest)) adjlist(dest) + src + else Set(src) //this has some side-effects adjlist.update(src, newset1) adjlist.update(dest, newset2) - edgeCount += 1 } } diff --git a/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala b/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala index 7a9605c922e21373abb6df06fe9d585a104da2ba..1992607fa3dfed9e64ae4e920fd683e973f62f17 100644 --- a/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala +++ b/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala @@ -17,6 +17,7 @@ import leon.solvers.Model import Util._ import PredicateUtil._ import ProgramUtil._ +import invariant.factories.TemplateInstantiator._ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: FunDef) extends FunctionTemplateSolver { @@ -95,25 +96,21 @@ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: val timeUpperBound = ExpressionTransformer.normalizeMultiplication( Plus(FunctionInvocation(TypedFunDef(multFun, Seq()), Seq(recFunInst, tprFunInst)), tprFunInst), ctx.multOp) - // res = body - val plainBody = Equals(getResId(rootFd).get.toVariable, matchToIfThenElse(rootFd.body.get)) - val bodyExpr = if (rootFd.hasPrecondition) { - And(matchToIfThenElse(rootFd.precondition.get), plainBody) - } else plainBody - - val Operator(Seq(timeInstExpr, _), _) = timeTmpl - val compositionAnt = And(Seq(LessEquals(timeInstExpr, timeUpperBound), bodyExpr)) - val prototypeVC = And(compositionAnt, Not(timeTmpl)) // map the old functions in the vc using the new functions val substMap = origProg.definedFunctions.collect { - case fd => - (fd -> functionByName(fd.id.name, compProg).get) + case fd => (fd -> functionByName(fd.id.name, compProg).get) }.toMap - val vcExpr = mapFunctionsInExpr(substMap)(prototypeVC) + // res = body + val body = mapFunctionsInExpr(substMap)(Equals(getResId(rootFd).get.toVariable, rootFd.body.get)) + val pre = rootFd.precondition.getOrElse(tru) + val Operator(Seq(timeInstExpr, _), _) = timeTmpl + val trans = mapFunctionsInExpr(substMap) _ + val assump = trans(createAnd(Seq(LessEquals(timeInstExpr, timeUpperBound), pre))) + val conseq = trans(timeTmpl) if (printIntermediatePrograms) reporter.info("Comp prog: " + compProg) - if (debugComposition) reporter.info("Compositional VC: " + vcExpr) + if (debugComposition) reporter.info("Compositional VC: " + createAnd(Seq(assump, body, Not(conseq)))) val recTempSolver = new UnfoldingTemplateSolver(ctx, compProg, compFunDef) { val minFunc = { @@ -124,14 +121,13 @@ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: TemplateSolverFactory.createTemplateSolver(ctx, compProg, constTracker, rootFd, minFunc) override def instantiateModel(model: Model, funcs: Seq[FunDef]) = { funcs.collect { - case `compFunDef` => - compFunDef -> timeTmpl + case `compFunDef` => compFunDef -> timeTmpl case fd if fd.hasTemplate => - fd -> fd.getTemplate + fd -> instantiateNormTemplates(model, fd.normalizedTemplate.get) }.toMap } } - recTempSolver.solveParametricVC(vcExpr) match { + recTempSolver.solveParametricVC(assump, body, conseq) match { case Some(InferResult(true, Some(timeModel),timeInferredFuncs)) => val inferredFuns = (recInfRes.get.inferredFuncs ++ tprInfRes.get.inferredFuncs ++ timeInferredFuncs).distinct Some(InferResult(true, Some(recModel ++ tprModel.toMap ++ timeModel.toMap), @@ -200,29 +196,24 @@ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: def inferTPRTemplate(tprProg: Program) = { val tempSolver = new UnfoldingTemplateSolver(ctx, tprProg, findRoot(tprProg)) { - override def constructVC(rootFd: FunDef): (Expr, Expr) = { - val body = Equals(getResId(rootFd).get.toVariable, matchToIfThenElse(rootFd.body.get)) - val preExpr = - if (rootFd.hasPrecondition) - matchToIfThenElse(rootFd.precondition.get) - else tru + override def constructVC(rootFd: FunDef): (Expr, Expr, Expr) = { + val body = Equals(getResId(rootFd).get.toVariable, rootFd.body.get) + val preExpr = rootFd.precondition.getOrElse(tru) val tprTmpl = rootFd.getTemplate - val postWithTemplate = matchToIfThenElse(And(rootFd.getPostWoTemplate, tprTmpl)) + val postWithTemplate = And(rootFd.getPostWoTemplate, tprTmpl) // generate constraints characterizing decrease of the tpr function with recursive calls val Operator(Seq(_, tprFun), op) = tprTmpl val bodyFormula = new Formula(rootFd, ExpressionTransformer.normalizeExpr(body, ctx.multOp), ctx) - val constraints = bodyFormula.disjunctsInFormula.flatMap { - case (guard, ctrs) => - ctrs.collect { - case call @ Call(_, FunctionInvocation(TypedFunDef(`rootFd`, _), _)) => //direct recursive call ? - Implies(guard, LessEquals(replace(formalToActual(call), tprFun), tprFun)) - } + val constraints = bodyFormula.callsInFormula.collect { + case call @ Call(_, FunctionInvocation(TypedFunDef(`rootFd`, _), _)) => //direct recursive call ? + val cdata = bodyFormula.callData(call) + Implies(cdata.guard, LessEquals(replace(formalToActual(call), tprFun), tprFun)) } if (debugDecreaseConstraints) reporter.info("Decrease constraints: " + createAnd(constraints.toSeq)) val fullPost = createAnd(postWithTemplate +: constraints.toSeq) - (And(preExpr, bodyFormula.toExpr), fullPost) + (bodyFormula.toExpr, preExpr, fullPost) } } tempSolver() diff --git a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala index 95927a66ad3bda803a329b61e284f854a94bfcb4..6c7d1e2ec7d6c170cb66505357b515f73de521ab 100644 --- a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala +++ b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala @@ -4,9 +4,16 @@ package invariant.engine import purescala.Definitions._ import purescala.Expressions._ import invariant.structure._ +import invariant.util.ExpressionTransformer._ +import purescala.ExprOps._ +import invariant.util.PredicateUtil._ +object ConstraintTracker { + val debugVC = false +} class ConstraintTracker(ctx : InferenceContext, program: Program, rootFun : FunDef/*, temFactory: TemplateFactory*/) { + import ConstraintTracker._ //a mapping from functions to its VCs represented as a CNF formula protected var funcVCs = Map[FunDef,Formula]() @@ -17,8 +24,26 @@ class ConstraintTracker(ctx : InferenceContext, program: Program, rootFun : FunD def hasVC(fdef: FunDef) = funcVCs.contains(fdef) def getVC(fd: FunDef) : Formula = funcVCs(fd) - def addVC(fd: FunDef, vc: Expr) = { - funcVCs += (fd -> new Formula(fd, vc, ctx)) + /** + * @param body the body part of the VC that may possibly have instrumentation + * @param assump is the additional assumptions e.g. pre and conseq + * is the goal e.g. post + * The VC constructed is assump ^ body ^ Not(conseq) + */ + def addVC(fd: FunDef, assump: Expr, body: Expr, conseq: Expr) = { + if(debugVC) { + println(s"Init VC \n assumption: $assump \n body: $body \n conseq: $conseq") + } + val flatBody = normalizeExpr(body, ctx.multOp) + val flatAssump = normalizeExpr(assump, ctx.multOp) + val conseqNeg = normalizeExpr(Not(conseq), ctx.multOp) + val callCollect = collect { + case c @ Equals(_, _: FunctionInvocation) => Set[Expr](c) + case _ => Set[Expr]() + } _ + val specCalls = callCollect(flatAssump) ++ callCollect(conseqNeg) + val vc = createAnd(Seq(flatAssump, flatBody, conseqNeg)) + funcVCs += (fd -> new Formula(fd, vc, ctx, specCalls)) } def initialize = { diff --git a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala index 66a4bdfa16d80ea4a69926cc62704b92c6121535..78fc476dd8ab9f3b8bc70010a40e021e2424f367 100644 --- a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala +++ b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala @@ -22,12 +22,11 @@ object InferInvariantsPhase extends SimpleLeonPhase[Program, InferenceReport] { val optNLTimeout = LeonLongOptionDef("nlTimeout", "Timeout after T seconds when trying to solve nonlinear constraints", 20, "s") val optDisableInfer = LeonFlagOptionDef("disableInfer", "Disable automatic inference of auxiliary invariants", false) val optAssumePre = LeonFlagOptionDef("assumepreInf", "Assume preconditions of callees during unrolling", false) - val optStats = LeonFlagOptionDef("stats", "Tracks and prints detailed statistics", false) override val definedOptions: Set[LeonOptionDef[Any]] = Set(optFunctionUnroll, optWithMult, optUseReals, optMinBounds, optInferTemp, optCegis, optStatsSuffix, optVCTimeout, - optNLTimeout, optDisableInfer, optAssumePre, optStats) + optNLTimeout, optDisableInfer, optAssumePre) def apply(ctx: LeonContext, program: Program): InferenceReport = { val inferctx = new InferenceContext(program, ctx) diff --git a/src/main/scala/leon/invariant/engine/InferenceContext.scala b/src/main/scala/leon/invariant/engine/InferenceContext.scala index ef560e9f8e195233dd5ab3eecb1a5b2312ea0d39..48a45b048a33691f368242b63b45ae5bc062881c 100644 --- a/src/main/scala/leon/invariant/engine/InferenceContext.scala +++ b/src/main/scala/leon/invariant/engine/InferenceContext.scala @@ -33,10 +33,10 @@ class InferenceContext(val initProgram: Program, val leonContext: LeonContext) { val withmult = leonContext.findOption(optWithMult).getOrElse(false) val usereals = leonContext.findOption(optUseReals).getOrElse(false) val useCegis: Boolean = leonContext.findOption(optCegis).getOrElse(false) - val dumpStats = leonContext.findOption(optStats).getOrElse(false) + val dumpStats = leonContext.findOption(SharedOptions.optBenchmark).getOrElse(false) // the following options have default values - val vcTimeout = leonContext.findOption(optVCTimeout).getOrElse(30L) // in secs + val vcTimeout = leonContext.findOption(optVCTimeout).getOrElse(15L) // in secs val nlTimeout = leonContext.findOption(optNLTimeout).getOrElse(15L) val totalTimeout = leonContext.findOption(SharedOptions.optTimeout) // in secs val functionsToInfer = leonContext.findOption(SharedOptions.optFunctions) @@ -97,7 +97,9 @@ class InferenceContext(val initProgram: Program, val leonContext: LeonContext) { def isFunctionPostVerified(funName: String) = { if (validPosts.contains(funName)) { validPosts(funName).isValid - } else { + } + else if (abort) false + else { val verifyPipe = VerificationPhase val ctxWithTO = createLeonContext(leonContext, s"--timeout=$vcTimeout", s"--functions=$funName") (true /: verifyPipe.run(ctxWithTO, qMarksRemovedProg)._2.results) { diff --git a/src/main/scala/leon/invariant/engine/InferenceEngine.scala b/src/main/scala/leon/invariant/engine/InferenceEngine.scala index 1af3c2c56e83125dab7f2da859136667f460c946..3e7f2fa3fe143c18ddbf4d4cc43294e9f824caa1 100644 --- a/src/main/scala/leon/invariant/engine/InferenceEngine.scala +++ b/src/main/scala/leon/invariant/engine/InferenceEngine.scala @@ -24,8 +24,10 @@ import Stats._ class InferenceEngine(ctx: InferenceContext) extends Interruptible { val debugBottomupIterations = false + val debugAnalysisOrder = false val ti = new TimeoutFor(this) + val reporter = ctx.reporter def interrupt() = { ctx.abort = true @@ -48,39 +50,27 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { } private def run(progressCallback: Option[InferenceCondition => Unit] = None): InferenceReport = { - val reporter = ctx.reporter val program = ctx.inferProgram reporter.info("Running Inference Engine...") if (ctx.dumpStats) { //register a shutdownhook sys.ShutdownHookThread({ dumpStats(ctx.statsSuffix) }) } + val relfuns = ctx.functionsToInfer.getOrElse(program.definedFunctions.map(InstUtil.userFunctionName)) var results: Map[FunDef, InferenceCondition] = null time { - //compute functions to analyze by sorting based on topological order (this is an ascending topological order) - val callgraph = CallGraphUtil.constructCallGraph(program, withTemplates = true) - val functionsToAnalyze = ctx.functionsToInfer match { - case Some(rootfuns) => - val rootset = rootfuns.toSet - val rootfds = program.definedFunctions.filter(fd => rootset(InstUtil.userFunctionName(fd))) - val relfuns = rootfds.flatMap(callgraph.transitiveCallees _).toSet - callgraph.topologicalOrder.filter { fd => relfuns(fd) } - case _ => - callgraph.topologicalOrder - } - //reporter.info("Analysis Order: " + functionsToAnalyze.map(_.id)) - if (!ctx.useCegis) { - results = analyseProgram(program, functionsToAnalyze, defaultVCSolver, progressCallback) + results = analyseProgram(program, relfuns, defaultVCSolver, progressCallback) //println("Inferrence did not succeeded for functions: "+functionsToAnalyze.filterNot(succeededFuncs.contains _).map(_.id)) } else { - var remFuncs = functionsToAnalyze + var remFuncs = relfuns var b = 200 val maxCegisBound = 200 breakable { while (b <= maxCegisBound) { Stats.updateCumStats(1, "CegisBoundsTried") val succeededFuncs = analyseProgram(program, remFuncs, defaultVCSolver, progressCallback) - remFuncs = remFuncs.filterNot(succeededFuncs.contains _) + val successes = succeededFuncs.keySet.map(InstUtil.userFunctionName) + remFuncs = remFuncs.filterNot(successes.contains _) if (remFuncs.isEmpty) break b += 5 //increase bounds in steps of 5 } @@ -92,15 +82,12 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { reporter.info("- Dumping statistics") dumpStats(ctx.statsSuffix) } - new InferenceReport(results.map(pair => { - val (fd, ic) = pair - (fd -> List[VC](ic)) - }), program)(ctx) + new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, program)(ctx) } def dumpStats(statsSuffix: String) = { //pick the module id. - val modid = ctx.inferProgram.modules.find(_.definedFunctions.exists(!_.isLibrary)).get.id + val modid = ctx.inferProgram.units.find(_.isMainUnit).get.id val filename = modid + statsSuffix + ".txt" val pw = new PrintWriter(filename) Stats.dumpStats(pw) @@ -108,7 +95,8 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { if (ctx.tightBounds) { SpecificStats.dumpMinimizationStats(pw) } - ctx.reporter.info("Stats dumped to file: "+filename) + pw.close() + ctx.reporter.info("Stats dumped to file: " + filename) } def defaultVCSolver = @@ -119,15 +107,33 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { new UnfoldingTemplateSolver(ctx, prog, funDef) } + /** + * sort the given functions based on ascending topological order of the callgraph. + * For SCCs, preserve the order in which the functions are called in the program + */ + def sortByTopologicalOrder(program: Program, relfuns: Seq[String]) = { + val callgraph = CallGraphUtil.constructCallGraph(program, onlyBody = true) + val relset = relfuns.toSet + val relfds = program.definedFunctions.filter(fd => relset(InstUtil.userFunctionName(fd))) + val funsToAnalyze = relfds.flatMap(callgraph.transitiveCallees _).toSet + // note: the order preserves the order in which functions appear in the program within an SCC + val funsInOrder = callgraph.reverseTopologicalOrder(program.definedFunctions).filter(funsToAnalyze) + if (debugAnalysisOrder) + reporter.info("Analysis Order: " + funsInOrder.map(_.id.uniqueName)) + funsInOrder + } + /** * Returns map from analyzed functions to their inference conditions. + * @param - a list of user-level function names that need to analyzed. The names should not + * include the instrumentation suffixes * TODO: use function names in inference conditions, so that * we an get rid of dependence on origFd in many places. */ - def analyseProgram(startProg: Program, functionsToAnalyze: Seq[FunDef], + def analyseProgram(startProg: Program, relfuns: Seq[String], vcSolver: (FunDef, Program) => FunctionTemplateSolver, progressCallback: Option[InferenceCondition => Unit]): Map[FunDef, InferenceCondition] = { - val reporter = ctx.reporter + val functionsToAnalyze = sortByTopologicalOrder(startProg, relfuns) val funToTmpl = if (ctx.autoInference) { //A template generator that generates templates for the functions (here we are generating templates by enumeration) @@ -140,10 +146,11 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { val progWithTemplates = assignTemplateAndCojoinPost(funToTmpl, startProg) var analyzedSet = Map[FunDef, InferenceCondition]() - functionsToAnalyze.filterNot((fd) => { + functionsToAnalyze.filterNot(fd => { (fd.annotations contains "verified") || (fd.annotations contains "library") || - (fd.annotations contains "theoryop") + (fd.annotations contains "theoryop") || + (fd.annotations contains "extern") }).foldLeft(progWithTemplates) { (prog, origFun) => if (debugBottomupIterations) { @@ -182,10 +189,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { !analyzedSet.contains(origFd) && origFd.hasTemplate } // now the templates of these functions will be replaced by inferred invariants - val invs = TemplateInstantiator.getAllInvariants(model.get, - funsWithTemplates.collect { - case fd if fd.hasTemplate => fd -> fd.getTemplate - }.toMap) + val invs = TemplateInstantiator.getAllInvariants(model.get, funsWithTemplates) // collect templates of remaining functions val funToTmpl = prog.definedFunctions.collect { case fd if !invs.contains(fd) && fd.hasTemplate => @@ -197,11 +201,9 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { inferredFuns.foreach { fd => val origFd = origFds(fd) val invOpt = if (funsWithTemplates.contains(fd)) { - Some(TemplateInstantiator.getAllInvariants(model.get, - Map(origFd -> origFd.getTemplate), prettyInv = true)(origFd)) + Some(TemplateInstantiator.getAllInvariants(model.get, Seq(origFd), prettyInv = true)(origFd)) } else if (fd.hasTemplate) { - val currentInv = TemplateInstantiator.getAllInvariants(model.get, - Map(fd -> fd.getTemplate), prettyInv = true)(fd) + val currentInv = TemplateInstantiator.getAllInvariants(model.get, Seq(fd), prettyInv = true)(fd) // map result variable in currentInv val repInv = replace(Map(getResId(fd).get.toVariable -> getResId(origFd).get.toVariable), currentInv) Some(translateExprToProgram(repInv, prog, startProg)) diff --git a/src/main/scala/leon/invariant/engine/RefinementEngine.scala b/src/main/scala/leon/invariant/engine/RefinementEngine.scala index 786a66b256cf953e91c3430f5ff6afa9cad102cd..f60a86af0228c30db578c2d288cb3b5c4de125f9 100644 --- a/src/main/scala/leon/invariant/engine/RefinementEngine.scala +++ b/src/main/scala/leon/invariant/engine/RefinementEngine.scala @@ -65,7 +65,7 @@ class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: Constra val newguards = formula.disjunctsInFormula.keySet.diff(exploredGuards) exploredGuards ++= newguards - val newheads = newguards.flatMap(g => disjuncts(g).collect { case c: Call => c }) + val newheads = formula.getCallsOfGuards(newguards.toSeq) //.flatMap(g => disjuncts(g).collect { case c: Call => c }) val allheads = getHeads(fd) ++ newheads //unroll each call in the head pointers and in toRefineCalls @@ -115,13 +115,18 @@ class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: Constra unrolls }).toSet } + import leon.transformations.InstUtil._ - def shouldCreateVC(recFun: FunDef): Boolean = { + def shouldCreateVC(recFun: FunDef, inSpec: Boolean): Boolean = { if (ctrTracker.hasVC(recFun)) false else { - //need not create vcs for theory operations - !recFun.isTheoryOperation && recFun.hasTemplate && - !recFun.annotations.contains("library") + //need not create vcs for theory operations and library methods + !recFun.isTheoryOperation && !recFun.annotations.contains("library") && + (recFun.template match { + case Some(temp) if inSpec && isResourceBoundOf(recFun)(temp) => false // TODO: here we can also drop resource templates if it is used with other templates + case Some(_) => true + case _ => false + }) } } @@ -130,82 +135,66 @@ class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: Constra * here we unroll the methods in the current abstraction by one step. * This procedure has side-effects on 'headCalls' and 'callDataMap' */ - def unrollCall(call: Call, formula: Formula) = { + def unrollCall(call: Call, formula: Formula) { val fi = call.fi + val calldata = formula.callData(call) + val callee = fi.tfd.fd if (fi.tfd.fd.hasBody) { - //freshen the body and the post - val isRecursive = cg.isRecursive(fi.tfd.fd) + val isRecursive = cg.isRecursive(callee) if (isRecursive) { - val recFun = fi.tfd.fd + val recFun = callee val recFunTyped = fi.tfd - //check if we need to create a VC formula for the call's target - if (shouldCreateVC(recFun)) { + if (shouldCreateVC(recFun, calldata.inSpec)) { reporter.info("Creating VC for " + recFun.id) // instantiate the body with new types val tparamMap = (recFun.tparams zip recFunTyped.tps).toMap val paramMap = recFun.params.map{pdef => pdef.id -> FreshIdentifier(pdef.id.name, instantiateType(pdef.id.getType, tparamMap)) }.toMap - val newbody = freshenLocals(matchToIfThenElse(recFun.body.get)) - val freshBody = instantiateType(newbody, tparamMap, paramMap) - val resvar = if (recFun.hasPostcondition) { - //create a new result variable here for the same reason as freshening the locals, - //which is to avoid variable capturing during unrolling - val origRes = getResId(recFun).get - Variable(FreshIdentifier(origRes.name, recFunTyped.returnType, true)) - } else { - //create a new resvar - Variable(FreshIdentifier("res", recFunTyped.returnType, true)) - } - val plainBody = Equals(resvar, freshBody) - val bodyExpr = - if (recFun.hasPrecondition) { - val pre = instantiateType(matchToIfThenElse(recFun.precondition.get), tparamMap, paramMap) - And(pre, plainBody) - } else plainBody - + val freshBody = instantiateType(freshenLocals(recFun.body.get), tparamMap, paramMap) + val resname = if (recFun.hasPostcondition) getResId(recFun).get.name else "res" + //create a new result variable here for the same reason as freshening the locals, + //which is to avoid variable capturing during unrolling + val resvar = Variable(FreshIdentifier(resname, recFunTyped.returnType, true)) + val bodyExpr = Equals(resvar, freshBody) + val pre = recFun.precondition.map(p => instantiateType(p, tparamMap, paramMap)).getOrElse(tru) //note: here we are only adding the template as the postcondition (other post need not be proved again) - val idmap = formalToActual(Call(resvar, FunctionInvocation(recFunTyped, - paramMap.values.toSeq.map(_.toVariable)))) + val idmap = formalToActual(Call(resvar, FunctionInvocation(recFunTyped, paramMap.values.toSeq.map(_.toVariable)))) val postTemp = replace(idmap, recFun.getTemplate) - val vcExpr = ExpressionTransformer.normalizeExpr(And(bodyExpr, Not(postTemp)), ctx.multOp) - ctrTracker.addVC(recFun, vcExpr) + //val vcExpr = ExpressionTransformer.normalizeExpr(And(bodyExpr, Not(postTemp)), ctx.multOp) + ctrTracker.addVC(recFun, pre, bodyExpr, postTemp) } - //Here, unroll the call into the caller tree if (verbose) reporter.info("Unrolling " + Equals(call.retexpr, call.fi)) - inilineCall(call, formula) + inilineCall(call, calldata, formula) } else { //here we are unrolling a function without template if (verbose) reporter.info("Unfolding " + Equals(call.retexpr, call.fi)) - inilineCall(call, formula) + inilineCall(call, calldata, formula) } } else Set() } - def inilineCall(call: Call, formula: Formula) = { + def inilineCall(call: Call, calldata: CallData, formula: Formula) { val tfd = call.fi.tfd val callee = tfd.fd if (callee.isBodyVisible) { //here inline the body and conjoin it with the guard //Important: make sure we use a fresh body expression here, and freshenlocals val tparamMap = (callee.tparams zip tfd.tps).toMap - val newbody = freshenLocals(matchToIfThenElse(callee.body.get)) - val freshBody = instantiateType(newbody, tparamMap, Map()) - val calleeSummary = - Equals(getFunctionReturnVariable(callee), freshBody) + val freshBody = instantiateType(freshenLocals(callee.body.get), tparamMap, Map()) + val calleeSummary = Equals(getFunctionReturnVariable(callee), freshBody) val argmap1 = formalToActual(call) val inlinedSummary = ExpressionTransformer.normalizeExpr(replace(argmap1, calleeSummary), ctx.multOp) if (this.dumpInlinedSummary) - println("Inlined Summary: " + inlinedSummary) + println(s"Inlined Summary of ${callee.id}: " + inlinedSummary) //conjoin the summary with the disjunct corresponding to the 'guard' //note: the parents of the summary are the parents of the call plus the callee function - val calldata = formula.callData(call) - formula.conjoinWithDisjunct(calldata.guard, inlinedSummary, (callee +: calldata.parents)) + formula.conjoinWithDisjunct(calldata.guard, inlinedSummary, (callee +: calldata.parents), calldata.inSpec) } else { if (verbose) reporter.info(s"Not inlining ${call.fi}: body invisible!") diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala index 0f3e39265d9b252b776d4ff8a40f720b40d3784c..d38bdce1f781e9f438ff2bb273b0e2911e0c6ba7 100644 --- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala +++ b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala @@ -14,6 +14,7 @@ import scala.util.control.Breaks._ import solvers._ import scala.concurrent._ import scala.concurrent.duration._ +import leon.evaluators.DefaultEvaluator import invariant.templateSolvers._ import invariant.factories._ @@ -47,7 +48,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val newguards = disjuncts.keySet.diff(exploredGuards) exploredGuards ++= newguards - val newcalls = newguards.flatMap(g => disjuncts(g).collect { case c: Call => c }) + val newcalls = formula.getCallsOfGuards(newguards.toSeq).toSet //flatMap(g => disjuncts(g).collect { case c: Call => c }) instantiateSpecs(formula, newcalls, funcs.toSet) if (!disableAxioms) { @@ -85,26 +86,25 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val spec = specForCall(call) if (spec.isDefined && spec.get != tru) { val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, spec.get, cdata.parents) + formula.conjoinWithDisjunct(cdata.guard, spec.get, cdata.parents, inSpec = true) } }) //try to assume templates for all the current un-templated calls var newUntemplatedCalls = Set[Call]() - getUntempCalls(formula.fd).foreach((call) => { - //first get the template for the call if one needs to be added - if (funcsWithVC.contains(call.fi.tfd.fd)) { + getUntempCalls(formula.fd).foreach { call => + if (funcsWithVC.contains(call.fi.tfd.fd)) { // add templates of only functions for which there exists a VC templateForCall(call) match { case Some(temp) => val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, temp, cdata.parents) + formula.conjoinWithDisjunct(cdata.guard, temp, cdata.parents, inSpec = true) case _ => ; // here there is no template for the call } } else { newUntemplatedCalls += call } - }) + } resetUntempCalls(formula.fd, newUntemplatedCalls ++ calls) } @@ -116,7 +116,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons if (callee.hasPostcondition) { // instantiate the post val tparamMap = (callee.tparams zip tfd.tps).toMap - val trans = freshenLocals _ andThen (e => instantiateType(e, tparamMap, Map())) andThen matchToIfThenElse _ + val trans = freshenLocals _ andThen (e => instantiateType(e, tparamMap, Map())) //get the postcondition without templates val rawpost = trans(callee.getPostWoTemplate) val rawspec = if (callee.hasPrecondition) { @@ -145,7 +145,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val tempExpr = replace(argmap, instantiateType(callee.getTemplate, tparamMap, Map())) val template = if (callee.hasPrecondition) { val pre = replace(argmap, instantiateType(callee.precondition.get, tparamMap, Map())) - val freshPre = freshenLocals(matchToIfThenElse(pre)) + val freshPre = freshenLocals(pre) if (ctx.assumepre) And(freshPre, tempExpr) else @@ -153,9 +153,8 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons } else { tempExpr } - //flatten functions - //TODO: should we freshen locals here ?? - Some(ExpressionTransformer.normalizeExpr(matchToIfThenElse(template), ctx.multOp)) + //TODO: should we freshen locals of template here ?? + Some(ExpressionTransformer.normalizeExpr(template, ctx.multOp)) } else None } @@ -201,7 +200,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val axiomInst = Implies(ant, conseq) val nnfAxiom = ExpressionTransformer.normalizeExpr(axiomInst, ctx.multOp) val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, nnfAxiom, cdata.parents) + formula.conjoinWithDisjunct(cdata.guard, nnfAxiom, cdata.parents, inSpec = true) axiomInst } } @@ -255,7 +254,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val (ant, conseq) = inst val axiom = Implies(ant, conseq) val nnfAxiom = ExpressionTransformer.normalizeExpr(axiom, ctx.multOp) - val (axroot, _) = formula.conjoinWithRoot(nnfAxiom, parents) + val axroot = formula.conjoinWithRoot(nnfAxiom, parents, true) //important: here we need to update the axiom roots axiomRoots += (Seq(pair._1, pair._2) -> axroot) acc :+ axiom @@ -269,7 +268,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons * Note: taking a formula as input may not be necessary. We can store it as a part of the state * TODO: can we use transitivity here to optimize ? */ - def axiomsForCalls(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { + def axiomsForCalls(formula: Formula, calls: Set[Call], model: LazyModel, tmplMap: Map[Identifier,Expr], eval: DefaultEvaluator): Seq[Constraint] = { //note: unary axioms need not be instantiated //consider only binary axioms (for (x <- calls; y <- calls) yield (x, y)).foldLeft(Seq[Constraint]())((acc, pair) => { @@ -277,7 +276,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons if (c1 != c2) { val axRoot = axiomRoots.get(Seq(c1, c2)) if (axRoot.isDefined) - acc ++ formula.pickSatDisjunct(axRoot.get, model) + acc ++ formula.pickSatDisjunct(axRoot.get, model, tmplMap, eval) else acc } else acc }) diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index a5e48d4472e03a6e2b1481892f636382a7c63bc9..15063dd35ae1812c42ef9d1831a10406c4ae84ae 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -8,6 +8,7 @@ import purescala.ExprOps._ import purescala.Types._ import purescala.DefOps._ import purescala.ScalaPrinter +import purescala.Constructors._ import solvers._ import verification._ @@ -40,19 +41,15 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F lazy val constTracker = new ConstraintTracker(ctx, program, rootFd) lazy val templateSolver = TemplateSolverFactory.createTemplateSolver(ctx, program, constTracker, rootFd) - def constructVC(funDef: FunDef): (Expr, Expr) = { - val body = funDef.body.get + def constructVC(funDef: FunDef): (Expr, Expr, Expr) = { val Lambda(Seq(ValDef(resid)), _) = funDef.postcondition.get - val resvar = resid.toVariable - - val simpBody = matchToIfThenElse(body) - val plainBody = Equals(resvar, simpBody) - val bodyExpr = if (funDef.hasPrecondition) { - And(matchToIfThenElse(funDef.precondition.get), plainBody) - } else plainBody - + val body = Equals(resid.toVariable, funDef.body.get) val funName = fullName(funDef, useUniqueIds = false)(program) - val fullPost = matchToIfThenElse( + val assumptions = + if (funDef.usePost && ctx.isFunctionPostVerified(funName)) + createAnd(Seq(funDef.getPostWoTemplate, funDef.precOrTrue)) + else funDef.precOrTrue + val fullPost = if (funDef.hasTemplate) { // if the postcondition is verified do not include it in the sequent if (ctx.isFunctionPostVerified(funName)) @@ -61,19 +58,13 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F And(funDef.getPostWoTemplate, funDef.getTemplate) } else if (!ctx.isFunctionPostVerified(funName)) funDef.getPostWoTemplate - else - BooleanLiteral(true)) - - (bodyExpr, fullPost) + else tru + (body, assumptions, fullPost) } - def solveParametricVC(vc: Expr) = { - val vcExpr = ExpressionTransformer.normalizeExpr(vc, ctx.multOp) - //for debugging - if (debugVCs) reporter.info("flattened VC: " + ScalaPrinter(vcExpr)) - + def solveParametricVC(assump: Expr, body: Expr, conseq: Expr) = { // initialize the constraint tracker - constTracker.addVC(rootFd, vcExpr) + constTracker.addVC(rootFd, assump, body, conseq) var refinementStep: Int = 0 var toRefineCalls: Option[Set[Call]] = None @@ -113,8 +104,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F toRefineCalls = callsInPath //Validate the model here instantiateAndValidateModel(model, constTracker.getFuncs) - Some(InferResult(true, Some(model), - constTracker.getFuncs.toList)) + Some(InferResult(true, Some(model), constTracker.getFuncs.toList)) case (None, callsInPath) => toRefineCalls = callsInPath //here, we do not know if the template is solvable or not, we need to do more unrollings. @@ -130,26 +120,23 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F if(ctx.abort) { Some(InferResult(false, None, List())) } else { - //create a body and post of the function - val (bodyExpr, fullPost) = constructVC(rootFd) - if (fullPost == tru) + val (body, pre, post) = constructVC(rootFd) + if (post == tru) Some(InferResult(true, Some(Model.empty), List())) else - solveParametricVC(And(bodyExpr, Not(fullPost))) + solveParametricVC(pre, body, post) } } def instantiateModel(model: Model, funcs: Seq[FunDef]) = { funcs.collect { case fd if fd.hasTemplate => - fd -> fd.getTemplate + fd -> TemplateInstantiator.instantiateNormTemplates(model, fd.normalizedTemplate.get) }.toMap } def instantiateAndValidateModel(model: Model, funcs: Seq[FunDef]) = { - val templates = instantiateModel(model, funcs) - val sols = TemplateInstantiator.getAllInvariants(model, templates) - + val sols = instantiateModel(model, funcs) var output = "Invariants for Function: " + rootFd.id + "\n" sols foreach { case (fd, inv) => @@ -181,87 +168,32 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F * the inferred postcondition */ def verifyInvariant(newposts: Map[FunDef, Expr]): (Option[Boolean], Model) = { - //create a fundef for each function in the program - //note: mult functions are also copied - val newFundefs = program.definedFunctions.collect { - case fd @ _ => { //if !isMultFunctions(fd) - val newfd = new FunDef(FreshIdentifier(fd.id.name, Untyped, false), fd.tparams, fd.params, fd.returnType) - (fd, newfd) - } - }.toMap - //note: we are not replacing "mult" function by "Times" - val replaceFun = (e: Expr) => e match { - case fi @ FunctionInvocation(tfd1, args) if newFundefs.contains(tfd1.fd) => - FunctionInvocation(TypedFunDef(newFundefs(tfd1.fd), tfd1.tps), args) - case _ => e - } - //create a body, pre, post for each newfundef - newFundefs.foreach((entry) => { - val (fd, newfd) = entry - //add a new precondition - newfd.precondition = - if (fd.precondition.isDefined) - Some(simplePostTransform(replaceFun)(fd.precondition.get)) - else None - - //add a new body - newfd.body = if (fd.hasBody) - Some(simplePostTransform(replaceFun)(fd.body.get)) - else None - - //add a new postcondition - val newpost = if (newposts.contains(fd)) { - val inv = newposts(fd) - if (fd.postcondition.isDefined) { - val Lambda(resultBinder, _) = fd.postcondition.get - Some(Lambda(resultBinder, And(fd.getPostWoTemplate, inv))) - } else { - //replace #res in the invariant by a new result variable - val resvar = FreshIdentifier("res", fd.returnType, true) - // FIXME: Is this correct (ResultVariable(fd.returnType) -> resvar.toVariable)) - val ninv = replace(Map(ResultVariable(fd.returnType) -> resvar.toVariable), inv) - Some(Lambda(Seq(ValDef(resvar)), ninv)) - } - } else if (fd.postcondition.isDefined) { - val Lambda(resultBinder, _) = fd.postcondition.get - Some(Lambda(resultBinder, fd.getPostWoTemplate)) - } else None - - newfd.postcondition = if (newpost.isDefined) { - val Lambda(resultBinder, pexpr) = newpost.get - // Some((resvar, simplePostTransform(replaceFun)(pexpr))) - Some(Lambda(resultBinder, simplePostTransform(replaceFun)(pexpr))) - } else None - newfd.addFlags(fd.flags) - }) - - val augmentedProg = copyProgram(program, (defs: Seq[Definition]) => defs.collect { - case fd: FunDef if (newFundefs.contains(fd)) => newFundefs(fd) - case d if (!d.isInstanceOf[FunDef]) => d - }) + val augProg = assignTemplateAndCojoinPost(Map(), program, newposts, uniqueIdDisplay = false) //convert the program back to an integer program if necessary - val (newprog, newroot) = if (ctx.usereals) { - val realToIntconverter = new RealToIntProgram() - val intProg = realToIntconverter(augmentedProg) - (intProg, realToIntconverter.mappedFun(newFundefs(rootFd))) - } else { - (augmentedProg, newFundefs(rootFd)) + val newprog = + if (ctx.usereals) new RealToIntProgram()(augProg) + else augProg + val newroot = functionByFullName(fullName(rootFd)(program), newprog).get + verifyVC(newprog, newroot) + } + + /** + * Uses default postcondition VC, but can be overriden in the case of non-standard VCs + */ + def verifyVC(newprog: Program, newroot: FunDef) = { + (newroot.postcondition, newroot.body) match { + case (Some(post), Some(body)) => + val vc = implies(newroot.precOrTrue, application(post, Seq(body))) + solveUsingLeon(ctx.leonContext, newprog, VC(vc, newroot, VCKinds.Postcondition)) } - // TODO: note here we must reuse the created vc, instead of creating default VC - val solFactory = SolverFactory.uninterpreted(ctx.leonContext, newprog) - val vericontext = VerificationContext(ctx.leonContext, newprog, solFactory, reporter) - val defaultTactic = new DefaultTactic(vericontext) - val vc = defaultTactic.generatePostconditions(newroot).head - solveUsingLeon(ctx.leonContext, newprog, vc) } import leon.solvers._ import leon.solvers.combinators.UnrollingSolver def solveUsingLeon(leonctx: LeonContext, p: Program, vc: VC) = { val solFactory = SolverFactory.uninterpreted(leonctx, program) - val verifyTimeout = 5 val smtUnrollZ3 = new UnrollingSolver(ctx.leonContext, program, solFactory.getNewSolver()) with TimeoutSolver - smtUnrollZ3.setTimeout(verifyTimeout * 1000) + smtUnrollZ3.setTimeout(ctx.vcTimeout * 1000) smtUnrollZ3.assertVC(vc) smtUnrollZ3.check match { case Some(true) => diff --git a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala index 34bdd7431c8827ca726467f65182ef10fef67e1d..dc7caf9b3406e6e4d47e7908bc6400d1c24b4c3a 100644 --- a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala +++ b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala @@ -7,31 +7,36 @@ import purescala.ExprOps._ import purescala.Extractors._ import invariant.util._ import invariant.structure._ +import invariant.engine.InferenceContext import leon.solvers.Model import leon.invariant.util.RealValuedExprEvaluator import PredicateUtil._ +import FunctionUtils._ +import ExpressionTransformer._ object TemplateInstantiator { + /** - * Computes the invariant for all the procedures given a mapping for the - * template variables. + * Computes the invariant for all the procedures given a model for the template variables. * (Undone) If the mapping does not have a value for an id, then the id is bound to the simplest value */ - def getAllInvariants(model: Model, templates: Map[FunDef, Expr], prettyInv: Boolean = false): Map[FunDef, Expr] = { - val invs = templates.map((pair) => { - val (fd, t) = pair - val template = ExpressionTransformer.FlattenFunction(t) - val tempvars = getTemplateVars(template) - val tempVarMap: Map[Expr, Expr] = tempvars.map((v) => { - (v, model(v.id)) - }).toMap - val instTemplate = instantiate(template, tempVarMap, prettyInv) - val comprTemp = ExpressionTransformer.unflatten(instTemplate) - (fd, comprTemp) - }) + def getAllInvariants(model: Model, funs: Seq[FunDef], prettyInv: Boolean = false): Map[FunDef, Expr] = { + val invs = funs.collect { + case fd if fd.hasTemplate => + (fd, instantiateNormTemplates(model, fd.normalizedTemplate.get, prettyInv)) + }.toMap invs } + /** + * This function expects a template in a normalized form. + */ + def instantiateNormTemplates(model: Model, template: Expr, prettyInv: Boolean = false): Expr = { + val tempvars = getTemplateVars(template) + val instTemplate = instantiate(template, tempvars.map { v => (v, model(v.id)) }.toMap, prettyInv) + unflatten(instTemplate) + } + /** * Instantiates templated subexpressions of the given expression (expr) using the given mapping for the template variables. * The instantiation also takes care of converting the rational coefficients to integer coefficients. diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala index 79c9a42ca40642be9f3e0834bb95865c5978e550..513bfd9687e0881491307ab53807b344fd7994c8 100644 --- a/src/main/scala/leon/invariant/structure/Constraint.scala +++ b/src/main/scala/leon/invariant/structure/Constraint.scala @@ -5,13 +5,28 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import invariant.util._ +import Util._ import PredicateUtil._ import TypeUtil._ import purescala.Extractors._ +import ExpressionTransformer._ +import solvers.Model +import purescala.Common._ +import leon.evaluators._ trait Constraint { def toExpr: Expr } + +trait ExtendedConstraint extends Constraint { + def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator): Constraint +} + +object LinearTemplate { + val debug = false + val debugPickSat = false +} + /** * Class representing linear templates which is a constraint of the form * a1*v1 + a2*v2 + .. + an*vn + a0 <= 0 or = 0 or < 0 where ai's are unknown coefficients @@ -24,42 +39,71 @@ class LinearTemplate(oper: Seq[Expr] => Expr, coeffTemp: Map[Expr, Expr], constTemp: Option[Expr]) extends Constraint { + import LinearTemplate._ + val zero = InfiniteIntegerLiteral(0) + val op = oper - val op = { - oper - } val coeffTemplate = { - //assert if the coefficients are templated expressions - assert(coeffTemp.values.forall(e => isTemplateExpr(e))) + if(debug) assert(coeffTemp.values.forall(e => isTemplateExpr(e))) coeffTemp } val constTemplate = { - assert(constTemp match { - case None => true - case Some(e) => isTemplateExpr(e) - }) + if(debug) assert(constTemp.map(isTemplateExpr).getOrElse(true)) constTemp } - val template = { + val lhsExpr = { //construct the expression corresponding to the template here - var lhs = coeffTemp.foldLeft(null: Expr)((acc, entry) => { - val (term, coeff) = entry - val minterm = Times(coeff, term) - if (acc == null) minterm else Plus(acc, minterm) - }) - lhs = if (constTemp.isDefined) { + var lhs = coeffTemp.foldLeft(null: Expr) { + case (acc, (term, coeff)) => + val minterm = Times(coeff, term) + if (acc == null) minterm else Plus(acc, minterm) + } + if (constTemp.isDefined) { if (lhs == null) constTemp.get else Plus(lhs, constTemp.get) } else lhs - val expr = oper(Seq(lhs, zero)) - expr } - def templateVars: Set[Variable] = { - getTemplateVars(template) + val template = oper(Seq(lhsExpr, zero)) + + def templateVars: Set[Variable] = getTemplateVars(template) + + /** + * Picks a sat disjunct of the negation of the template w.r.t to the + * given model. + */ + lazy val negTmpls = { + val args = template match { + case _: Equals => Seq(GreaterThan(lhsExpr, zero), LessThan(lhsExpr,zero)) + case _: LessEquals => Seq(GreaterThan(lhsExpr, zero)) + case _: LessThan => Seq(GreaterEquals(lhsExpr, zero)) + case _: GreaterEquals => Seq(LessThan(lhsExpr, zero)) + case _: GreaterThan => Seq(LessEquals(lhsExpr, zero)) + } + args map LinearConstraintUtil.exprToTemplate + } + + def pickSatDisjunctOfNegation(model: LazyModel, tmplModel: Map[Identifier, Expr], eval: DefaultEvaluator) = { + val err = new IllegalStateException(s"Cannot pick a sat disjunct of negation: ${toString} is sat!") + template match { + case _: Equals => // here, negation is a disjunction + UnflatHelper.evaluate(replaceFromIDs(tmplModel, lhsExpr), model, eval) match { + case InfiniteIntegerLiteral(lval) => + val Seq(grt, less) = negTmpls + if (lval > 0) grt + else if (lval < 0) less + else throw err + } + case _ => // here, the negation must be sat + if (debugPickSat) { + if (UnflatHelper.evaluate(replaceFromIDs(tmplModel, negTmpls.head.toExpr), model, eval) != tru) + throw err + } + negTmpls.head + } } def coeffEntryToString(coeffEntry: (Expr, Expr)): String = { @@ -117,7 +161,6 @@ class LinearTemplate(oper: Seq[Expr] => Expr, } override def toString(): String = { - val coeffStr = if (coeffTemplate.isEmpty) "" else { val (head :: tail) = coeffTemplate.toList @@ -139,17 +182,10 @@ class LinearTemplate(oper: Seq[Expr] => Expr, }) + "0" } - override def hashCode(): Int = { - template.hashCode() - } + override def hashCode(): Int = template.hashCode() override def equals(obj: Any): Boolean = obj match { - case lit: LinearTemplate => { - if (!lit.template.equals(this.template)) { - //println(lit.template + " and " + this.template+ " are not equal ") - false - } else true - } + case lit: LinearTemplate => lit.template.equals(this.template) case _ => false } } @@ -159,29 +195,49 @@ class LinearTemplate(oper: Seq[Expr] => Expr, */ class LinearConstraint(opr: Seq[Expr] => Expr, cMap: Map[Expr, Expr], constant: Option[Expr]) extends LinearTemplate(opr, cMap, constant) { + val coeffMap = cMap + val const = constant +} + +/** + * Class representing Equality or disequality of a boolean variable and an linear template. + * Used for efficiently choosing a disjunct + */ +case class ExtendedLinearTemplate(v: Variable, tmpl: LinearTemplate, diseq: Boolean) extends ExtendedConstraint { + val expr = { + val eqExpr = Equals(v, tmpl.toExpr) + if(diseq) Not(eqExpr) else eqExpr + } + override def toExpr = expr + override def toString: String = expr.toString + + /** + * Chooses a sat disjunct of the constraint + */ + override def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator) = { + if((model(v.id) == tru && !diseq) || (model(v.id) == fls && diseq)) tmpl + else { + //println(s"Picking sat disjunct of: ${toExpr} model($v) = ${model(v.id)}") + tmpl.pickSatDisjunctOfNegation(model, tmplModel, eval) + } + } +} - val coeffMap = { - //assert if the coefficients are only constant expressions - assert(cMap.values.forall(e => variablesOf(e).isEmpty)) - //TODO: here we should try to simplify the constant expressions - cMap +object BoolConstraint { + def isBoolConstraint(e: Expr): Boolean = e match { + case _: Variable | _: BooleanLiteral if e.getType == BooleanType => true + case Equals(l, r) => isBoolConstraint(l) && isBoolConstraint(r) //enabling makes the system slower!! surprising + case Not(arg) => isBoolConstraint(arg) + case And(args) => args forall isBoolConstraint + case Or(args) => args forall isBoolConstraint + case _ => false } - val const = constant.map((c) => { - //check if constant does not have any variables - assert(variablesOf(c).isEmpty) - c - }) } case class BoolConstraint(e: Expr) extends Constraint { + import BoolConstraint._ val expr = { - assert(e match { - case Variable(_) => true - case Not(Variable(_)) => true - case t: BooleanLiteral => true - case Not(t: BooleanLiteral) => true - case _ => false - }) + assert(isBoolConstraint(e)) e } override def toString(): String = expr.toString @@ -189,26 +245,21 @@ case class BoolConstraint(e: Expr) extends Constraint { } object ADTConstraint { - // note: we consider even type parameters as ADT type - def adtType(e: Expr) = { - val tpe = e.getType - tpe.isInstanceOf[ClassType] || tpe.isInstanceOf[TupleType] || tpe.isInstanceOf[TypeParameter] - } - def apply(e: Expr): ADTConstraint = e match { - case Equals(_: Variable, _: CaseClassSelector | _: TupleSelect) => - new ADTConstraint(e, sel = true) - case Equals(_: Variable, _: CaseClass | _: Tuple) => - new ADTConstraint(e, cons = true) - case Equals(_: Variable, _: IsInstanceOf) => - new ADTConstraint(e, inst = true) - case Equals(lhs @ Variable(_), AsInstanceOf(rhs @ Variable(_), _)) => - new ADTConstraint(Equals(lhs, rhs), comp= true) + def apply(e: Expr): ADTConstraint = e match { + case Equals(_: Variable, _: CaseClassSelector | _: TupleSelect) => + new ADTConstraint(e, sel = true) + case Equals(_: Variable, _: CaseClass | _: Tuple) => + new ADTConstraint(e, cons = true) + case Equals(_: Variable, _: IsInstanceOf) => + new ADTConstraint(e, inst = true) + case Equals(lhs @ Variable(_), AsInstanceOf(rhs @ Variable(_), _)) => + new ADTConstraint(Equals(lhs, rhs), comp= true) case Equals(lhs: Variable, _: Variable) if adtType(lhs) => new ADTConstraint(e, comp = true) - case Not(Equals(lhs: Variable, _: Variable)) if adtType(lhs) => - new ADTConstraint(e, comp = true) - case _ => - throw new IllegalStateException(s"Expression not an ADT constraint: $e") + case Not(Equals(lhs: Variable, _: Variable)) if adtType(lhs) => + new ADTConstraint(e, comp = true) + case _ => + throw new IllegalStateException(s"Expression not an ADT constraint: $e") } } @@ -216,15 +267,40 @@ class ADTConstraint(val expr: Expr, val cons: Boolean = false, val inst: Boolean = false, val comp: Boolean = false, - val sel: Boolean = false) extends Constraint { - - override def toString(): String = expr.toString + val sel: Boolean = false) extends Constraint { + + override def toString(): String = expr.toString override def toExpr = expr } +case class ExtendedADTConstraint(v: Variable, adtCtr: ADTConstraint, diseq: Boolean) extends ExtendedConstraint { + val expr = { + assert(adtCtr.comp) + val eqExpr = Equals(v, adtCtr.toExpr) + if(diseq) Not(eqExpr) else eqExpr + } + override def toExpr = expr + override def toString: String = expr.toString + + /** + * Chooses a sat disjunct of the constraint + */ + override def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator) = { + if((model(v.id) == tru && !diseq) || (model(v.id) == fls && diseq)) adtCtr + else ADTConstraint(Not(adtCtr.toExpr)) + } +} + case class Call(retexpr: Expr, fi: FunctionInvocation) extends Constraint { val expr = Equals(retexpr, fi) + override def toExpr = expr +} +/** + * If-then-else constraint + */ +case class ITE(cond: BoolConstraint, ths: Seq[Constraint], elzs: Seq[Constraint]) extends Constraint { + val expr = IfExpr(cond.toExpr, createAnd(ths.map(_.toExpr)), createAnd(elzs.map(_.toExpr))) override def toExpr = expr } @@ -269,37 +345,42 @@ case class SetConstraint(expr: Expr) extends Constraint { override def toExpr = expr } -object ConstraintUtil { - - def createConstriant(ie: Expr): Constraint = { - ie match { - case Variable(_) | Not(Variable(_)) | BooleanLiteral(_) | Not(BooleanLiteral(_)) => - BoolConstraint(ie) - case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => - Call(v, fi) - case Equals(_: Variable, _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple |_: IsInstanceOf) => - ADTConstraint(ie) - case _ if SetConstraint.isSetConstraint(ie) => - SetConstraint(ie) - // every non-integer equality will be considered an ADT constraint (including TypeParameter equalities) - case Equals(lhs, rhs) if !isNumericType(lhs.getType) => - //println("ADT constraint: "+ie) - ADTConstraint(ie) - case Not(Equals(lhs, rhs)) if !isNumericType(lhs.getType) => - ADTConstraint(ie) +object ConstraintUtil { + def toLinearTemplate(ie: Expr) = { + simplifyArithmetic(ie) match { + case b: BooleanLiteral => BoolConstraint(b) case _ => { - val simpe = simplifyArithmetic(ie) - simpe match { - case b: BooleanLiteral => BoolConstraint(b) - case _ => { - val template = LinearConstraintUtil.exprToTemplate(ie) - LinearConstraintUtil.evaluate(template) match { - case Some(v) => BoolConstraint(BooleanLiteral(v)) - case _ => template - } - } + val template = LinearConstraintUtil.exprToTemplate(ie) + LinearConstraintUtil.evaluate(template) match { + case Some(v) => BoolConstraint(BooleanLiteral(v)) + case _ => template } } } } + + def toExtendedTemplate(v: Variable, ie: Expr, diseq: Boolean) = { + toLinearTemplate(ie) match { + case bc: BoolConstraint => BoolConstraint(Equals(v, bc.toExpr)) + case t: LinearTemplate => ExtendedLinearTemplate(v, t, diseq) + } + } + + def createConstriant(ie: Expr): Constraint = { + ie match { + case _ if BoolConstraint.isBoolConstraint(ie) => BoolConstraint(ie) + case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => Call(v, fi) + case Equals(_: Variable, _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple | _: IsInstanceOf) => + ADTConstraint(ie) + case _ if SetConstraint.isSetConstraint(ie) => SetConstraint(ie) + case Equals(v: Variable, rhs) if (isArithmeticRelation(rhs) != Some(false)) => toExtendedTemplate(v, rhs, false) + case Not(Equals(v: Variable, rhs)) if (isArithmeticRelation(rhs) != Some(false)) => toExtendedTemplate(v, rhs, true) + case _ if (isArithmeticRelation(ie) != Some(false)) => toLinearTemplate(ie) + case Equals(v: Variable, rhs@Equals(l, _)) if adtType(l) => ExtendedADTConstraint(v, ADTConstraint(rhs), false) + + // every other equality will be considered an ADT constraint (including TypeParameter equalities) + case Equals(lhs, rhs) if !isNumericType(lhs.getType) => ADTConstraint(ie) + case Not(Equals(lhs, rhs)) if !isNumericType(lhs.getType) => ADTConstraint(ie) + } + } } diff --git a/src/main/scala/leon/invariant/structure/Formula.scala b/src/main/scala/leon/invariant/structure/Formula.scala index d0995c71647a9c0f0fbdffe7cb2a69281583ae3e..588cf86ea7b87c4e8840bf288e78a2baa87d3c33 100644 --- a/src/main/scala/leon/invariant/structure/Formula.scala +++ b/src/main/scala/leon/invariant/structure/Formula.scala @@ -22,14 +22,18 @@ import TVarFactory._ import ExpressionTransformer._ import evaluators._ import invariant.factories._ +import evaluators._ +import EvaluationResults._ /** - * Data associated with a call + * Data associated with a call. + * @param inSpec true if the call (transitively) made within specifications */ -class CallData(val guard : Variable, val parents: List[FunDef]) +class CallData(val guard : Variable, val parents: List[FunDef], val inSpec: Boolean) object Formula { val debugUnflatten = false + val dumpUnflatFormula = false // a context for creating blockers val blockContext = newContext } @@ -38,22 +42,23 @@ object Formula { * Representation of an expression as a set of implications. * 'initexpr' is required to be in negation normal form and And/Ors have been pulled up * TODO: optimize the representation so that we use fewer guards. + * @param initSpecCalls when specified it optimizes the handling of calls made in the specification. */ -class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { +class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext, initSpecCalls: Set[Expr] = Set()) { import Formula._ val fls = BooleanLiteral(false) val tru = BooleanLiteral(true) - val useImplies = false + val useImplies = false // note: we have to use equality for 'cond' blockers (no matter what!) val combiningOp = if(useImplies) Implies.apply _ else Equals.apply _ protected var disjuncts = Map[Variable, Seq[Constraint]]() //a mapping from guards to conjunction of atoms protected var conjuncts = Map[Variable, Expr]() //a mapping from guards to disjunction of atoms - private var callDataMap = Map[Call, CallData]() //a mapping from a 'call' to the 'guard' guarding the call plus the list of transitive callers of 'call' private var paramBlockers = Set[Variable]() + private var callDataMap = Map[Call, CallData]() //a mapping from a 'call' to the 'guard' guarding the call plus the list of transitive callers of 'call' - val firstRoot : Variable = addConstraints(initexpr, List(fd))._1 + val firstRoot: Variable = addConstraints(initexpr, List(fd), c => initSpecCalls(c.toExpr))._1 protected var roots : Seq[Variable] = Seq(firstRoot) //a list of roots, the formula is a conjunction of formula of each root def disjunctsInFormula = disjuncts @@ -62,37 +67,36 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { //return the root variable and the sequence of disjunct guards added //(which includes the root variable incase it respresents a disjunct) - def addConstraints(ine: Expr, callParents : List[FunDef]) : (Variable, Seq[Variable]) = { - + def addConstraints(ine: Expr, callParents: List[FunDef], inSpec: Call => Boolean): (Variable, Seq[Variable]) = { + def atoms(e: Expr) = e match { + case And(atms) => atms + case _ => Seq(e) + } var newDisjGuards = Seq[Variable]() + var condBlockers = Map[Variable, (Variable, Variable)]() // a mapping from condition constraint to then and else blockers - def getCtrsFromExprs(guard: Variable, exprs: Seq[Expr]) : Seq[Constraint] = { + def getCtrsFromExprs(guard: Variable, exprs: Seq[Expr]): Seq[Constraint] = { var break = false - exprs.foldLeft(Seq[Constraint]())((acc, e) => { - if (break) acc - else { - val ctr = ConstraintUtil.createConstriant(e) - ctr match { + exprs.foldLeft(Seq[Constraint]()) { + case (acc, _) if break => acc + case (acc, ife @ IfExpr(cond: Variable, th, elze)) => + val (thBlock, elseBlock) = condBlockers(cond) + acc :+ ITE(BoolConstraint(cond), BoolConstraint(thBlock) +: getCtrsFromExprs(thBlock, atoms(th)), + BoolConstraint(elseBlock) +: getCtrsFromExprs(elseBlock, atoms(elze))) + case (acc, e) => + ConstraintUtil.createConstriant(e) match { case BoolConstraint(BooleanLiteral(true)) => acc - case BoolConstraint(BooleanLiteral(false)) => { + case fls @ BoolConstraint(BooleanLiteral(false)) => break = true - Seq(ctr) - } - case call@Call(_,_) => { - - if(callParents.isEmpty) - throw new IllegalArgumentException("Parent not specified for call: "+ctr) - else { - callDataMap += (call -> new CallData(guard, callParents)) - } + Seq(fls) + case call @ Call(_, _) => + if (callParents.isEmpty) throw new IllegalArgumentException("Parent not specified for call: " + call) + else callDataMap += (call -> new CallData(guard, callParents, inSpec(call))) acc :+ call - } - case _ => acc :+ ctr + case ctr => acc :+ ctr } - } - }) + } } - /** * Creates disjunct of the form b == exprs and updates the necessary mutable states */ @@ -105,121 +109,147 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { paramBlockers += g g } - - val f1 = simplePostTransform { - case e@Or(args) => { - val newargs = args.map { - case arg@(v: Variable) if (disjuncts.contains(v)) => arg - case v: Variable if (conjuncts.contains(v)) => throw new IllegalStateException("or gaurd inside conjunct: " + e + " or-guard: " + v) - case arg => { - val atoms = arg match { - case And(atms) => atms - case _ => Seq(arg) - } - val g = addToDisjunct(atoms, !getTemplateIds(arg).isEmpty) + def rec(e: Expr)(implicit insideOperation: Boolean): Expr = e match { + case Or(args) if !insideOperation => + val newargs = (args map rec).map { + case v: Variable if disjuncts.contains(v) => v + case v: Variable if conjuncts.contains(v) => throw new IllegalStateException("or gaurd inside conjunct: " + e + " or-guard: " + v) + case arg => + val g = addToDisjunct(atoms(arg), !getTemplateIds(arg).isEmpty) //println(s"creating a new OR blocker $g for "+atoms) g - } } - //create a temporary for Or val gor = createTemp("b", BooleanType, blockContext).toVariable val newor = createOr(newargs) + //println("Creating or const: "+(gor -> newor)) conjuncts += (gor -> newor) gor - } - case e@And(args) => { + + case And(args) => //if the expression has template variables then we separate it using guards - val (nonparams, params) = args.partition(getTemplateIds(_).isEmpty) + val (nonparams, params) = (args map rec).partition(getTemplateIds(_).isEmpty) val newargs = - if (!params.isEmpty) { - val g = addToDisjunct(params, true) - //println(s"creating a new Temp blocker $g for "+arg) - paramBlockers += g - g +: nonparams - } else nonparams + if (!params.isEmpty) + addToDisjunct(params, true) +: nonparams + else nonparams createAnd(newargs) - } - case e => e - }(ExpressionTransformer.simplify(simplifyArithmetic( + + case e : IfExpr => + val (con, th, elze) = (rec(e.cond)(true), rec(e.thenn)(false), rec(e.elze)(false)) + if(!isAtom(con) || !getTemplateIds(con).isEmpty) + throw new IllegalStateException(s"Condition of ifexpr is not an atom: $e") + // create condition and anti-condition blockers + val ncond = addToDisjunct(Seq(con), false) + val thBlock = addToDisjunct(Seq(), false) + val elseBlock = addToDisjunct(Seq(), false) + condBlockers += (ncond -> (thBlock, elseBlock)) + // normalize thn and elze + val trans = (e: Expr) => { + if(getTemplateIds(e).isEmpty) e + else addToDisjunct(atoms(e), true) + } + IfExpr(ncond, trans(th), trans(elze)) + + case Operator(args, op) => + op(args.map(rec(_)(true))) + } + val f1 = rec(ExpressionTransformer.simplify(simplifyArithmetic( //TODO: this is a hack as of now. Fix this. //Note: it is necessary to convert real literals to integers since the linear constraint cannot handle real literals if(ctx.usereals) ExpressionTransformer.FractionalLiteralToInt(ine) else ine - ))) - + )))(false) val rootvar = f1 match { case v: Variable if(conjuncts.contains(v)) => v case v: Variable if(disjuncts.contains(v)) => throw new IllegalStateException("f1 is a disjunct guard: "+v) - case _ => { - val atoms = f1 match { - case And(atms) => atms - case _ => Seq(f1) - } - val g = addToDisjunct(atoms, !getTemplateIds(f1).isEmpty) - g - } + case _ => addToDisjunct(atoms(f1), !getTemplateIds(f1).isEmpty) } (rootvar, newDisjGuards) } - //'satGuard' is required to a guard variable - def pickSatDisjunct(startGaurd : Variable, model: LazyModel): Seq[Constraint] = { + def pickSatDisjunct(startGaurd : Variable, model: LazyModel, tmplModel: Map[Identifier, Expr], eval: DefaultEvaluator): Seq[Constraint] = { - def traverseOrs(gd: Variable, model: LazyModel): Seq[Variable] = { - val e @ Or(guards) = conjuncts(gd) - //pick one guard that is true - val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } + def traverseOrs(ine: Expr): Seq[Constraint] = { + val Or(guards) = ine + val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } //pick one guard that is true if (guard.isEmpty) - throw new IllegalStateException("No satisfiable guard found: " + e) - guard.get +: traverseAnds(guard.get, model) + throw new IllegalStateException("No satisfiable guard found: " + ine) + BoolConstraint(guard.get) +: traverseAnds(disjuncts(guard.get)) } - - def traverseAnds(gd: Variable, model: LazyModel): Seq[Variable] = { - val ctrs = disjuncts(gd) - val guards = ctrs.collect { - case BoolConstraint(v @ Variable(_)) if (conjuncts.contains(v) || disjuncts.contains(v)) => v + def traverseAnds(inctrs: Seq[Constraint]): Seq[Constraint] = + inctrs.foldLeft(Seq[Constraint]()) { + case (acc, ITE(BoolConstraint(c: Variable), ths, elzes)) => + val conds = disjuncts(c) // here, cond it guaranteed to be an atom + assert(conds.size <= 1) + val ctrs = + if (model(c.id) == tru) + conds ++ traverseAnds(ths) + else { + val condCtr = conds match { + case Seq(bc: BoolConstraint) => BoolConstraint(Not(bc.toExpr)) + case Seq(lc: LinearTemplate) => lc.pickSatDisjunctOfNegation(model, tmplModel, eval) + case Seq(adteq: ADTConstraint) if adteq.comp => + adteq.toExpr match { + case Not(eq) => ADTConstraint(eq) + case eq => ADTConstraint(Not(eq)) + } + } + condCtr +: traverseAnds(elzes) + } + acc ++ ctrs + case (acc, elt: ExtendedConstraint) => + acc :+ elt.pickSatDisjunct(model, tmplModel, eval) + case (acc, ctr @ BoolConstraint(v: Variable)) if conjuncts.contains(v) => //assert(model(v.id) == tru) + acc ++ (ctr +: traverseOrs(conjuncts(v))) + case (acc, ctr @ BoolConstraint(v: Variable)) if disjuncts.contains(v) => //assert(model(v.id) == tru) + acc ++ (ctr +: traverseAnds(disjuncts(v))) + case (acc, ctr) => acc :+ ctr } - if (guards.isEmpty) Seq() + val path = + if (model(startGaurd.id) == fls) Seq() //if startGuard is unsat return empty else { - guards.foldLeft(Seq[Variable]())((acc, g) => { - if (model(g.id) != tru) - throw new IllegalStateException("Not a satisfiable guard: " + g) - - if (conjuncts.contains(g)) - acc ++ traverseOrs(g, model) - else { - acc ++ (g +: traverseAnds(g, model)) - } - }) + if (conjuncts.contains(startGaurd)) + traverseOrs(conjuncts(startGaurd)) + else + BoolConstraint(startGaurd) +: traverseAnds(disjuncts(startGaurd)) } - } - //if startGuard is unsat return empty - if (model(startGaurd.id) == fls) Seq() - else { - val satGuards = if (conjuncts.contains(startGaurd)) traverseOrs(startGaurd, model) - else (startGaurd +: traverseAnds(startGaurd, model)) - satGuards.flatMap(g => disjuncts(g)) - } + /*println("Path: " + simplifyArithmetic(createAnd(path.map(_.toExpr)))) + scala.io.StdIn.readLine()*/ + path } /** * 'neweexpr' is required to be in negation normal form and And/Ors have been pulled up */ - def conjoinWithDisjunct(guard: Variable, newexpr: Expr, callParents: List[FunDef]) : (Variable, Seq[Variable]) = { - val (exprRoot, newGaurds) = addConstraints(newexpr, callParents) + def conjoinWithDisjunct(guard: Variable, newexpr: Expr, callParents: List[FunDef], inSpec:Boolean) = { + val (exprRoot, newGaurds) = addConstraints(newexpr, callParents, _ => inSpec) //add 'newguard' in conjunction with 'disjuncts(guard)' val ctrs = disjuncts(guard) disjuncts -= guard disjuncts += (guard -> (BoolConstraint(exprRoot) +: ctrs)) - (exprRoot, newGaurds) + exprRoot } - def conjoinWithRoot(newexpr: Expr, callParents: List[FunDef]): (Variable, Seq[Variable]) = { - val (exprRoot, newGaurds) = addConstraints(newexpr, callParents) + def conjoinWithRoot(newexpr: Expr, callParents: List[FunDef], inSpec: Boolean) = { + val (exprRoot, newGaurds) = addConstraints(newexpr, callParents, _ => inSpec) roots :+= exprRoot - (exprRoot, newGaurds) + exprRoot } + def getCallsOfGuards(guards: Seq[Variable]): Seq[Call] = { + def calls(ctrs: Seq[Constraint]): Seq[Call] = { + ctrs.flatMap { + case c: Call => Seq(c) + case ITE(_, th, el) => + calls(th) ++ calls(el) + case _ => Seq() + } + } + guards.flatMap{g => calls(disjuncts(g)) } + } + + def callsInFormula: Seq[Call] = getCallsOfGuards(disjuncts.keys.toSeq) + def templateIdsInFormula = paramBlockers.flatMap { g => getTemplateIds(createAnd(disjuncts(g).map(_.toExpr))) }.toSet @@ -257,13 +287,26 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { val paramPart = paramBlockers.toSeq.map{ g => combiningOp(g,createAnd(disjuncts(g).map(_.toExpr))) } + // simplify blockers if we can, and close the map + val blockMap = substClosure(disjuncts.collect { + case (g, Seq(ctr)) if !paramBlockers(g) => (g.id -> ctr.toExpr) + case (g, Seq()) => (g.id -> tru) + }.toMap) + val conjs = conjuncts.map { + case (g, rhs) => replaceFromIDs(blockMap, combiningOp(g, rhs)) + }.toSeq ++ roots.map(replaceFromIDs(blockMap, _)) + val flatRest = disjuncts.toSeq collect { + case (g, ctrs) if !paramBlockers(g) && !blockMap.contains(g.id) => + //val ng = blockMap.getOrElse(g.id, g) + (g, replaceFromIDs(blockMap, createAnd(ctrs.map(_.toExpr)))) + } // compute variables used in more than one disjunct + var sharedVars = (paramPart ++ conjs).flatMap(variablesOf).toSet var uniqueVars = Set[Identifier]() - var sharedVars = Set[Identifier]() var freevars = Set[Identifier]() - disjuncts.foreach{ - case (g, ctrs) => - val fvs = ctrs.flatMap(c => variablesOf(c.toExpr)).toSet + flatRest.foreach{ + case (g, rhs) => + val fvs = variablesOf(rhs).toSet val candUniques = fvs -- sharedVars val newShared = uniqueVars.intersect(candUniques) freevars ++= fvs @@ -271,10 +314,10 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { uniqueVars = (uniqueVars ++ candUniques) -- newShared } // unflatten rest - var flatIdMap = Map[Identifier, Expr]() - val unflatRest = (disjuncts collect { - case (g, ctrs) if !paramBlockers(g) => - val rhs = createAnd(ctrs.map(_.toExpr)) + var flatIdMap = blockMap + val unflatRest = (flatRest collect { + case (g, rhs) => + // note: we call simple unflatten in the presence of if-then-else because it will not have flat-ids transcending then and else branches val (unflatRhs, idmap) = simpleUnflattenWithMap(rhs, sharedVars, includeFuns = false) // sanity checks if (debugUnflatten) { @@ -288,8 +331,20 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { flatIdMap ++= idmap combiningOp(g, unflatRhs) }).toSeq + val modelCons = (m: Model, eval: DefaultEvaluator) => new FlatModel(freevars, flatIdMap, m, eval) - val conjs = conjuncts.map{ case(g,rhs) => combiningOp(g, rhs) }.toSeq ++ roots + + if (dumpUnflatFormula) { + val unf = ((paramPart ++ unflatRest.map(_.toString) ++ conjs.map(_.toString)).mkString("\n")) + val filename = "unflatVC-" + FileCountGUID.getID + val wr = new PrintWriter(new File(filename + ".txt")) + println("Printed VC of " + fd.id + " to file: " + filename) + wr.println(unf) + wr.close() + } + if (ctx.dumpStats) { + Stats.updateCounterStats(atomNum(And(paramPart ++ unflatRest ++ conjs)), "unflatSize", "VC-refinement") + } (createAnd(paramPart), createAnd(unflatRest ++ conjs), modelCons) } @@ -314,7 +369,6 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { //var removeGuards = Seq[Variable]() while(replacedGuard) { replacedGuard = false - val newDisjs = unpackedDisjs.map(entry => { val (g,d) = entry val guards = variablesOf(d).collect{ case id@_ if disjuncts.contains(id.toVariable) => id.toVariable } @@ -345,7 +399,10 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { * Functions for stats */ def atomsCount = disjuncts.map(_._2.size).sum + conjuncts.map(i => atomNum(i._2)).sum - def funsCount = disjuncts.map(_._2.filter(_.isInstanceOf[Call]).size).sum + def funsCount = disjuncts.map(_._2.filter { + case _: Call | _: ADTConstraint => true + case _ => false + }.size).sum /** * Functions solely used for debugging @@ -364,7 +421,8 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { //println("packed formula: "+packedFor) val satdisj = if (unflatSat == Some(true)) - Some(pickSatDisjunct(firstRoot, new SimpleLazyModel(unflatModel))) + Some(pickSatDisjunct(firstRoot, new SimpleLazyModel(unflatModel), + tempMap.map{ case (Variable(id), v) => id -> v }.toMap, eval)) else None if (unflatSat != flatSat) { if (satdisj.isDefined) { @@ -398,4 +456,28 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { } } } + + /** + * A method for picking a sat disjunct of unflat formula. Mostly used for debugging. + */ + def pickSatFromUnflatFormula(unflate: Expr, model: Model, evaluator: DefaultEvaluator): Seq[Expr] = { + def rec(e: Expr): Seq[Expr] = e match { + case IfExpr(cond, thn, elze) => + evaluator.eval(cond, model) match { + case Successful(BooleanLiteral(true)) => cond +: rec(thn) + case Successful(BooleanLiteral(false)) => Not(cond) +: rec(elze) + } + case And(args) => args flatMap rec + case Or(args) => rec(args.find(evaluator.eval(_, model) == Successful(BooleanLiteral(true))).get) + case Equals(b: Variable, rhs) if b.getType == BooleanType => + evaluator.eval(b, model) match { + case Successful(BooleanLiteral(true)) => + rec(b) ++ rec(rhs) + case Successful(BooleanLiteral(false)) => + Seq(Not(b)) + } + case e => Seq(e) + } + rec(unflate) + } } diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala index fa9597e70d09ddd368b11c57119c51977255e978..ccee2b16f2be7040b36c77b88703fa19239b2463 100644 --- a/src/main/scala/leon/invariant/structure/FunctionUtils.scala +++ b/src/main/scala/leon/invariant/structure/FunctionUtils.scala @@ -11,6 +11,7 @@ import invariant.util._ import Util._ import PredicateUtil._ import scala.language.implicitConversions +import ExpressionTransformer._ /** * Some utiliy methods for functions. @@ -31,6 +32,7 @@ object FunctionUtils { lazy val hasFieldFlag = fd.flags.contains(IsField(false)) lazy val hasLazyFieldFlag = fd.flags.contains(IsField(true)) lazy val isUserFunction = !hasFieldFlag && !hasLazyFieldFlag + lazy val usePost = fd.annotations.contains("usePost") //the template function lazy val tmplFunctionName = "tmpl" @@ -108,14 +110,23 @@ object FunctionUtils { // collect all terms with question marks and convert them to a template val postWoQmarks = postBody match { case And(args) if args.exists(exists(isQMark)) => - val (tempExprs, otherPreds) = args.partition { - case a if exists(isQMark)(a) => true - case _ => false - } + val (tempExprs, otherPreds) = args.partition(exists(isQMark)) //println(s"Otherpreds: $otherPreds ${qmarksToTmplFunction(createAnd(tempExprs))}") createAnd(otherPreds :+ qmarksToTmplFunction(createAnd(tempExprs))) case pb if exists(isQMark)(pb) => - qmarksToTmplFunction(pb) + pb match { + case l: Let => + val (letsCons, letsBody) = letStarUnapplyWithSimplify(l) // we try to see if the post is let* .. in e_1 ^ e_2 ^ ... + letsBody match { + case And(args) => + val (tempExprs, rest) = args.partition(exists(isQMark)) + val toTmplFun = qmarksToTmplFunction(letsCons(createAnd(tempExprs))) + createAnd(Seq(letsCons(createAnd(rest)), toTmplFun)) + case _ => + qmarksToTmplFunction(pb) + } + case _ => qmarksToTmplFunction(pb) + } case other => other } //the 'body' could be a template or 'And(pred, template)' @@ -123,10 +134,7 @@ object FunctionUtils { case finv @ FunctionInvocation(_, args) if isTemplateInvocation(finv) => (None, Some(finv)) case And(args) if args.exists(isTemplateInvocation) => - val (tempFuns, otherPreds) = args.partition { - case a if isTemplateInvocation(a) => true - case _ => false - } + val (tempFuns, otherPreds) = args.partition(isTemplateInvocation) if (tempFuns.size > 1) { throw new IllegalStateException("Multiple template functions used in the postcondition: " + postBody) } else { @@ -142,6 +150,8 @@ object FunctionUtils { } lazy val template = templateExpr map (finv => extractTemplateFromLambda(finv.args(0).asInstanceOf[Lambda])) + lazy val normalizedTemplate = template.map(normalizeExpr(_, (e1: Expr, e2: Expr) => + throw new IllegalStateException("Not implemented yet!"))) def hasTemplate: Boolean = templateExpr.isDefined def getPostWoTemplate = postWoTemplate match { diff --git a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala index dfdace00bf024e9e6dda8894dbf428de00581de0..0848ccbb487679dcaf07892b4d39c8ba4edf6942 100644 --- a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala +++ b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala @@ -5,11 +5,14 @@ import purescala._ import purescala.Common._ import purescala.Expressions._ import purescala.ExprOps._ +import leon.purescala.Types._ import purescala.Extractors._ -import scala.collection.mutable.{ Map => MutableMap } +import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet, MutableList } import invariant.util._ import BigInt._ import PredicateUtil._ +import Stats._ + class NotImplementedException(message: String) extends RuntimeException(message) @@ -21,6 +24,8 @@ object LinearConstraintUtil { val tru = BooleanLiteral(true) val fls = BooleanLiteral(false) + val debugElimination = false + //some utility methods def getFIs(ctr: LinearConstraint): Set[FunctionInvocation] = { val fis = ctr.coeffMap.keys.collect { @@ -33,12 +38,12 @@ object LinearConstraintUtil { case lc: LinearConstraint if lc.coeffMap.isEmpty => ExpressionTransformer.simplify(lt.toExpr) match { case BooleanLiteral(v) => Some(v) - case _ => None + case _ => None } case _ => None } - /** + /** * the expression 'Expr' is required to be a linear atomic predicate (or a template), * if not, an exception would be thrown. * For now some of the constructs are not handled. @@ -47,37 +52,33 @@ object LinearConstraintUtil { */ def exprToTemplate(expr: Expr): LinearTemplate = { - //println("Expr: "+expr) //these are the result values var coeffMap = MutableMap[Expr, Expr]() var constant: Option[Expr] = None - var isTemplate : Boolean = false + var isTemplate: Boolean = false def addCoefficient(term: Expr, coeff: Expr) = { if (coeffMap.contains(term)) { val value = coeffMap(term) val newcoeff = simplifyArithmetic(Plus(value, coeff)) - //if newcoeff becomes zero remove it from the coeffMap - if(newcoeff == zero) { + if (newcoeff == zero) { coeffMap.remove(term) - } else{ + } else { coeffMap.update(term, newcoeff) } } else coeffMap += (term -> simplifyArithmetic(coeff)) - if (variablesOf(coeff).nonEmpty) { isTemplate = true } } - def addConstant(coeff: Expr) ={ + def addConstant(coeff: Expr) = { if (constant.isDefined) { val value = constant.get constant = Some(simplifyArithmetic(Plus(value, coeff))) } else constant = Some(simplifyArithmetic(coeff)) - if (variablesOf(coeff).nonEmpty) { isTemplate = true } @@ -86,54 +87,37 @@ object LinearConstraintUtil { //recurse into plus and get all minterms def getMinTerms(lexpr: Expr): Seq[Expr] = lexpr match { case Plus(e1, e2) => getMinTerms(e1) ++ getMinTerms(e2) - case _ => Seq(lexpr) + case _ => Seq(lexpr) } - val linearExpr = MakeLinear(expr) //the top most operator should be a relation - val Operator(Seq(lhs, InfiniteIntegerLiteral(x)), op) = linearExpr + val Operator(Seq(lhs, InfiniteIntegerLiteral(x)), op) = makeLinear(expr) /*if (lhs.isInstanceOf[InfiniteIntegerLiteral]) - throw new IllegalStateException("relation on two integers, not in canonical form: " + linearExpr)*/ - - val minterms = getMinTerms(lhs) - + throw new IllegalStateException("relation on two integers, not in canonical form: " + linearExpr)*/ //handle each minterm - minterms.foreach((minterm: Expr) => minterm match { - case _ if (isTemplateExpr(minterm)) => { - addConstant(minterm) - } - case Times(e1, e2) => { + getMinTerms(lhs).foreach(minterm => minterm match { + case _ if (isTemplateExpr(minterm)) => addConstant(minterm) + case Times(e1, e2) => e2 match { - case Variable(_) => ; - case ResultVariable(_) => ; - case FunctionInvocation(_, _) => ; - case _ => throw new IllegalStateException("Multiplicand not a constraint variable: " + e2) + case Variable(_) | ResultVariable(_) | FunctionInvocation(_, _) => + case _ => throw new IllegalStateException("Multiplicand not a constraint variable: " + e2) } - e1 match { - //case c @ InfiniteIntegerLiteral(_) => addCoefficient(e2, c) - case _ if (isTemplateExpr(e1)) => { - addCoefficient(e2, e1) - } + e1 match { + case _ if (isTemplateExpr(e1)) => addCoefficient(e2, e1) case _ => throw new IllegalStateException("Coefficient not a constant or template expression: " + e1) - } - } - case Variable(_) => { - //here the coefficient is 1 - addCoefficient(minterm, one) - } - case ResultVariable(_) => { - addCoefficient(minterm, one) - } + } + case Variable(_) => addCoefficient(minterm, one) //here the coefficient is 1 + case ResultVariable(_) => addCoefficient(minterm, one) case _ => throw new IllegalStateException("Unhandled min term: " + minterm) }) - if(coeffMap.isEmpty && constant.isEmpty) { + if (coeffMap.isEmpty && constant.isEmpty) { //here the generated template the constant term is zero. new LinearConstraint(op, Map.empty, Some(zero)) - } else if(isTemplate) { + } else if (isTemplate) { new LinearTemplate(op, coeffMap.toMap, constant) - } else{ - new LinearConstraint(op, coeffMap.toMap,constant) + } else { + new LinearConstraint(op, coeffMap.toMap, constant) } } @@ -142,56 +126,53 @@ object LinearConstraintUtil { * This assumes that the input expression is an atomic predicate (i.e, without and, or and nots) * This is subjected to constant modification. */ - def MakeLinear(atom: Expr): Expr = { + def makeLinear(atom: Expr): Expr = { //pushes the minus inside the arithmetic terms //we assume that inExpr is in linear form - def PushMinus(inExpr: Expr): Expr = { + def pushMinus(inExpr: Expr): Expr = { inExpr match { - case IntLiteral(v) => IntLiteral(-v) - case InfiniteIntegerLiteral(v) => InfiniteIntegerLiteral(-v) - case t: Terminal => Times(mone, t) + case IntLiteral(v) => IntLiteral(-v) + case InfiniteIntegerLiteral(v) => InfiniteIntegerLiteral(-v) + case t: Terminal => Times(mone, t) case fi @ FunctionInvocation(fdef, args) => Times(mone, fi) - case UMinus(e1) => e1 - case RealUMinus(e1) => e1 - case Minus(e1, e2) => Plus(PushMinus(e1), e2) - case RealMinus(e1, e2) => Plus(PushMinus(e1), e2) - case Plus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2)) - case RealPlus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2)) - case Times(e1, e2) => { + case UMinus(e1) => e1 + case RealUMinus(e1) => e1 + case Minus(e1, e2) => Plus(pushMinus(e1), e2) + case RealMinus(e1, e2) => Plus(pushMinus(e1), e2) + case Plus(e1, e2) => Plus(pushMinus(e1), pushMinus(e2)) + case RealPlus(e1, e2) => Plus(pushMinus(e1), pushMinus(e2)) + case Times(e1, e2) => //here push the minus in to the coefficient which is the first argument - Times(PushMinus(e1), e2) - } - case RealTimes(e1, e2) => Times(PushMinus(e1), e2) - case _ => throw new NotImplementedException("PushMinus -- Operators not yet handled: " + inExpr) + Times(pushMinus(e1), e2) + case RealTimes(e1, e2) => Times(pushMinus(e1), e2) + case _ => throw new NotImplementedException("pushMinus -- Operators not yet handled: " + inExpr) } } - - import leon.purescala.Types._ + //we assume that ine is in linear form - def PushTimes(mul: Expr, ine: Expr): Expr = { + def pushTimes(mul: Expr, ine: Expr): Expr = { val isReal = ine.getType == RealType && mul.getType == RealType val timesCons = - if(isReal) RealTimes + if (isReal) RealTimes else Times ine match { - case t: Terminal => timesCons(mul, t) + case t: Terminal => timesCons(mul, t) case fi @ FunctionInvocation(fdef, ars) => timesCons(mul, fi) - case Plus(e1, e2) => Plus(PushTimes(mul, e1), PushTimes(mul, e2)) + case Plus(e1, e2) => Plus(pushTimes(mul, e1), pushTimes(mul, e2)) case RealPlus(e1, e2) => - val r1 = PushTimes(mul, e1) - val r2 = PushTimes(mul, e2) + val r1 = pushTimes(mul, e1) + val r2 = pushTimes(mul, e2) if (isReal) RealPlus(r1, r2) else Plus(r1, r2) - case Times(e1, e2) => { + case Times(e1, e2) => //here push the times into the coefficient which should be the first expression - Times(PushTimes(mul, e1), e2) - } + Times(pushTimes(mul, e1), e2) case RealTimes(e1, e2) => - val r = PushTimes(mul, e1) - if(isReal) RealTimes(r, e2) + val r = pushTimes(mul, e1) + if (isReal) RealTimes(r, e2) else Times(r, e2) - case _ => throw new NotImplementedException("PushTimes -- Operators not yet handled: " + ine) + case _ => throw new NotImplementedException("pushTimes -- Operators not yet handled: " + ine) } } @@ -199,16 +180,15 @@ object LinearConstraintUtil { //we assume that ine is in linear form and also that all constants are integers def simplifyConsts(ine: Expr): (Option[Expr], BigInt) = { ine match { - case IntLiteral(v) => (None, v) + case IntLiteral(v) => (None, v) case InfiniteIntegerLiteral(v) => (None, v) case Plus(e1, e2) => { val (r1, c1) = simplifyConsts(e1) val (r2, c2) = simplifyConsts(e2) - val newe = (r1, r2) match { - case (None, None) => None - case (Some(t), None) => Some(t) - case (None, Some(t)) => Some(t) + case (None, None) => None + case (Some(t), None) => Some(t) + case (None, Some(t)) => Some(t) case (Some(t1), Some(t2)) => Some(Plus(t1, t2)) } (newe, c1 + c2) @@ -220,31 +200,26 @@ object LinearConstraintUtil { def mkLinearRecur(inExpr: Expr): Expr = { //println("inExpr: "+inExpr + " tpe: "+inExpr.getType) val res = inExpr match { - case e @ Operator(Seq(e1, e2), op) - if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] - || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] - || e.isInstanceOf[GreaterEquals])) => { + case e @ Operator(Seq(e1, e2), op) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] + || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] + || e.isInstanceOf[GreaterEquals])) => { //check if the expression has real valued sub-expressions - val isReal = hasReals(e1) || hasReals(e2) - //doing something else ... ? - // println("[DEBUG] Expr 1 " + e1 + " of type " + e1.getType + " and Expr 2 " + e2 + " of type" + e2.getType) + val isReal = hasReals(e1) || hasReals(e2) val (newe, newop) = e match { - case t: Equals => (Minus(e1, e2), Equals) - case t: LessEquals => (Minus(e1, e2), LessEquals) + case t: Equals => (Minus(e1, e2), Equals) + case t: LessEquals => (Minus(e1, e2), LessEquals) case t: GreaterEquals => (Minus(e2, e1), LessEquals) - case t: LessThan => { + case t: LessThan => if (isReal) (Minus(e1, e2), LessThan) else - (Plus(Minus(e1, e2), one), LessEquals) - } - case t: GreaterThan => { - if(isReal) - (Minus(e2,e1),LessThan) + (Plus(Minus(e1, e2), one), LessEquals) + case t: GreaterThan => + if (isReal) + (Minus(e2, e1), LessThan) else - (Plus(Minus(e2, e1), one), LessEquals) - } + (Plus(Minus(e2, e1), one), LessEquals) } val r = mkLinearRecur(newe) //simplify the resulting constants @@ -252,36 +227,30 @@ object LinearConstraintUtil { val finale = if (r2.isDefined) { if (const != 0) Plus(r2.get, InfiniteIntegerLiteral(const)) else r2.get - } else InfiniteIntegerLiteral(const) - //println(r + " simplifies to "+finale) + } else InfiniteIntegerLiteral(const) newop(finale, zero) } - case Minus(e1, e2) => Plus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2))) - case RealMinus(e1, e2) => RealPlus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2))) - case UMinus(e1) => PushMinus(mkLinearRecur(e1)) - case RealUMinus(e1) => PushMinus(mkLinearRecur(e1)) + case Minus(e1, e2) => Plus(mkLinearRecur(e1), pushMinus(mkLinearRecur(e2))) + case RealMinus(e1, e2) => RealPlus(mkLinearRecur(e1), pushMinus(mkLinearRecur(e2))) + case UMinus(e1) => pushMinus(mkLinearRecur(e1)) + case RealUMinus(e1) => pushMinus(mkLinearRecur(e1)) case Times(_, _) | RealTimes(_, _) => { val Operator(Seq(e1, e2), op) = inExpr val (r1, r2) = (mkLinearRecur(e1), mkLinearRecur(e2)) - if(isTemplateExpr(r1)) { - PushTimes(r1, r2) - } else if(isTemplateExpr(r2)){ - PushTimes(r2, r1) - } else + if (isTemplateExpr(r1)) + pushTimes(r1, r2) + else if (isTemplateExpr(r2)) + pushTimes(r2, r1) + else throw new IllegalStateException("Expression not linear: " + Times(r1, r2)) } case Plus(e1, e2) => Plus(mkLinearRecur(e1), mkLinearRecur(e2)) - case rp@RealPlus(e1, e2) => - //println(s"Expr: $rp arg1: $e1 tpe: ${e1.getType} arg2: $e2 tpe: ${e2.getType}") - val r1 = mkLinearRecur(e1) - val r2 = mkLinearRecur(e2) - //println(s"Res1: $r1 tpe: ${r1.getType} Res2: $r2 tpe: ${r2.getType}") - RealPlus(r1, r2) - case t: Terminal => t + case rp @ RealPlus(e1, e2) => + RealPlus(mkLinearRecur(e1), mkLinearRecur(e2)) + case t: Terminal => t case fi: FunctionInvocation => fi - case _ => throw new IllegalStateException("Expression not linear: " + inExpr) + case _ => throw new IllegalStateException("Expression not linear: " + inExpr) } - //println("Res: "+res+" tpe: "+res.getType) res } val rese = mkLinearRecur(atom) @@ -291,161 +260,124 @@ object LinearConstraintUtil { /** * Replaces an expression by another expression in the terms of the given linear constraint. */ - def replaceInCtr(replaceMap: Map[Expr, Expr], lc: LinearConstraint): Option[LinearConstraint] = { - + def replaceInCtr(replaceMap: Map[Identifier, Expr], lc: LinearConstraint): Option[LinearConstraint] = { //println("Replacing in "+lc+" repMap: "+replaceMap) - val newexpr = ExpressionTransformer.simplify(simplifyArithmetic(replace(replaceMap, lc.toExpr))) - //println("new expression: "+newexpr) + val newexpr = ExpressionTransformer.simplify(replaceFromIDs(replaceMap, lc.toExpr)) if (newexpr == tru) None - else if(newexpr == fls) throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) + else if (newexpr == fls) throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) else { val res = exprToTemplate(newexpr) //check if res is true or false evaluate(res) match { case Some(false) => throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) - case Some(true) => None //constraint reduced to true + case Some(true) => None //constraint reduced to true case _ => - val resctr = res.asInstanceOf[LinearConstraint] - Some(resctr) + Some(res.asInstanceOf[LinearConstraint]) } } } - /** - * Eliminates the specified variables from a conjunction of linear constraints (a disjunct) (that is satisfiable) - * We assume that the disjunct is in nnf form + def ctrVars(lc: LinearConstraint) = lc.coeffMap.keySet.map { case Variable(id) => id } + + /** + * Eliminates all variables except the `retainVars` from a conjunction of linear constraints (a disjunct) (that is satisfiable) + * We assume that the disjunct is in nnf form. + * The strategy is to look for (a) equality involving the elimVars or (b) check if all bounds are lower or (c) if all bounds are upper. + * TODO: handle cases wherein the coefficient of the variable that is substituted is not 1 or -1 * - * debugger is a function used for debugging + * @param debugger is a function used for debugging */ - val debugElimination = false - def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVars: Set[Identifier], - debugger: Option[(Seq[LinearConstraint] => Unit)]): Seq[LinearConstraint] = { - //eliminate one variable at a time - //each iteration produces a new set of linear constraints - elimVars.foldLeft(linearCtrs)((acc, elimVar) => { - val newdisj = apply1PRuleOnDisjunct(acc, elimVar) - - if(debugElimination) { - if(debugger.isDefined) { - debugger.get(newdisj) - } - } - - newdisj - }) - } - - def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVar: Identifier): Seq[LinearConstraint] = { - - if(debugElimination) - println("Trying to eliminate: "+elimVar) - - //collect all relevant constraints - val emptySeq = Seq[LinearConstraint]() - val (relCtrs, rest) = linearCtrs.foldLeft((emptySeq,emptySeq))((acc,lc) => { - if(variablesOf(lc.toExpr).contains(elimVar)) { - (lc +: acc._1,acc._2) - } else { - (acc._1,lc +: acc._2) - } - }) - - //now consider each constraint look for (a) equality involving the elimVar or (b) check if all bounds are lower - //or (c) if all bounds are upper. - var elimExpr : Option[Expr] = None - var elimCtr : Option[LinearConstraint] = None - var allUpperBounds : Boolean = true - var allLowerBounds : Boolean = true - var foundEquality : Boolean = false - var skippingEquality : Boolean = false - - relCtrs.foreach((lc) => { - //check for an equality - if (lc.toExpr.isInstanceOf[Equals] && lc.coeffMap.contains(elimVar.toVariable)) { - foundEquality = true - - //here, sometimes we replace an existing expression with a better one if available - if (elimExpr.isEmpty || shouldReplace(elimExpr.get, lc, elimVar)) { - //if the coeffcient of elimVar is +ve the the sign of the coeff of every other term should be changed - val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) - //make sure the value of the coefficient is 1 or -1 - //TODO: handle cases wherein the coefficient is not 1 or -1 - if (elimCoeff == 1 || elimCoeff == -1) { - val changeSign = if (elimCoeff > 0) true else false - - val startval = if (lc.const.isDefined) { - val InfiniteIntegerLiteral(cval) = lc.const.get - val newconst = if (changeSign) -cval else cval - InfiniteIntegerLiteral(newconst) - - } else zero - - val substExpr = lc.coeffMap.foldLeft(startval: Expr)((acc, summand) => { - val (term, InfiniteIntegerLiteral(coeff)) = summand - if (term != elimVar.toVariable) { - - val newcoeff = if (changeSign) -coeff else coeff - val newsummand = if (newcoeff == 1) term else Times(term, InfiniteIntegerLiteral(newcoeff)) - if (acc == zero) newsummand - else Plus(acc, newsummand) - - } else acc - }) - - elimExpr = Some(simplifyArithmetic(substExpr)) - elimCtr = Some(lc) - - if (debugElimination) { - println("Using ctr: " + lc + " found mapping: " + elimVar + " --> " + substExpr) + def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], retainVars: Set[Identifier], + debugger: Option[(Seq[LinearConstraint] => Unit)]): Seq[LinearConstraint] = { + val idsWithUpperBounds = MutableSet[Identifier]() // identifiers with only upper bounds + val idsWithLowerBounds = MutableSet[Identifier]() // identifiers with only lower bounds + val idsWithEquality = MutableSet[Identifier]() // identifiers for which an equality constraint exist + var eqctrs = MutableList[LinearConstraint]() + var restctrs = MutableList[LinearConstraint]() + linearCtrs.foreach { + case lc => + val vars = ctrVars(lc) + val elimVars = vars -- retainVars + lc.template match { + case eq: Equals => + idsWithEquality ++= vars + if (!elimVars.isEmpty) + eqctrs += lc + else restctrs += lc + // choose all vars whose coefficient is either 1 or -1 + case _: LessEquals | _: LessThan => + elimVars.foreach { elimVar => + val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) + if (elimCoeff > 0) + idsWithUpperBounds += elimVar //here, we have found an upper bound + else + idsWithLowerBounds += elimVar //here, we have found a lower bound } - } else { - skippingEquality = true - } + restctrs += lc + case _ => throw new IllegalStateException("LinearConstraint not in expeceted form : " + lc.toExpr) } - } else if ((lc.toExpr.isInstanceOf[LessEquals] || lc.toExpr.isInstanceOf[LessThan]) - && lc.coeffMap.contains(elimVar.toVariable)) { - - val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) - if (elimCoeff > 0) { - //here, we have found an upper bound - allLowerBounds = false - } else { - //here, we have found a lower bound - allUpperBounds = false + } + // sort 'eqctrs' by the size of the constraints so that we use smaller expressions in 'subst' map. + var currEqs = eqctrs.sortBy(eqc => eqc.coeffMap.keySet.size + (if (eqc.const.isDefined) 1 else 0)) + // compute the subst map recursively + var nextEqs = MutableList[LinearConstraint]() + var foundSubst = true + var subst = Map[Identifier, Expr]() + while (foundSubst) { + foundSubst = false + currEqs.foreach { eq => + // replace the constraint by the current subst (which may require multiple applications) + replaceInCtr(subst, eq) match { + case None => // constraint reduced to true, drop the constraint + case Some(newc) => + // choose one new variable that can be substituted + val elimVarOpt = ctrVars(newc).find { evar => + !retainVars.contains(evar) && !subst.contains(evar) && + (newc.coeffMap(evar.toVariable) match { + case InfiniteIntegerLiteral(elimCoeff) if (elimCoeff == 1 || elimCoeff == -1) => true + case _ => false + }) + } + elimVarOpt match { + case None => + nextEqs += newc // here, the constraint cannot be substituted, so we need to preserve it + case Some(elimVar) => + //if the coeffcient of elimVar is +ve the the sign of the coeff of every other term should be changed + val InfiniteIntegerLiteral(elimCoeff) = newc.coeffMap(elimVar.toVariable) + val changeSign = elimCoeff > 0 + val startval = if (newc.const.isDefined) { + val InfiniteIntegerLiteral(cval) = newc.const.get + val newconst = if (changeSign) -cval else cval + InfiniteIntegerLiteral(newconst) + } else zero + val substExpr = newc.coeffMap.foldLeft(startval: Expr) { + case (acc, (term, InfiniteIntegerLiteral(coeff))) if (term != elimVar.toVariable) => + val newcoeff = if (changeSign) -coeff else coeff + val newsummand = if (newcoeff == 1) term else Times(term, InfiniteIntegerLiteral(newcoeff)) + if (acc == zero) newsummand + else Plus(acc, newsummand) + case (acc, _) => acc + } + if (debugElimination) { + println("Analyzing ctr: " + newc + " found mapping: " + elimVar + " --> " + substExpr) + } + subst = Util.substClosure(subst + (elimVar -> simplifyArithmetic(substExpr))) + foundSubst = true + } } - } else { - //here, we assume that the operators are normalized to Equals, LessThan and LessEquals - throw new IllegalStateException("LinearConstraint not in expeceted form : " + lc.toExpr) } - }) - - val newctrs = if (elimExpr.isDefined) { - - val elimMap = Map[Expr, Expr](elimVar.toVariable -> elimExpr.get) - var repCtrs = Seq[LinearConstraint]() - relCtrs.foreach((ctr) => { - if (ctr != elimCtr.get) { - //replace 'elimVar' by 'elimExpr' in ctr - val repCtr = this.replaceInCtr(elimMap, ctr) - if (repCtr.isDefined) - repCtrs +:= repCtr.get - } - }) - repCtrs - - } else if (!foundEquality && (allLowerBounds || allUpperBounds)) { - //here, drop all relCtrs. None of them are important - Seq() - } else { - //for stats - if(skippingEquality) { - Stats.updateCumStats(1,"SkippedVar") - } - //cannot eliminate the variable - relCtrs + currEqs = nextEqs } - val resctrs = (newctrs ++ rest) - //println("After eliminating: "+elimVar+" : "+resctrs) + val oneSidedVars = ((idsWithUpperBounds -- idsWithLowerBounds) ++ (idsWithLowerBounds -- idsWithUpperBounds)) -- idsWithEquality + val resctrs = (restctrs.flatMap { + case ctr if ctrVars(ctr).intersect(oneSidedVars).isEmpty => + replaceInCtr(subst, ctr) match { + case None => Seq() + case Some(newctr) => Seq(newctr) + } + case _ => Seq() // drop constraints with `oneSidedVars` + } ++ currEqs).distinct // note: this is very important!! + Stats.updateCounterStats(currEqs.size, "UneliminatedEqualities", "disjuncts") resctrs } @@ -459,43 +391,29 @@ object LinearConstraintUtil { size } - def sizeCtr(ctr : LinearConstraint) : Int = { + def sizeCtr(ctr: LinearConstraint): Int = { val coeffSize = ctr.coeffMap.foldLeft(0)((acc, pair) => { val (term, coeff) = pair - if(coeff == one) acc + 1 + if (coeff == one) acc + 1 else acc + sizeExpr(coeff) + 2 }) - if(ctr.const.isDefined) coeffSize + 1 + if (ctr.const.isDefined) coeffSize + 1 else coeffSize } - def shouldReplace(currExpr : Expr, candidateCtr : LinearConstraint, elimVar: Identifier) : Boolean = { - if(!currExpr.isInstanceOf[InfiniteIntegerLiteral]) { - //is the candidate a constant - if(candidateCtr.coeffMap.size == 1) true - else{ - //computing the size of currExpr - if(sizeExpr(currExpr) > (sizeCtr(candidateCtr) - 1)) true - else false - } - } else false - } - - //remove transitive axioms - /** * Checks if the expression is linear i.e, * is only conjuntion and disjunction of linear atomic predicates */ - def isLinear(e: Expr) : Boolean = { - e match { - case And(args) => args forall isLinear - case Or(args) => args forall isLinear - case Not(arg) => isLinear(arg) - case Implies(e1, e2) => isLinear(e1) && isLinear(e2) - case t : Terminal => true - case atom => - exprToTemplate(atom).isInstanceOf[LinearConstraint] - } + def isLinearFormula(e: Expr): Boolean = { + e match { + case And(args) => args forall isLinearFormula + case Or(args) => args forall isLinearFormula + case Not(arg) => isLinearFormula(arg) + case Implies(e1, e2) => isLinearFormula(e1) && isLinearFormula(e2) + case t: Terminal => true + case atom => + exprToTemplate(atom).isInstanceOf[LinearConstraint] + } } } diff --git a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala index ee31167305c42e020fac4dc8bafcec53ff8f8eeb..a3eca287b0d221af93120b1578020e8baa7a5fc9 100644 --- a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala @@ -15,6 +15,7 @@ import leon.invariant.util.RealValuedExprEvaluator._ import PredicateUtil._ import SolverUtil._ import Stats._ +import Util._ class CegisSolver(ctx: InferenceContext, program: Program, rootFun: FunDef, ctrTracker: ConstraintTracker, @@ -109,7 +110,7 @@ class CegisCore(ctx: InferenceContext, if (dumpCandidateInvs) { reporter.info("Candidate invariants") - val candInvs = cegisSolver.getAllInvariants(model) + val candInvs = TemplateInstantiator.getAllInvariants(model, cegisSolver.ctrTracker.getFuncs) candInvs.foreach((entry) => println(entry._1.id + "-->" + entry._2)) } val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap diff --git a/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala b/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala new file mode 100644 index 0000000000000000000000000000000000000000..b3a935124bbe609bc423e49404336fd7cf20ef7c --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala @@ -0,0 +1,213 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import java.io._ +import solvers._ +import solvers.combinators._ +import solvers.smtlib._ +import solvers.z3._ +import scala.util.control.Breaks._ +import purescala.ScalaPrinter +import scala.collection.mutable.{ Map => MutableMap } +import scala.reflect.runtime.universe +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import Stats._ + +import Util._ +import PredicateUtil._ +import SolverUtil._ + +class DisjunctChooser(ctx: InferenceContext, program: Program, ctrTracker: ConstraintTracker, defaultEval: DefaultEvaluator) { + val debugElimination = false + val debugChooseDisjunct = false + val debugTheoryReduction = false + val debugAxioms = false + val debugReducedFormula = false + val verifyInvariant = false + val printPathToFile = false + val dumpPathAsSMTLIB = false + + val leonctx = ctx.leonContext + val linearEval = new LinearRelationEvaluator(ctx) // an evaluator for quickly checking the result of linear predicates + + //additional book-keeping for statistics + val trackNumericalDisjuncts = false + var numericalDisjuncts = List[Expr]() + + /** + * A helper function used only in debugging. + */ + protected def doesSatisfyExpr(expr: Expr, model: LazyModel): Boolean = { + val compModel = variablesOf(expr).map { k => k -> model(k) }.toMap + defaultEval.eval(expr, new Model(compModel)).result match { + case Some(BooleanLiteral(true)) => true + case _ => false + } + } + + /** + * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. + * This method can be overloaded by the subclasses. + */ + protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = Seq() + + /** + * Chooses a purely numerical disjunct from a given formula that is + * satisfied by the model + * @precondition the formula is satisfied by the model + * @tempIdMap a model for the template variables + */ + def chooseNumericalDisjunct(formula: Formula, initModel: LazyModel, tempIdMap: Map[Identifier, Expr]): (Seq[LinearConstraint], Seq[LinearTemplate], Set[Call]) = { + val satCtrs = formula.pickSatDisjunct(formula.firstRoot, initModel, tempIdMap, defaultEval) //this picks the satisfiable disjunct of the VC modulo axioms + //for debugging + if (debugChooseDisjunct || printPathToFile || dumpPathAsSMTLIB || verifyInvariant) { + val pathctrs = satCtrs.map(_.toExpr) + val plainFormula = createAnd(pathctrs) + val pathcond = simplifyArithmetic(plainFormula) + if (printPathToFile) { + //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) + ExpressionTransformer.PrintWithIndentation("full-path", pathcond) + } + if (dumpPathAsSMTLIB) { + val filename = "pathcond" + FileCountGUID.getID + ".smt2" + toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) + println("Path dumped to: " + filename) + } + if (debugChooseDisjunct) { + satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyExpr(ctr, initModel)) + throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) + }) + } + if (verifyInvariant) { + println("checking invariant for path...") + val sat = checkInvariant(pathcond, leonctx, program) + } + } + var calls = Set[Call]() + var adtExprs = Seq[Expr]() + satCtrs.foreach { + case t: Call => calls += t + case t: ADTConstraint if (t.cons || t.sel) => adtExprs :+= t.expr + // TODO: ignoring all set constraints here, fix this + case _ => ; + } + val callExprs = calls.map(_.toExpr) + + val axiomCtrs = time { + ctrTracker.specInstantiator.axiomsForCalls(formula, calls, initModel, tempIdMap, defaultEval) + } { updateCumTime(_, "Total-AxiomChoose-Time") } + + //here, handle theory operations by reducing them to axioms. + //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle + //other theory axioms like: multiplication, sets, arrays, maps etc. + val theoryCtrs = time { + axiomsForTheory(formula, calls, initModel) + } { updateCumTime(_, "Total-TheoryAxiomatization-Time") } + + //Finally, eliminate UF/ADT + // convert all adt constraints to 'cons' ctrs, and expand the model + val selTrans = new SelectorToCons() + val cons = selTrans.selToCons(adtExprs) + val expModel = selTrans.getModel(initModel) + // get constraints for UFADTs + val callCtrs = time { + (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), + linearEval.predEval(expModel)).map(ConstraintUtil.createConstriant _) + } { updateCumTime(_, "Total-ElimUF-Time") } + + //exclude guards, separate calls and cons from the rest + var lnctrs = Set[LinearConstraint]() + var temps = Set[LinearTemplate]() + (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { + case t: LinearConstraint => lnctrs += t + case t: LinearTemplate => temps += t + case _ => ; + } + if (debugChooseDisjunct) { + lnctrs.map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyExpr(ctr, expModel)) + throw new IllegalStateException("Ctr not satisfied by model: " + ctr) + }) + } + if (debugTheoryReduction) { + val simpPathCond = createAnd((lnctrs ++ temps).map(_.template).toSeq) + if (verifyInvariant) { + println("checking invariant for simp-path...") + checkInvariant(simpPathCond, leonctx, program) + } + } + if (trackNumericalDisjuncts) { + numericalDisjuncts :+= createAnd((lnctrs ++ temps).map(_.template).toSeq) + } + val tempCtrs = temps.toSeq + val elimCtrs = eliminateVars(lnctrs.toSeq, tempCtrs) + //for debugging + if (debugReducedFormula) { + println("Final Path Constraints: " + elimCtrs ++ tempCtrs) + if (verifyInvariant) { + println("checking invariant for final disjunct... ") + checkInvariant(createAnd((elimCtrs ++ tempCtrs).map(_.template)), leonctx, program) + } + } + (elimCtrs, tempCtrs, calls) + } + + /** + * TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c + * TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) + * e.g, remove: a <= b if we have a = b or if a < b + * Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. + * TODO: Use the dependence chains in the formulas to identify what to assertionize + * and what can never be implied by solving for the templates + */ + import LinearConstraintUtil._ + def eliminateVars(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): Seq[LinearConstraint] = { + if (temps.isEmpty) lnctrs //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path + else { + if (debugElimination && verifyInvariant) { + println("checking invariant for disjunct before elimination...") + checkInvariant(createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) + } + // for debugging + val debugger = + if (debugElimination && verifyInvariant) { + Some((ctrs: Seq[LinearConstraint]) => { + val debugRes = checkInvariant(createAnd((ctrs ++ temps).map(_.template)), leonctx, program) + }) + } else None + val elimLnctrs = time { + apply1PRuleOnDisjunct(lnctrs, temps.flatMap(lt => variablesOf(lt.template)).toSet, debugger) + } { updateCumTime(_, "ElimTime") } + + if (debugElimination) { + println("Path constriants (after elimination): " + elimLnctrs) + if (verifyInvariant) { + println("checking invariant for disjunct after elimination...") + checkInvariant(createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) + } + } + //for stats + if (ctx.dumpStats) { + Stats.updateCounterStats(lnctrs.size, "CtrsBeforeElim", "disjuncts") + Stats.updateCounterStats(lnctrs.size - elimLnctrs.size, "EliminatedAtoms", "disjuncts") + Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") + Stats.updateCounterStats(elimLnctrs.size, "NonParam-Atoms", "disjuncts") + } + elimLnctrs + } + } +} diff --git a/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..5af100ed7037832291416229fc690b215a6acca4 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala @@ -0,0 +1,93 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import java.io._ +import solvers._ +import solvers.combinators._ +import solvers.smtlib._ +import solvers.z3._ +import scala.util.control.Breaks._ +import purescala.ScalaPrinter +import scala.collection.mutable.{ Map => MutableMap } +import scala.reflect.runtime.universe +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import Stats._ + +import Util._ +import PredicateUtil._ +import SolverUtil._ + +/** + * This class uses Farkas' lemma to try and falsify numerical disjuncts with templates provided one by one + */ +class ExistentialQuantificationSolver(ctx: InferenceContext, program: Program, + ctrTracker: ConstraintTracker, defaultEval: DefaultEvaluator) { + import NLTemplateSolver._ + val reporter = ctx.reporter + + var currentCtr: Expr = tru + private val farkasSolver = new FarkasLemmaSolver(ctx, program) + val disjunctChooser = new DisjunctChooser(ctx, program, ctrTracker, defaultEval) + + def getSolvedCtrs = currentCtr + + def generateCtrsForUNSAT(fd: FunDef, univModel: LazyModel, tempModel: Model) = { + // chooose a sat numerical disjunct from the model + val (lnctrs, temps, calls) = + time { + disjunctChooser.chooseNumericalDisjunct(ctrTracker.getVC(fd), univModel, tempModel.toMap) + } { chTime => + updateCounterTime(chTime, "Disj-choosing-time", "disjuncts") + updateCumTime(chTime, "Total-Choose-Time") + } + val disjunct = (lnctrs ++ temps) + if (temps.isEmpty) { + //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path + (fls, disjunct, calls) + } else + (farkasSolver.constraintsForUnsat(lnctrs, temps), disjunct, calls) + } + + /** + * Solves the nonlinear Farkas' constraints + */ + def solveConstraints(newctrs: Seq[Expr], oldModel: Model): (Option[Boolean], Model) = { + val newPart = createAnd(newctrs) + val newSize = atomNum(newPart) + val currSize = atomNum(currentCtr) + + Stats.updateCounterStats((newSize + currSize), "NLsize", "disjuncts") + if (verbose) reporter.info("# of atomic predicates: " + newSize + " + " + currSize) + + val combCtr = And(currentCtr, newPart) + val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) + res match { + case _ if ctx.abort => + (None, Model.empty) // stop immediately + case None => + //here we have timed out while solving the non-linear constraints + if (verbose) reporter.info("NLsolver timed-out on the disjunct...") + (None, Model.empty) + case Some(false) => + currentCtr = fls + (Some(false), Model.empty) + case Some(true) => + currentCtr = combCtr + //new model may not have mappings for all the template variables, hence, use the mappings from earlier models + (Some(true), completeWithRefModel(newModel, oldModel)) + } + } +} diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala index 1a4d1cd06a916eaf8197b97471e47f920d66f4cb..62eabca927059aa61c54a341d5339385a5c26e42 100644 --- a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala @@ -249,7 +249,7 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { if (this.debugNLCtrs && hasInts(simpctrs)) { throw new IllegalStateException("Nonlinear constraints have integers: " + simpctrs) } - if (verbose && LinearConstraintUtil.isLinear(simpctrs)) { + if (verbose && LinearConstraintUtil.isLinearFormula(simpctrs)) { reporter.info("Constraints reduced to linear !") } if (this.dumpNLCtrs) { diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala index 5b46b2df17271a16847b7c9498c9e18f01905c51..f0da67891a2240fbb175957a192d39b926f5939f 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala @@ -30,54 +30,16 @@ import Util._ import PredicateUtil._ import SolverUtil._ +object NLTemplateSolver { + val verbose = true +} + class NLTemplateSolver(ctx: InferenceContext, program: Program, rootFun: FunDef, ctrTracker: ConstraintTracker, minimizer: Option[(Expr, Model) => Model]) extends TemplateSolver(ctx, rootFun, ctrTracker) { - //flags controlling debugging - val debugUnflattening = false - val debugIncrementalVC = false - val debugElimination = false - val debugChooseDisjunct = false - val debugTheoryReduction = false - val debugAxioms = false - val verifyInvariant = false - val debugReducedFormula = false - val trackCompressedVCCTime = false - - //print flags - val verbose = true - val printCounterExample = false - val printPathToConsole = false - val dumpPathAsSMTLIB = false - val printCallConstriants = false - val dumpInstantiatedVC = false - - private val timeout = ctx.vcTimeout - private val leonctx = ctx.leonContext - - //flag controlling behavior - private val farkasSolver = new FarkasLemmaSolver(ctx, program) - private val startFromEarlierModel = true - private val disableCegis = true - private val useIncrementalSolvingForVCs = true - private val usePortfolio = false // portfolio has a bug in incremental solving - - // an evaluator for extracting models - val defaultEval = new DefaultEvaluator(leonctx, program) - // an evaluator for quicky checking the result of linear predicates - val linearEval = new LinearRelationEvaluator(ctx) - // solver factory - val solverFactory = - if (usePortfolio) { - if (useIncrementalSolvingForVCs) - throw new IllegalArgumentException("Cannot perform incremental solving with portfolio solvers!") - SolverFactory(() => new PortfolioSolver(leonctx, Seq(new SMTLIBCVC4Solver(leonctx, program), - new SMTLIBZ3Solver(leonctx, program))) with TimeoutSolver) - } else - SolverFactory.uninterpreted(leonctx, program) - + private val startFromEarlierModel = false // state for tracking the last model private var lastFoundModel: Option[Model] = None @@ -86,573 +48,12 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, * The result is a mapping from function definitions to the corresponding invariants. */ override def solve(tempIds: Set[Identifier], funs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) = { - val initModel = - if (this.startFromEarlierModel && lastFoundModel.isDefined) { - val candModel = lastFoundModel.get - new Model(tempIds.map(id => - (id -> candModel.getOrElse(id, simplestValue(id.getType)))).toMap) - } else { - new Model(tempIds.map((id) => - (id -> simplestValue(id.getType))).toMap) - } - val cgSolver = new CGSolver(funs) - val (resModel, seenCalls, lastModel) = cgSolver.solveUNSAT(initModel) - lastFoundModel = Some(lastModel) - cgSolver.free + val initModel = completeModel( + (if (this.startFromEarlierModel && lastFoundModel.isDefined) lastFoundModel.get + else Model.empty), tempIds) + val univSolver = new UniversalQuantificationSolver(ctx, program, funs, ctrTracker, minimizer) + val (resModel, seenCalls) = univSolver.solveUNSAT(initModel, (m: Model) => lastFoundModel = Some(m)) + univSolver.free (resModel, seenCalls) } - - /** - * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. - * This method can overloaded by the subclasses. - */ - protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = Seq() - - //a helper method - //TODO: this should also handle reals - protected def doesSatisfyExpr(expr: Expr, model: LazyModel): Boolean = { - val compModel = variablesOf(expr).map { k => k -> model(k) }.toMap - defaultEval.eval(expr, new Model(compModel)).result match { - case Some(BooleanLiteral(true)) => true - case _ => false - } - } - - def splitVC(fd: FunDef) = { - val (paramPart, rest, modCons) = ctrTracker.getVC(fd).toUnflatExpr - if (ctx.usereals) { - (IntLiteralToReal(paramPart), IntLiteralToReal(rest), modCons) - } else (paramPart, rest, modCons) - } - - /** - * Counter-example guided solver - */ - class CGSolver(funs: Seq[FunDef]) { - - //for miscellaneous things - val trackNumericalDisjuncts = false - var numericalDisjuncts = List[Expr]() - - case class FunData(vcSolver: Solver with TimeoutSolver, modelCons: (Model, DefaultEvaluator) => FlatModel, - paramParts: Expr, simpleParts: Expr) - val funInfos = - if (useIncrementalSolvingForVCs) { - funs.foldLeft(Map[FunDef, FunData]()) { - case (acc, fd) => - val (paramPart, rest, modelCons) = splitVC(fd) - if (hasReals(rest) && hasInts(rest)) - throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) - if (debugIncrementalVC) { - assert(getTemplateVars(rest).isEmpty) - println("For function: " + fd.id) - println("Param part: " + paramPart) - } - if (!ctx.abort) { // this is required to ensure that solvers are not created after interrupts - val vcSolver = solverFactory.getNewSolver() - vcSolver.assertCnstr(rest) - acc + (fd -> FunData(vcSolver, modelCons, paramPart, rest)) - } else acc - } - } else Map[FunDef, FunData]() - - def free = { - if (useIncrementalSolvingForVCs) - funInfos.foreach(entry => entry._2.vcSolver.free) - if (trackNumericalDisjuncts) - this.numericalDisjuncts = List[Expr]() - } - - //state for minimization - var minStarted = false - var minStartTime: Long = 0 - var minimized = false - - def minimizationInProgress { - if (!minStarted) { - minStarted = true - minStartTime = System.currentTimeMillis() - } - } - - def minimizationCompleted { - minStarted = false - val mintime = (System.currentTimeMillis() - minStartTime) - /*Stats.updateCounterTime(mintime, "minimization-time", "procs") - Stats.updateCumTime(mintime, "Total-Min-Time")*/ - } - - def solveUNSAT(initModel: Model): (Option[Model], Option[Set[Call]], Model) = solveUNSAT(initModel, tru, Seq(), Set()) - - def solveUNSAT(model: Model, inputCtr: Expr, solvedDisjs: Seq[Expr], seenCalls: Set[Call]): (Option[Model], Option[Set[Call]], Model) = { - if (verbose) { - reporter.info("Candidate invariants") - val candInvs = getAllInvariants(model) - candInvs.foreach((entry) => reporter.info(entry._1.id + "-->" + entry._2)) - } - val (res, newCtr, newModel, newdisjs, newcalls) = invalidateSATDisjunct(inputCtr, model) - res match { - case _ if ctx.abort => - (None, None, model) - case None => - //here, we cannot proceed and have to return unknown - //However, we can return the calls that need to be unrolled - (None, Some(seenCalls ++ newcalls), model) - case Some(false) => - //here, the vcs are unsatisfiable when instantiated with the invariant - if (minimizer.isDefined) { - //for stats - minimizationInProgress - if (minimized) { - minimizationCompleted - (Some(model), None, model) - } else { - val minModel = minimizer.get(inputCtr, model) - minimized = true - if (minModel == model) { - minimizationCompleted - (Some(model), None, model) - } else { - solveUNSAT(minModel, inputCtr, solvedDisjs, seenCalls) - } - } - } else { - (Some(model), None, model) - } - case Some(true) => - //here, we have found a new candidate invariant. Hence, the above process needs to be repeated - minimized = false - solveUNSAT(newModel, newCtr, solvedDisjs ++ newdisjs, seenCalls ++ newcalls) - } - } - - //TODO: this code does too much imperative update. - //TODO: use guards to block a path and not use the path itself - def invalidateSATDisjunct(inputCtr: Expr, model: Model): (Option[Boolean], Expr, Model, Seq[Expr], Set[Call]) = { - - val tempIds = model.map(_._1) - val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap - val inputSize = atomNum(inputCtr) - - var disjsSolvedInIter = Seq[Expr]() - var callsInPaths = Set[Call]() - var conflictingFuns = funs.toSet - //mapping from the functions to the counter-example paths that were seen - var seenPaths = MutableMap[FunDef, Seq[Expr]]() - def updateSeenPaths(fd: FunDef, cePath: Expr): Unit = { - if (seenPaths.contains(fd)) { - seenPaths.update(fd, cePath +: seenPaths(fd)) - } else { - seenPaths += (fd -> Seq(cePath)) - } - } - - def invalidateDisjRecr(prevCtr: Expr): (Option[Boolean], Expr, Model) = { - - Stats.updateCounter(1, "disjuncts") - - var blockedCEs = false - var confFunctions = Set[FunDef]() - var confDisjuncts = Seq[Expr]() - val newctrsOpt = conflictingFuns.foldLeft(Some(Seq()): Option[Seq[Expr]]) { - case (None, _) => None - case (Some(acc), fd) => - val disableCounterExs = if (seenPaths.contains(fd)) { - blockedCEs = true - Not(createOr(seenPaths(fd))) - } else tru - if (ctx.abort) None - else - getUNSATConstraints(fd, model, disableCounterExs) match { - case None => - None - case Some(((disjunct, callsInPath), ctrsForFun)) => - if (ctrsForFun == tru) Some(acc) - else { - confFunctions += fd - confDisjuncts :+= disjunct - callsInPaths ++= callsInPath - //instantiate the disjunct - val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) - //some sanity checks - if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) - throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) - updateSeenPaths(fd, cePath) - Some(acc :+ ctrsForFun) - } - } - } - newctrsOpt match { - case None => - // give up, the VC cannot be decided - (None, tru, Model.empty) - case Some(newctrs) => - //update conflicting functions - conflictingFuns = confFunctions - if (newctrs.isEmpty) { - if (!blockedCEs) { - //yes, hurray,found an inductive invariant - (Some(false), prevCtr, model) - } else { - //give up, only hard paths remaining - reporter.info("- Exhausted all easy paths !!") - reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) - //TODO: what to unroll here ? - (None, tru, Model.empty) - } - } else { - //check that the new constraints does not have any reals - val newPart = createAnd(newctrs) - val newSize = atomNum(newPart) - Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") - if (verbose) - reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) - val combCtr = And(prevCtr, newPart) - val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) - res match { - case _ if ctx.abort => - // stop immediately - (None, tru, Model.empty) - case None => { - //here we have timed out while solving the non-linear constraints - if (verbose) - if (!disableCegis) - reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") - else - reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") - if (!disableCegis) { - val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, createOr(confDisjuncts), inputCtr, Some(model)) - cres match { - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - (Some(true), And(inputCtr, cctr), cmodel) - } - case Some(false) => { - disjsSolvedInIter ++= confDisjuncts - //here also return the calls that needs to be unrolled - (None, fls, Model.empty) - } - case _ => { - if (verbose) reporter.info("retrying...") - Stats.updateCumStats(1, "retries") - //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration - invalidateDisjRecr(And(inputCtr, cctr)) - } - } - } else { - if (verbose) reporter.info("retrying...") - Stats.updateCumStats(1, "retries") - invalidateDisjRecr(inputCtr) - } - } - case Some(false) => { - //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) - disjsSolvedInIter ++= confDisjuncts - (None, fls, Model.empty) - } - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - //new model may not have mappings for all the template variables, hence, use the mappings from earlier models - val compModel = new Model(tempIds.map((id) => { - if (newModel.isDefinedAt(id)) - (id -> newModel(id)) - else - (id -> model(id)) - }).toMap) - (Some(true), combCtr, compModel) - } - } - } - } - } - val (res, newctr, newmodel) = invalidateDisjRecr(inputCtr) - (res, newctr, newmodel, disjsSolvedInIter, callsInPaths) - } - - def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { - val cegisSolver = new CegisCore(ctx, program, timeout.toInt, NLTemplateSolver.this) - val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) - if (res.isEmpty) - reporter.info("cegis timed-out on the disjunct...") - (res, ctr, model) - } - - protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { - if (ctx.usereals) replace(tempVarMap, e) - else - simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) - } - - /** - * Constructs a quantifier-free non-linear constraint for unsatisfiability - */ - def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): Option[((Expr, Set[Call]), Expr)] = { - - val tempVarMap: Map[Expr, Expr] = inModel.map((elem) => (elem._1.toVariable, elem._2)).toMap - val (solver, instExpr, modelCons) = - if (useIncrementalSolvingForVCs) { - val funData = funInfos(fd) - val instParamPart = instantiateTemplate(funData.paramParts, tempVarMap) - (funData.vcSolver, And(instParamPart, disableCounterExs), funData.modelCons) - } else { - val (paramPart, rest, modCons) = ctrTracker.getVC(fd).toUnflatExpr - val instPart = instantiateTemplate(paramPart, tempVarMap) - (solverFactory.getNewSolver(), createAnd(Seq(rest, instPart, disableCounterExs)), modCons) - } - //For debugging - if (dumpInstantiatedVC) { - val filename = "vcInst-" + FileCountGUID.getID - val wr = new PrintWriter(new File(filename+".txt")) - val fullExpr = - if (useIncrementalSolvingForVCs) { - And(funInfos(fd).simpleParts, instExpr) - } else instExpr - wr.println("Function name: " + fd.id+" \nFormula expr: ") - ExpressionTransformer.PrintWithIndentation(wr, fullExpr) - wr.close() - } - if(debugUnflattening){ - ctrTracker.getVC(fd).checkUnflattening(tempVarMap, - SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())), - defaultEval) - } - // sanity check - if (hasMixedIntReals(instExpr)) - throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) - - //reporter.info("checking VC inst ...") - solver.setTimeout(timeout * 1000) - val (res, packedModel) = - time { - if (useIncrementalSolvingForVCs) { - solver.push - solver.assertCnstr(instExpr) - val solRes = solver.check match { - case _ if ctx.abort => - (None, Model.empty) - case r @ Some(true) => - (r, solver.getModel) - case r => (r, Model.empty) - } - solver.pop() - solRes - } else - SimpleSolverAPI(SolverFactory(() => solver)).solveSAT(instExpr) - } { vccTime => - if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") - updateCounterTime(vccTime, "VC-check-time", "disjuncts") - updateCumTime(vccTime, "TotalVCCTime") - } - //for statistics - if (trackCompressedVCCTime) { - val compressedVC = - unflatten(simplifyArithmetic(instantiateTemplate( - ctrTracker.getVC(fd).eliminateBlockers, tempVarMap))) - Stats.updateCounterStats(atomNum(compressedVC), "Compressed-VC-size", "disjuncts") - time { - SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())).solveSAT(compressedVC) - } { compTime => - Stats.updateCumTime(compTime, "TotalCompressVCCTime") - reporter.info("checked compressed VC... in " + compTime / 1000.0 + "s") - } - } - res match { - case None => None // cannot check satisfiability of VCinst !! - case Some(false) => - Some(((fls, Set()), tru)) //do not generate any constraints - case Some(true) => - //For debugging purposes. - if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") - if (printCounterExample) { - reporter.info("Model: " + packedModel) - } - //get the disjuncts that are satisfied - val model = modelCons(packedModel, defaultEval) - val (data, newctr) = - time { generateCtrsFromDisjunct(fd, model) } { chTime => - updateCounterTime(chTime, "Disj-choosing-time", "disjuncts") - updateCumTime(chTime, "Total-Choose-Time") - } - if (newctr == tru) throw new IllegalStateException("Cannot find a counter-example path!!") - Some((data, newctr)) - } - } - - protected def generateCtrsFromDisjunct(fd: FunDef, initModel: LazyModel): ((Expr, Set[Call]), Expr) = { - - val formula = ctrTracker.getVC(fd) - //this picks the satisfiable disjunct of the VC modulo axioms - val satCtrs = formula.pickSatDisjunct(formula.firstRoot, initModel) - //for debugging - if (debugChooseDisjunct || printPathToConsole || dumpPathAsSMTLIB || verifyInvariant) { - val pathctrs = satCtrs.map(_.toExpr) - val plainFormula = createAnd(pathctrs) - val pathcond = simplifyArithmetic(plainFormula) - - if (debugChooseDisjunct) { - satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyExpr(ctr, initModel)) - throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) - }) - } - if (verifyInvariant) { - println("checking invariant for path...") - val sat = checkInvariant(pathcond, leonctx, program) - } - if (printPathToConsole) { - //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) - val simpcond = pathcond - println("Full-path: " + ScalaPrinter(simpcond)) - val filename = "full-path-" + FileCountGUID.getID + ".txt" - val wr = new PrintWriter(new File(filename)) - ExpressionTransformer.PrintWithIndentation(wr, simpcond) - println("Printed to file: " + filename) - wr.flush() - wr.close() - } - if (dumpPathAsSMTLIB) { - val filename = "pathcond" + FileCountGUID.getID + ".smt2" - toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) - println("Path dumped to: " + filename) - } - } - - var calls = Set[Call]() - var adtExprs = Seq[Expr]() - satCtrs.foreach { - case t: Call => calls += t - case t: ADTConstraint if (t.cons || t.sel) => adtExprs :+= t.expr - // TODO: ignoring all set constraints here, fix this - case _ => ; - } - val callExprs = calls.map(_.toExpr) - - val axiomCtrs = time { - ctrTracker.specInstantiator.axiomsForCalls(formula, calls, initModel) - } { updateCumTime(_, "Total-AxiomChoose-Time") } - - //here, handle theory operations by reducing them to axioms. - //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle - //other theory axioms like: multiplication, sets, arrays, maps etc. - val theoryCtrs = time { - axiomsForTheory(formula, calls, initModel) - } { updateCumTime(_, "Total-TheoryAxiomatization-Time") } - - //Finally, eliminate UF/ADT - // convert all adt constraints to 'cons' ctrs, and expand the model - val selTrans = new SelectorToCons() - val cons = selTrans.selToCons(adtExprs) - val expModel = selTrans.getModel(initModel) - // get constraints for UFADTs - val callCtrs = time { - (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), - linearEval.predEval(expModel)).map(ConstraintUtil.createConstriant _) - } { updateCumTime(_, "Total-ElimUF-Time") } - - //exclude guards, separate calls and cons from the rest - var lnctrs = Set[LinearConstraint]() - var temps = Set[LinearTemplate]() - (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { - case t: LinearConstraint => lnctrs += t - case t: LinearTemplate => temps += t - case _ => ; - } - if (debugChooseDisjunct) { - lnctrs.map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyExpr(ctr, expModel)) - throw new IllegalStateException("Ctr not satisfied by model: " + ctr) - }) - } - if (debugTheoryReduction) { - val simpPathCond = createAnd((lnctrs ++ temps).map(_.template).toSeq) - if (verifyInvariant) { - println("checking invariant for simp-path...") - checkInvariant(simpPathCond, leonctx, program) - } - } - if (trackNumericalDisjuncts) { - numericalDisjuncts :+= createAnd((lnctrs ++ temps).map(_.template).toSeq) - } - val (data, nlctr) = processNumCtrs(lnctrs.toSeq, temps.toSeq) - ((data, calls), nlctr) - } - - /** - * Endpoint of the pipeline. Invokes the Farkas Lemma constraint generation. - */ - def processNumCtrs(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): (Expr, Expr) = { - //here we are invalidating A^~(B) - if (temps.isEmpty) { - //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path - (createAnd(lnctrs.map(_.toExpr)), fls) - } else { - if (debugElimination) { - //println("Path Constraints (before elim): "+(lnctrs ++ temps)) - if (verifyInvariant) { - println("checking invariant for disjunct before elimination...") - checkInvariant(createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) - } - } - //compute variables to be eliminated - val ctrVars = lnctrs.foldLeft(Set[Identifier]())((acc, lc) => acc ++ variablesOf(lc.toExpr)) - val tempVars = temps.foldLeft(Set[Identifier]())((acc, lt) => acc ++ variablesOf(lt.template)) - val elimVars = ctrVars.diff(tempVars) - // for debugging - val debugger = - if (debugElimination && verifyInvariant) { - Some((ctrs: Seq[LinearConstraint]) => { - val debugRes = checkInvariant(createAnd((ctrs ++ temps).map(_.template)), leonctx, program) - }) - } else None - val elimLnctrs = time { - LinearConstraintUtil.apply1PRuleOnDisjunct(lnctrs, elimVars, debugger) - } { updateCumTime(_, "ElimTime") } - - if (debugElimination) { - println("Path constriants (after elimination): " + elimLnctrs) - if (verifyInvariant) { - println("checking invariant for disjunct after elimination...") - checkInvariant(createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) - } - } - //for stats - if (ctx.dumpStats) { - var elimCtrCount = 0 - var elimCtrs = Seq[LinearConstraint]() - var elimRems = Set[Identifier]() - elimLnctrs.foreach((lc) => { - val evars = variablesOf(lc.toExpr).intersect(elimVars) - if (evars.nonEmpty) { - elimCtrs :+= lc - elimCtrCount += 1 - elimRems ++= evars - } - }) - Stats.updateCounterStats((elimVars.size - elimRems.size), "Eliminated-Vars", "disjuncts") - Stats.updateCounterStats((lnctrs.size - elimLnctrs.size), "Eliminated-Atoms", "disjuncts") - Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") - Stats.updateCounterStats(lnctrs.size, "NonParam-Atoms", "disjuncts") - } - val newLnctrs = elimLnctrs.toSet.toSeq - - //TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c - //TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) - //e.g, remove: a <= b if we have a = b or if a < b - //Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. - //TODO: Use the dependence chains in the formulas to identify what to assertionize - // and what can never be implied by solving for the templates - val disjunct = createAnd((newLnctrs ++ temps).map(_.template)) - val implCtrs = farkasSolver.constraintsForUnsat(newLnctrs, temps) - //for debugging - if (debugReducedFormula) { - println("Final Path Constraints: " + disjunct) - if (verifyInvariant) { - println("checking invariant for final disjunct... ") - checkInvariant(disjunct, leonctx, program) - } - } - (disjunct, implCtrs) - } - } - } } diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala index 3ae146a8a862bafefe488241e8dd3384099664eb..40ecbf14c8e30bede887e380b1fd27243a6e5de8 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala @@ -17,78 +17,79 @@ class NLTemplateSolverWithMult(ctx: InferenceContext, program: Program, rootFun: ctrTracker: ConstraintTracker, minimizer: Option[(Expr, Model) => Model]) extends NLTemplateSolver(ctx, program, rootFun, ctrTracker, minimizer) { - val axiomFactory = new AxiomFactory(ctx) - - override def getVCForFun(fd: FunDef): Expr = { - val plainvc = ctrTracker.getVC(fd).toExpr - val nlvc = multToTimes(plainvc) - nlvc - } - - override def splitVC(fd: FunDef) = { - val (paramPart, rest, modCons) = super.splitVC(fd) - (multToTimes(paramPart), multToTimes(rest), modCons) - } - - override def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { - - //in the sequel we instantiate axioms for multiplication - val inst1 = unaryMultAxioms(formula, calls, linearEval.predEval(model)) - val inst2 = binaryMultAxioms(formula, calls, linearEval.predEval(model)) - val multCtrs = (inst1 ++ inst2).flatMap { - case And(args) => args.map(ConstraintUtil.createConstriant _) - case e => Seq(ConstraintUtil.createConstriant(e)) - } - - Stats.updateCounterStats(multCtrs.size, "MultAxiomBlowup", "disjuncts") - ctx.reporter.info("Number of multiplication induced predicates: " + multCtrs.size) - multCtrs - } - - def chooseSATPredicate(expr: Expr, predEval: (Expr => Option[Boolean])): Expr = { - val norme = ExpressionTransformer.normalizeExpr(expr, ctx.multOp) - val preds = norme match { - case Or(args) => args - case Operator(_, _) => Seq(norme) - case _ => throw new IllegalStateException("Not(ant) is not in expected format: " + norme) - } - //pick the first predicate that holds true - preds.collectFirst { case pred @ _ if predEval(pred).get => pred }.get - } - - def isMultOp(call: Call): Boolean = { - isMultFunctions(call.fi.tfd.fd) - } - - def unaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { - val axioms = calls.flatMap { - case call @ _ if (isMultOp(call) && axiomFactory.hasUnaryAxiom(call)) => { - val (ant, conseq) = axiomFactory.unaryAxiom(call) - if (predEval(ant).get) - Seq(ant, conseq) - else - Seq(chooseSATPredicate(Not(ant), predEval)) - } - case _ => Seq() - } - axioms.toSeq - } - - def binaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { - - val mults = calls.filter(call => isMultOp(call) && axiomFactory.hasBinaryAxiom(call)) - val product = cross(mults, mults).collect { case (c1, c2) if c1 != c2 => (c1, c2) } - - ctx.reporter.info("Theory axioms: " + product.size) - Stats.updateCumStats(product.size, "-Total-theory-axioms") - - val newpreds = product.flatMap(pair => { - val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) - axiomInsts.flatMap { - case (ant, conseq) if predEval(ant).get => Seq(ant, conseq) //if axiom-pre holds. - case (ant, _) => Seq(chooseSATPredicate(Not(ant), predEval)) //if axiom-pre does not hold. - } - }) - newpreds.toSeq - } + throw new IllegalStateException("Not Maintained!!") +// val axiomFactory = new AxiomFactory(ctx) +// +// override def getVCForFun(fd: FunDef): Expr = { +// val plainvc = ctrTracker.getVC(fd).toExpr +// val nlvc = multToTimes(plainvc) +// nlvc +// } +// +// override def splitVC(fd: FunDef) = { +// val (paramPart, rest, modCons) = super.splitVC(fd) +// (multToTimes(paramPart), multToTimes(rest), modCons) +// } +// +// override def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { +// +// //in the sequel we instantiate axioms for multiplication +// val inst1 = unaryMultAxioms(formula, calls, linearEval.predEval(model)) +// val inst2 = binaryMultAxioms(formula, calls, linearEval.predEval(model)) +// val multCtrs = (inst1 ++ inst2).flatMap { +// case And(args) => args.map(ConstraintUtil.createConstriant _) +// case e => Seq(ConstraintUtil.createConstriant(e)) +// } +// +// Stats.updateCounterStats(multCtrs.size, "MultAxiomBlowup", "disjuncts") +// ctx.reporter.info("Number of multiplication induced predicates: " + multCtrs.size) +// multCtrs +// } +// +// def chooseSATPredicate(expr: Expr, predEval: (Expr => Option[Boolean])): Expr = { +// val norme = ExpressionTransformer.normalizeExpr(expr, ctx.multOp) +// val preds = norme match { +// case Or(args) => args +// case Operator(_, _) => Seq(norme) +// case _ => throw new IllegalStateException("Not(ant) is not in expected format: " + norme) +// } +// //pick the first predicate that holds true +// preds.collectFirst { case pred @ _ if predEval(pred).get => pred }.get +// } +// +// def isMultOp(call: Call): Boolean = { +// isMultFunctions(call.fi.tfd.fd) +// } +// +// def unaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { +// val axioms = calls.flatMap { +// case call @ _ if (isMultOp(call) && axiomFactory.hasUnaryAxiom(call)) => { +// val (ant, conseq) = axiomFactory.unaryAxiom(call) +// if (predEval(ant).get) +// Seq(ant, conseq) +// else +// Seq(chooseSATPredicate(Not(ant), predEval)) +// } +// case _ => Seq() +// } +// axioms.toSeq +// } +// +// def binaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { +// +// val mults = calls.filter(call => isMultOp(call) && axiomFactory.hasBinaryAxiom(call)) +// val product = cross(mults, mults).collect { case (c1, c2) if c1 != c2 => (c1, c2) } +// +// ctx.reporter.info("Theory axioms: " + product.size) +// Stats.updateCumStats(product.size, "-Total-theory-axioms") +// +// val newpreds = product.flatMap(pair => { +// val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) +// axiomInsts.flatMap { +// case (ant, conseq) if predEval(ant).get => Seq(ant, conseq) //if axiom-pre holds. +// case (ant, _) => Seq(chooseSATPredicate(Not(ant), predEval)) //if axiom-pre does not hold. +// } +// }) +// newpreds.toSeq +// } } diff --git a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala index 4bd170f8c9a49a1ce15111431e02cb174cf1be3a..e42817a1e51bea5cd4c65e08f549489add3249f0 100644 --- a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala @@ -17,16 +17,9 @@ import PredicateUtil._ import ExpressionTransformer._ abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, - ctrTracker: ConstraintTracker) { + val ctrTracker: ConstraintTracker) { protected val reporter = ctx.reporter - //protected val cg = CallGraphUtil.constructCallGraph(program) - - //some constants - protected val fls = BooleanLiteral(false) - protected val tru = BooleanLiteral(true) - //protected val zero = IntLiteral(0) - private val dumpVCtoConsole = false private val dumpVCasText = false @@ -34,26 +27,14 @@ abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, * Completes a model by adding mapping to new template variables */ def completeModel(model: Model, ids: Set[Identifier]) = { - val idmap = ids.map((id) => { + val idmap = ids.map { id => if (!model.isDefinedAt(id)) { (id, simplestValue(id.getType)) } else (id, model(id)) - }).toMap + }.toMap new Model(idmap) } - /** - * Computes the invariant for all the procedures given a mapping for the - * template variables. - */ - def getAllInvariants(model: Model): Map[FunDef, Expr] = { - val templates = ctrTracker.getFuncs.collect { - case fd if fd.hasTemplate => - fd -> fd.getTemplate - } - TemplateInstantiator.getAllInvariants(model, templates.toMap) - } - var vcCache = Map[FunDef, Expr]() protected def getVCForFun(fd: FunDef): Expr = { vcCache.getOrElse(fd, { diff --git a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala index 93d2608fff979e00d910cfd752c7368efeb2c810..4edc0f59a9e12ad702e4913b5d7080e716177b56 100644 --- a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala +++ b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala @@ -5,10 +5,12 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.Extractors._ import purescala.Types._ -import invariant.datastructure.UndirectedGraph +import invariant.datastructure._ import invariant.util._ import leon.purescala.TypeOps import PredicateUtil._ +import Stats._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap, MutableList } class UFADTEliminator(ctx: LeonContext, program: Program) { @@ -17,65 +19,108 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { val reporter = ctx.reporter val verbose = false - def collectCompatibleCalls(calls: Set[Expr]) = { - //compute the cartesian product of the calls and select the pairs having the same function symbol and also implied by the precond - val vec = calls.toArray - val size = calls.size - var j = 0 - //for stats - var tuples = 0 - var functions = 0 - var adts = 0 - val product = vec.foldLeft(Set[(Expr, Expr)]())((acc, call) => { + // def collectCompatibleCalls(calls: Set[Expr]) = { + // //compute the cartesian product of the calls and select the pairs having the same function symbol and also implied by the precond + // val vec = calls.toArray + // val size = calls.size + // var j = 0 + // //for stats + // var tuples = 0 + // var functions = 0 + // var adts = 0 + // val product = vec.foldLeft(Set[(Expr, Expr)]())((acc, call) => { + // //an optimization: here we can exclude calls to maxFun from axiomatization, they will be inlined anyway + // /*val shouldConsider = if(InvariantisCallExpr(call)) { + // val BinaryOperator(_,FunctionInvocation(calledFun,_), _) = call + // if(calledFun == DepthInstPhase.maxFun) false + // else true + // } else true*/ + // var pairs = Set[(Expr, Expr)]() + // for (i <- j + 1 until size) { + // val call2 = vec(i) + // if (mayAlias(call, call2)) { + // call match { + // case Equals(_, fin: FunctionInvocation) => functions += 1 + // case Equals(_, tup: Tuple) => tuples += 1 + // case _ => adts += 1 + // } + // if (debugAliases) + // println("Aliases: " + call + "," + call2) + // pairs ++= Set((call, call2)) + // } else { + // if (debugAliases) { + // (call, call2) match { + // case (Equals(_, t1 @ Tuple(_)), Equals(_, t2 @ Tuple(_))) => + // println("No Aliases: " + t1.getType + "," + t2.getType) + // case _ => println("No Aliases: " + call + "," + call2) + // } + // } + // } + // } + // j += 1 + // acc ++ pairs + // }) + // if (verbose) reporter.info("Number of compatible calls: " + product.size) + // Stats.updateCounterStats(product.size, "Compatible-Calls", "disjuncts") + // Stats.updateCumStats(functions, "Compatible-functioncalls") + // Stats.updateCumStats(adts, "Compatible-adtcalls") + // Stats.updateCumStats(tuples, "Compatible-tuples") + // product + // } + def collectCompatibleTerms(terms: Set[Expr]) = { + class Comp(val key: Either[TypedFunDef, TypeTree]) { + override def equals(other: Any) = other match { + case otherComp: Comp => mayAlias(key, otherComp.key) + case _ => false + } + // an weaker property whose equality is necessary for mayAlias + val hashcode = + key match { + case Left(TypedFunDef(fd, _)) => fd.id.hashCode() + case Right(ct: CaseClassType) => ct.classDef.id.hashCode() + case Right(tp @ TupleType(tps)) => (tps.hashCode() << 3) ^ tp.dimension + } + override def hashCode = hashcode + } + val compTerms = MutableMap[Comp, MutableList[Expr]]() + terms.foreach { term => //an optimization: here we can exclude calls to maxFun from axiomatization, they will be inlined anyway /*val shouldConsider = if(InvariantisCallExpr(call)) { val BinaryOperator(_,FunctionInvocation(calledFun,_), _) = call if(calledFun == DepthInstPhase.maxFun) false else true } else true*/ - var pairs = Set[(Expr, Expr)]() - for (i <- j + 1 until size) { - val call2 = vec(i) - if (mayAlias(call, call2)) { - - call match { - case Equals(_, fin: FunctionInvocation) => functions += 1 - case Equals(_, tup: Tuple) => tuples += 1 - case _ => adts += 1 - } - if (debugAliases) - println("Aliases: " + call + "," + call2) - - pairs ++= Set((call, call2)) - - } else { - if (debugAliases) { - (call, call2) match { - case (Equals(_, t1 @ Tuple(_)), Equals(_, t2 @ Tuple(_))) => - println("No Aliases: " + t1.getType + "," + t2.getType) - case _ => println("No Aliases: " + call + "," + call2) - } - } + val compKey: Either[TypedFunDef, TypeTree] = term match { + case Equals(_, rhs) => rhs match { // tuple types require special handling before they are used as keys + case tp: Tuple => + val TupleType(tps) = tp.getType + Right(TupleType(tps.map { TypeOps.bestRealType })) + case FunctionInvocation(tfd, _) => Left(tfd) + case CaseClass(ct, _) => Right(ct) } } - j += 1 - acc ++ pairs - }) - if (verbose) reporter.info("Number of compatible calls: " + product.size) - /*reporter.info("Compatible Tuples: "+tuples) - reporter.info("Compatible Functions+ADTs: "+(functions+adts))*/ - Stats.updateCounterStats(product.size, "Compatible-Calls", "disjuncts") - Stats.updateCumStats(functions, "Compatible-functioncalls") - Stats.updateCumStats(adts, "Compatible-adtcalls") - Stats.updateCumStats(tuples, "Compatible-tuples") - product + val comp = new Comp(compKey) + val compList = compTerms.getOrElse(comp, { + val newl = new MutableList[Expr]() + compTerms += (comp -> newl) + newl + }) + compList += term + } + if (debugAliases) { + compTerms.foreach { + case (_, v) => println("Aliases: " + v.mkString("{", ",", "}")) + } + } + compTerms } /** * Convert the theory formula into linear arithmetic formula. * The calls could be functions calls or ADT constructor calls. * 'predEval' is an evaluator that evaluates a predicate to a boolean value + * TODO: is type parameter inheritance handled correctly ? */ def constraintsForCalls(calls: Set[Expr], predEval: (Expr => Option[Boolean])): Seq[Expr] = { @@ -99,7 +144,6 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { } def predForEquality(call1: Expr, call2: Expr): Seq[Expr] = { - val eqs = if (isCallExpr(call1)) { val (_, rhs) = axiomatizeCalls(call1, call2) Seq(rhs) @@ -122,13 +166,11 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { } def predForDisequality(call1: Expr, call2: Expr): Seq[Expr] = { - val (ants, _) = if (isCallExpr(call1)) { axiomatizeCalls(call1, call2) } else { axiomatizeADTCons(call1, call2) } - if (makeEfficient && ants.exists { case Equals(l, r) if (l.getType != RealType && l.getType != BooleanType && l.getType != IntegerType) => true case _ => false @@ -163,60 +205,82 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { } } - var eqGraph = new UndirectedGraph[Expr]() //an equality graph - var neqSet = Set[(Expr, Expr)]() - val product = collectCompatibleCalls(calls) - val newctrs = product.foldLeft(Seq[Expr]())((acc, pair) => { - val (call1, call2) = (pair._1, pair._2) - //note: here it suffices to check for adjacency and not reachability of calls (i.e, exprs). - //This is because the transitive equalities (corresponding to rechability) are encoded by the generated equalities. - if (!eqGraph.BFSReach(call1, call2) && !neqSet.contains((call1, call2)) && !neqSet.contains((call2, call1))) { - doesAlias(call1, call2) match { - case Some(true) => - eqGraph.addEdge(call1, call2) - acc ++ predForEquality(call1, call2) - case Some(false) => - neqSet ++= Set((call1, call2)) - acc ++ predForDisequality(call1, call2) - case _ => - // in this case, we construct a weaker disjunct by dropping this predicate - acc + var equivClasses = new DisjointSets[Expr]() + var neqSet = MutableSet[(Expr, Expr)]() + val termClasses = collectCompatibleTerms(calls) + val preds = MutableList[Expr]() + termClasses.foreach { + case (_, compTerms) => + val vec = compTerms.toArray + val size = vec.size + vec.zipWithIndex.foreach { + case (t1, j) => + (j + 1 until size).foreach { i => + val t2 = vec(i) + if (compatibleTArgs(termTArgs(t1), termTArgs(t2))) { + //note: here we omit constraints that encode transitive equality facts + val class1 = equivClasses.findOrCreate(t1) + val class2 = equivClasses.findOrCreate(t2) + if (class1 != class2 && !neqSet.contains((t1, t2)) && !neqSet.contains((t2, t1))) { + doesAlias(t1, t2) match { + case Some(true) => + equivClasses.union(class1, class2) + preds ++= predForEquality(t1, t2) + case Some(false) => + neqSet ++= Set((t1, t2)) + preds ++= predForDisequality(t1, t2) + case _ => + // in this case, we construct a weaker disjunct by dropping this predicate + } + } + } + } } - } else acc - }) - //reporter.info("Number of equal calls: " + eqGraph.getEdgeCount) - newctrs + } + Stats.updateCounterStats(preds.size, "CallADT-Constraints", "disjuncts") + preds.toSeq + } + + def termTArgs(t: Expr) = { + t match { + case Equals(_, e) => + e match { + case FunctionInvocation(TypedFunDef(_, tps), _) => tps + case CaseClass(ct, _) => ct.tps + case tp: Tuple => + val TupleType(tps) = tp.getType + tps + } + } } /** * This function actually checks if two non-primitive expressions could have the same value * (when some constraints on their arguments hold). * Remark: notice that when the expressions have ADT types, then this is basically a form of may-alias check. - * TODO: handling generic can become very trickier here. + * TODO: handling type parameters can become very trickier here. + * For now ignoring type parameters of functions and classes. (This is complete, but may be less efficient) */ - def mayAlias(e1: Expr, e2: Expr): Boolean = { - (e1, e2) match { - case (Equals(_, FunctionInvocation(fd1, _)), Equals(_, FunctionInvocation(fd2, _))) => { - (fd1.id == fd2.id && fd1.fd.tparams == fd2.fd.tparams) - } - case (Equals(_, CaseClass(cd1, _)), Equals(_, CaseClass(cd2, _))) => { - // if (cd1.id == cd2.id && cd1.tps != cd2.tps) println("Invalidated the classes " + e1 + " " + e2) - (cd1.id == cd2.id && cd1.tps == cd2.tps) - } - case (Equals(_, tp1 @ Tuple(e1)), Equals(_, tp2 @ Tuple(e2))) => { - //get the types and check if the types are compatible - val TupleType(tps1) = tp1.getType - val TupleType(tps2) = tp2.getType - (tps1 zip tps2).forall(pair => { - val (t1, t2) = pair - val lub = TypeOps.leastUpperBound(t1, t2) - (lub == Some(t1) || lub == Some(t2)) - }) - } + def mayAlias(term1: Either[TypedFunDef, TypeTree], term2: Either[TypedFunDef, TypeTree]): Boolean = { + (term1, term2) match { + case (Left(TypedFunDef(fd1, _)), Left(TypedFunDef(fd2, _))) => + fd1.id == fd2.id + case (Right(ct1: CaseClassType), Right(ct2: CaseClassType)) => + ct1.classDef.id == ct2.classDef.id + case (Right(tp1 @ TupleType(tps1)), Right(tp2 @ TupleType(tps2))) if tp1.dimension == tp2.dimension => + compatibleTArgs(tps1, tps2) //get the types and check if the types are compatible case _ => false } } + def compatibleTArgs(tps1: Seq[TypeTree], tps2: Seq[TypeTree]): Boolean = { + (tps1 zip tps2).forall { + case (t1, t2) => + val lub = TypeOps.leastUpperBound(t1, t2) + (lub == Some(t1) || lub == Some(t2)) // is t1 a super type of t2 + } + } + /** * This procedure generates constraints for the calls to be equal */ @@ -239,7 +303,6 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { * The returned pairs should be interpreted as a bidirectional implication */ def axiomatizeADTCons(sel1: Expr, sel2: Expr): (Seq[Expr], Expr) = { - val (v1, args1, v2, args2) = sel1 match { case Equals(r1 @ Variable(_), CaseClass(_, a1)) => { val Equals(r2 @ Variable(_), CaseClass(_, a2)) = sel2 @@ -250,7 +313,6 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { (r1, a1, r2, a2) } } - val ants = (args1.zip(args2)).foldLeft(Seq[Expr]())((acc, pair) => { val (arg1, arg2) = pair acc :+ Equals(arg1, arg2) diff --git a/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala b/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..bb39630784f3af54776b9f884acf1028c17cd139 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala @@ -0,0 +1,399 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import java.io._ +import solvers._ +import solvers.combinators._ +import solvers.smtlib._ +import solvers.z3._ +import scala.util.control.Breaks._ +import purescala.ScalaPrinter +import scala.collection.mutable.{ Map => MutableMap } +import scala.reflect.runtime.universe +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import Stats._ +import leon.evaluators._ +import EvaluationResults._ + +import Util._ +import PredicateUtil._ +import SolverUtil._ + +class UniversalQuantificationSolver(ctx: InferenceContext, program: Program, + funs: Seq[FunDef], ctrTracker: ConstraintTracker, + minimizer: Option[(Expr, Model) => Model]) { + + import NLTemplateSolver._ + + //flags controlling debugging + val debugUnflattening = false + val debugIncrementalVC = false + val trackCompressedVCCTime = false + + val printCounterExample = false + val dumpInstantiatedVC = false + + val reporter = ctx.reporter + val timeout = ctx.vcTimeout + val leonctx = ctx.leonContext + + //flag controlling behavior + val disableCegis = true + private val useIncrementalSolvingForVCs = true + private val usePortfolio = false // portfolio has a bug in incremental solving + + val defaultEval = new DefaultEvaluator(leonctx, program) // an evaluator for extracting models + val existSolver = new ExistentialQuantificationSolver(ctx, program, ctrTracker, defaultEval) + + val solverFactory = + if (usePortfolio) { + if (useIncrementalSolvingForVCs) + throw new IllegalArgumentException("Cannot perform incremental solving with portfolio solvers!") + SolverFactory(() => new PortfolioSolver(leonctx, Seq(new SMTLIBCVC4Solver(leonctx, program), + new SMTLIBZ3Solver(leonctx, program))) with TimeoutSolver) + } else + SolverFactory.uninterpreted(leonctx, program) + + def splitVC(fd: FunDef) = { + val (paramPart, rest, modCons) = + time { ctrTracker.getVC(fd).toUnflatExpr } { + t => Stats.updateCounterTime(t, "UnflatTime", "VC-refinement") + } + if (ctx.usereals) { + (IntLiteralToReal(paramPart), IntLiteralToReal(rest), modCons) + } else (paramPart, rest, modCons) + } + + case class FunData(modelCons: (Model, DefaultEvaluator) => FlatModel, paramParts: Expr, simpleParts: Expr) + val funInfos = funs.map { fd => + val (paramPart, rest, modelCons) = splitVC(fd) + if (hasReals(rest) && hasInts(rest)) + throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) + if (debugIncrementalVC) { + assert(getTemplateVars(rest).isEmpty) + println("For function: " + fd.id) + println("Param part: " + paramPart) + } + (fd -> FunData(modelCons, paramPart, rest)) + }.toMap + + var funSolvers = initializeSolvers + def initializeSolvers = + if (!ctx.abort) { // this is required to ensure that solvers are not created after interrupts + funInfos.map { + case (fd, FunData(_, _, rest)) => + val vcSolver = solverFactory.getNewSolver() + vcSolver.assertCnstr(rest) + (fd -> vcSolver) + }.toMap + } else Map[FunDef, Solver with TimeoutSolver]() + + def free = { + if (useIncrementalSolvingForVCs) + funSolvers.foreach(entry => entry._2.free) + } + + /** + * State for minimization + */ + class MinimizationInfo { + var minStarted = false + var lastCorrectModel: Option[Model] = None + var minStartTime: Long = 0 // for stats + + def started = minStarted + def reset() = { + minStarted = false + lastCorrectModel = None + } + def updateProgress(model: Model) { + lastCorrectModel = Some(model) + if (!minStarted) { + minStarted = true + minStartTime = System.currentTimeMillis() + } + } + def complete { + reset() + /*val mintime = (System.currentTimeMillis() - minStartTime) + Stats.updateCounterTime(mintime, "minimization-time", "procs") + Stats.updateCumTime(mintime, "Total-Min-Time")*/ + } + def getLastCorrectModel = lastCorrectModel + } + + /** + * State for recording diffcult paths + */ + class DifficultPaths { + var paths = MutableMap[FunDef, Seq[Expr]]() + + def addPath(fd: FunDef, cePath: Expr): Unit = { + if (paths.contains(fd)) { + paths.update(fd, cePath +: paths(fd)) + } else { + paths += (fd -> Seq(cePath)) + } + } + def get(fd: FunDef) = paths.get(fd) + def hasPath(fd: FunDef) = paths.contains(fd) + def pathsToExpr(fd: FunDef) = Not(createOr(paths(fd))) + def size = paths.values.map(_.size).sum + } + + abstract class RefineRes + case class UnsolvableVC() extends RefineRes + case class NoSolution() extends RefineRes + case class CorrectSolution() extends RefineRes + case class NewSolution(tempModel: Model) extends RefineRes + + class ModelRefiner(tempModel: Model) { + val tempVarMap: Map[Expr, Expr] = tempModel.map { case (k, v) => (k.toVariable -> v) }.toMap + val seenPaths = new DifficultPaths() + private var callsInPaths = Set[Call]() + + def callsEncountered = callsInPaths + + def nextCandidate(conflicts: Seq[FunDef]): RefineRes = { + var newConflicts = Seq[FunDef]() + var blockedPaths = false + val newctrsOpt = conflicts.foldLeft(Some(Seq()): Option[Seq[Expr]]) { + case (None, _) => None + case _ if (ctx.abort) => None + case (Some(acc), fd) => + val disabledPaths = + if (seenPaths.hasPath(fd)) { + blockedPaths = true + seenPaths.pathsToExpr(fd) + } else tru + checkVCSAT(fd, tempModel, disabledPaths) match { + case (None, _) => None // VC cannot be decided + case (Some(false), _) => Some(acc) // VC is unsat + case (Some(true), univModel) => // VC is sat + newConflicts :+= fd + if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") + if (printCounterExample) { + reporter.info("Model: " + univModel) + } + // generate constraints for making preventing the model + val (existCtr, linearpath, calls) = existSolver.generateCtrsForUNSAT(fd, univModel, tempModel) + if (existCtr == tru) throw new IllegalStateException("Cannot find a counter-example path!!") + callsInPaths ++= calls + //instantiate the disjunct + val cePath = simplifyArithmetic(TemplateInstantiator.instantiate( + createAnd(linearpath.map(_.template)), tempVarMap)) + //some sanity checks + if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) + throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) + seenPaths.addPath(fd, cePath) + Some(acc :+ existCtr) + } + } + newctrsOpt match { + case None => // give up, the VC cannot be decided + UnsolvableVC() + case Some(newctrs) if (newctrs.isEmpty) => + if (!blockedPaths) { //yes, hurray,found an inductive invariant + CorrectSolution() + } else { + //give up, only hard paths remaining + reporter.info("- Exhausted all easy paths !!") + reporter.info("- Number of remaining hard paths: " + seenPaths.size) + NoSolution() //TODO: what to unroll here ? + } + case Some(newctrs) => + existSolver.solveConstraints(newctrs, tempModel) match { + case (None, _) => + //here we have timed out while solving the non-linear constraints + if (verbose) + reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") + Stats.updateCumStats(1, "retries") + nextCandidate(newConflicts) + case (Some(false), _) => // template not solvable, need more unrollings here + NoSolution() + case (Some(true), nextModel) => + NewSolution(nextModel) + } + } + } + def nextCandidate: RefineRes = nextCandidate(funs) + } + + /** + * @param foundModel a call-back that will be invoked every time a new model is found + */ + def solveUNSAT(initModel: Model, foundModel: Model => Unit): (Option[Model], Option[Set[Call]]) = { + val minInfo = new MinimizationInfo() + var sat: Option[Boolean] = Some(true) + var tempModel = initModel + var callsInPaths = Set[Call]() + var minimized = false + while (sat == Some(true) && !ctx.abort) { + Stats.updateCounter(1, "disjuncts") + if (verbose) { + reporter.info("Candidate invariants") + TemplateInstantiator.getAllInvariants(tempModel, ctrTracker.getFuncs).foreach( + entry => reporter.info(entry._1.id + "-->" + entry._2)) + } + val modRefiner = new ModelRefiner(tempModel) + sat = modRefiner.nextCandidate match { + case CorrectSolution() if (minimizer.isDefined && !minimized) => + minInfo.updateProgress(tempModel) + val minModel = minimizer.get(existSolver.getSolvedCtrs, tempModel) + minimized = true + if (minModel.toMap == tempModel.toMap) { + minInfo.complete + Some(false) + } else { + tempModel = minModel + Some(true) + } + case CorrectSolution() => // minimization has completed or is not applicable + minInfo.complete + Some(false) + case NewSolution(newModel) => + foundModel(newModel) + minimized = false + tempModel = newModel + Some(true) + case NoSolution() => // here template is unsolvable or only hard paths remain + None + case UnsolvableVC() if minInfo.started => + tempModel = minInfo.getLastCorrectModel.get + Some(false) + case UnsolvableVC() if !ctx.abort => + if (verbose) { + reporter.info("VC solving failed!...retrying with a bigger model...") + } + existSolver.solveConstraints(retryStrategy(tempModel), tempModel) match { + case (Some(true), newModel) => + foundModel(newModel) + tempModel = newModel + funSolvers = initializeSolvers // reinitialize all VC solvers as they all timed out + Some(true) + case _ => // give up, no other bigger invariant exist or existential solving timed out! + None + } + case _ => None + } + callsInPaths ++= modRefiner.callsEncountered + } + sat match { + case _ if ctx.abort => (None, None) + case None => (None, Some(callsInPaths)) //cannot solve template, more unrollings + case _ => (Some(tempModel), None) // template solved + } + } + + /** + * Strategy: try to find a value for templates that is bigger than the current value + */ + import RealValuedExprEvaluator._ + val rtwo = FractionalLiteral(2, 1) + def retryStrategy(tempModel: Model): Seq[Expr] = { + tempModel.map { + case (id, z @ FractionalLiteral(n, _)) if n == 0 => GreaterThan(id.toVariable, z) + case (id, fl: FractionalLiteral) => GreaterThan(id.toVariable, evaluate(RealTimes(rtwo, fl))) + }.toSeq + } + + protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { + if (ctx.usereals) replace(tempVarMap, e) + else + simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) + } + + /** + * Checks if the VC of fd is unsat + */ + def checkVCSAT(fd: FunDef, tempModel: Model, disabledPaths: Expr): (Option[Boolean], LazyModel) = { + val tempIdMap = tempModel.toMap + val tempVarMap: Map[Expr, Expr] = tempIdMap.map { case (k, v) => k.toVariable -> v }.toMap + val funData = funInfos(fd) + val (solver, instExpr, modelCons) = + if (useIncrementalSolvingForVCs) { + val instParamPart = instantiateTemplate(funData.paramParts, tempVarMap) + (funSolvers(fd), And(instParamPart, disabledPaths), funData.modelCons) + } else { + val FunData(modCons, paramPart, rest) = funData + val instPart = instantiateTemplate(paramPart, tempVarMap) + (solverFactory.getNewSolver(), createAnd(Seq(rest, instPart, disabledPaths)), modCons) + } + //For debugging + if (dumpInstantiatedVC) { + val fullExpr = if (useIncrementalSolvingForVCs) And(funData.simpleParts, instExpr) else instExpr + ExpressionTransformer.PrintWithIndentation("vcInst", fullExpr) + } + // sanity check + if (hasMixedIntReals(instExpr)) + throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) + //reporter.info("checking VC inst ...") + solver.setTimeout(timeout * 1000) + val (res, packedModel) = + time { + if (useIncrementalSolvingForVCs) { + solver.push + solver.assertCnstr(instExpr) + val solRes = solver.check match { + case _ if ctx.abort => + (None, Model.empty) + case r @ Some(true) => + (r, solver.getModel) + case r => (r, Model.empty) + } + if (solRes._1.isDefined) // invoking pop() otherwise will throw an exception + solver.pop() + solRes + } else + SimpleSolverAPI(SolverFactory(() => solver)).solveSAT(instExpr) + } { vccTime => + if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") + updateCounterTime(vccTime, "VC-check-time", "disjuncts") + updateCumTime(vccTime, "TotalVCCTime") + } + if (debugUnflattening) { + /*ctrTracker.getVC(fd).checkUnflattening(tempVarMap, + SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())), + defaultEval)*/ + verifyModel(funData.simpleParts, packedModel, SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver()))) + //val unflatPath = ctrTracker.getVC(fd).pickSatFromUnflatFormula(funData.simpleParts, packedModel, defaultEval) + } + //for statistics + if (trackCompressedVCCTime) { + val compressedVC = + unflatten(simplifyArithmetic(instantiateTemplate(ctrTracker.getVC(fd).eliminateBlockers, tempVarMap))) + Stats.updateCounterStats(atomNum(compressedVC), "Compressed-VC-size", "disjuncts") + time { + SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())).solveSAT(compressedVC) + } { compTime => + Stats.updateCumTime(compTime, "TotalCompressVCCTime") + reporter.info("checked compressed VC... in " + compTime / 1000.0 + "s") + } + } + (res, modelCons(packedModel, defaultEval)) + } + + // cegis code, now not used + //val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, createOr(newConfDisjuncts), inputCtr, Some(model)) + // def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { + // val cegisSolver = new CegisCore(ctx, program, timeout.toInt, NLTemplateSolver.this) + // val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) + // if (res.isEmpty) + // reporter.info("cegis timed-out on the disjunct...") + // (res, ctr, model) + // } + +} diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala index 4559a90572d39b4dce2fe9c393a50ade299c1a8f..9c8d1359b0f833d613aded2b517cd467e80ba759 100644 --- a/src/main/scala/leon/invariant/util/CallGraph.scala +++ b/src/main/scala/leon/invariant/util/CallGraph.scala @@ -47,10 +47,10 @@ class CallGraph { } /** - * Checks if the src transitively calls the procedure proc + * Checks if the src transitively calls the procedure proc. + * Note: We cannot say that src calls itself even though source is reachable from itself in the callgraph */ def transitivelyCalls(src: FunDef, proc: FunDef): Boolean = { - //important: We cannot say that src calls it self even though source is reachable from itself in the callgraph graph.BFSReach(src, proc, excludeSrc = true) } @@ -59,38 +59,13 @@ class CallGraph { } /** - * sorting functions in ascending topological order + * Sorting functions in reverse topological order. + * For functions within an SCC, we preserve the initial order + * given as input */ - def topologicalOrder: Seq[FunDef] = { - - def insert(index: Int, l: Seq[FunDef], fd: FunDef): Seq[FunDef] = { - var i = 0 - var head = Seq[FunDef]() - l.foreach((elem) => { - if (i == index) - head :+= fd - head :+= elem - i += 1 - }) - head - } - - var funcList = Seq[FunDef]() - graph.getNodes.toList.foreach((f) => { - var inserted = false - var index = 0 - for (i <- funcList.indices) { - if (!inserted && this.transitivelyCalls(funcList(i), f)) { - index = i - inserted = true - } - } - if (!inserted) - funcList :+= f - else funcList = insert(index, funcList, f) - }) - - funcList + def reverseTopologicalOrder(initOrder: Seq[FunDef]): Seq[FunDef] = { + val orderMap = initOrder.zipWithIndex.toMap + graph.sccs.flatMap{scc => scc.sortWith((f1, f2) => orderMap(f1) <= orderMap(f2)) } } override def toString: String = { @@ -108,9 +83,8 @@ object CallGraphUtil { onlyBody: Boolean = false, withTemplates: Boolean = false, calleesFun: Expr => Set[FunDef] = getCallees): CallGraph = { - val cg = new CallGraph() - functionsWOFields(prog.definedFunctions).foreach((fd) => { + functionsWOFields(prog.definedFunctions).foreach{fd => cg.addFunction(fd) if (fd.hasBody) { var funExpr = fd.body.get @@ -126,7 +100,7 @@ object CallGraphUtil { //introduce a new edge for every callee calleesFun(funExpr).foreach(cg.addEdgeIfNotPresent(fd, _)) } - }) + } cg } diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala index abda26d430a851bbdaec38f338cb4d5aea5340fb..5de31ddac39d0529ea1e715a747581df831ae198 100644 --- a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala +++ b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala @@ -20,13 +20,6 @@ import TVarFactory._ */ object ExpressionTransformer { - val zero = InfiniteIntegerLiteral(0) - val one = InfiniteIntegerLiteral(1) - val mone = InfiniteIntegerLiteral(-1) - val tru = BooleanLiteral(true) - val fls = BooleanLiteral(false) - val bone = BigInt(1) - // identifier for temporaries that are generated during flattening of terms other than functions val flatContext = newContext // temporaries used in the function flattening @@ -42,29 +35,28 @@ object ExpressionTransformer { * @param insideFunction when set to true indicates that the newConjuncts (second argument) * should not conjoined to the And(..) / Or(..) expressions found because they * may be called inside a function. + * TODO: remove this function altogether and treat 'and' and 'or's as functions. */ def conjoinWithinClause(e: Expr, transformer: (Expr, Boolean) => (Expr, Set[Expr]), insideFunction: Boolean): (Expr, Set[Expr]) = { e match { - case And(args) if !insideFunction => { - val newargs = args.map((arg) => { + case And(args) if !insideFunction => + val newargs = args.map{arg => val (nexp, ncjs) = transformer(arg, false) createAnd(nexp +: ncjs.toSeq) - }) + } (createAnd(newargs), Set()) - } - case Or(args) if !insideFunction => { - val newargs = args.map((arg) => { + case Or(args) if !insideFunction => + val newargs = args.map{arg => val (nexp, ncjs) = transformer(arg, false) createAnd(nexp +: ncjs.toSeq) - }) + } (createOr(newargs), Set()) - } case t: Terminal => (t, Set()) - case n @ Operator(args, op) => { + case n @ Operator(args, op) => var ncjs = Set[Expr]() val newargs = args.map((arg) => { val (nexp, js) = transformer(arg, true) @@ -72,7 +64,6 @@ object ExpressionTransformer { nexp }) (op(newargs), ncjs) - } case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + e) } } @@ -86,14 +77,14 @@ object ExpressionTransformer { def transform(e: Expr, insideFunction: Boolean): (Expr, Set[Expr]) = { e match { // Handle asserts here. Return flattened body as the result - case as @ Assert(pred, _, body) => { + case as @ Assert(pred, _, body) => val freshvar = createFlatTemp("asrtres", e.getType).toVariable val newexpr = Equals(freshvar, body) val resset = transform(newexpr, insideFunction) (freshvar, resset._2 + resset._1) - } + //handles division by constant - case Division(lhs, rhs @ InfiniteIntegerLiteral(v)) => { + case Division(lhs, rhs @ InfiniteIntegerLiteral(v)) => //this models floor and not integer division val quo = createTemp("q", IntegerType, langContext).toVariable var possibs = Seq[Expr]() @@ -106,9 +97,9 @@ object ExpressionTransformer { //println("newexpr: "+newexpr) val resset = transform(newexpr, true) (quo, resset._2 + resset._1) - } + //handles division by variables - case Division(lhs, rhs) => { + case Division(lhs, rhs) => //this models floor and not integer division val quo = createTemp("q", IntegerType, langContext).toVariable val rem = createTemp("r", IntegerType, langContext).toVariable @@ -118,30 +109,40 @@ object ExpressionTransformer { val newexpr = createAnd(Seq(divsem, LessEquals(zero, rem), LessEquals(rem, Minus(rhs, one)))) val resset = transform(newexpr, true) (quo, resset._2 + resset._1) - } - case err @ Error(_, msg) => { + + case err @ Error(_, msg) => //replace this by a fresh variable of the error type (createTemp("err", err.getType, langContext).toVariable, Set[Expr]()) - } - case Equals(lhs, rhs) => { + + case Equals(lhs, rhs) => val (nexp1, ncjs1) = transform(lhs, true) val (nexp2, ncjs2) = transform(rhs, true) (Equals(nexp1, nexp2), ncjs1 ++ ncjs2) - } - case IfExpr(cond, thn, elze) => { + + case IfExpr(cond, thn, elze) if insideFunction => val freshvar = createTemp("ifres", e.getType, langContext).toVariable - val newexpr = Or(And(cond, Equals(freshvar, thn)), And(Not(cond), Equals(freshvar, elze))) - val resset = transform(newexpr, insideFunction) - (freshvar, resset._2 + resset._1) - } - case Let(binder, value, body) => { + val (ncond, condConjs) = transform(cond, true) + val (nthen, thenConjs) = transform(Equals(freshvar, thn), false) + val (nelze, elzeConjs) = transform(Equals(freshvar, elze), false) + val conjs = condConjs + IfExpr(cond, + createAnd(nthen +: thenConjs.toSeq), createAnd(nelze +: elzeConjs.toSeq)) + (freshvar, conjs) + + case IfExpr(cond, thn, elze) => // here, we are at the top, and hence can avoid creating freshids + val (ncond, condConjs) = transform(cond, true) + val (nthen, thenConjs) = transform(thn, false) + val (nelze, elzeConjs) = transform(elze, false) + (IfExpr(cond, + createAnd(nthen +: thenConjs.toSeq), createAnd(nelze +: elzeConjs.toSeq)), condConjs) + + case Let(binder, value, body) => //TODO: do we have to consider reuse of let variables ? val (resbody, bodycjs) = transform(body, true) val (resvalue, valuecjs) = transform(value, true) (resbody, (valuecjs + Equals(binder.toVariable, resvalue)) ++ bodycjs) - } + //the value is a tuple in the following case - case LetTuple(binders, value, body) => { + /*case LetTuple(binders, value, body) => //TODO: do we have to consider reuse of let variables ? val (resbody, bodycjs) = transform(body, true) val (resvalue, valuecjs) = transform(value, true) @@ -170,11 +171,9 @@ object ExpressionTransformer { (cjs ++ cjs2) } } - (resbody, (valuecjs ++ newConjuncts) ++ bodycjs) - } - case _ => { - conjoinWithinClause(e, transform, false) - } + (resbody, (valuecjs ++ newConjuncts) ++ bodycjs)*/ + + case _ => conjoinWithinClause(e, transform, insideFunction) } } val (nexp, ncjs) = transform(inexpr, false) @@ -184,6 +183,21 @@ object ExpressionTransformer { res } + def isAtom(e: Expr): Boolean = e match { + case _: And | _: Or | _: IfExpr => false + case _ => true + } + + def isADTTheory(e: Expr) = e match { + case _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple | _: IsInstanceOf => true + case _ => false + } + + def isSetTheory(e: Expr) = e match { + case _: SetUnion | _: ElementOfSet | _: SubsetOf | _: FiniteSet => true + case _ => false + } + /** * Requires: The expression has to be in NNF form and without if-then-else and let constructs * Assumed that that given expression has boolean type @@ -204,57 +218,22 @@ object ExpressionTransformer { e match { case fi @ FunctionInvocation(fd, args) => val (newargs, newConjuncts) = flattenArgs(args, true) - val newfi = FunctionInvocation(fd, newargs) val freshResVar = Variable(createTemp("r", fi.getType, funFlatContext)) - val res = (freshResVar, newConjuncts + Equals(freshResVar, newfi)) - res - - case inst @ IsInstanceOf(e1, cd) => - //replace e by a variable - val (newargs, newcjs) = flattenArgs(Seq(e1), true) - var newConjuncts = newcjs - val freshArg = newargs(0) - val newInst = IsInstanceOf(freshArg, cd) - val freshResVar = Variable(createFlatTemp("ci", inst.getType)) - newConjuncts += Equals(freshResVar, newInst) - (freshResVar, newConjuncts) - - case cs @ CaseClassSelector(cd, e1, sel) => - val (newargs, newcjs) = flattenArgs(Seq(e1), true) - var newConjuncts = newcjs - val freshArg = newargs(0) - val newCS = CaseClassSelector(cd, freshArg, sel) - val freshResVar = Variable(createFlatTemp("cs", cs.getType)) - //val freshResVar = Variable(createTemp("cs", cs.getType, funFlatContext)) // we cannot flatten these as they will converted to cons - newConjuncts += Equals(freshResVar, newCS) - (freshResVar, newConjuncts) - - case ts @ TupleSelect(e1, index) => - val (newargs, newcjs) = flattenArgs(Seq(e1), true) - var newConjuncts = newcjs - val freshArg = newargs(0) - val newTS = TupleSelect(freshArg, index) - val freshResVar = Variable(createFlatTemp("ts", ts.getType)) - //val freshResVar = Variable(createTemp("ts", ts.getType, funFlatContext)) - newConjuncts += Equals(freshResVar, newTS) - (freshResVar, newConjuncts) - - case cc @ CaseClass(cd, args) => + (freshResVar, newConjuncts + Equals(freshResVar, FunctionInvocation(fd, newargs))) + + case adte if isADTTheory(adte) => + val Operator(args, op) = adte + val freshName = adte match { + case _: IsInstanceOf => "ci" + case _: CaseClassSelector => "cs" + case _: CaseClass => "cc" + case _: TupleSelect => "ts" + case _: Tuple => "tp" + } + val freshVar = Variable(createFlatTemp(freshName, adte.getType)) val (newargs, newcjs) = flattenArgs(args, true) - var newConjuncts = newcjs - val newCC = CaseClass(cd, newargs) - val freshResVar = Variable(createFlatTemp("cc", cc.getType)) - newConjuncts += Equals(freshResVar, newCC) - (freshResVar, newConjuncts) + (freshVar, newcjs + Equals(freshVar, op(newargs))) - case tp @ Tuple(args) => { - val (newargs, newcjs) = flattenArgs(args, true) - var newConjuncts = newcjs - val newTP = Tuple(newargs) - val freshResVar = Variable(createFlatTemp("tp", tp.getType)) - newConjuncts += Equals(freshResVar, newTP) - (freshResVar, newConjuncts) - } case SetUnion(_, _) | ElementOfSet(_, _) | SubsetOf(_, _) => val Operator(args, op) = e val (Seq(a1, a2), newcjs) = flattenArgs(args, true) @@ -269,6 +248,26 @@ object ExpressionTransformer { val freshResVar = Variable(createFlatTemp("fset", fs.getType)) (freshResVar, newcjs + Equals(freshResVar, newexpr)) + case And(args) if insideFunction => + val (nargs, cjs) = flattenArithmeticCtrs(args) + (And(nargs), cjs) + + case Or(args) if insideFunction => + val (nargs, cjs) = flattenArithmeticCtrs(args) + (Or(nargs), cjs) + + case IfExpr(cond, thn, elze) => // make condition of if-then-elze an atom + val (nthen, thenConjs) = flattenFunc(thn, false) + val (nelze, elzeConjs) = flattenFunc(elze, false) + val (ncond, condConjs) = flattenFunc(cond, true) match { + case r@(nc, _) if isAtom(nc) && getTemplateIds(nc).isEmpty => r + case (nc, conjs) => + val condvar = createFlatTemp("cond", cond.getType).toVariable + (condvar, conjs + Equals(condvar, nc)) + } + (IfExpr(ncond, createAnd(nthen +: thenConjs.toSeq), + createAnd(nelze +: elzeConjs.toSeq)), condConjs) + case _ => conjoinWithinClause(e, flattenFunc, insideFunction) } } @@ -292,100 +291,85 @@ object ExpressionTransformer { } (newargs, newConjuncts) } + + def flattenArithmeticCtrs(args: Seq[Expr]) = { + val (flatArgs, cjs) = flattenArgs(args, true) + var ncjs = Set[Expr]() + val nargs = flatArgs.map { + case farg if isArithmeticRelation(farg) != Some(false) => + // 'farg' is a possibly arithmetic relation. + val argvar = createFlatTemp("ar", farg.getType).toVariable + ncjs += Equals(argvar, farg) + argvar + case farg => farg + } + (nargs, cjs ++ ncjs) + } + val (nexp, ncjs) = flattenFunc(inExpr, false) if (ncjs.nonEmpty) { createAnd(nexp +: ncjs.toSeq) } else nexp } - def testHelp(e: Expr) = { - e match { - case Operator(args, op) => - args.foreach { arg => - if (arg.getType == Untyped) { - println(s"$arg is untyped! ") - arg match { - case CaseClassSelector(cct, cl, fld) => - println("cl type: " + cl.getType + " cct: " + cct) - case _ => - } - } - } - case _ => - } + /** + * note: we consider even type parameters as ADT type + */ + def adtType(e: Expr) = { + val tpe = e.getType + tpe.isInstanceOf[ClassType] || tpe.isInstanceOf[TupleType] || tpe.isInstanceOf[TypeParameter] } /** * The following procedure converts the formula into negated normal form by pushing all not's inside. - * It also handles disequality constraints. + * It will not convert boolean equalities or inequalities to disjunctions for performance. * Assumption: * (a) the formula does not have match constructs + * (b) all lets have been pulled to the top * Some important features. * (a) For a strict inequality with real variables/constants, the following produces a strict inequality * (b) Strict inequalities with only integer variables/constants are reduced to non-strict inequalities */ - def TransformNot(expr: Expr, retainNEQ: Boolean = false): Expr = { // retainIff : Boolean = false - def nnf(inExpr: Expr): Expr = { - if (inExpr.getType != BooleanType) inExpr - else { - inExpr match { - case Not(Not(e1)) => nnf(e1) - case e @ Not(t: Terminal) => e - case e @ Not(FunctionInvocation(_, _)) => e - case Not(And(args)) => createOr(args.map(arg => nnf(Not(arg)))) - case Not(Or(args)) => createAnd(args.map(arg => nnf(Not(arg)))) - case Not(e @ Operator(Seq(e1, e2), op)) => { - //matches integer binary relation or a boolean equality - if (e1.getType == BooleanType || e1.getType == Int32Type || e1.getType == RealType || e1.getType == IntegerType) { - e match { - case e: Equals => { - if (e1.getType == BooleanType && e2.getType == BooleanType) { - Or(And(nnf(e1), nnf(Not(e2))), And(nnf(e2), nnf(Not(e1)))) - } else { - if (retainNEQ) Not(Equals(e1, e2)) - else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2))) - } - } - case e: LessThan => GreaterEquals(nnf(e1), nnf(e2)) - case e: LessEquals => GreaterThan(nnf(e1), nnf(e2)) - case e: GreaterThan => LessEquals(nnf(e1), nnf(e2)) - case e: GreaterEquals => LessThan(nnf(e1), nnf(e2)) - case e: Implies => And(nnf(e1), nnf(Not(e2))) - case _ => throw new IllegalStateException("Unknown binary operation: " + e) - } - } else { - //in this case e is a binary operation over ADTs - e match { - // TODO: is this a bug ? - case ninst @ Not(IsInstanceOf(e1, cd)) => Not(IsInstanceOf(nnf(e1), cd)) - case SubsetOf(_, _) | ElementOfSet(_, _) | SetUnion(_, _) | FiniteSet(_, _) => - Not(e) - case e: Equals => Not(Equals(nnf(e1), nnf(e2))) - case _ => throw new IllegalStateException("Unknown operation on algebraic data types: " + e) - } - } - } - case e @ Equals(lhs, SubsetOf(_, _) | ElementOfSet(_, _) | SetUnion(_, _) | FiniteSet(_, _)) => - // all are set operations - e - case e @ Equals(lhs, IsInstanceOf(_, _) | CaseClassSelector(_, _, _) | TupleSelect(_, _) | FunctionInvocation(_, _)) => - //all case where rhs could use an ADT tree e.g. instanceOF, tupleSelect, fieldSelect, function invocation - e - case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs)) - case Equals(lhs, rhs) if (lhs.getType == BooleanType && rhs.getType == BooleanType) => { - nnf(And(Implies(lhs, rhs), Implies(rhs, lhs))) + def toNNF(inExpr: Expr, retainNEQ: Boolean = false): Expr = { + def nnf(expr: Expr): Expr = { +// /println("Invoking nnf on: "+expr) + expr match { + //case e if e.getType != BooleanType => e + case Not(Not(e1)) => nnf(e1) + case e @ Not(t: Terminal) => e + case Not(FunctionInvocation(tfd, args)) => Not(FunctionInvocation(tfd, args map nnf)) + case Not(And(args)) => createOr(args.map(arg => nnf(Not(arg)))) + case Not(Or(args)) => createAnd(args.map(arg => nnf(Not(arg)))) + case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e))) + case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze))) + case Not(e @ Operator(Seq(e1, e2), op)) => // Not of binary operator ? + e match { + case _: LessThan => GreaterEquals(e1, e2) + case _: LessEquals => GreaterThan(e1, e2) + case _: GreaterThan => LessEquals(e1, e2) + case _: GreaterEquals => LessThan(e1, e2) + case _: Implies => And(nnf(e1), nnf(Not(e2))) + case _: SubsetOf | _: ElementOfSet | _: SetUnion | _: FiniteSet => Not(e) // set ops + // handle equalities (which is shared by theories) + case _: Equals if e1.getType == BooleanType => Not(Equals(nnf(e1), nnf(e2))) + case _: Equals if adtType(e1) || e1.getType.isInstanceOf[SetType] => Not(e) // adt or set equality + case _: Equals if TypeUtil.isNumericType(e1.getType) => + if (retainNEQ) Not(Equals(e1, e2)) + else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2))) + case _ => throw new IllegalStateException(s"Unknown binary operation: $e arg types: ${e1.getType},${e2.getType}") } - case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze))) - case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e))) - case t: Terminal => t - case n @ Operator(args, op) => op(args.map(nnf(_))) - - case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr) - } + case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs)) + case Equals(lhs, rhs @ (_: SubsetOf | _: ElementOfSet | _: IsInstanceOf | _: TupleSelect | _: CaseClassSelector)) => + Equals(nnf(lhs), rhs) + case Equals(lhs, FunctionInvocation(tfd, args)) => + Equals(nnf(lhs), FunctionInvocation(tfd, args map nnf)) + case Equals(lhs, rhs) if lhs.getType == BooleanType => Equals(nnf(lhs), nnf(rhs)) + case t: Terminal => t + case n @ Operator(args, op) => op(args map nnf) + case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr) } } - val nnfvc = nnf(expr) - nnfvc + nnf(inExpr) } /** @@ -393,22 +377,19 @@ object ExpressionTransformer { * This is supposed to be a semantic preserving transformation */ def pullAndOrs(expr: Expr): Expr = { - simplePostTransform { - case Or(args) => { + case Or(args) => val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { case Or(inArgs) => acc ++ inArgs case _ => acc :+ arg }) createOr(newArgs) - } - case And(args) => { + case And(args) => val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { case And(inArgs) => acc ++ inArgs case _ => acc :+ arg }) createAnd(newArgs) - } case e => e }(expr) } @@ -417,17 +398,12 @@ object ExpressionTransformer { * Normalizes the expressions */ def normalizeExpr(expr: Expr, multOp: (Expr, Expr) => Expr): Expr = { - //reduce the language before applying flatten function //println("Normalizing " + ScalaPrinter(expr) + "\n") - val redex = reduceLangBlocks(expr, multOp) - //println("Redex: " + ScalaPrinter(redex) + "\n") - val nnfExpr = TransformNot(redex) - //println("NNFexpr: " + ScalaPrinter(nnfExpr) + "\n") - //flatten all function calls - val flatExpr = FlattenFunction(nnfExpr) - //println("Flatexpr: " + ScalaPrinter(flatExpr) + "\n") - //perform additional simplification - val simpExpr = pullAndOrs(TransformNot(flatExpr)) + val redex = reduceLangBlocks(toNNF(matchToIfThenElse(expr)), multOp) + //println("After reducing lang blocks: " + ScalaPrinter(redex) + "\n") + val flatExpr = FlattenFunction(redex) + val simpExpr = pullAndOrs(flatExpr) + //println("After Normalizing: " + ScalaPrinter(flatExpr) + "\n") simpExpr } @@ -462,8 +438,8 @@ object ExpressionTransformer { // specially handle boolean function to prevent unnecessary simplifications case Or(args) => Or(args map rec) case And(args) => And(args map rec) - case Not(arg) => Not(rec(arg)) - case Operator(args, op) => op(args map rec) + case IfExpr(cond, th, elze) => IfExpr(rec(cond), rec(th), rec(elze)) + case e => e // we should not recurse in other operations, note: Not(equals) should not be considered } val newe = rec(ine) val closure = (e: Expr) => replaceFromIDs(idMap, e) @@ -559,7 +535,10 @@ object ExpressionTransformer { case _ => true } - def PrintWithIndentation(wr: PrintWriter, expr: Expr): Unit = { + def PrintWithIndentation(filePrefix: String, expr: Expr): Unit = { + + val filename = filePrefix + FileCountGUID.getID + ".txt" + val wr = new PrintWriter(new File(filename)) def uniOP(e: Expr, seen: Int): Boolean = e match { case And(args) => { @@ -605,6 +584,8 @@ object ExpressionTransformer { } } printRec(expr, 0) + wr.close() + println("Printed to file: " + filename) } /** diff --git a/src/main/scala/leon/invariant/util/LetTupleSimplification.scala b/src/main/scala/leon/invariant/util/LetTupleSimplification.scala index 641f6045e004df79be1a30fccf24b385c490a384..d560f5bc2719797679f64ec99e3d9806b12227b4 100644 --- a/src/main/scala/leon/invariant/util/LetTupleSimplification.scala +++ b/src/main/scala/leon/invariant/util/LetTupleSimplification.scala @@ -335,8 +335,7 @@ object LetTupleSimplification { } // println(s"E : $e After Pulling lets to top : \n $transe") transe - } - //val res = pullLetToTop(matchToIfThenElse(ine)) + } val res = pullLetToTop(ine) /*if(debug) println(s"InE : $ine After Pulling lets to top : \n ${ScalaPrinter.apply(res)}")*/ diff --git a/src/main/scala/leon/invariant/util/SolverUtil.scala b/src/main/scala/leon/invariant/util/SolverUtil.scala index 9d798ac020a953e5f138bc7fd35494c30ce12358..a9ff8b0f19cfb4733243b28ce59d6379ff372727 100644 --- a/src/main/scala/leon/invariant/util/SolverUtil.scala +++ b/src/main/scala/leon/invariant/util/SolverUtil.scala @@ -12,6 +12,9 @@ import leon.invariant.templateSolvers.ExtendedUFSolver import java.io._ import Util._ import PredicateUtil._ +import evaluators._ +import EvaluationResults._ +import purescala.Extractors._ object SolverUtil { @@ -24,6 +27,15 @@ object SolverUtil { }) } + def completeWithRefModel(currModel: Model, refModel: Model) = { + new Model(refModel.toMap.map { + case (id, _) if currModel.isDefinedAt(id) => + (id -> currModel(id)) + case (id, v) => + (id -> v) + }.toMap) + } + def toZ3SMTLIB(expr: Expr, filename: String, theory: String, ctx: LeonContext, pgm: Program, useBitvectors: Boolean = false, @@ -38,6 +50,14 @@ object SolverUtil { writer.close() } + def verifyModel(e: Expr, model: Model, solver: SimpleSolverAPI) = { + solver.solveSAT(And(e, modelToExpr(model))) match { + case (Some(false), _) => + throw new IllegalStateException("Model doesn't staisfy formula!") + case _ => + } + } + /** * A helper function that can be used to hardcode an invariant and see if it unsatifies the paths */ diff --git a/src/main/scala/leon/invariant/util/TreeUtil.scala b/src/main/scala/leon/invariant/util/TreeUtil.scala index 5949082d339c198d57a7d17255ce2db5f5368498..b5650b7b2e1a65461664eeb6365f4f3b7fbf2379 100644 --- a/src/main/scala/leon/invariant/util/TreeUtil.scala +++ b/src/main/scala/leon/invariant/util/TreeUtil.scala @@ -121,7 +121,7 @@ object ProgramUtil { * will be removed */ def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, - funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay: Boolean = true): Program = { + funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay: Boolean = false): Program = { val funMap = functionsWOFields(prog.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { case (accMap, fd) if fd.isTheoryOperation => @@ -236,13 +236,23 @@ object ProgramUtil { } def translateExprToProgram(ine: Expr, currProg: Program, newProg: Program): Expr = { + var funCache = Map[String, Option[FunDef]]() + def funInNewprog(fn: String) = + funCache.get(fn) match { + case None => + val fd = functionByFullName(fn, newProg) + funCache += (fn -> fd) + fd + case Some(fd) => fd + } simplePostTransform { case FunctionInvocation(TypedFunDef(fd, tps), args) => - functionByName(fullName(fd)(currProg), newProg) match { + val fname = fullName(fd)(currProg) + funInNewprog(fname) match { case Some(nfd) => FunctionInvocation(TypedFunDef(nfd, tps), args) case _ => - throw new IllegalStateException(s"Cannot find translation for ${fd.id.name}") + throw new IllegalStateException(s"Cannot find translation for ${fname}") } case e => e }(ine) @@ -286,27 +296,37 @@ object PredicateUtil { (e => e, base) } + def letStarUnapplyWithSimplify(e: Expr): (Expr => Expr, Expr) = { + val (letCons, letBody) = letStarUnapply(e) + (letCons andThen simplifyLets, letBody) + } + /** * Checks if the input expression has only template variables as free variables */ def isTemplateExpr(expr: Expr): Boolean = { var foundVar = false - simplePostTransform { - case e @ Variable(id) => { + postTraversal { + case e @ Variable(id) => if (!TemplateIdFactory.IsTemplateIdentifier(id)) - foundVar = true - e - } - case e @ ResultVariable(_) => { - foundVar = true - e - } - case e => e + foundVar = true + case e @ ResultVariable(_) => + foundVar = true + case e => }(expr) - !foundVar } + def isArithmeticRelation(e: Expr) = { + e match { + case Equals(l, r) => + if (l.getType == Untyped) None + else Some(TypeUtil.isNumericType(l.getType)) + case _: LessThan | _: LessEquals | _: GreaterThan | _: GreaterEquals => Some(true) + case _ => Some(false) + } + } + def getTemplateIds(expr: Expr) = { variablesOf(expr).filter(TemplateIdFactory.IsTemplateIdentifier) } @@ -352,20 +372,15 @@ object PredicateUtil { hasInts(expr) && hasReals(expr) } - def atomNum(e: Expr): Int = { - var count: Int = 0 - simplePostTransform { - case e @ And(args) => { - count += args.size - e - } - case e @ Or(args) => { - count += args.size - e - } - case e => e - }(e) - count + /** + * Assuming a flattenned formula + */ + def atomNum(e: Expr): Int = e match { + case And(args) => (args map atomNum).sum + case Or(args) => (args map atomNum).sum + case IfExpr(c, th, el) => atomNum(c) + atomNum(th) + atomNum(el) + case Not(arg) => atomNum(arg) + case e => 1 } def numUIFADT(e: Expr): Int = { @@ -448,8 +463,9 @@ object PredicateUtil { * Computes the set of variables that are shared across disjunctions. * This may return bound variables as well */ - def sharedIds(e: Expr): Set[Identifier] = e match { - case Or(args) => + def sharedIds(ine: Expr): Set[Identifier] = { + + def sharedOfDisjointExprs(args: Seq[Expr]) = { var uniqueVars = Set[Identifier]() var sharedVars = Set[Identifier]() args.foreach { arg => @@ -458,9 +474,17 @@ object PredicateUtil { sharedVars ++= newShared uniqueVars = (uniqueVars ++ candUniques) -- newShared } - sharedVars ++ (args flatMap sharedIds) - case Variable(_) => Set() - case Operator(args, op) => - (args flatMap sharedIds).toSet + sharedVars ++ (args flatMap rec) + } + def rec(e: Expr): Set[Identifier] = + e match { + case Or(args) => sharedOfDisjointExprs(args) + case IfExpr(c, th, el) => + rec(c) ++ sharedOfDisjointExprs(Seq(th, el)) + case Variable(_) => Set() + case Operator(args, op) => + (args flatMap rec).toSet + } + rec(ine) } } diff --git a/src/main/scala/leon/invariant/util/TypeUtil.scala b/src/main/scala/leon/invariant/util/TypeUtil.scala index bb1a70e326c43e362641cbd042d6f7e7d8862149..6d7c59e8068a3fc0ec2ed99bbebb926c6aa9fa41 100644 --- a/src/main/scala/leon/invariant/util/TypeUtil.scala +++ b/src/main/scala/leon/invariant/util/TypeUtil.scala @@ -49,4 +49,10 @@ object TypeUtil { throw new IllegalStateException("BitVector types not supported yet!") case _ => false } + + def rootType(t: TypeTree): Option[AbstractClassType] = t match { + case absT: AbstractClassType => Some(absT) + case ct: CaseClassType => ct.parent + case _ => None + } } \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/UnflatHelper.scala b/src/main/scala/leon/invariant/util/UnflatHelper.scala index 16ca818eda7092e22d62669dfb4b304d230dd8ea..c32d2208478553eddd822b2018c0ca3866fe79ee 100644 --- a/src/main/scala/leon/invariant/util/UnflatHelper.scala +++ b/src/main/scala/leon/invariant/util/UnflatHelper.scala @@ -33,7 +33,7 @@ class SimpleLazyModel(m: Model) extends LazyModel { * Expands a given model into a model with mappings for identifiers introduced during flattening. * Note: this class cannot be accessed in parallel. */ -class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], initModel: Model, eval: DefaultEvaluator) extends LazyModel { +class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], initModel: Model, eval: DefaultEvaluator) extends LazyModel { var idModel = initModel.toMap override def get(iden: Identifier) = { @@ -55,7 +55,8 @@ class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], ini idModel += (id -> v) Some(v) case _ => - throw new IllegalStateException(s"Evaluation Falied for $id -> $rhs") + None + //throw new IllegalStateException(s"Evaluation Falied for $id -> $rhs") } } else if (freeVars(id)) { // here, `id` either belongs to values of the flatIdMap, or to flate or was lost in unflattening @@ -72,13 +73,26 @@ class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], ini } } +object UnflatHelper { + def evaluate(e: Expr, m: LazyModel, eval: DefaultEvaluator): Expr = { + val varsMap = variablesOf(e).collect { + case v if m.isDefinedAt(v) => (v -> m(v)) + }.toMap + eval.eval(e, varsMap) match { + case Successful(v) => v + case _ => + throw new IllegalStateException(s"Evaluation Falied for $e") + } + } +} + /** * A class that can used to compress a flattened expression * and also expand the compressed models to the flat forms */ class UnflatHelper(ine: Expr, excludeIds: Set[Identifier], eval: DefaultEvaluator) { - val (unflate, flatIdMap) = unflattenWithMap(ine, excludeIds, includeFuns = false) + val (unflate, flatIdMap) = unflattenWithMap(ine, excludeIds, includeFuns = false) val invars = variablesOf(ine) def getModel(m: Model) = new FlatModel(invars, flatIdMap, m, eval) diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala index a792f1586479d5989773cdeee7e322eaa66983a6..c660b065062ca99c5f4e87d18a259b7d1abbbeb7 100644 --- a/src/main/scala/leon/invariant/util/Util.scala +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -6,6 +6,9 @@ import purescala.Types._ import purescala.PrettyPrintable import purescala.PrinterContext import purescala.PrinterHelpers._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.ExprOps._ object FileCountGUID { var fileCount = 0 @@ -38,6 +41,8 @@ object Util { val zero = InfiniteIntegerLiteral(0) val one = InfiniteIntegerLiteral(1) + val mone = InfiniteIntegerLiteral(-1) + val bone = BigInt(1) val tru = BooleanLiteral(true) val fls = BooleanLiteral(false) @@ -62,4 +67,33 @@ object Util { else product } + + /** + * Transitively close the substitution map from identifiers to expressions. + * Note: the map is required to be acyclic. + */ + def substClosure(initMap: Map[Identifier, Expr]): Map[Identifier, Expr] = { + if (initMap.isEmpty) initMap + else { + var stables = Seq[(Identifier, Expr)]() + var unstables = initMap.toSeq + var changed = true + while (changed) { + changed = false + var foundStable = false + unstables = unstables.flatMap { + case (k, v) if variablesOf(v).intersect(initMap.keySet).isEmpty => + foundStable = true + stables +:= (k -> v) + Seq() + case (k, v) => + changed = true + Seq((k -> replaceFromIDs(initMap, v))) + } + if (!foundStable) + throw new IllegalStateException(s"No stable entry was found in the map! The map is possibly cyclic: $initMap") + } + stables.toMap + } + } } diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala index 439dc5f8b7ed81492581df370e41554276e14d0b..f5a10ad83e77b6acdb743b3804dfbec00687d59f 100644 --- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala +++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala @@ -38,6 +38,8 @@ import PredicateUtil._ import leon.invariant.engine._ import LazyVerificationPhase._ import utils._ +import java.io. +_ /** * TODO: Function names are assumed to be small case. Fix this!! */ @@ -119,9 +121,18 @@ object LazinessEliminationPhase extends TransformationPhase { } // check specifications (to be moved to a different phase) if (!skipResourceVerification) - checkInstrumentationSpecs(instProg, checkCtx) + checkInstrumentationSpecs(instProg, checkCtx, + checkCtx.findOption(LazinessEliminationPhase.optUseOrb).getOrElse(false)) // dump stats - dumpStats() + if (ctx.findOption(SharedOptions.optBenchmark).getOrElse(false)) { + val modid = prog.units.find(_.isMainUnit).get.id + val filename = modid + "-stats.txt" + val pw = new PrintWriter(filename) + Stats.dumpStats(pw) + SpecificStats.dumpOutputs(pw) + ctx.reporter.info("Stats dumped to file: " + filename) + pw.close() + } instProg } diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala index e5546a7eba12b8de25a0d9255f26615c501b37a8..39c7505c534d7615ee9536f4dd10a1b14546b168 100644 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -171,13 +171,7 @@ object LazinessUtil { case ctype @ CaseClassType(_, Seq(innerType)) if isLazyType(ctype) || isMemType(ctype) => Some(innerType) case _ => None - } - - def rootType(t: TypeTree): Option[AbstractClassType] = t match { - case absT: AbstractClassType => Some(absT) - case ct: ClassType => ct.parent - case _ => None - } + } def opNameToCCName(name: String) = { name.capitalize + "@" diff --git a/src/main/scala/leon/laziness/LazyVerificationPhase.scala b/src/main/scala/leon/laziness/LazyVerificationPhase.scala index e17548465055c89688308006a2beeb2b2e5d9ed2..d9b4c18e1f55de015e7009819946bfc8cbc700b7 100644 --- a/src/main/scala/leon/laziness/LazyVerificationPhase.scala +++ b/src/main/scala/leon/laziness/LazyVerificationPhase.scala @@ -40,6 +40,7 @@ import leon.invariant.engine._ object LazyVerificationPhase { val debugInstVCs = false + val debugInferProgram = true def removeInstrumentationSpecs(p: Program): Program = { def hasInstVar(e: Expr) = { @@ -74,34 +75,17 @@ object LazyVerificationPhase { solverOptions.options ++ userOptions.options) } - // cumulative stats - var totalTime = 0L - var totalVCs = 0 - var solvedWithZ3 = 0 - var solvedWithCVC4 = 0 - var z3Time = 0L - var cvc4Time = 0L - def collectCumulativeStats(rep: VerificationReport) { - totalTime += rep.totalTime - totalVCs += rep.totalConditions - val (withz3, withcvc) = rep.vrs.partition{ + Stats.updateCumTime(rep.totalTime, "Total-Verification-Time") + Stats.updateCumStats(rep.totalConditions, "Total-VCs-Generated") + val (withz3, withcvc) = rep.vrs.partition { case (vc, vr) => vr.solvedWith.map(s => s.name.contains("smt-z3")).get } - solvedWithZ3 += withz3.size - solvedWithCVC4 += withcvc.size - z3Time += withz3.map(_._2.timeMs.getOrElse(0L)).sum - cvc4Time += withcvc.map(_._2.timeMs.getOrElse(0L)).sum - } - - def dumpStats() { - println("totalTime: "+f"${totalTime/1000d}%-3.3f") - println("totalVCs: "+totalVCs) - println("solvedWithZ3: "+ solvedWithZ3) - println("solvedWithCVC4: "+ solvedWithCVC4) - println("z3Time: "+f"${z3Time/1000d}%-3.3f") - println("cvc4Time: "+f"${cvc4Time/1000d}%-3.3f") + Stats.updateCounter(withz3.size, "Z3SolvedVCs") + Stats.updateCounter(withcvc.size, "CVC4SolvedVCs") + Stats.updateCounterStats(withz3.map(_._2.timeMs.getOrElse(0L)).sum, "Z3-Time", "Z3SolvedVCs") + Stats.updateCounterStats(withcvc.map(_._2.timeMs.getOrElse(0L)).sum, "CVC4-Time", "CVC4SolvedVCs") } def checkSpecifications(prog: Program, checkCtx: LeonContext) { @@ -110,52 +94,40 @@ object LazyVerificationPhase { if (fd.annotations.contains("axiom")) fd.addFlag(Annotation("library", Seq())) } - // val functions = Seq() //Seq("--functions=rotate") - // val solverOptions = if (debugSolvers) Seq("--debug=solver") else Seq() - // val unfoldFactor = 3 , - // "--unfoldFactor="+unfoldFactor) ++ solverOptions ++ functions - //val solverOptions = Main.processOptions(Seq("--solvers=smt-cvc4,smt-z3", "--assumepre") val report = VerificationPhase.apply(checkCtx, prog) // collect stats collectCumulativeStats(report) println(report.summaryString) - /*ctx.reporter.whenDebug(leon.utils.DebugSectionTimers) { debug => - ctx.timers.outputTable(debug) - }*/ } - def checkInstrumentationSpecs(p: Program, checkCtx: LeonContext) = { - - val useOrb = checkCtx.findOption(LazinessEliminationPhase.optUseOrb).getOrElse(false) + def checkInstrumentationSpecs(p: Program, checkCtx: LeonContext, useOrb: Boolean) = { p.definedFunctions.foreach { fd => if (fd.annotations.contains("axiom")) fd.addFlag(Annotation("library", Seq())) } val funsToCheck = p.definedFunctions.filter(shouldGenerateVC) - if (useOrb) { - // create an inference context - val inferOpts = Main.processOptions(Seq("--disableInfer", "--assumepreInf", "--minbounds","--solvers=smt-cvc4")) - val ctxForInf = LeonContext(checkCtx.reporter, checkCtx.interruptManager, - inferOpts.options ++ checkCtx.options) - val inferctx = new InferenceContext(p, ctxForInf) - val vcSolver = (funDef: FunDef, prog: Program) => new VCSolver(inferctx, prog, funDef) - prettyPrintProgramToFile(inferctx.inferProgram, checkCtx, "-inferProg", true) - (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, funsToCheck, vcSolver, None) - } else { - val vcs = funsToCheck.map { fd => - val (ants, post, tmpl) = createVC(fd) - if (tmpl.isDefined) - throw new IllegalStateException("Postcondition has holes! Run with --useOrb option") - val vc = implies(ants, post) - if (debugInstVCs) - println(s"VC for function ${fd.id} : " + vc) - VC(vc, fd, VCKinds.Postcondition) + val rep = + if (useOrb) { + // create an inference context + val inferOpts = Main.processOptions(Seq("--disableInfer", "--assumepreInf", "--minbounds", "--solvers=smt-cvc4")) + val ctxForInf = LeonContext(checkCtx.reporter, checkCtx.interruptManager, + inferOpts.options ++ checkCtx.options) + val inferctx = new InferenceContext(p, ctxForInf) + val vcSolver = (funDef: FunDef, prog: Program) => new VCSolver(inferctx, prog, funDef) + + if (debugInferProgram) + prettyPrintProgramToFile(inferctx.inferProgram, checkCtx, "-inferProg", true) + + val results = (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, + funsToCheck.map(InstUtil.userFunctionName), vcSolver, None) + new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, inferctx.inferProgram)(inferctx) + } else { + val rep = checkVCs(funsToCheck.map(vcForFun), checkCtx, p) + // record some stats + collectCumulativeStats(rep) + rep } - val rep = checkVCs(vcs, checkCtx, p) - // record some stats - collectCumulativeStats(rep) - println("Resource Verification Results: \n" + rep.summaryString) - } + println("Resource Verification Results: \n" + rep.summaryString) } def accessesSecondRes(e: Expr, resid: Identifier): Boolean = @@ -178,7 +150,17 @@ object LazyVerificationPhase { * Moreover, we can add other specs as assumptions since (A => B) ^ ((A ^ B) => C) => A => B ^ C * checks if the expression uses res._2 which corresponds to instvars after instrumentation */ - def createVC(fd: FunDef) = { + def vcForFun(fd: FunDef) = { + val (body, ants, post, tmpl) = collectAntsPostTmpl(fd) + if (tmpl.isDefined) + throw new IllegalStateException("Postcondition has holes! Run with --useOrb option") + val vc = implies(And(ants, body), post) + if (debugInstVCs) + println(s"VC for function ${fd.id} : " + vc) + VC(vc, fd, VCKinds.Postcondition) + } + + def collectAntsPostTmpl(fd: FunDef) = { val Lambda(Seq(resdef), _) = fd.postcondition.get val (pbody, tmpl) = (fd.getPostWoTemplate, fd.template) val (instPost, assumptions) = pbody match { @@ -186,7 +168,7 @@ object LazyVerificationPhase { val (instSpecs, rest) = args.partition(accessesSecondRes(_, resdef.id)) (createAnd(instSpecs), createAnd(rest)) case l: Let => - val (letsCons, letsBody) = letStarUnapply(l) + val (letsCons, letsBody) = letStarUnapplyWithSimplify(l) letsBody match { case And(args) => val (instSpecs, rest) = args.partition(accessesSecondRes(_, resdef.id)) @@ -196,8 +178,10 @@ object LazyVerificationPhase { } case e => (e, Util.tru) } - val ants = createAnd(Seq(fd.precOrTrue, assumptions, Equals(resdef.id.toVariable, fd.body.get))) - (ants, instPost, tmpl) + val ants = + if (fd.usePost) createAnd(Seq(fd.precOrTrue, assumptions)) + else fd.precOrTrue + (Equals(resdef.id.toVariable, fd.body.get), ants, instPost, tmpl) } def checkVCs(vcs: List[VC], checkCtx: LeonContext, p: Program) = { @@ -223,10 +207,15 @@ object LazyVerificationPhase { class VCSolver(ctx: InferenceContext, p: Program, rootFd: FunDef) extends UnfoldingTemplateSolver(ctx, p, rootFd) { - override def constructVC(fd: FunDef): (Expr, Expr) = { - val (ants, post, tmpl) = createVC(rootFd) + override def constructVC(fd: FunDef): (Expr, Expr, Expr) = { + val (body, ants, post, tmpl) = collectAntsPostTmpl(rootFd) val conseq = matchToIfThenElse(createAnd(Seq(post, tmpl.getOrElse(Util.tru)))) - (matchToIfThenElse(ants), conseq) + //println(s"body: $body ants: $ants conseq: $conseq") + (matchToIfThenElse(body), matchToIfThenElse(ants), conseq) + } + + override def verifyVC(newprog: Program, newroot: FunDef) = { + solveUsingLeon(contextForChecks(ctx.leonContext), newprog, vcForFun(newroot)) } } } diff --git a/src/main/scala/leon/transformations/InstrumentationUtil.scala b/src/main/scala/leon/transformations/InstrumentationUtil.scala index 6d0a97b577c2bccd5733f5196346d10a73d8db0e..8e4c10451080f233aa1ac1bff4ae56dafb32d616 100644 --- a/src/main/scala/leon/transformations/InstrumentationUtil.scala +++ b/src/main/scala/leon/transformations/InstrumentationUtil.scala @@ -124,4 +124,18 @@ object InstUtil { val newres = FreshIdentifier(resvar.id.name, resvar.getType).toVariable replace(getInstVariableMap(fd) + (TupleSelect(resvar, 1) -> newres), e) } + + /** + * Checks if the given expression is a resource bound of the given function. + */ + def isResourceBoundOf(fd: FunDef)(e: Expr) = { + val instExprs = InstTypes.map(getInstExpr(fd, _)).collect { + case Some(inste) => inste + }.toSet + !instExprs.isEmpty && isArithmeticRelation(e).get && + exists { + case sub: TupleSelect => instExprs(sub) + case _ => false + }(e) + } } diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala index ecd0bca4b1dbf343be42f8c1e43b5e099de4b66f..6d740de6aff8b077d271377740e86c7fa5d327ab 100644 --- a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala +++ b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala @@ -173,15 +173,17 @@ class NonlinearityEliminator(skipAxioms: Boolean, domain: TypeTree) { } else None fd.flags.foreach(newfd.addFlag(_)) - }) - - val newprog = copyProgram(program, (defs: Seq[Definition]) => { + }) + val transProg = copyProgram(program, (defs: Seq[Definition]) => { defs.map { case fd: FunDef => newFundefs(fd) case d => d - } ++ (if (addMult) Seq(multFun, pivMultFun) else Seq()) + } }) - + val newprog = + if (addMult) + addDefs(transProg, Seq(multFun, pivMultFun), transProg.units.find(_.isMainUnit).get.definedFunctions.last) + else transProg if (debugNLElim) println("After Nonlinearity Elimination: \n" + ScalaPrinter.apply(newprog)) diff --git a/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala b/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala index 692951f2eb37100a2782605a06701e28d9905334..82288376e4b484ebd04ae125110ffa8f2c967157 100644 --- a/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala +++ b/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala @@ -37,7 +37,7 @@ class OrbInstrumentationTestSuite extends LeonRegressionSuite { // check properties. val (ctx3, instProg) = processPipe.run(ctx2, program) val sizeFun = instProg.definedFunctions.find(_.id.name.startsWith("size")) - if(!sizeFun.isDefined || !sizeFun.get.returnType.isInstanceOf[TupleType]) + if (!sizeFun.isDefined || !sizeFun.get.returnType.isInstanceOf[TupleType]) fail("Error in instrumentation") } diff --git a/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala b/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala index ebde5cc5d0e1d93638b0e4698306196827f8eaa5..00c01cf27fa7db43963028fda40bd9f7c3b8fa17 100644 --- a/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala +++ b/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala @@ -19,7 +19,7 @@ class OrbRegressionSuite extends LeonRegressionSuite { } private def testInference(f: File, bound: Option[Int] = None) { - val ctx = createLeonContext("--inferInv","--solvers=smt-z3") + val ctx = createLeonContext("--inferInv", "--vcTimeout=3", "--solvers=smt-z3") val beginPipe = leon.frontends.scalac.ExtractionPhase andThen new leon.utils.PreprocessingPhase val (ctx2, program) = beginPipe.run(ctx, f.getAbsolutePath :: Nil) diff --git a/src/test/scala/leon/test/helpers/ExpressionsDSL.scala b/src/test/scala/leon/test/helpers/ExpressionsDSL.scala index a7c9692e151ca0b9813a41cd431981f7e28bd5b4..715a874e6d7d29b471ee923dff682fc19b26881f 100644 --- a/src/test/scala/leon/test/helpers/ExpressionsDSL.scala +++ b/src/test/scala/leon/test/helpers/ExpressionsDSL.scala @@ -76,5 +76,4 @@ trait ExpressionsDSL { val tfd = funDef(name).typed(Seq()) FunctionInvocation(tfd, args.toSeq) } - } diff --git a/src/test/scala/leon/unit/orb/OrbUnitTestSuite.scala b/src/test/scala/leon/unit/orb/OrbUnitTestSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..32229100b782b154a360843d2a4d3dd6ffa4f210 --- /dev/null +++ b/src/test/scala/leon/unit/orb/OrbUnitTestSuite.scala @@ -0,0 +1,82 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.unit.orb + +import leon._ +import leon.test._ +import leon.purescala.Expressions._ +import leon.purescala.Types._ +import leon.purescala.Common._ +import leon.purescala.ExprOps._ +import leon.invariant.util.LetTupleSimplification +import scala.math.BigInt.int2bigInt +import leon.purescala.Definitions._ +import leon.invariant.engine._ +import leon.transformations._ +import java.io.File +import leon.purescala.Types.TupleType +import leon.utils._ +import invariant.structure.LinearConstraintUtil._ +import leon.invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.util._ +import ProgramUtil._ +import Util.zero + +class OrbUnitTestSuite extends LeonTestSuite { + val a = FreshIdentifier("a", IntegerType).toVariable + val b = FreshIdentifier("b", IntegerType).toVariable + val c = FreshIdentifier("c", IntegerType).toVariable + val d = FreshIdentifier("d", IntegerType).toVariable + val l42 = InfiniteIntegerLiteral(42) + val l43 = InfiniteIntegerLiteral(43) + val mtwo = InfiniteIntegerLiteral(-2) + + test("Pull lets to top with tuples and tuple select") { ctx => + val in = TupleSelect(Tuple(Seq(a, b)), 1) + val exp = in + val out = LetTupleSimplification.removeLetsFromLetValues(in) + assert(out === exp) + } + + test("TestElimination") {ctx => + val exprs = Seq(Equals(a, b), Equals(c, Plus(a, b)), GreaterEquals(Plus(c, d), zero)) + println("Exprs: "+exprs) + val elimVars = Set(a, b, c).map(_.id) + val ctrs = exprs map ConstraintUtil.createConstriant + val nctrs = apply1PRuleOnDisjunct(ctrs.collect{ case c: LinearConstraint => c }, elimVars, None) + //println("Constraints after elimination: "+nctrs) + assert(nctrs.size == 1) + } + + test("TestElimination2") {ctx => + val exprs = Seq(Equals(zero, Plus(a, b)), Equals(a, zero), GreaterEquals(Plus(b, c), zero)) + println("Exprs: "+exprs) + val elimVars = Set(a, b).map(_.id) + val ctrs = exprs map ConstraintUtil.createConstriant + val nctrs = apply1PRuleOnDisjunct(ctrs.collect{ case c: LinearConstraint => c }, elimVars, None) + //println("Constraints after elimination: "+nctrs) + assert(nctrs.size == 1) + } + + def createLeonContext(opts: String*): LeonContext = { + val reporter = new TestSilentReporter + Main.processOptions(opts.toList).copy(reporter = reporter, interruptManager = new InterruptManager(reporter)) + } + + def scalaExprToTree(scalaExpr: String): Expr = { + val ctx = createLeonContext() + val testFilename = toTempFile(s""" + import leon.annotation._ + object Test { def test() = { $scalaExpr } }""") + val beginPipe = leon.frontends.scalac.ExtractionPhase andThen + new leon.utils.PreprocessingPhase + val (_, program) = beginPipe.run(ctx, testFilename) + //println("Program: "+program) + functionByName("test", program).get.body.get + } + + def toTempFile(content: String): List[String] = { + TemporaryInputPhase.run(createLeonContext(), (List(content), Nil))._2 + } +} diff --git a/src/test/scala/leon/unit/orb/SimplifyLetsSuite.scala b/src/test/scala/leon/unit/orb/SimplifyLetsSuite.scala deleted file mode 100644 index b267e5d0f9e53cfd8d878c8910ec0ca49ea21b3f..0000000000000000000000000000000000000000 --- a/src/test/scala/leon/unit/orb/SimplifyLetsSuite.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.unit.orb - -import leon.test._ -import leon.purescala.Expressions._ -import leon.purescala.Types._ -import leon.purescala.Common._ -import leon.purescala.ExprOps._ -import leon.invariant.util.LetTupleSimplification -import scala.math.BigInt.int2bigInt - -class SimplifyLetsSuite extends LeonTestSuite { - val a = FreshIdentifier("a", IntegerType) - val b = FreshIdentifier("b", IntegerType) - val c = FreshIdentifier("c", IntegerType) - val l42 = InfiniteIntegerLiteral(42) - val l43 = InfiniteIntegerLiteral(43) - - test("Pull lets to top with tuples and tuple select") { ctx => - val in = TupleSelect(Tuple(Seq(a.toVariable, b.toVariable)), 1) - val exp = in - val out = LetTupleSimplification.removeLetsFromLetValues(in) - assert(out === exp) - } -} diff --git a/testcases/lazy-datastructures/conc/ConcTrees.scala b/testcases/lazy-datastructures/conc/ConcTrees.scala index ce3c54ecf401ca1d03ea9a55f347103910914e23..6bfb6642dd35e88d9ad7e853cf69b21498534dc4 100644 --- a/testcases/lazy-datastructures/conc/ConcTrees.scala +++ b/testcases/lazy-datastructures/conc/ConcTrees.scala @@ -5,9 +5,14 @@ import leon.collection._ import leon.lang._ import ListSpecs._ import leon.annotation._ +import leon.invariant._ +/** + * For better performance use --disableInfer on this benchmark + */ object ConcTrees { + @inline def max(x: BigInt, y: BigInt): BigInt = if (x >= y) x else y def abs(x: BigInt): BigInt = if (x < 0) -x else x @@ -19,11 +24,11 @@ object ConcTrees { def isLeaf: Boolean = { this match { - case Empty() => true + case Empty() => true case Single(_) => true - case _ => false + case _ => false } - } + } def valid: Boolean = { concInv && balanced @@ -35,7 +40,7 @@ object ConcTrees { */ def concInv: Boolean = this match { case CC(l, r) => - !l.isEmpty && !r.isEmpty && + !l.isEmpty && !r.isEmpty && l.concInv && r.concInv case _ => true } @@ -51,7 +56,7 @@ object ConcTrees { val level: BigInt = { (this match { - case Empty() => 0 + case Empty() => 0 case Single(x) => 0 case CC(l, r) => 1 + max(l.level, r.level) @@ -60,7 +65,7 @@ object ConcTrees { val size: BigInt = { (this match { - case Empty() => 0 + case Empty() => 0 case Single(x) => 1 case CC(l, r) => l.size + r.size @@ -69,10 +74,10 @@ object ConcTrees { def toList: List[T] = { this match { - case Empty() => Nil[T]() + case Empty() => Nil[T]() case Single(x) => Cons(x, Nil[T]()) case CC(l, r) => - l.toList ++ r.toList // note: left elements precede the right elements in the list + l.toList ++ r.toList // note: left elements precede the right elements in the list } } ensuring (res => res.size == this.size) } @@ -86,25 +91,63 @@ object ConcTrees { override def toString = s"Chunk(${array.mkString("", ", ", "")}; $size; $k)" }*/ - @library - def lookup[T](xs: Conc[T], i: BigInt): (T, BigInt) = { + @invisibleBody + def concatNonEmpty[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid && !xs.isEmpty && !ys.isEmpty) + val diff = ys.level - xs.level + if (diff >= -1 && diff <= 1) CC(xs, ys) + else if (diff < -1) { // ys is smaller than xs + xs match { + case CC(l, r) if (l.level >= r.level) => + CC(l, concatNonEmpty(r, ys)) + case CC(l, r) => + r match { + case CC(rl, rr) => + val nrr = concatNonEmpty(rr, ys) + if (nrr.level == xs.level - 3) + CC(l, CC(rl, nrr)) + else + CC(CC(l, rl), nrr) + } + } + } else ys match { + case CC(l, r) if (r.level >= l.level) => + CC(concatNonEmpty(xs, l), r) + case CC(l, r) => + l match { + case CC(ll, lr) => + val nll = concatNonEmpty(xs, ll) + if (nll.level == ys.level - 3) + CC(CC(nll, lr), r) + else + CC(nll, CC(lr, r)) + } + } + } ensuring (res => + appendAssocInst(xs, ys) && // instantiation of an axiom + res.level <= max(xs.level, ys.level) + 1 && // height invariants + res.level >= max(xs.level, ys.level) && + res.balanced && res.concInv && //this is should not be needed. But, seems necessary for leon + res.valid && // tree invariant is preserved + res.toList == xs.toList ++ ys.toList && // correctness + time <= 39 * abs(xs.level - ys.level) + 9) // time bounds + + @invisibleBody + def lookup[T](xs: Conc[T], i: BigInt): T = { require(xs.valid && !xs.isEmpty && i >= 0 && i < xs.size) xs match { - case Single(x) => (x, 0) + case Single(x) => x case CC(l, r) => - if (i < l.size) { - val (res, t) = lookup(l, i) - (res, t + 1) - } else { - val (res, t) = lookup(r, i - l.size) - (res, t + 1) - } + if (i < l.size) lookup(l, i) + else lookup(r, i - l.size) } - } ensuring (res => res._2 <= xs.level && // lookup time is linear in the height - res._1 == xs.toList(i) && // correctness - instAppendIndexAxiom(xs, i)) // an auxiliary axiom instantiation that required for the proof + } ensuring (res => + // axiom instantiation + instAppendIndexAxiom(xs, i) && + res == xs.toList(i) && // correctness + time <= ? * xs.level + ?) // lookup time is linear in the height - @library + @invisibleBody def instAppendIndexAxiom[T](xs: Conc[T], i: BigInt): Boolean = { require(0 <= i && i < xs.size) xs match { @@ -114,27 +157,22 @@ object ConcTrees { } }.holds - @library - def update[T](xs: Conc[T], i: BigInt, y: T): (Conc[T], BigInt) = { + @invisibleBody + def update[T](xs: Conc[T], i: BigInt, y: T): Conc[T] = { require(xs.valid && !xs.isEmpty && i >= 0 && i < xs.size) xs match { - case Single(x) => (Single(y), 0) + case Single(x) => Single(y) case CC(l, r) => - if (i < l.size) { - val (nl, t) = update(l, i, y) - (CC(nl, r), t + 1) - } else { - val (nr, t) = update(r, i - l.size, y) - (CC(l, nr), t + 1) - } + if (i < l.size) CC(update(l, i, y), r) + else CC(l, update(r, i - l.size, y)) } - } ensuring (res => res._1.level == xs.level && // heights of the input and output trees are equal - res._1.valid && // tree invariants are preserved - res._2 <= xs.level && // update time is linear in the height of the tree - res._1.toList == xs.toList.updated(i, y) && // correctness - instAppendUpdateAxiom(xs, i, y)) // an auxiliary axiom instantiation + } ensuring (res => instAppendUpdateAxiom(xs, i, y) && // an auxiliary axiom instantiation + res.level == xs.level && // heights of the input and output trees are equal + res.valid && // tree invariants are preserved + res.toList == xs.toList.updated(i, y) && // correctness + time <= ? * xs.level + ?) // update time is linear in the height of the tree - @library + @invisibleBody def instAppendUpdateAxiom[T](xs: Conc[T], i: BigInt, y: T): Boolean = { require(i >= 0 && i < xs.size) xs match { @@ -144,104 +182,11 @@ object ConcTrees { } }.holds - /** - * A generic concat that applies to general concTrees - */ - /*def concat[T](xs: Conc[T], ys: Conc[T]): (Conc[T], BigInt) = { - require(xs.valid && ys.valid) - val (nxs, t1) = normalize(xs) - val (nys, t2) = normalize(ys) - val (res, t3) = concatNormalized(nxs, nys) - (res, t1 + t2 + t3) - }*/ - - /** - * This concat applies only to normalized trees. - * This prevents concat from being recursive - */ - @library - def concatNormalized[T](xs: Conc[T], ys: Conc[T]): (Conc[T], BigInt) = { - require(xs.valid && ys.valid) - (xs, ys) match { - case (xs, Empty()) => (xs, 0) - case (Empty(), ys) => (ys, 0) - case _ => - concatNonEmpty(xs, ys) - } - } ensuring (res => res._1.valid && // tree invariants - res._1.level <= max(xs.level, ys.level) + 1 && // height invariants - res._1.level >= max(xs.level, ys.level) && - (res._1.toList == xs.toList ++ ys.toList) // correctness - ) - - @library - def concatNonEmpty[T](xs: Conc[T], ys: Conc[T]): (Conc[T], BigInt) = { - require(xs.valid && ys.valid && - !xs.isEmpty && !ys.isEmpty) - - val diff = ys.level - xs.level - if (diff >= -1 && diff <= 1) - (CC(xs, ys), 0) - else if (diff < -1) { - // ys is smaller than xs - xs match { - case CC(l, r) => - if (l.level >= r.level) { - val (nr, t) = concatNonEmpty(r, ys) - (CC(l, nr), t + 1) - } else { - r match { - case CC(rl, rr) => - val (nrr, t) = concatNonEmpty(rr, ys) - if (nrr.level == xs.level - 3) { - val nl = l - val nr = CC(rl, nrr) - (CC(nl, nr), t + 1) - } else { - val nl = CC(l, rl) - val nr = nrr - (CC(nl, nr), t + 1) - } - } - } - } - } else { - ys match { - case CC(l, r) => - if (r.level >= l.level) { - val (nl, t) = concatNonEmpty(xs, l) - (CC(nl, r), t + 1) - } else { - l match { - case CC(ll, lr) => - val (nll, t) = concatNonEmpty(xs, ll) - if (nll.level == ys.level - 3) { - val nl = CC(nll, lr) - val nr = r - (CC(nl, nr), t + 1) - } else { - val nl = nll - val nr = CC(lr, r) - (CC(nl, nr), t + 1) - } - } - } - } - } - } ensuring (res => res._2 <= abs(xs.level - ys.level) && // time bound - res._1.level <= max(xs.level, ys.level) + 1 && // height invariants - res._1.level >= max(xs.level, ys.level) && - res._1.balanced && res._1.concInv && //this is should not be needed. But, seems necessary for leon - res._1.valid && // tree invariant is preserved - res._1.toList == xs.toList ++ ys.toList && // correctness - appendAssocInst(xs, ys) // instantiation of an axiom - ) - - @library + @invisibleBody def appendAssocInst[T](xs: Conc[T], ys: Conc[T]): Boolean = { (xs match { case CC(l, r) => - appendAssoc(l.toList, r.toList, ys.toList) && //instantiation of associativity of concatenation + appendAssoc(l.toList, r.toList, ys.toList) && //instantiation of associativity of concatenation (r match { case CC(rl, rr) => appendAssoc(rl.toList, rr.toList, ys.toList) && @@ -263,38 +208,61 @@ object ConcTrees { }) }.holds - @library - def insert[T](xs: Conc[T], i: BigInt, y: T): (Conc[T], BigInt) = { - require(xs.valid && i >= 0 && i <= xs.size) //note the precondition + /** + * A generic concat that applies to general concTrees + */ + // @invisibleBody + // def concat[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + // require(xs.valid && ys.valid) + // concatNormalized(normalize(xs), normalize(ys)) + // } + + /** + * This concat applies only to normalized trees. + * This prevents concat from being recursive + */ + @invisibleBody + def concatNormalized[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid) + (xs, ys) match { + case (xs, Empty()) => xs + case (Empty(), ys) => ys + case _ => concatNonEmpty(xs, ys) + } + } ensuring (res => res.valid && // tree invariants + res.level <= max(xs.level, ys.level) + 1 && // height invariants + res.level >= max(xs.level, ys.level) && + (res.toList == xs.toList ++ ys.toList) && // correctness + time <= ? * abs(xs.level - ys.level) + ?) + + @invisibleBody + def insert[T](xs: Conc[T], i: BigInt, y: T): Conc[T] = { + //xs.valid && + require(xs.concInv && xs.balanced && i >= 0 && i <= xs.size) //note the precondition xs match { - case Empty() => (Single(y), 0) + case Empty() => Single(y) case Single(x) => - if (i == 0) - (CC(Single(y), xs), 0) - else - (CC(xs, Single(y)), 0) + if (i == 0) CC(Single(y), xs) + else CC(xs, Single(y)) case CC(l, r) if i < l.size => - val (nl, t) = insert(l, i, y) - val (res, t1) = concatNonEmpty(nl, r) - (res, t + t1 + 1) + concatNonEmpty(insert(l, i, y), r) case CC(l, r) => - val (nr, t) = insert(r, i - l.size, y) - val (res, t1) = concatNonEmpty(l, nr) - (res, t + t1 + 1) + concatNonEmpty(l, insert(r, i - l.size, y)) } - } ensuring (res => res._1.valid && // tree invariants - res._1.level - xs.level <= 1 && res._1.level >= xs.level && // height of the output tree is at most 1 greater than that of the input tree - res._2 <= 3 * xs.level && // time is linear in the height of the tree - res._1.toList == insertAtIndex(xs.toList, i, y) && // correctness - insertAppendAxiomInst(xs, i, y) // instantiation of an axiom - ) + } ensuring (res => + insertAppendAxiomInst(xs, i, y) && // instantiation of an axiom + res.valid && // tree invariants + res.level - xs.level <= 1 && res.level >= xs.level && // height of the output tree is at most 1 greater than that of the input tree + res.toList == insertAtIndex(xs.toList, i, y) && // correctness + time <= ? * xs.level + ? // time is linear in the height of the tree + ) /** * Using a different version of insert than of the library * because the library implementation in unnecessarily complicated. * TODO: update the code to use the library instead ? */ - @library + @invisibleBody def insertAtIndex[T](l: List[T], i: BigInt, y: T): List[T] = { require(0 <= i && i <= l.size) l match { @@ -308,11 +276,11 @@ object ConcTrees { } // A lemma about `append` and `insertAtIndex` - @library + @invisibleBody def appendInsertIndex[T](l1: List[T], l2: List[T], i: BigInt, y: T): Boolean = { require(0 <= i && i <= l1.size + l2.size) (l1 match { - case Nil() => true + case Nil() => true case Cons(x, xs) => if (i == 0) true else appendInsertIndex[T](xs, l2, i - 1, y) }) && // lemma @@ -321,17 +289,17 @@ object ConcTrees { else l1 ++ insertAtIndex(l2, (i - l1.size), y))) }.holds - @library + @invisibleBody def insertAppendAxiomInst[T](xs: Conc[T], i: BigInt, y: T): Boolean = { require(i >= 0 && i <= xs.size) xs match { case CC(l, r) => appendInsertIndex(l.toList, r.toList, i, y) - case _ => true + case _ => true } }.holds //TODO: why with instrumentation we are not able prove the running time here ? (performance bug ?) - @library + /*@library def split[T](xs: Conc[T], n: BigInt): (Conc[T], Conc[T], BigInt) = { require(xs.valid) xs match { @@ -356,11 +324,11 @@ object ConcTrees { (l, r, BigInt(0)) } } - } ensuring (res => res._1.valid && res._2.valid && // tree invariants are preserved + } ensuring (res => res._1.valid && res._2.valid && // tree invariants are preserved xs.level >= res._1.level && xs.level >= res._2.level && // height bounds of the resulting tree res._3 <= xs.level + res._1.level + res._2.level && // time is linear in height res._1.toList == xs.toList.take(n) && res._2.toList == xs.toList.drop(n) && // correctness - instSplitAxiom(xs, n) // instantiation of an axiom + instSplitAxiom(xs, n) // instantiation of an axiom ) @library @@ -370,5 +338,5 @@ object ConcTrees { appendTakeDrop(l.toList, r.toList, n) case _ => true } - }.holds + }.holds*/ } diff --git a/testcases/lazy-datastructures/withOrb/BottomUpMegeSort.scala b/testcases/lazy-datastructures/withOrb/BottomUpMegeSort.scala index bf9387fd80f40f8c3a9605a521f1a986d4d02515..10a0047cf2813906ad084be22253b656d6fed512 100644 --- a/testcases/lazy-datastructures/withOrb/BottomUpMegeSort.scala +++ b/testcases/lazy-datastructures/withOrb/BottomUpMegeSort.scala @@ -45,6 +45,7 @@ object BottomUpMergeSort { } case class SCons(x: BigInt, tail: $[IStream]) extends IStream case class SNil() extends IStream + @inline def ssize(l: $[IStream]): BigInt = (l*).size /** @@ -81,6 +82,7 @@ object BottomUpMergeSort { * on each pair. * Takes time linear in the size of the input list */ + @invisibleBody def pairs(l: LList): LList = { require(l.valid) l match { @@ -98,6 +100,7 @@ object BottomUpMergeSort { * Create a linearized tree of merges e.g. merge(merge(2, 1), merge(17, 19)). * Takes time linear in the size of the input list. */ + @invisibleBody def constructMergeTree(l: LList): LList = { require(l.valid) l match { @@ -120,6 +123,8 @@ object BottomUpMergeSort { * Note: the sorted stream of integers may by recursively constructed using merge. * Takes time linear in the size of the streams (non-trivial to prove due to cascading of lazy calls) */ + @invisibleBody + @usePost def merge(a: $[IStream], b: $[IStream]): IStream = { require(((a*) != SNil() || b.isEvaluated) && // if one of the arguments is Nil then the other is evaluated ((b*) != SNil() || a.isEvaluated) && @@ -142,6 +147,7 @@ object BottomUpMergeSort { /** * Converts a list of integers to a list of streams of integers */ + @invisibleBody def IListToLList(l: IList): LList = { l match { case INil() => LNil() diff --git a/testcases/lazy-datastructures/withOrb/Deque.scala b/testcases/lazy-datastructures/withOrb/Deque.scala index 870eed380bc729186c5eeafd718cd4938f16c2d7..6493eb8c8b2bef8b3955c31c71c8ffe4eedfb6f2 100644 --- a/testcases/lazy-datastructures/withOrb/Deque.scala +++ b/testcases/lazy-datastructures/withOrb/Deque.scala @@ -191,6 +191,7 @@ object RealTimeDeque { * A function that takes streams where the size of front and rear streams violate * the balance invariant, and restores the balance. */ + @invisibleBody def createQueue[T](f: $[Stream[T]], lenf: BigInt, sf: $[Stream[T]], r: $[Stream[T]], lenr: BigInt, sr: $[Stream[T]]): Queue[T] = { require(firstUneval(f) == firstUneval(sf) && @@ -219,12 +220,16 @@ object RealTimeDeque { } ensuring(res => res.valid && time <= ?) + + @invisibleBody + def funeEqual[T](s1: $[Stream[T]], s2: $[Stream[T]]) = firstUneval(s1) == firstUneval(s2) + /** * Forces the schedules, and ensures that `firstUneval` equality is preserved */ + @invisibleBody def force[T](tar: $[Stream[T]], htar: $[Stream[T]], other: $[Stream[T]], hother: $[Stream[T]]): $[Stream[T]] = { - require(firstUneval(tar) == firstUneval(htar) && - firstUneval(other) == firstUneval(hother)) + require(funeEqual(tar, htar) && funeEqual(other, hother)) tar.value match { case SCons(_, tail) => tail case _ => tar @@ -246,12 +251,13 @@ object RealTimeDeque { /** * Forces the schedules in the queue twice and ensures the `firstUneval` property. */ + @invisibleBody def forceTwice[T](q: Queue[T]): ($[Stream[T]], $[Stream[T]]) = { require(q.valid) val nsf = force(force(q.sf, q.f, q.r, q.sr), q.f, q.r, q.sr) // forces q.sf twice val nsr = force(force(q.sr, q.r, q.f, nsf), q.r, q.f, nsf) // forces q.sr twice (nsf, nsr) - } + } ensuring(time <= ?) // the following properties are ensured, but need not be stated /*ensuring (res => { val nsf = res._1 diff --git a/testcases/lazy-datastructures/withOrb/LazyNumericalRep.scala b/testcases/lazy-datastructures/withOrb/LazyNumericalRep.scala new file mode 100644 index 0000000000000000000000000000000000000000..47a42bfa1d6d891a843db9767ac94562bc343c0e --- /dev/null +++ b/testcases/lazy-datastructures/withOrb/LazyNumericalRep.scala @@ -0,0 +1,479 @@ +package orb + +import leon.lazyeval._ +import leon.lang._ +import leon.annotation._ +import leon.instrumentation._ +import leon.lazyeval.$._ +import leon.invariant._ + +object DigitObject { + sealed abstract class Digit + case class Zero() extends Digit + case class One() extends Digit +} + +import DigitObject._ +object LazyNumericalRep { + + sealed abstract class NumStream { + val isSpine: Boolean = this match { + case Spine(_, _, _) => true + case _ => false + } + val isTip = !isSpine + } + + case class Tip() extends NumStream + case class Spine(head: Digit, createdWithSuspension: Bool, rear: $[NumStream]) extends NumStream + + sealed abstract class Bool + case class True() extends Bool + case class False() extends Bool + + /** + * Checks whether there is a zero before an unevaluated closure + */ + def zeroPrecedeLazy[T](q: $[NumStream]): Boolean = { + if (q.isEvaluated) { + q* match { + case Spine(Zero(), _, rear) => + true // here we have seen a zero + case Spine(_, _, rear) => + zeroPrecedeLazy(rear) //here we have not seen a zero + case Tip() => true + } + } else false + } + + /** + * Checks whether there is a zero before a given suffix + */ + def zeroPrecedeSuf[T](q: $[NumStream], suf: $[NumStream]): Boolean = { + if (q != suf) { + q* match { + case Spine(Zero(), _, rear) => true + case Spine(_, _, rear) => + zeroPrecedeSuf(rear, suf) + case Tip() => false + } + } else false + } + + /** + * Everything until suf is evaluated. This + * also asserts that suf should be a suffix of the list + */ + def concreteUntil[T](l: $[NumStream], suf: $[NumStream]): Boolean = { + if (l != suf) { + l.isEvaluated && (l* match { + case Spine(_, cws, tail) => + concreteUntil(tail, suf) + case _ => + false + }) + } else true + } + + def isConcrete[T](l: $[NumStream]): Boolean = { + l.isEvaluated && (l* match { + case Spine(_, _, tail) => + isConcrete(tail) + case _ => true + }) + } + + sealed abstract class Scheds + case class Cons(h: $[NumStream], tail: Scheds) extends Scheds + case class Nil() extends Scheds + + def schedulesProperty[T](q: $[NumStream], schs: Scheds): Boolean = { + schs match { + case Cons(head, tail) => + head* match { + case Spine(Zero(), _, _) => // head starts with zero + head.isSuspension(incLazy _) && + concreteUntil(q, head) && + schedulesProperty(pushUntilCarry(head), tail) + case _ => + false + } + case Nil() => + isConcrete(q) + } + } + + @invisibleBody + def strongSchedsProp[T](q: $[NumStream], schs: Scheds) = { + q.isEvaluated && { + schs match { + case Cons(head, tail) => + zeroPrecedeSuf(q, head) // zeroPrecedeSuf holds initially + case Nil() => true + } + } && + schedulesProperty(q, schs) + } + + /** + * Note: if 'q' has a suspension then it would have a carry. + */ + @invisibleBody + def pushUntilCarry[T](q: $[NumStream]): $[NumStream] = { + q* match { + case Spine(Zero(), _, rear) => // if we push a carry and get back 0 then there is a new carry + pushUntilCarry(rear) + case Spine(_, _, rear) => // if we push a carry and get back 1 then there the carry has been fully pushed + rear + case Tip() => + q + } + } + + case class Number(digs: $[NumStream], schedule: Scheds) { + val valid = strongSchedsProp(digs, schedule) + } + + @invisibleBody + def inc(xs: $[NumStream]): NumStream = { + require(zeroPrecedeLazy(xs)) + xs.value match { + case Tip() => + Spine(One(), False(), xs) + case s @ Spine(Zero(), _, rear) => + Spine(One(), False(), rear) + case s @ Spine(_, _, _) => + incLazy(xs) + } + } ensuring (_ => time <= ?) + + @invisibleBody + @invstate + def incLazy(xs: $[NumStream]): NumStream = { + require(zeroPrecedeLazy(xs) && + (xs* match { + case Spine(h, _, _) => h != Zero() // xs doesn't start with a zero + case _ => false + })) + xs.value match { + case Spine(head, _, rear) => // here, rear is guaranteed to be evaluated by 'zeroPrecedeLazy' invariant + val carry = One() + rear.value match { + case s @ Spine(Zero(), _, srear) => + val tail: NumStream = Spine(carry, False(), srear) + Spine(Zero(), False(), tail) + + case s @ Spine(_, _, _) => + Spine(Zero(), True(), $(incLazy(rear))) + + case t @ Tip() => + val y: NumStream = Spine(carry, False(), rear) + Spine(Zero(), False(), y) + } + } + } ensuring { res => + (res match { + case Spine(Zero(), _, rear) => + (!isConcrete(xs) || isConcrete(pushUntilCarry(rear))) && + { + val _ = rear.value // this is necessary to assert properties on the state in the recursive invocation (and note this cannot go first) + rear.isEvaluated // this is a tautology + } + case _ => + false + }) && + time <= ? + } + + /** + * Lemma: + * forall suf. suf*.head != Zero() ^ zeroPredsSuf(xs, suf) ^ concUntil(xs.tail.tail, suf) => concUntil(push(rear), suf) + */ + @invisibleBody + @invstate + def incLazyLemma[T](xs: $[NumStream], suf: $[NumStream]): Boolean = { + require(zeroPrecedeSuf(xs, suf) && + (xs* match { + case Spine(h, _, _) => h != Zero() + case _ => false + }) && + (suf* match { + case Spine(Zero(), _, _) => + concreteUntil(xs, suf) + case _ => false + })) + // induction scheme + (xs* match { + case Spine(head, _, rear) => + rear* match { + case s @ Spine(h, _, _) => + if (h != Zero()) + incLazyLemma(rear, suf) + else true + case _ => true + } + }) && + // instantiate the lemma that implies zeroPrecedeLazy + (if (zeroPredSufConcreteUntilLemma(xs, suf)) { + // property + (incLazy(xs) match { + case Spine(Zero(), _, rear) => + concreteUntil(pushUntilCarry(rear), suf) + }) + } else false) + } holds + + @invisibleBody + def incNum[T](w: Number) = { + require(w.valid && + // instantiate the lemma that implies zeroPrecedeLazy + (w.schedule match { + case Cons(h, _) => + zeroPredSufConcreteUntilLemma(w.digs, h) + case _ => + concreteZeroPredLemma(w.digs) + })) + val nq = inc(w.digs) + val nsched = nq match { + case Spine(Zero(), createdWithSusp, rear) => + if (createdWithSusp == True()) + Cons(rear, w.schedule) // this is the only case where we create a new lazy closure + else + w.schedule + case _ => + w.schedule + } + val lq: $[NumStream] = nq + (lq, nsched) + } ensuring { res => + // lemma instantiations + (w.schedule match { + case Cons(head, tail) => + w.digs* match { + case Spine(h, _, _) => + if (h != Zero()) + incLazyLemma(w.digs, head) + else true + case _ => true + } + case _ => true + }) && + schedulesProperty(res._1, res._2) && + time <= ? + } + + @invisibleBody + def Pay[T](q: $[NumStream], scheds: Scheds): Scheds = { + require(schedulesProperty(q, scheds) && q.isEvaluated) + scheds match { + case c @ Cons(head, rest) => + head.value match { + case Spine(Zero(), createdWithSusp, rear) => + if (createdWithSusp == True()) + Cons(rear, rest) + else + rest + } + case Nil() => scheds + } + } ensuring { res => + { + val in = inState[NumStream] + val out = outState[NumStream] + // instantiations for proving the scheds property + (scheds match { + case Cons(head, rest) => + concUntilExtenLemma(q, head, in, out) && + (head* match { + case Spine(Zero(), _, rear) => + res match { + case Cons(rhead, rtail) => + schedMonotone(in, out, rtail, pushUntilCarry(rhead)) && + concUntilMonotone(rear, rhead, in, out) && + concUntilCompose(q, rear, rhead) + case _ => + concreteMonotone(in, out, rear) && + concUntilConcreteExten(q, rear) + } + }) + case _ => true + }) && + // instantiations for zeroPrecedeSuf property + (scheds match { + case Cons(head, rest) => + (concreteUntilIsSuffix(q, head) withState in) && + (res match { + case Cons(rhead, rtail) => + concreteUntilIsSuffix(pushUntilCarry(head), rhead) && + suffixZeroLemma(q, head, rhead) && + zeroPrecedeSuf(q, rhead) + case _ => + true + }) + case _ => + true + }) + } && // properties + schedulesProperty(q, res) && + time <= ? + } + + /** + * Pushing an element to the left of the queue preserves the data-structure invariants + */ + @invisibleBody + def incAndPay[T](w: Number) = { + require(w.valid) + + val (q, scheds) = incNum(w) + val nscheds = Pay(q, scheds) + Number(q, nscheds) + + } ensuring { res => res.valid && time <= ? } + + // monotonicity lemmas + def schedMonotone[T](st1: Set[$[NumStream]], st2: Set[$[NumStream]], scheds: Scheds, l: $[NumStream]): Boolean = { + require(st1.subsetOf(st2) && + (schedulesProperty(l, scheds) withState st1)) // here the input state is fixed as 'st1' + //induction scheme + (scheds match { + case Cons(head, tail) => + head* match { + case Spine(_, _, rear) => + concUntilMonotone(l, head, st1, st2) && + schedMonotone(st1, st2, tail, pushUntilCarry(head)) + case _ => true + } + case Nil() => + concreteMonotone(st1, st2, l) + }) && (schedulesProperty(l, scheds) withState st2) //property + } holds + + @traceInduct + def concreteMonotone[T](st1: Set[$[NumStream]], st2: Set[$[NumStream]], l: $[NumStream]): Boolean = { + ((isConcrete(l) withState st1) && st1.subsetOf(st2)) ==> (isConcrete(l) withState st2) + } holds + + @traceInduct + def concUntilMonotone[T](q: $[NumStream], suf: $[NumStream], st1: Set[$[NumStream]], st2: Set[$[NumStream]]): Boolean = { + ((concreteUntil(q, suf) withState st1) && st1.subsetOf(st2)) ==> (concreteUntil(q, suf) withState st2) + } holds + + // suffix predicates and their properties (this should be generalizable) + + def suffix[T](q: $[NumStream], suf: $[NumStream]): Boolean = { + if (q == suf) true + else { + q* match { + case Spine(_, _, rear) => + suffix(rear, suf) + case Tip() => false + } + } + } + + def properSuffix[T](l: $[NumStream], suf: $[NumStream]): Boolean = { + l* match { + case Spine(_, _, rear) => + suffix(rear, suf) + case _ => false + } + } ensuring (res => !res || (suffixDisequality(l, suf) && suf != l)) + + /** + * suf(q, suf) ==> suf(q.rear, suf.rear) + */ + @traceInduct + def suffixTrans[T](q: $[NumStream], suf: $[NumStream]): Boolean = { + suffix(q, suf) ==> ((q*, suf*) match { + case (Spine(_, _, rear), Spine(_, _, sufRear)) => + // 'sufRear' should be a suffix of 'rear1' + suffix(rear, sufRear) + case _ => true + }) + }.holds + + /** + * properSuf(l, suf) ==> l != suf + */ + def suffixDisequality[T](l: $[NumStream], suf: $[NumStream]): Boolean = { + require(properSuffix(l, suf)) + suffixTrans(l, suf) && // lemma instantiation + ((l*, suf*) match { // induction scheme + case (Spine(_, _, rear), Spine(_, _, sufRear)) => + // 'sufRear' should be a suffix of 'rear1' + suffixDisequality(rear, sufRear) + case _ => true + }) && l != suf // property + }.holds + + @traceInduct + def suffixCompose[T](q: $[NumStream], suf1: $[NumStream], suf2: $[NumStream]): Boolean = { + (suffix(q, suf1) && properSuffix(suf1, suf2)) ==> properSuffix(q, suf2) + } holds + + // properties of 'concUntil' + + @traceInduct + def concreteUntilIsSuffix[T](l: $[NumStream], suf: $[NumStream]): Boolean = { + concreteUntil(l, suf) ==> suffix(l, suf) + }.holds + + // properties that extend `concUntil` to larger portions of the queue + + @traceInduct + def concUntilExtenLemma[T](q: $[NumStream], suf: $[NumStream], st1: Set[$[NumStream]], st2: Set[$[NumStream]]): Boolean = { + ((concreteUntil(q, suf) withState st1) && st2 == st1 ++ Set(suf)) ==> + (suf* match { + case Spine(_, _, rear) => + concreteUntil(q, rear) withState st2 + case _ => true + }) + } holds + + @traceInduct + def concUntilConcreteExten[T](q: $[NumStream], suf: $[NumStream]): Boolean = { + (concreteUntil(q, suf) && isConcrete(suf)) ==> isConcrete(q) + } holds + + @traceInduct + def concUntilCompose[T](q: $[NumStream], suf1: $[NumStream], suf2: $[NumStream]): Boolean = { + (concreteUntil(q, suf1) && concreteUntil(suf1, suf2)) ==> concreteUntil(q, suf2) + } holds + + // properties that relate `concUntil`, `concrete`, `zeroPrecedeSuf` with `zeroPrecedeLazy` + // - these are used in preconditions to derive the `zeroPrecedeLazy` property + + @invisibleBody + @traceInduct + def zeroPredSufConcreteUntilLemma[T](q: $[NumStream], suf: $[NumStream]): Boolean = { + (zeroPrecedeSuf(q, suf) && concreteUntil(q, suf)) ==> zeroPrecedeLazy(q) + } holds + + @invisibleBody + @traceInduct + def concreteZeroPredLemma[T](q: $[NumStream]): Boolean = { + isConcrete(q) ==> zeroPrecedeLazy(q) + } holds + + // properties relating `suffix` an `zeroPrecedeSuf` + + def suffixZeroLemma[T](q: $[NumStream], suf: $[NumStream], suf2: $[NumStream]): Boolean = { + require(suf* match { + case Spine(Zero(), _, _) => + suffix(q, suf) && properSuffix(suf, suf2) + case _ => false + }) + suffixCompose(q, suf, suf2) && ( + // induction scheme + if (q != suf) { + q* match { + case Spine(_, _, tail) => + suffixZeroLemma(tail, suf, suf2) + case _ => + true + } + } else true) && + zeroPrecedeSuf(q, suf2) // property + }.holds +} diff --git a/testcases/lazy-datastructures/withOrb/PackratParsing.scala b/testcases/lazy-datastructures/withOrb/PackratParsing.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9377da8ff6b0f3cb987dc32880b5216a90e9574 --- /dev/null +++ b/testcases/lazy-datastructures/withOrb/PackratParsing.scala @@ -0,0 +1,182 @@ +package orb + +import leon.lazyeval._ +import leon.lazyeval.Mem._ +import leon.lang._ +import leon.annotation._ +import leon.instrumentation._ +import leon.invariant._ + +/** + * The packrat parser that uses the Expressions grammar presented in Bran Ford ICFP'02 paper. + * The implementation is almost exactly as it was presented in the paper, but + * here indices are passed around between parse functions, instead of strings. + */ +object PackratParsing { + + sealed abstract class Terminal + case class Open() extends Terminal + case class Close() extends Terminal + case class Plus() extends Terminal + case class Times() extends Terminal + case class Digit() extends Terminal + + /** + * A mutable array of tokens returned by the lexer + */ + @ignore + var string = Array[Terminal]() + + /** + * looking up the ith token + */ + @extern + def lookup(i: BigInt): Terminal = { + string(i.toInt) + } ensuring(_ => time <= 1) + + sealed abstract class Result { + /** + * Checks if the index in the result (if any) is + * smaller than `i` + */ + @inline + def smallerIndex(i: BigInt) = this match { + case Parsed(m) => m < i + case _ => true + } + } + case class Parsed(rest: BigInt) extends Result + case class NoParse() extends Result + + @invisibleBody + @memoize + @invstate + def pAdd(i: BigInt): Result = { + require(depsEval(i) && + pMul(i).isCached && pPrim(i).isCached && + resEval(i, pMul(i))) // lemma inst + + // Rule 1: Add <- Mul + Add + pMul(i) match { + case Parsed(j) => + if (j > 0 && lookup(j) == Plus()) { + pAdd(j - 1) match { + case Parsed(rem) => + Parsed(rem) + case _ => + pMul(i) // Rule2: Add <- Mul + } + } else pMul(i) + case _ => + pMul(i) + } + } ensuring (res => res.smallerIndex(i) && time <= ?) + + @invisibleBody + @memoize + @invstate + def pMul(i: BigInt): Result = { + require(depsEval(i) && pPrim(i).isCached && + resEval(i, pPrim(i)) // lemma inst + ) + // Rule 1: Mul <- Prim * Mul + pPrim(i) match { + case Parsed(j) => + if (j > 0 && lookup(j) == Plus()) { + pMul(j - 1) match { + case Parsed(rem) => + Parsed(rem) + case _ => + pPrim(i) // Rule2: Mul <- Prim + } + } else pPrim(i) + case _ => + pPrim(i) + } + } ensuring (res => res.smallerIndex(i) && time <= ?) + + @invisibleBody + @memoize + @invstate + def pPrim(i: BigInt): Result = { + require(depsEval(i)) + val char = lookup(i) + if (char == Digit()) { + if (i > 0) + Parsed(i - 1) // Rule1: Prim <- Digit + else + Parsed(-1) // here, we can use -1 to convery that the suffix is empty + } else if (char == Open() && i > 0) { + pAdd(i - 1) match { // Rule 2: pPrim <- ( Add ) + case Parsed(rem) => + Parsed(rem) + case _ => + NoParse() + } + } else NoParse() + } ensuring (res => res.smallerIndex(i) && time <= ?) + + //@inline + def depsEval(i: BigInt) = i == 0 || (i > 0 && allEval(i-1)) + + def allEval(i: BigInt): Boolean = { + require(i >= 0) + (pPrim(i).isCached && pMul(i).isCached && pAdd(i).isCached) &&( + if (i == 0) true + else allEval(i - 1)) + } + + @traceInduct + def evalMono(i: BigInt, st1: Set[Mem[Result]], st2: Set[Mem[Result]]) = { + require(i >= 0) + (st1.subsetOf(st2) && (allEval(i) withState st1)) ==> (allEval(i) withState st2) + } holds + + @traceInduct + def depsLem(x: BigInt, y: BigInt) = { + require(x >= 0 && y >= 0) + (x <= y && allEval(y)) ==> allEval(x) + } holds + + /** + * Instantiates the lemma `depsLem` on the result index (if any) + */ + //@inline + def resEval(i: BigInt, res: Result) = { + (res match { + case Parsed(j) => + if (j >= 0 && i > 1) depsLem(j, i - 1) + else true + case _ => true + }) + } + + @invisibleBody + def invoke(i: BigInt): (Result, Result, Result) = { + require(i == 0 || (i > 0 && allEval(i-1))) + (pPrim(i), pMul(i), pAdd(i)) + } ensuring (res => { + val in = Mem.inState[Result] + val out = Mem.outState[Result] + (if(i >0) evalMono(i-1, in, out) else true) && + allEval(i) && + time <= ? + }) + + /** + * Parsing a string of length 'n+1'. + * Word is represented as an array indexed by 'n'. We only pass around the index. + * The 'lookup' function will return a character of the array. + */ + @invisibleBody + def parse(n: BigInt): Result = { + require(n >= 0) + if(n == 0) invoke(n)._3 + else { + val tailres = parse(n-1) // we parse the prefixes ending at 0, 1, 2, 3, ..., n + invoke(n)._3 + } + } ensuring(_ => allEval(n) && + time <= ? * n + ?) +} diff --git a/testcases/lazy-datastructures/withOrb/SortingnConcat-orb.scala b/testcases/lazy-datastructures/withOrb/SortingnConcat-orb.scala index 94acdf9caaae8dd289262048d9d3bf7e3ad69c9f..78eff9bbe6c7f0a2c0ffa7cdba3ac385c4f11649 100644 --- a/testcases/lazy-datastructures/withOrb/SortingnConcat-orb.scala +++ b/testcases/lazy-datastructures/withOrb/SortingnConcat-orb.scala @@ -71,8 +71,8 @@ object SortingnConcat { } } ensuring (_ => time <= ? * ssize(l) + ?) - /* Orb can prove this - * def kthMin(l: $[LList], k: BigInt): BigInt = { + // Orb can prove this + def kthMin(l: $[LList], k: BigInt): BigInt = { require(k >= 1) l.value match { case SCons(x, xs) => @@ -81,5 +81,5 @@ object SortingnConcat { kthMin(xs, k - 1) case SNil() => BigInt(0) // None[BigInt] } - } ensuring (_ => time <= 15 * k * ssize(l) + 20 * k + 20)*/ + } ensuring (_ => time <= ? * k * ssize(l) + ? * k + ?) } diff --git a/testcases/lazy-datastructures/withOrb/WeightedScheduling.scala b/testcases/lazy-datastructures/withOrb/WeightedScheduling.scala new file mode 100644 index 0000000000000000000000000000000000000000..74281ac66511eca27c394a62e4a7e5d88e3a3964 --- /dev/null +++ b/testcases/lazy-datastructures/withOrb/WeightedScheduling.scala @@ -0,0 +1,109 @@ +package orb + +import leon.lazyeval._ +import leon.lazyeval.Mem._ +import leon.lang._ +import leon.annotation._ +import leon.instrumentation._ +import leon.invariant._ + +object WeightedSched { + sealed abstract class IList { + def size: BigInt = { + this match { + case Cons(_, tail) => 1 + tail.size + case Nil() => BigInt(0) + } + } ensuring(_ >= 0) + } + case class Cons(x: BigInt, tail: IList) extends IList + case class Nil() extends IList + + /** + * array of jobs + * (a) each job has a start time, finish time, and weight + * (b) Jobs are sorted in ascending order of finish times + */ + @ignore + var jobs = Array[(BigInt, BigInt, BigInt)]() + + /** + * A precomputed mapping from each job i to the previous job j it is compatible with. + */ + @ignore + var p = Array[Int]() + + @extern + def jobInfo(i: BigInt) = { + jobs(i.toInt) + } ensuring(_ => time <= 1) + + @extern + def prevCompatibleJob(i: BigInt) = { + BigInt(p(i.toInt)) + } ensuring(res => res >=0 && res < i && time <= 1) + + @inline + def max(x: BigInt, y: BigInt) = if (x >= y) x else y + + def depsEval(i: BigInt) = i == 0 || (i > 0 && allEval(i-1)) + + def allEval(i: BigInt): Boolean = { + require(i >= 0) + sched(i).isCached && + (if (i == 0) true + else allEval(i - 1)) + } + + @traceInduct + def evalMono(i: BigInt, st1: Set[Mem[BigInt]], st2: Set[Mem[BigInt]]) = { + require(i >= 0) + (st1.subsetOf(st2) && (allEval(i) withState st1)) ==> (allEval(i) withState st2) + } holds + + @traceInduct + def evalLem(x: BigInt, y: BigInt) = { + require(x >= 0 && y >= 0) + (x <= y && allEval(y)) ==> allEval(x) + } holds + + @invisibleBody + @invstate + @memoize + def sched(jobIndex: BigInt): BigInt = { + require(depsEval(jobIndex) && + (jobIndex == 0 || evalLem(prevCompatibleJob(jobIndex), jobIndex-1))) + val (st, fn, w) = jobInfo(jobIndex) + if(jobIndex == 0) w + else { + // we may either include the head job or not: + // if we include the head job, we have to skip every job that overlaps with it + val tailValue = sched(jobIndex - 1) + val prevCompatVal = sched(prevCompatibleJob(jobIndex)) + max(w + prevCompatVal, tailValue) + } + } ensuring(_ => time <= ?) + + @invisibleBody + def invoke(jobIndex: BigInt) = { + require(depsEval(jobIndex)) + sched(jobIndex) + } ensuring (res => { + val in = Mem.inState[BigInt] + val out = Mem.outState[BigInt] + (jobIndex == 0 || evalMono(jobIndex-1, in, out)) && + time <= ? + }) + + @invisibleBody + def schedBU(jobi: BigInt): IList = { + require(jobi >= 0) + if(jobi == 0) { + Cons(invoke(jobi), Nil()) + } else { + val tailRes = schedBU(jobi-1) + Cons(invoke(jobi), tailRes) + } + } ensuring(_ => allEval(jobi) && + time <= ? * (jobi + 1)) +} diff --git a/testcases/orb-testcases/timing/AVLTree.scala b/testcases/orb-testcases/timing/AVLTree.scala index d34787eba6ab5c416bffd10cd5d50a72addc9448..2dad4712c0933969d49963ef9ecc94af093d632d 100644 --- a/testcases/orb-testcases/timing/AVLTree.scala +++ b/testcases/orb-testcases/timing/AVLTree.scala @@ -1,36 +1,30 @@ import leon.invariant._ import leon.instrumentation._ import leon.math._ +import leon.annotation._ -/** - * created by manos and modified by ravi. - * BST property cannot be verified - */ -object AVLTree { +object AVLTree { sealed abstract class Tree case class Leaf() extends Tree - case class Node(left : Tree, value : BigInt, right: Tree, rank : BigInt) extends Tree + case class Node(left: Tree, value: BigInt, right: Tree, rank: BigInt) extends Tree sealed abstract class OptionInt case class None() extends OptionInt case class Some(i: BigInt) extends OptionInt - //def min(i1:BigInt, i2:BigInt) : BigInt = if (i1<=i2) i1 else i2 - //def max(i1:BigInt, i2:BigInt) : BigInt = if (i1>=i2) i1 else i2 - - /*def twopower(x: BigInt) : BigInt = { + /*def expGoldenRatio(x: BigInt) : BigInt = { //require(x >= 0) if(x < 1) 1 else 3/2 * twopower(x - 1) } ensuring(res => res >= 1 template((a) => a <= 0))*/ - def rank(t: Tree) : BigInt = { + def rank(t: Tree): BigInt = { t match { - case Leaf() => 0 - case Node(_,_,_,rk) => rk + case Leaf() => 0 + case Node(_, _, _, rk) => rk } - } //ensuring(res => res >= 0) + } def height(t: Tree): BigInt = { t match { @@ -38,7 +32,7 @@ object AVLTree { case Node(l, x, r, _) => { val hl = height(l) val hr = height(r) - max(hl,hr) + 1 + max(hl, hr) + 1 } } } @@ -46,21 +40,21 @@ object AVLTree { def size(t: Tree): BigInt = { //require(isAVL(t)) (t match { - case Leaf() => 0 - case Node(l, _, r,_) => size(l) + 1 + size(r) + case Leaf() => 0 + case Node(l, _, r, _) => size(l) + 1 + size(r) }) } - //ensuring (res => true template((a,b) => height(t) <= a*res + b)) + //ensuring (_ => height(t) <= ? * res + ?) - def rankHeight(t: Tree) : Boolean = t match { - case Leaf() => true - case Node(l,_,r,rk) => rankHeight(l) && rankHeight(r) && rk == height(t) + def rankHeight(t: Tree): Boolean = t match { + case Leaf() => true + case Node(l, _, r, rk) => rankHeight(l) && rankHeight(r) && rk == height(t) } - def balanceFactor(t : Tree) : BigInt = { - t match{ - case Leaf() => 0 + def balanceFactor(t: Tree): BigInt = { + t match { + case Leaf() => 0 case Node(l, _, r, _) => rank(l) - rank(r) } } @@ -72,48 +66,39 @@ object AVLTree { } }*/ - def unbalancedInsert(t: Tree, e : BigInt) : Tree = { + def unbalancedInsert(t: Tree, e: BigInt): Tree = { t match { case Leaf() => Node(Leaf(), e, Leaf(), 1) - case Node(l,v,r,h) => - if (e == v) t - else if (e < v){ - val newl = avlInsert(l,e) + case Node(l, v, r, h) => + if (e == v) t + else if (e < v) { + val newl = avlInsert(l, e) Node(newl, v, r, max(rank(newl), rank(r)) + 1) - } - else { - val newr = avlInsert(r,e) + } else { + val newr = avlInsert(r, e) Node(l, v, newr, max(rank(l), rank(newr)) + 1) } } } - def avlInsert(t: Tree, e : BigInt) : Tree = { - - balance(unbalancedInsert(t,e)) - - } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) - //ensuring(res => time <= 276*height(t) + 38) - //minbound: ensuring(res => time <= 138*height(t) + 19) + def avlInsert(t: Tree, e: BigInt): Tree = { + balance(unbalancedInsert(t, e)) + } ensuring (_ => time <= ? * height(t) + ?) def deletemax(t: Tree): (Tree, OptionInt) = { - t match { case Node(Leaf(), v, Leaf(), _) => (Leaf(), Some(v)) case Node(l, v, Leaf(), _) => { - val (newl, opt) = deletemax(l) - opt match { - case None() => (t, None()) - case Some(lmax) => { - val newt = balance(Node(newl, lmax, Leaf(), rank(newl) + 1)) - (newt, Some(v)) - } + deletemax(l) match { + case (_, None()) => (t, None()) + case (newl, Some(lmax)) => + (balance(Node(newl, lmax, Leaf(), rank(newl) + 1)), Some(v)) } } case Node(_, _, r, _) => deletemax(r) - case _ => (t, None()) + case _ => (t, None()) } - } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) + } ensuring (res => time <= ? * height(t) + ?) def unbalancedDelete(t: Tree, e: BigInt): Tree = { t match { @@ -121,14 +106,12 @@ object AVLTree { case Node(l, v, r, h) => if (e == v) { if (l == Leaf()) r - else if(r == Leaf()) l + else if (r == Leaf()) l else { - val (newl, opt) = deletemax(l) - opt match { - case None() => t - case Some(newe) => { + deletemax(l) match { + case (_, None()) => t + case (newl, Some(newe)) => Node(newl, newe, r, max(rank(newl), rank(r)) + 1) - } } } } else if (e < v) { @@ -142,54 +125,48 @@ object AVLTree { } def avlDelete(t: Tree, e: BigInt): Tree = { - balance(unbalancedDelete(t, e)) + } ensuring (res => tmpl((a, b) => time <= a * height(t) + b)) - } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) - - def balance(t:Tree) : Tree = { + @invisibleBody + def balance(t: Tree): Tree = { t match { case Leaf() => Leaf() // impossible... case Node(l, v, r, h) => val bfactor = balanceFactor(t) // at this point, the tree is unbalanced - if(bfactor > 1 ) { // left-heavy + if (bfactor > 1) { // left-heavy val newL = if (balanceFactor(l) < 0) { // l is right heavy rotateLeft(l) - } - else l - rotateRight(Node(newL,v,r, max(rank(newL), rank(r)) + 1)) - } - else if(bfactor < -1) { + } else l + rotateRight(Node(newL, v, r, max(rank(newL), rank(r)) + 1)) + } else if (bfactor < -1) { val newR = if (balanceFactor(r) > 0) { // r is left heavy rotateRight(r) - } - else r - rotateLeft(Node(l,v,newR, max(rank(newR), rank(l)) + 1)) + } else r + rotateLeft(Node(l, v, newR, max(rank(newR), rank(l)) + 1)) } else t - } - } + } + } ensuring (_ => time <= ?) - def rotateRight(t:Tree) = { + def rotateRight(t: Tree) = { t match { - case Node(Node(ll, vl, rl, _),v,r, _) => - - val hr = max(rank(rl),rank(r)) + 1 - Node(ll, vl, Node(rl,v,r,hr), max(rank(ll),hr) + 1) - + case Node(Node(ll, vl, rl, _), v, r, _) => + val hr = max(rank(rl), rank(r)) + 1 + Node(ll, vl, Node(rl, v, r, hr), max(rank(ll), hr) + 1) case _ => t // this should not happen - } } - + } + } - def rotateLeft(t:Tree) = { + def rotateLeft(t: Tree) = { t match { - case Node(l, v, Node(lr,vr,rr,_), _) => - - val hl = max(rank(l),rank(lr)) + 1 - Node(Node(l,v,lr,hl), vr, rr, max(hl, rank(rr)) + 1) + case Node(l, v, Node(lr, vr, rr, _), _) => + val hl = max(rank(l), rank(lr)) + 1 + Node(Node(l, v, lr, hl), vr, rr, max(hl, rank(rr)) + 1) case _ => t // this should not happen - } } + } + } } diff --git a/testcases/orb-testcases/timing/InsertionSort.scala b/testcases/orb-testcases/timing/InsertionSort.scala index 8fd79a2e89f60441fd522584fae4197079f9294e..8c0f029798e29fc8834c0dd598f61e323f307f67 100644 --- a/testcases/orb-testcases/timing/InsertionSort.scala +++ b/testcases/orb-testcases/timing/InsertionSort.scala @@ -15,12 +15,11 @@ object InsertionSort { l match { case Cons(x,xs) => if (x <= e) Cons(x,sortedIns(e, xs)) else Cons(e, l) case _ => Cons(e,Nil()) - } - } ensuring(res => size(res) == size(l) + 1 && tmpl((a,b) => time <= a*size(l) +b && depth <= a*size(l) +b)) + } + } ensuring(res => size(res) == size(l) + 1 && time <= ? * size(l) + ? && depth <= ? * size(l) + ?) def sort(l: List): List = (l match { case Cons(x,xs) => sortedIns(x, sort(xs)) case _ => Nil() - - }) ensuring(res => size(res) == size(l) && tmpl((a,b) => time <= a*(size(l)*size(l)) +b && rec <= a*size(l) + b)) + }) ensuring(res => size(res) == size(l) && time <= ? * (size(l)*size(l)) + ? && rec <= ? * size(l) + ?) } diff --git a/testcases/orb-testcases/timing/SpeedBenchmarks.scala b/testcases/orb-testcases/timing/SpeedBenchmarks.scala index a7349ab260eeec44f80222b7893e7cc16ea08b08..4728ae6bc6f5a5089d0ee046887e7f85ed495686 100644 --- a/testcases/orb-testcases/timing/SpeedBenchmarks.scala +++ b/testcases/orb-testcases/timing/SpeedBenchmarks.scala @@ -1,6 +1,7 @@ import leon.invariant._ import leon.instrumentation._ + object SpeedBenchmarks { sealed abstract class List case class Cons(head: BigInt, tail: List) extends List @@ -71,7 +72,7 @@ object SpeedBenchmarks { //Fig. 2 of Speed POPL'09 def Dis1(x : BigInt, y : BigInt, n: BigInt, m: BigInt) : BigInt = { - if(x >= n) 0 + if(x >= n) BigInt(0) else { if(y < m) Dis1(x, y+1, n, m) else Dis1(x+1, y, n, m) @@ -80,7 +81,7 @@ object SpeedBenchmarks { //Fig. 2 of Speed POPL'09 def Dis2(x : BigInt, z : BigInt, n: BigInt) : BigInt = { - if(x >= n) 0 + if(x >= n) BigInt(0) else { if(z > x) Dis2(x+1, z, n) else Dis2(x, z+1, n) @@ -90,7 +91,7 @@ object SpeedBenchmarks { //Pg. 138, Speed POPL'09 def Dis3(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { require((b && t == 1) || (!b && t == -1)) - if(x > n || x < 0) 0 + if(x > n || x < 0) BigInt(0) else { if(b) Dis3(x+t, b, t, n) else Dis3(x-t, b, t, n) @@ -99,7 +100,7 @@ object SpeedBenchmarks { //Pg. 138, Speed POPL'09 def Dis4(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { - if(x > n || x < 0) 0 + if(x > n || x < 0) BigInt(0) else { if(b) Dis4(x+t, b, t, n) else Dis4(x-t, b, t, n)