From 51e2f35ca6125dbb75e044e6cd9a2df92be28f33 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Mon, 4 May 2015 18:25:37 +0200 Subject: [PATCH] Unify and centralize use of solvers --- .../codegen/runtime/ChooseEntryPoint.scala | 6 ++-- src/main/scala/leon/repair/Repairman.scala | 4 ++- .../scala/leon/solvers/SolverFactory.scala | 20 ++++++++++--- .../leon/synthesis/SynthesisContext.scala | 28 ++++--------------- .../leon/synthesis/SynthesisSettings.scala | 1 - .../leon/synthesis/rules/CEGISLike.scala | 6 ++-- .../scala/leon/termination/Processor.scala | 8 ++++-- src/main/scala/leon/utils/Simplifiers.scala | 11 ++++---- .../leon/test/synthesis/SynthesisSuite.scala | 13 +++++---- 9 files changed, 49 insertions(+), 48 deletions(-) diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala index b74b46e89..a79fe4a2d 100644 --- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala +++ b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala @@ -7,12 +7,12 @@ import utils._ import purescala.Expressions._ import purescala.ExprOps.valuateWithModel import purescala.Constructors._ -import solvers.TimeoutSolver -import solvers.z3._ +import solvers.SolverFactory import java.util.WeakHashMap import java.lang.ref.WeakReference import scala.collection.mutable.{HashMap => MutableMap} +import scala.concurrent.duration._ import codegen.CompilationUnit @@ -79,7 +79,7 @@ object ChooseEntryPoint { } else { val tStart = System.currentTimeMillis - val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(10000L) + val solver = SolverFactory.default(ctx, program).withTimeout(10.second).getNewSolver() val inputsMap = (p.as zip inputs).map { case (id, v) => diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index bb7c2ef9a..5fbfc60a7 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -25,6 +25,8 @@ import rules._ import graph.DotGenerator import leon.utils.ASCIIHelpers.title +import scala.concurrent.duration._ + class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) { val reporter = ctx.reporter @@ -469,7 +471,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout None } else { val diff = and(p.pc, not(Equals(s1, s2))) - val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(1000) + val solver = SolverFactory.default(ctx, program).withTimeout(1.second).getNewSolver() solver.assertCnstr(diff) solver.check match { diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 4e53b8b9e..2b838b68b 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -3,6 +3,10 @@ package leon package solvers +import combinators._ +import z3._ +import smtlib._ + import purescala.Definitions._ import scala.reflect.runtime.universe._ @@ -40,10 +44,6 @@ object SolverFactory { } def getFromName(ctx: LeonContext, program: Program)(names: String*): SolverFactory[TimeoutSolver] = { - import combinators._ - import z3._ - import smtlib._ - def getSolver(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => @@ -93,4 +93,16 @@ object SolverFactory { } + // Solver qualifiers that get used internally: + + // Fast solver used by simplifiactions, to discharge simple tautologies + def uninterpreted(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = { + SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3Target with TimeoutSolver) + } + + // Full featured solver used by default + def default(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = { + getFromName(ctx, program)("fairz3") + } + } diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index d9b9a6039..920901e4f 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -22,36 +22,20 @@ case class SynthesisContext( val rules = settings.rules - val allSolvers: Map[String, SolverFactory[SynthesisContext.SynthesisSolver]] = Map( - "fairz3" -> SolverFactory(() => new FairZ3Solver(context, program) with TimeoutAssumptionSolver), - "enum" -> SolverFactory(() => new EnumerationSolver(context, program) with TimeoutAssumptionSolver) - ) + val solverFactory = SolverFactory.getFromSettings(context, program) - val solversToUse = allSolvers.filterKeys(settings.selectedSolvers) - - val solverFactory: SolverFactory[SynthesisContext.SynthesisSolver] = solversToUse.values.toSeq match { - case Seq() => - reporter.fatalError("No solver selected. Aborting") - case Seq(value) => - value - case more => - SolverFactory( () => new PortfolioSolverSynth(context, more) with TimeoutAssumptionSolver ) - } - - def newSolver: SynthesisContext.SynthesisSolver = { + def newSolver = { solverFactory.getNewSolver() } - def newFastSolver: SynthesisContext.SynthesisSolver = { - new UninterpretedZ3Solver(context, program) with TimeoutAssumptionSolver - } - - val fastSolverFactory = SolverFactory(() => newFastSolver) + val fastSolverFactory = SolverFactory.uninterpreted(context, program) + def newFastSolver = { + fastSolverFactory.getNewSolver() + } } object SynthesisContext { - type SynthesisSolver = TimeoutAssumptionSolver with IncrementalSolver def fromSynthesizer(synth: Synthesizer) = { SynthesisContext( diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index 9a52e58d6..9d227bdb6 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -12,7 +12,6 @@ case class SynthesisSettings( rules: Seq[Rule] = Rules.all, manualSearch: Option[String] = None, searchBound: Option[Int] = None, - selectedSolvers: Set[String] = Set("fairz3"), functions: Option[Set[String]] = None, functionsToIgnore: Set[FunDef] = Set(), diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index fce0363fb..8c5a88bf0 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -392,7 +392,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi))) //println("Solving for: "+cnstr) - val solver = (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(cexSolverTo) + val solver = SolverFactory.default(ctx, prog).withTimeout(cexSolverTo).getNewSolver() try { solver.assertCnstr(cnstr) solver.check match { @@ -658,7 +658,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { - val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo) + val solver = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo).getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) //debugCExpr(cTree) @@ -735,7 +735,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { - val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo) + val solver = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo).getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType) diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 87637e316..b6c6972f9 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -7,6 +7,8 @@ import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ +import scala.concurrent.duration._ + import leon.solvers._ import leon.solvers.z3._ @@ -35,15 +37,15 @@ trait Solvable extends Processor { val checker : TerminationChecker with Strengthener with StructuralSize - private val solver: SolverFactory[Solver] = SolverFactory(() => { + private val solver: SolverFactory[Solver] = { val program : Program = checker.program val context : LeonContext = checker.context val sizeModule : ModuleDef = ModuleDef(FreshIdentifier("$size"), checker.defs.toSeq, false) val sizeUnit : UnitDef = UnitDef(FreshIdentifier("$size"),Seq(sizeModule)) val newProgram : Program = program.copy( units = sizeUnit :: program.units) - (new FairZ3Solver(context, newProgram) with TimeoutAssumptionSolver).setTimeout(500L) - }) + SolverFactory.default(context, newProgram).withTimeout(500.millisecond) + } type Solution = (Option[Boolean], Map[Identifier, Expr]) diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala index 2426539ec..82fa58a72 100644 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ b/src/main/scala/leon/utils/Simplifiers.scala @@ -7,18 +7,17 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.ScopeSimplifier -import solvers.z3.UninterpretedZ3Solver import solvers._ object Simplifiers { def bestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = { - val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p)) + val solver = SolverFactory.uninterpreted(ctx, p) val simplifiers = List[Expr => Expr]( - simplifyTautologies(uninterpretedZ3)(_), + simplifyTautologies(solver)(_), simplifyLets, - simplifyPaths(uninterpretedZ3)(_), + simplifyPaths(solver)(_), simplifyArithmetic, evalGround(ctx, p), normalizeExpression @@ -38,10 +37,10 @@ object Simplifiers { } def namePreservingBestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = { - val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p)) + val solver = SolverFactory.uninterpreted(ctx, p) val simplifiers = List[Expr => Expr]( - simplifyTautologies(uninterpretedZ3)(_), + simplifyTautologies(solver)(_), simplifyArithmetic, evalGround(ctx, p), normalizeExpression diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala index c043ac0fa..b8ef5681f 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala @@ -81,9 +81,9 @@ class SynthesisSuite extends LeonTestSuite { } - def forProgram(title: String, opts: SynthesisSettings = SynthesisSettings())(content: String)(strats: PartialFunction[String, SynStrat]) { + def forProgram(title: String, opts: Seq[LeonOption[Any]] = Nil)(content: String)(strats: PartialFunction[String, SynStrat]) { test(f"Synthesizing ${nextInt()}%3d: [$title]") { - val ctx = testContext + val ctx = testContext.copy(options = opts ++ testContext.options) val pipeline = leon.utils.TemporaryInputPhase andThen leon.frontends.scalac.ExtractionPhase andThen PreprocessingPhase andThen SynthesisProblemExtractionPhase @@ -92,9 +92,12 @@ class SynthesisSuite extends LeonTestSuite { for ((f,cis) <- results; ci <- cis) { info(f"${ci.fd.id.toString}%-20s") - val sctx = SynthesisContext(ctx, opts, ci.fd, program) + val sctx = SynthesisContext(ctx, + SynthesisSettings(), + ci.fd, + program) - val p = ci.problem + val p = ci.problem if (strats.isDefinedAt(f.id.name)) { val search = new TestSearch(ctx, ci, p, strats(f.id.name)) @@ -109,7 +112,7 @@ class SynthesisSuite extends LeonTestSuite { } } - forProgram("Ground Enum", SynthesisSettings(selectedSolvers = Set("enum")))( + forProgram("Ground Enum", Seq(LeonOption(SharedOptions.optSelectedSolvers)(Set("enum"))))( """ import leon.annotation._ import leon.lang._ -- GitLab