From 02b62f4bfe251a0b062f4f0723fb736cda0b17fc Mon Sep 17 00:00:00 2001 From: ravi <ravi.kandhadai@epfl.ch> Date: Wed, 7 Oct 2015 19:10:26 +0200 Subject: [PATCH] Porting Orb to smtlib solvers --- .../engine/UnfoldingTemplateSolver.scala | 45 +-- .../templateSolvers/ExtendedUFSolver.scala | 11 +- .../templateSolvers/FarkasLemmaSolver.scala | 5 +- .../templateSolvers/NLTemplateSolver.scala | 248 +++++++-------- .../scala/leon/invariant/util/Minimizer.scala | 5 +- src/main/scala/leon/invariant/util/Util.scala | 2 +- .../leon/solvers/smtlib/SMTLIBTarget.scala | 288 ++++++++++-------- .../regression/orb/timing/BinomialHeap.scala | 181 ----------- 8 files changed, 320 insertions(+), 465 deletions(-) delete mode 100644 src/test/resources/regression/orb/timing/BinomialHeap.scala diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index a933be56a..e21641dd8 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -32,7 +32,7 @@ case class InferResult(res: Boolean, model: Option[Model], inferredFuncs: List[F } trait FunctionTemplateSolver { - def apply() : Option[InferResult] + def apply(): Option[InferResult] } class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends FunctionTemplateSolver { @@ -57,14 +57,13 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends Fun val fullPost = matchToIfThenElse(if (funDef.hasTemplate) if (ctx.toVerifyPostFor.contains(funDef.id.name)) - And(funDef.getPostWoTemplate, funDef.getTemplate) - else - funDef.getTemplate + And(funDef.getPostWoTemplate, funDef.getTemplate) else - if (ctx.toVerifyPostFor.contains(funDef.id.name)) - funDef.getPostWoTemplate - else - BooleanLiteral(true)) + funDef.getTemplate + else if (ctx.toVerifyPostFor.contains(funDef.id.name)) + funDef.getPostWoTemplate + else + BooleanLiteral(true)) (bodyExpr, fullPost) } @@ -178,7 +177,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends Fun val newFundefs = program.definedFunctions.collect { case fd @ _ => { //if !Util.isMultFunctions(fd) val newfd = new FunDef(FreshIdentifier(fd.id.name, Untyped, false), - fd.tparams, fd.returnType, fd.params) + fd.tparams, fd.returnType, fd.params) (fd, newfd) } }.toMap @@ -215,7 +214,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends Fun val ninv = replace(Map(ResultVariable(fd.returnType) -> resvar.toVariable), inv) Some(Lambda(Seq(ValDef(resvar, Some(fd.returnType))), ninv)) } - } else if(fd.postcondition.isDefined) { + } else if (fd.postcondition.isDefined) { val Lambda(resultBinder, _) = fd.postcondition.get Some(Lambda(resultBinder, fd.getPostWoTemplate)) } else None @@ -241,19 +240,29 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends Fun (augmentedProg, newFundefs(rootFd)) } //println("New Root: "+newroot) - import leon.solvers.z3._ - val dummySolFactory = new leon.solvers.SolverFactory[ExtendedUFSolver] { - def getNewSolver() = new ExtendedUFSolver(ctx.leonContext, program) + import leon.solvers.smtlib.SMTLIBZ3Solver + import leon.solvers.combinators.UnrollingSolver + val dummySolFactory = new leon.solvers.SolverFactory[SMTLIBZ3Solver] { + def getNewSolver() = new SMTLIBZ3Solver(ctx.leonContext, program) } val vericontext = VerificationContext(ctx.leonContext, newprog, dummySolFactory, reporter) val defaultTactic = new DefaultTactic(vericontext) val vc = defaultTactic.generatePostconditions(newroot)(0) val verifyTimeout = 5 - val fairZ3 = new SimpleSolverAPI( - new TimeoutSolverFactory(SolverFactory(() => new FairZ3Solver(ctx.leonContext, newprog) with TimeoutSolver), - verifyTimeout * 1000)) - val sat = fairZ3.solveSAT(Not(vc.condition)) - sat + // val fairZ3 = new SimpleSolverAPI( + // new TimeoutSolverFactory(SolverFactory(() => + // new FairZ3Solver(ctx.leonContext, newprog) with TimeoutSolver), + // verifyTimeout * 1000)) + val smtUnrollZ3 = new UnrollingSolver(ctx.leonContext, program, + new SMTLIBZ3Solver(ctx.leonContext, program)) with TimeoutSolver + smtUnrollZ3.setTimeout(verifyTimeout * 1000) + smtUnrollZ3.assertVC(vc) + smtUnrollZ3.check match { + case Some(true) => + (Some(true), smtUnrollZ3.getModel) + case r => + (r, Model.empty) + } } } diff --git a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala index 8bc93564c..b072ddc67 100644 --- a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala @@ -13,16 +13,17 @@ import purescala.ExprOps._ import purescala.Types._ import leon.LeonContext import leon.solvers.z3.UninterpretedZ3Solver +import leon.solvers.smtlib.SMTLIBZ3Solver /** * A uninterpreted solver extended with additional functionalities. * TODO: need to handle bit vectors */ -class ExtendedUFSolver(context : LeonContext, program: Program) - extends UninterpretedZ3Solver(context, program) { +class ExtendedUFSolver(context: LeonContext, program: Program) + extends UninterpretedZ3Solver(context, program) { override val name = "Z3-eu" - override val description = "Extended UF-ADT Z3 Solver" + override val description = "Extended UF-ADT Z3 Solver" /** * This uses z3 methods to evaluate the model @@ -36,7 +37,7 @@ class ExtendedUFSolver(context : LeonContext, program: Program) else None } - def getAssertions : Expr = { + def getAssertions: Expr = { val assers = solver.getAssertions.map((ast) => fromZ3Formula(null, ast, null)) And(assers) } @@ -60,7 +61,7 @@ class ExtendedUFSolver(context : LeonContext, program: Program) if (line == "; benchmark") newHeaders :+= line else if (line.startsWith("(set")) newHeaders :+= line else if (line.startsWith("(declare")) newHeaders :+= line - else if(line.startsWith("(check-sat)")) {} //do nothing + else if (line.startsWith("(check-sat)")) {} //do nothing else asserts :+= line }) headers ++= newHeaders diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala index 13e8727a4..6e80c4749 100644 --- a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala @@ -21,6 +21,7 @@ import leon.solvers.TimeoutSolver import leon.solvers.SolverFactory import leon.solvers.TimeoutSolverFactory import leon.solvers.Model +import leon.solvers.smtlib.SMTLIBZ3Solver import leon.invariant.util.RealValuedExprEvaluator._ class FarkasLemmaSolver(ctx: InferenceContext) { @@ -278,7 +279,9 @@ class FarkasLemmaSolver(ctx: InferenceContext) { throw new IllegalStateException("Not supported now. Will be in the future!") //new ExtendedUFSolver(leonctx, program, useBitvectors = true, bitvecSize = bvsize) with TimeoutSolver } else { - new ExtendedUFSolver(leonctx, program) with TimeoutSolver + // use SMTLIBSolver to solve the constraints so that it can be timed out effectively + new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver + //new ExtendedUFSolver(leonctx, program) with TimeoutSolver } val solver = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => innerSolver), timeout * 1000)) if (verbose) reporter.info("solving...") diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala index 2e01e4682..47901d3b9 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala @@ -12,6 +12,8 @@ import evaluators._ import scala.collection.mutable.{ Map => MutableMap } import java.io._ import solvers._ +import solvers.combinators._ +import solvers.smtlib._ import solvers.z3._ import scala.util.control.Breaks._ import purescala.ScalaPrinter @@ -56,12 +58,14 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const private val startFromEarlierModel = true private val disableCegis = true private val useIncrementalSolvingForVCs = true + private val useCVCToCheckVCs = false //this is private mutable state used by initialized during every call to 'solve' and used by 'solveUNSAT' protected var funcVCs = Map[FunDef, Expr]() //TODO: can incremental solving be trusted ? There were problems earlier. - protected var vcSolvers = Map[FunDef, ExtendedUFSolver]() + protected var vcSolvers = Map[FunDef, Solver with TimeoutSolver]() protected var paramParts = Map[FunDef, Expr]() + protected var simpleParts = Map[FunDef, Expr]() private var lastFoundModel: Option[Model] = None //for miscellaneous things @@ -83,7 +87,11 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const if (Util.hasReals(rest) && Util.hasInts(rest)) throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) - val vcSolver = new ExtendedUFSolver(leonctx, program) + val vcSolver = + if (this.useCVCToCheckVCs) + new SMTLIBCVC4Solver(leonctx, program) with TimeoutSolver + else + new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver vcSolver.assertCnstr(rest) if (debugIncrementalVC) { @@ -97,6 +105,7 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const } vcSolvers += (fd -> vcSolver) paramParts += (fd -> paramPart) + simpleParts += (fd -> rest) }) } @@ -233,111 +242,111 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const var confFunctions = Set[FunDef]() var confDisjuncts = Seq[Expr]() - val newctrs = conflictingFuns.foldLeft(Seq[Expr]())((acc, fd) => { - - val disableCounterExs = if (seenPaths.contains(fd)) { - blockedCEs = true - Not(Util.createOr(seenPaths(fd))) - } else tru - val (data, ctrsForFun) = getUNSATConstraints(fd, model, disableCounterExs) - val (disjunct, callsInPath) = data - if (ctrsForFun == tru) acc - else { - confFunctions += fd - confDisjuncts :+= disjunct - callsInPaths ++= callsInPath - //instantiate the disjunct - val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) - - //some sanity checks - if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) - throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) - - updateSeenPaths(fd, cePath) - acc :+ ctrsForFun - } - }) - //update conflicting functions - conflictingFuns = confFunctions - if (newctrs.isEmpty) { - - if (!blockedCEs) { - //yes, hurray,found an inductive invariant - (Some(false), prevCtr, model) - } else { - //give up, only hard paths remaining - reporter.info("- Exhausted all easy paths !!") - reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) - //TODO: what to unroll here ? + val newctrsOpt = conflictingFuns.foldLeft(Some(Seq()): Option[Seq[Expr]]) { + case (None, _) => None + case (Some(acc), fd) => + val disableCounterExs = if (seenPaths.contains(fd)) { + blockedCEs = true + Not(Util.createOr(seenPaths(fd))) + } else tru + getUNSATConstraints(fd, model, disableCounterExs) match { + case None => + None + case Some(((disjunct, callsInPath), ctrsForFun)) => + if (ctrsForFun == tru) Some(acc) + else { + confFunctions += fd + confDisjuncts :+= disjunct + callsInPaths ++= callsInPath + //instantiate the disjunct + val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) + //some sanity checks + if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) + throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) + updateSeenPaths(fd, cePath) + Some(acc :+ ctrsForFun) + } + } + } + newctrsOpt match { + case None => + // give up, the VC cannot be decided (None, tru, Model.empty) - } - } else { - - //check that the new constraints does not have any reals - val newPart = Util.createAnd(newctrs) - val newSize = Util.atomNum(newPart) - Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") - if (verbose) - reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) - - /*if (this.debugIncremental) - solverWithCtr.assertCnstr(newPart)*/ - - //here we need to solve for the newctrs + inputCtrs - val combCtr = And(prevCtr, newPart) - val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) - - res match { - case None => { - //here we have timed out while solving the non-linear constraints + case Some(newctrs) => + //update conflicting functions + conflictingFuns = confFunctions + if (newctrs.isEmpty) { + if (!blockedCEs) { + //yes, hurray,found an inductive invariant + (Some(false), prevCtr, model) + } else { + //give up, only hard paths remaining + reporter.info("- Exhausted all easy paths !!") + reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) + //TODO: what to unroll here ? + (None, tru, Model.empty) + } + } else { + //check that the new constraints does not have any reals + val newPart = Util.createAnd(newctrs) + val newSize = Util.atomNum(newPart) + Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") if (verbose) - if (!this.disableCegis) - reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") - else - reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") - - if (!this.disableCegis) { - val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, Util.createOr(confDisjuncts), inputCtr, Some(model)) - cres match { - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - (Some(true), And(inputCtr, cctr), cmodel) - } - case Some(false) => { - disjsSolvedInIter ++= confDisjuncts - //here also return the calls that needs to be unrolled - (None, fls, Model.empty) - } - case _ => { + reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) + //here we need to solve for the newctrs + inputCtrs + val combCtr = And(prevCtr, newPart) + val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) + res match { + case None => { + //here we have timed out while solving the non-linear constraints + if (verbose) + if (!this.disableCegis) + reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") + else + reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") + if (!this.disableCegis) { + val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, Util.createOr(confDisjuncts), inputCtr, Some(model)) + cres match { + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + (Some(true), And(inputCtr, cctr), cmodel) + } + case Some(false) => { + disjsSolvedInIter ++= confDisjuncts + //here also return the calls that needs to be unrolled + (None, fls, Model.empty) + } + case _ => { + if (verbose) reporter.info("retrying...") + Stats.updateCumStats(1, "retries") + //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration + invalidateDisjRecr(And(inputCtr, cctr)) + } + } + } else { if (verbose) reporter.info("retrying...") Stats.updateCumStats(1, "retries") - //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration - invalidateDisjRecr(And(inputCtr, cctr)) + invalidateDisjRecr(inputCtr) } } - } else { - if (verbose) reporter.info("retrying...") - Stats.updateCumStats(1, "retries") - invalidateDisjRecr(inputCtr) + case Some(false) => { + //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) + disjsSolvedInIter ++= confDisjuncts + (None, fls, Model.empty) + } + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + //new model may not have mappings for all the template variables, hence, use the mappings from earlier models + val compModel = new Model(tempIds.map((id) => { + if (newModel.isDefinedAt(id)) + (id -> newModel(id)) + else + (id -> model(id)) + }).toMap) + (Some(true), combCtr, compModel) + } } } - case Some(false) => { - //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) - disjsSolvedInIter ++= confDisjuncts - (None, fls, Model.empty) - } - case Some(true) => { - disjsSolvedInIter ++= confDisjuncts - //new model may not have mappings for all the template variables, hence, use the mappings from earlier models - val compModel = new Model(tempIds.map((id) => { - if (newModel.isDefinedAt(id)) - (id -> newModel(id)) - else - (id -> model(id)) - }).toMap) - (Some(true), combCtr, compModel) - } - } } } val (res, newctr, newmodel) = invalidateDisjRecr(inputCtr) @@ -362,11 +371,12 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const /** * Constructs a quantifier-free non-linear constraint for unsatisfiability */ - def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): ((Expr, Set[Call]), Expr) = { + def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): Option[((Expr, Set[Call]), Expr)] = { val tempVarMap: Map[Expr, Expr] = inModel.map((elem) => (elem._1.toVariable, elem._2)).toMap - val innerSolver = if (this.useIncrementalSolvingForVCs) vcSolvers(fd) - else new ExtendedUFSolver(leonctx, program) + val innerSolver = + if (this.useIncrementalSolvingForVCs) vcSolvers(fd) + else new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver val instExpr = if (this.useIncrementalSolvingForVCs) { val instParamPart = instantiateTemplate(this.paramParts(fd), tempVarMap) And(instParamPart, disableCounterExs) @@ -379,7 +389,7 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const // println("Plain vc: "+funcVCs(fd)) val wr = new PrintWriter(new File("formula-dump.txt")) val fullExpr = if (this.useIncrementalSolvingForVCs) { - And(innerSolver.getAssertions, instExpr) + And(simpleParts(fd), instExpr) } else instExpr // println("Instantiated VC of " + fd.id + " is: " + fullExpr) @@ -389,31 +399,23 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const wr.flush() wr.close() } - //throw an exception if the candidate expression has reals if (Util.hasMixedIntReals(instExpr)) { - //variablesOf(instExpr).foreach(id => println("Id: "+id+" type: "+id.getType)) throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) } //reporter.info("checking VC inst ...") var t1 = System.currentTimeMillis() + innerSolver.setTimeout(timeout * 1000) val (res, model) = if (this.useIncrementalSolvingForVCs) { innerSolver.push innerSolver.assertCnstr(instExpr) - //dump the inst VC as SMTLIB - /*val filename = "vc" + FileCountGUID.getID + ".smt2" - Util.toZ3SMTLIB(innerSolver.getAssertions, filename, "", leonctx, program) - val writer = new PrintWriter(filename) - writer.println(innerSolver.ctrsToString("")) - writer.close() - println("vc dumped to: " + filename)*/ - - val solRes = innerSolver.check - innerSolver.pop() - solRes match { - case Some(true) => (solRes, innerSolver.getModel) - case _ => (solRes, Model.empty) + val solRes = innerSolver.check match { + case r @ Some(true) => + (r, innerSolver.getModel) + case r => (r, Model.empty) } + innerSolver.pop() + solRes } else { val solver = SimpleSolverAPI(SolverFactory(() => innerSolver)) solver.solveSAT(instExpr) @@ -430,7 +432,7 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const Stats.updateCounterStats(Util.atomNum(upVCinst), "UP-VC-size", "disjuncts") t1 = System.currentTimeMillis() - val (res2, _) = SimpleSolverAPI(SolverFactory(() => new ExtendedUFSolver(leonctx, program))).solveSAT(upVCinst) + val (res2, _) = SimpleSolverAPI(SolverFactory(() => new SMTLIBZ3Solver(leonctx, program))).solveSAT(upVCinst) val unpackedTime = System.currentTimeMillis() - t1 if (res != res2) { throw new IllegalStateException("Unpacked VC produces different result: " + upVCinst) @@ -442,11 +444,12 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const t1 = System.currentTimeMillis() res match { case None => { - throw new IllegalStateException("cannot check the satisfiability of " + funcVCs(fd)) + //throw new IllegalStateException("cannot check the satisfiability of " + funcVCs(fd)) + None } case Some(false) => { //do not generate any constraints - ((fls, Set()), tru) + Some(((fls, Set()), tru)) } case Some(true) => { //For debugging purposes. @@ -454,7 +457,6 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const if (this.printCounterExample) { reporter.info("Model: " + model) } - //get the disjuncts that are satisfied val (data, newctr) = generateCtrsFromDisjunct(fd, model) if (newctr == tru) @@ -464,7 +466,7 @@ class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: Const Stats.updateCounterTime((t2 - t1), "Disj-choosing-time", "disjuncts") Stats.updateCumTime((t2 - t1), "Total-Choose-Time") - (data, newctr) + Some((data, newctr)) } } } diff --git a/src/main/scala/leon/invariant/util/Minimizer.scala b/src/main/scala/leon/invariant/util/Minimizer.scala index 7e44eaaa7..78e82db2d 100644 --- a/src/main/scala/leon/invariant/util/Minimizer.scala +++ b/src/main/scala/leon/invariant/util/Minimizer.scala @@ -9,6 +9,7 @@ import purescala.Extractors._ import purescala.Types._ import solvers._ import solvers.z3._ +import solvers.smtlib.SMTLIBZ3Solver import leon.invariant._ import scala.util.control.Breaks._ import invariant.engine.InferenceContext @@ -55,9 +56,11 @@ class Minimizer(ctx: InferenceContext) { def minimizeBounds(nestMap: Map[Variable, Int])(inputCtr: Expr, initModel: Model): Model = { val orderedTempVars = nestMap.toSeq.sortWith((a, b) => a._2 >= b._2).map(_._1) //do a binary search sequentially on each of these tempvars + // note: use smtlib solvers so that they can be timedout val solver = SimpleSolverAPI( new TimeoutSolverFactory(SolverFactory(() => - new ExtendedUFSolver(leonctx, program) with TimeoutSolver), ctx.timeout * 1000)) + new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver), ctx.timeout * 1000)) + //new ExtendedUFSolver(leonctx, program) with TimeoutSolver), ctx.timeout * 1000)) reporter.info("minimizing...") var currentModel = initModel diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala index cc9bf2b2c..8ae7d40a2 100644 --- a/src/main/scala/leon/invariant/util/Util.scala +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -150,7 +150,7 @@ object Util { })(ine) } - def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, + def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay : Boolean = true): Program = { val funMap = Util.functionsWOFields(prog.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index eee03153e..952dbeadc 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -15,8 +15,8 @@ import purescala.Constructors._ import purescala.Definitions._ import _root_.smtlib.common._ -import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter} -import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, FunDef => _, Assert => SMTAssert, _} +import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } +import _root_.smtlib.parser.Commands.{ Constructor => SMTConstructor, FunDef => _, Assert => SMTAssert, _ } import _root_.smtlib.parser.Terms.{ Forall => SMTForall, Exists => SMTExists, @@ -27,7 +27,7 @@ import _root_.smtlib.parser.Terms.{ import _root_.smtlib.theories.Core.{ Equals => SMTEquals } -import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} +import _root_.smtlib.parser.CommandsResponses.{ Error => ErrorResponse, _ } import _root_.smtlib.theories._ import _root_.smtlib.interpreters.ProcessInterpreter @@ -71,18 +71,18 @@ trait SMTLIBTarget extends Interruptible { protected lazy val debugOut: Option[java.io.FileWriter] = { if (reporter.isDebugEnabled) { val file = context.files.headOption.map(_.getName).getOrElse("NA") - val n = DebugFileNumbers.next(targetName+file) + val n = DebugFileNumbers.next(targetName + file) val fileName = s"smt-sessions/$targetName-$file-$n.smt2" val javaFile = new java.io.File(fileName) javaFile.getParentFile.mkdirs() - reporter.debug(s"Outputting smt session into $fileName" ) + reporter.debug(s"Outputting smt session into $fileName") val fw = new java.io.FileWriter(javaFile, false) - fw.write("; Options: "+interpreterOps(context).mkString(" ")+"\n") + fw.write("; Options: " + interpreterOps(context).mkString(" ") + "\n") Some(fw) } else { @@ -98,7 +98,7 @@ trait SMTLIBTarget extends Interruptible { o.flush() } interpreter.eval(cmd) match { - case err@ErrorResponse(msg) if !hasError && !interrupted && !rawOut => + case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => reporter.warning(s"Unexpected error from $targetName solver: $msg") // Store that there was an error. Now all following check() // invocations will return None @@ -112,7 +112,7 @@ trait SMTLIBTarget extends Interruptible { def parseSuccess() = { val res = interpreter.parser.parseGenResponse if (res != Success) { - reporter.warning("Unnexpected result from "+targetName+": "+res+" expected success") + reporter.warning("Unnexpected result from " + targetName + ": " + res + " expected success") } } @@ -141,20 +141,18 @@ trait SMTLIBTarget extends Interruptible { SSymbol(id.uniqueNameDelimited("!").replace("|", "$pipe").replace("\\", "$backslash")) } - protected def freshSym(id: Identifier): SSymbol = freshSym(id.name) protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) - /* Metadata for CC, and variables */ - protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() - protected val testers = new IncrementalBijection[TypeTree, SSymbol]() - protected val variables = new IncrementalBijection[Identifier, SSymbol]() + protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() + protected val testers = new IncrementalBijection[TypeTree, SSymbol]() + protected val variables = new IncrementalBijection[Identifier, SSymbol]() protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() - protected val sorts = new IncrementalBijection[TypeTree, Sort]() - protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() - protected val errors = new IncrementalBijection[Unit, Boolean]() + protected val sorts = new IncrementalBijection[TypeTree, Sort]() + protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() + protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true @@ -163,16 +161,14 @@ trait SMTLIBTarget extends Interruptible { protected def normalizeType(t: TypeTree): TypeTree = t match { case ct: ClassType => ct.root case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) - case _ => t + case _ => t } protected def quantifiedTerm( quantifier: (SortedVar, Seq[SortedVar], Term) => Term, vars: Seq[Identifier], - body: Expr - )( - implicit bindings: Map[Identifier, Term] - ): Term = { + body: Expr)( + implicit bindings: Map[Identifier, Term]): Term = { if (vars.isEmpty) toSMT(body) if (vars.isEmpty) toSMT(body)(Map()) else { @@ -182,15 +178,13 @@ trait SMTLIBTarget extends Interruptible { quantifier( sortedVars.head, sortedVars.tail, - toSMT(body)(bindings ++ vars.map{ id => id -> (id2sym(id): Term)}.toMap) - ) + toSMT(body)(bindings ++ vars.map { id => id -> (id2sym(id): Term) }.toMap)) } } // Returns a quantified term where all free variables in the body have been quantified protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr)( - implicit bindings: Map[Identifier, Term] - ): Term = + implicit bindings: Map[Identifier, Term]): Term = quantifiedTerm(quantifier, variablesOf(body).toSeq, body) protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { @@ -218,12 +212,12 @@ trait SMTLIBTarget extends Interruptible { val elems = r.elems.flatMap { case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x) - case (k, _) => None + case (k, _) => None }.toSeq FiniteMap(elems, from, to) case other => - unsupported(other, "Unable to extract from raw array for "+tpe) + unsupported(other, "Unable to extract from raw array for " + tpe) } protected def declareUninterpretedSort(t: TypeParameter): Sort = { @@ -239,9 +233,9 @@ trait SMTLIBTarget extends Interruptible { tpe match { case BooleanType => Core.BoolSort() case IntegerType => Ints.IntSort() - case RealType => Reals.RealSort() - case Int32Type => FixedSizeBitVectors.BitVectorSort(32) - case CharType => FixedSizeBitVectors.BitVectorSort(32) + case RealType => Reals.RealSort() + case Int32Type => FixedSizeBitVectors.BitVectorSort(32) + case CharType => FixedSizeBitVectors.BitVectorSort(32) case RawArrayType(from, to) => Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) @@ -273,7 +267,7 @@ trait SMTLIBTarget extends Interruptible { def toDecl(c: Constructor): SMTConstructor = { val s = id2sym(c.sym) - testers += c.tpe -> SSymbol("is-"+s.name) + testers += c.tpe -> SSymbol("is-" + s.name) constructors += c.tpe -> s SMTConstructor(s, c.fields.zipWithIndex.map { @@ -327,9 +321,8 @@ trait SMTLIBTarget extends Interruptible { val s = id2sym(id) emit(DeclareFun( s, - tfd.params.map( (p: ValDef) => declareSort(p.getType)), - declareSort(tfd.returnType) - )) + tfd.params.map((p: ValDef) => declareSort(p.getType)), + declareSort(tfd.returnType))) s } } @@ -341,7 +334,7 @@ trait SMTLIBTarget extends Interruptible { case Sort(id, Nil) => id.symbol - case Sort(id, subs) => + case Sort(id, subs) => SList((id.symbol +: subs.map(sortToSMT)).toList) } } @@ -362,20 +355,19 @@ trait SMTLIBTarget extends Interruptible { declareVariable(FreshIdentifier("Unit", UnitType)) case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) - case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) - case FractionalLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) - case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) - case BooleanLiteral(v) => Core.BoolConst(v) - case Let(b,d,e) => - val id = id2sym(b) - val value = toSMT(d) + case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) + case FractionalLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) + case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) + case BooleanLiteral(v) => Core.BoolConst(v) + case Let(b, d, e) => + val id = id2sym(b) + val value = toSMT(d) val newBody = toSMT(e)(bindings + (b -> id)) SMTLet( VarBinding(id, value), Seq(), - newBody - ) + newBody) case er @ Error(tpe, _) => declareVariable(FreshIdentifier("error_value", tpe)) @@ -403,8 +395,7 @@ trait SMTLIBTarget extends Interruptible { case more => val es = freshSym("e") SMTLet(VarBinding(es, toSMT(e)), Seq(), - Core.Or(oneOf.map(FunctionApplication(_, Seq(es:Term))): _*) - ) + Core.Or(oneOf.map(FunctionApplication(_, Seq(es: Term))): _*)) } case CaseClass(cct, es) => @@ -425,7 +416,7 @@ trait SMTLIBTarget extends Interruptible { case ts @ TupleSelect(t, i) => val tpe = normalizeType(t.getType) declareSort(tpe) - val selector = selectors.toB((tpe, i-1)) + val selector = selectors.toB((tpe, i - 1)) FunctionApplication(selector, Seq(toSMT(t))) case al @ ArrayLength(a) => @@ -445,7 +436,7 @@ trait SMTLIBTarget extends Interruptible { val tpe = normalizeType(a.getType) val sa = toSMT(a) - val ssize = FunctionApplication(selectors.toB((tpe, 0)), Seq(sa)) + val ssize = FunctionApplication(selectors.toB((tpe, 0)), Seq(sa)) val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(sa)) val newcontent = ArraysEx.Store(scontent, toSMT(i), toSMT(e)) @@ -458,8 +449,7 @@ trait SMTLIBTarget extends Interruptible { var res: Term = FunctionApplication( QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(s)), - List(toSMT(default)) - ) + List(toSMT(default))) for ((k, v) <- elems) { res = ArraysEx.Store(res, toSMT(k), toSMT(v)) } @@ -489,7 +479,6 @@ trait SMTLIBTarget extends Interruptible { case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) }.toMap, CaseClass(library.noneType(to), Seq()))) - case MapApply(m, k) => val mt @ MapType(_, to) = m.getType declareSort(mt) @@ -497,8 +486,7 @@ trait SMTLIBTarget extends Interruptible { // (Some-value (select m k)) FunctionApplication( selectors.toB((library.someType(to), 0)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k))) - ) + Seq(ArraysEx.Select(toSMT(m), toSMT(k)))) case MapIsDefinedAt(m, k) => val mt @ MapType(_, to) = m.getType @@ -507,26 +495,25 @@ trait SMTLIBTarget extends Interruptible { // (is-Some (select m k)) FunctionApplication( testers.toB(library.someType(to)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k))) - ) + Seq(ArraysEx.Select(toSMT(m), toSMT(k)))) case MapUnion(m1, FiniteMap(elems, _, _)) => val MapType(_, t) = m1.getType - elems.foldLeft(toSMT(m1)) { case (m, (k,v)) => - ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) + elems.foldLeft(toSMT(m1)) { + case (m, (k, v)) => + ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) } - case p : Passes => + case p: Passes => toSMT(matchToIfThenElse(p.asConstraint)) - case m : MatchExpr => + case m: MatchExpr => toSMT(matchToIfThenElse(m)) - case gv @ GenericValue(tpe, n) => genericValues.cachedB(gv) { - declareVariable(FreshIdentifier("gv"+n, tpe)) + declareVariable(FreshIdentifier("gv" + n, tpe)) } /** @@ -535,18 +522,18 @@ trait SMTLIBTarget extends Interruptible { case ap @ Application(caller, args) => ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) - case Not(u) => Core.Not(toSMT(u)) - case UMinus(u) => Ints.Neg(toSMT(u)) - case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) - case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) - case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) - - case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b)) - case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b)) - case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b)) - case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b)) - case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b)) - case Division(a,b) => { + case Not(u) => Core.Not(toSMT(u)) + case UMinus(u) => Ints.Neg(toSMT(u)) + case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) + case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) + case Assert(a, _, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) + + case Equals(a, b) => Core.Equals(toSMT(a), toSMT(b)) + case Implies(a, b) => Core.Implies(toSMT(a), toSMT(b)) + case Plus(a, b) => Ints.Add(toSMT(a), toSMT(b)) + case Minus(a, b) => Ints.Sub(toSMT(a), toSMT(b)) + case Times(a, b) => Ints.Mul(toSMT(a), toSMT(b)) + case Division(a, b) => { val ar = toSMT(a) val br = toSMT(b) @@ -555,63 +542,62 @@ trait SMTLIBTarget extends Interruptible { Ints.Div(ar, br), Ints.Neg(Ints.Div(Ints.Neg(ar), br))) } - case Remainder(a,b) => { + case Remainder(a, b) => { val q = toSMT(Division(a, b)) Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q)) } - case Modulo(a,b) => { + case Modulo(a, b) => { Ints.Mod(toSMT(a), toSMT(b)) } - case LessThan(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) + case LessThan(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) - case RealType => Reals.LessThan(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) + case RealType => Reals.LessThan(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) } - case LessEquals(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) + case LessEquals(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) - case RealType => Reals.LessEquals(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) + case RealType => Reals.LessEquals(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) } - case GreaterThan(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) + case GreaterThan(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) - case RealType => Reals.GreaterThan(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterThan(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) } - case GreaterEquals(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) + case GreaterEquals(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) - case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) } - case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) - case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) - case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) - case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) - case BVRemainder(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) - case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) - case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) - case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) - case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) - case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) - case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) - - case RealPlus(a,b) => Reals.Add(toSMT(a), toSMT(b)) - case RealMinus(a,b) => Reals.Sub(toSMT(a), toSMT(b)) - case RealTimes(a,b) => Reals.Mul(toSMT(a), toSMT(b)) - case RealDivision(a,b) => Reals.Div(toSMT(a), toSMT(b)) - - case And(sub) => Core.And(sub.map(toSMT): _*) - case Or(sub) => Core.Or(sub.map(toSMT): _*) + case BVPlus(a, b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) + case BVMinus(a, b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) + case BVTimes(a, b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) + case BVDivision(a, b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) + case BVRemainder(a, b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) + case BVAnd(a, b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) + case BVOr(a, b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) + case BVXOr(a, b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) + case BVShiftLeft(a, b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) + case BVAShiftRight(a, b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) + case BVLShiftRight(a, b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) + + case RealPlus(a, b) => Reals.Add(toSMT(a), toSMT(b)) + case RealMinus(a, b) => Reals.Sub(toSMT(a), toSMT(b)) + case RealTimes(a, b) => Reals.Mul(toSMT(a), toSMT(b)) + case RealDivision(a, b) => Reals.Div(toSMT(a), toSMT(b)) + + case And(sub) => Core.And(sub.map(toSMT): _*) + case Or(sub) => Core.Or(sub.map(toSMT): _*) case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) - case f@FunctionInvocation(_, sub) => + case f @ FunctionInvocation(_, sub) => if (sub.isEmpty) declareFunction(f.tfd) else { FunctionApplication( declareFunction(f.tfd), - sub.map(toSMT) - ) + sub.map(toSMT)) } case Forall(vs, bd) => quantifiedTerm(SMTForall, vs map { _.id }, bd)(Map()) @@ -621,8 +607,7 @@ trait SMTLIBTarget extends Interruptible { } /* Translate an SMTLIB term back to a Leon Expr */ - protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) - (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { // Use as much information as there is, if there is an expected type, great, but it might not always be there (t, otpe) match { @@ -643,10 +628,19 @@ trait SMTLIBTarget extends Interruptible { case (SDecimal(d), Some(RealType)) => // converting bigdecimal to a fraction - val scale = d.scale - val num = BigInt(d.bigDecimal.scaleByPowerOfTen(scale).toBigInteger()) - val denom = BigInt(new java.math.BigDecimal(1).scaleByPowerOfTen(-scale).toBigInteger()) - FractionalLiteral(num, denom) + if (d == BigDecimal(0)) + FractionalLiteral(0, 1) + else { + d.toBigIntExact() match { + case Some(num) => + FractionalLiteral(num, 1) + case _ => + val scale = d.scale + val num = BigInt(d.bigDecimal.scaleByPowerOfTen(scale).toBigInteger()) + val denom = BigInt(new java.math.BigDecimal(1).scaleByPowerOfTen(scale).toBigInteger()) + FractionalLiteral(num, denom) + } + } case (SNumeral(n), Some(RealType)) => FractionalLiteral(n, 1) @@ -655,8 +649,7 @@ trait SMTLIBTarget extends Interruptible { IfExpr( fromSMT(cond, Some(BooleanType)), fromSMT(thenn, t), - fromSMT(elze, t) - ) + fromSMT(elze, t)) // Best-effort case case (SNumeral(n), _) => @@ -670,13 +663,13 @@ trait SMTLIBTarget extends Interruptible { }.toMap fromSMT(body, tpe)(lets ++ defsMap, letDefs) - case (SimpleSymbol(s), _) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - CaseClass(cct, Nil) - case t => - unsupported(t, "woot? for a single constructor for non-case-object") - } + case (SimpleSymbol(s), _) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + CaseClass(cct, Nil) + case t => + unsupported(t, "woot? for a single constructor for non-case-object") + } case (FunctionApplication(SimpleSymbol(s), List(e)), _) if testers.containsB(s) => testers.toA(s) match { @@ -700,17 +693,17 @@ trait SMTLIBTarget extends Interruptible { tupleWrap(rargs) case ArrayType(baseType) => - val IntLiteral(size) = fromSMT(args(0), Int32Type) + val IntLiteral(size) = fromSMT(args(0), Int32Type) val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) - if(size > 10) { + if (size > 10) { val definedElements = elems.collect { case (IntLiteral(i), value) => (i, value) } finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) } else { - val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) + val entries = for (i <- 0 to size - 1) yield elems.getOrElse(IntLiteral(i), default) finiteArray(entries, None, baseType) } @@ -736,8 +729,23 @@ trait SMTLIBTarget extends Interruptible { case ("+", args) => args.map(fromSMT(_, otpe)).reduceLeft(plus _) + case ("-", List(a)) if otpe == Some(RealType) => + val aexpr = fromSMT(a, otpe) + aexpr match { + case FractionalLiteral(na, da) => + FractionalLiteral(-na, da) + case _ => + UMinus(aexpr) + } + case ("-", List(a)) => - UMinus(fromSMT(a, otpe)) + val aexpr = fromSMT(a, otpe) + aexpr match { + case InfiniteIntegerLiteral(v) => + InfiniteIntegerLiteral(-v) + case _ => + UMinus(aexpr) + } case ("-", List(a, b)) => Minus(fromSMT(a, otpe), fromSMT(b, otpe)) @@ -745,6 +753,16 @@ trait SMTLIBTarget extends Interruptible { case ("*", args) => args.map(fromSMT(_, otpe)).reduceLeft(times _) + case ("/", List(a, b)) if otpe == Some(RealType) => + val aexpr = fromSMT(a, otpe) + val bexpr = fromSMT(b, otpe) + (aexpr, bexpr) match { + case (FractionalLiteral(na, da), FractionalLiteral(nb, db)) if da == 1 && db == 1 => + FractionalLiteral(na, nb) + case _ => + Division(aexpr, bexpr) + } + case ("/", List(a, b)) => Division(fromSMT(a, otpe), fromSMT(b, otpe)) @@ -765,7 +783,7 @@ trait SMTLIBTarget extends Interruptible { Equals(ra, fromSMT(b, ra.getType)) case _ => - reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) + reporter.fatalError("Function " + app + " not handled in fromSMT: " + s) } case (Core.True(), Some(BooleanType)) => BooleanLiteral(true) @@ -795,4 +813,4 @@ trait SMTLIBTarget extends Interruptible { } // Unique numbers -private [smtlib] object DebugFileNumbers extends UniqueCounter[String] +private[smtlib] object DebugFileNumbers extends UniqueCounter[String] diff --git a/src/test/resources/regression/orb/timing/BinomialHeap.scala b/src/test/resources/regression/orb/timing/BinomialHeap.scala deleted file mode 100644 index 81b990d41..000000000 --- a/src/test/resources/regression/orb/timing/BinomialHeap.scala +++ /dev/null @@ -1,181 +0,0 @@ -import leon.invariant._ -import leon.instrumentation._ - -object BinomialHeap { - //sealed abstract class TreeNode - case class TreeNode(rank: BigInt, elem: Element, children: BinomialHeap) - case class Element(n: BigInt) - - sealed abstract class BinomialHeap - case class ConsHeap(head: TreeNode, tail: BinomialHeap) extends BinomialHeap - case class NilHeap() extends BinomialHeap - - sealed abstract class List - case class NodeL(head: BinomialHeap, tail: List) extends List - case class NilL() extends List - - sealed abstract class OptionalTree - case class Some(t : TreeNode) extends OptionalTree - case class None() extends OptionalTree - - /* Lower or Equal than for Element structure */ - private def leq(a: Element, b: Element) : Boolean = { - a match { - case Element(a1) => { - b match { - case Element(a2) => { - if(a1 <= a2) true - else false - } - } - } - } - } - - /* isEmpty function of the Binomial Heap */ - def isEmpty(t: BinomialHeap) = t match { - case ConsHeap(_,_) => false - case _ => true - } - - /* Helper function to determine rank of a TreeNode */ - def rank(t: TreeNode) : BigInt = t.rank /*t match { - case TreeNode(r, _, _) => r - }*/ - - /* Helper function to get the root element of a TreeNode */ - def root(t: TreeNode) : Element = t.elem /*t match { - case TreeNode(_, e, _) => e - }*/ - - /* Linking trees of equal ranks depending on the root element */ - def link(t1: TreeNode, t2: TreeNode): TreeNode = { - if (leq(t1.elem, t2.elem)) { - TreeNode(t1.rank + 1, t1.elem, ConsHeap(t2, t1.children)) - } else { - TreeNode(t1.rank + 1, t2.elem, ConsHeap(t1, t2.children)) - } - } - - def treeNum(h: BinomialHeap) : BigInt = { - h match { - case ConsHeap(head, tail) => 1 + treeNum(tail) - case _ => 0 - } - } - - /* Insert a tree into a binomial heap. The tree should be correct in relation to the heap */ - def insTree(t: TreeNode, h: BinomialHeap) : BinomialHeap = { - h match { - case ConsHeap(head, tail) => { - if (rank(t) < rank(head)) { - ConsHeap(t, h) - } else if (rank(t) > rank(head)) { - ConsHeap(head, insTree(t,tail)) - } else { - insTree(link(t,head), tail) - } - } - case _ => ConsHeap(t, NilHeap()) - } - } ensuring(_ => time <= ? * treeNum(h) + ?) - - /* Merge two heaps together */ - def merge(h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { - h1 match { - case ConsHeap(head1, tail1) => { - h2 match { - case ConsHeap(head2, tail2) => { - if (rank(head1) < rank(head2)) { - ConsHeap(head1, merge(tail1, h2)) - } else if (rank(head2) < rank(head1)) { - ConsHeap(head2, merge(h1, tail2)) - } else { - mergeWithCarry(link(head1, head2), tail1, tail2) - } - } - case _ => h1 - } - } - case _ => h2 - } - } ensuring(_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) - - def mergeWithCarry(t: TreeNode, h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { - h1 match { - case ConsHeap(head1, tail1) => { - h2 match { - case ConsHeap(head2, tail2) => { - if (rank(head1) < rank(head2)) { - - if (rank(t) < rank(head1)) - ConsHeap(t, ConsHeap(head1, merge(tail1, h2))) - else - mergeWithCarry(link(t, head1), tail1, h2) - - } else if (rank(head2) < rank(head1)) { - - if (rank(t) < rank(head2)) - ConsHeap(t, ConsHeap(head2, merge(h1, tail2))) - else - mergeWithCarry(link(t, head2), h1, tail2) - - } else { - ConsHeap(t, mergeWithCarry(link(head1, head2), tail1, tail2)) - } - } - case _ => { - insTree(t, h1) - } - } - } - case _ => insTree(t, h2) - } - } ensuring (_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) - - //Auxiliary helper function to simplefy findMin and deleteMin - def removeMinTree(h: BinomialHeap): (OptionalTree, BinomialHeap) = { - h match { - case ConsHeap(head, NilHeap()) => (Some(head), NilHeap()) - case ConsHeap(head1, tail1) => { - val (opthead2, tail2) = removeMinTree(tail1) - opthead2 match { - case Some(head2) => - if (leq(root(head1), root(head2))) { - (Some(head1), tail1) - } else { - (Some(head2), ConsHeap(head1, tail2)) - } - case _ => (Some(head1), tail1) - } - } - case _ => (None(), NilHeap()) - } - } ensuring (res => treeNum(res._2) <= treeNum(h) && time <= ? * treeNum(h) + ?) - - /*def findMin(h: BinomialHeap) : Element = { - val (opt, _) = removeMinTree(h) - opt match { - case Some(TreeNode(_,e,ts1)) => e - case _ => Element(-1) - } - } ensuring(res => true && tmpl((a,b) => time <= a*treeNum(h) + b))*/ - - def minTreeChildren(h: BinomialHeap) : BigInt = { - val (min, _) = removeMinTree(h) - min match { - case Some(TreeNode(_,_,ch)) => treeNum(ch) - case _ => 0 - } - } - - // Discard the minimum element of the extracted min tree and put its children back into the heap - def deleteMin(h: BinomialHeap) : BinomialHeap = { - val (min, ts2) = removeMinTree(h) - min match { - case Some(TreeNode(_,_,ts1)) => merge(ts1, ts2) - case _ => h - } - } ensuring(_ => time <= ? * minTreeChildren(h) + ? * treeNum(h) + ?) - -} -- GitLab