Skip to content
Snippets Groups Projects
Commit 51e2f35c authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Unify and centralize use of solvers

parent 353fe9f5
No related branches found
No related tags found
No related merge requests found
......@@ -7,12 +7,12 @@ import utils._
import purescala.Expressions._
import purescala.ExprOps.valuateWithModel
import purescala.Constructors._
import solvers.TimeoutSolver
import solvers.z3._
import solvers.SolverFactory
import java.util.WeakHashMap
import java.lang.ref.WeakReference
import scala.collection.mutable.{HashMap => MutableMap}
import scala.concurrent.duration._
import codegen.CompilationUnit
......@@ -79,7 +79,7 @@ object ChooseEntryPoint {
} else {
val tStart = System.currentTimeMillis
val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(10000L)
val solver = SolverFactory.default(ctx, program).withTimeout(10.second).getNewSolver()
val inputsMap = (p.as zip inputs).map {
case (id, v) =>
......
......@@ -25,6 +25,8 @@ import rules._
import graph.DotGenerator
import leon.utils.ASCIIHelpers.title
import scala.concurrent.duration._
class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) {
val reporter = ctx.reporter
......@@ -469,7 +471,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
None
} else {
val diff = and(p.pc, not(Equals(s1, s2)))
val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(1000)
val solver = SolverFactory.default(ctx, program).withTimeout(1.second).getNewSolver()
solver.assertCnstr(diff)
solver.check match {
......
......@@ -3,6 +3,10 @@
package leon
package solvers
import combinators._
import z3._
import smtlib._
import purescala.Definitions._
import scala.reflect.runtime.universe._
......@@ -40,10 +44,6 @@ object SolverFactory {
}
def getFromName(ctx: LeonContext, program: Program)(names: String*): SolverFactory[TimeoutSolver] = {
import combinators._
import z3._
import smtlib._
def getSolver(name: String): SolverFactory[TimeoutSolver] = name match {
case "fairz3" =>
......@@ -93,4 +93,16 @@ object SolverFactory {
}
// Solver qualifiers that get used internally:
// Fast solver used by simplifiactions, to discharge simple tautologies
def uninterpreted(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3Target with TimeoutSolver)
}
// Full featured solver used by default
def default(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
getFromName(ctx, program)("fairz3")
}
}
......@@ -22,36 +22,20 @@ case class SynthesisContext(
val rules = settings.rules
val allSolvers: Map[String, SolverFactory[SynthesisContext.SynthesisSolver]] = Map(
"fairz3" -> SolverFactory(() => new FairZ3Solver(context, program) with TimeoutAssumptionSolver),
"enum" -> SolverFactory(() => new EnumerationSolver(context, program) with TimeoutAssumptionSolver)
)
val solverFactory = SolverFactory.getFromSettings(context, program)
val solversToUse = allSolvers.filterKeys(settings.selectedSolvers)
val solverFactory: SolverFactory[SynthesisContext.SynthesisSolver] = solversToUse.values.toSeq match {
case Seq() =>
reporter.fatalError("No solver selected. Aborting")
case Seq(value) =>
value
case more =>
SolverFactory( () => new PortfolioSolverSynth(context, more) with TimeoutAssumptionSolver )
}
def newSolver: SynthesisContext.SynthesisSolver = {
def newSolver = {
solverFactory.getNewSolver()
}
def newFastSolver: SynthesisContext.SynthesisSolver = {
new UninterpretedZ3Solver(context, program) with TimeoutAssumptionSolver
}
val fastSolverFactory = SolverFactory(() => newFastSolver)
val fastSolverFactory = SolverFactory.uninterpreted(context, program)
def newFastSolver = {
fastSolverFactory.getNewSolver()
}
}
object SynthesisContext {
type SynthesisSolver = TimeoutAssumptionSolver with IncrementalSolver
def fromSynthesizer(synth: Synthesizer) = {
SynthesisContext(
......
......@@ -12,7 +12,6 @@ case class SynthesisSettings(
rules: Seq[Rule] = Rules.all,
manualSearch: Option[String] = None,
searchBound: Option[Int] = None,
selectedSolvers: Set[String] = Set("fairz3"),
functions: Option[Set[String]] = None,
functionsToIgnore: Set[FunDef] = Set(),
......
......@@ -392,7 +392,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi)))
//println("Solving for: "+cnstr)
val solver = (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(cexSolverTo)
val solver = SolverFactory.default(ctx, prog).withTimeout(cexSolverTo).getNewSolver()
try {
solver.assertCnstr(cnstr)
solver.check match {
......@@ -658,7 +658,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
}
def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = {
val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo)
val solver = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo).getNewSolver()
val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
//debugCExpr(cTree)
......@@ -735,7 +735,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
}
def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = {
val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo)
val solver = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo).getNewSolver()
val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType)
......
......@@ -7,6 +7,8 @@ import purescala.Expressions._
import purescala.Common._
import purescala.Definitions._
import scala.concurrent.duration._
import leon.solvers._
import leon.solvers.z3._
......@@ -35,15 +37,15 @@ trait Solvable extends Processor {
val checker : TerminationChecker with Strengthener with StructuralSize
private val solver: SolverFactory[Solver] = SolverFactory(() => {
private val solver: SolverFactory[Solver] = {
val program : Program = checker.program
val context : LeonContext = checker.context
val sizeModule : ModuleDef = ModuleDef(FreshIdentifier("$size"), checker.defs.toSeq, false)
val sizeUnit : UnitDef = UnitDef(FreshIdentifier("$size"),Seq(sizeModule))
val newProgram : Program = program.copy( units = sizeUnit :: program.units)
(new FairZ3Solver(context, newProgram) with TimeoutAssumptionSolver).setTimeout(500L)
})
SolverFactory.default(context, newProgram).withTimeout(500.millisecond)
}
type Solution = (Option[Boolean], Map[Identifier, Expr])
......
......@@ -7,18 +7,17 @@ import purescala.Definitions._
import purescala.Expressions._
import purescala.ExprOps._
import purescala.ScopeSimplifier
import solvers.z3.UninterpretedZ3Solver
import solvers._
object Simplifiers {
def bestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = {
val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p))
val solver = SolverFactory.uninterpreted(ctx, p)
val simplifiers = List[Expr => Expr](
simplifyTautologies(uninterpretedZ3)(_),
simplifyTautologies(solver)(_),
simplifyLets,
simplifyPaths(uninterpretedZ3)(_),
simplifyPaths(solver)(_),
simplifyArithmetic,
evalGround(ctx, p),
normalizeExpression
......@@ -38,10 +37,10 @@ object Simplifiers {
}
def namePreservingBestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = {
val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p))
val solver = SolverFactory.uninterpreted(ctx, p)
val simplifiers = List[Expr => Expr](
simplifyTautologies(uninterpretedZ3)(_),
simplifyTautologies(solver)(_),
simplifyArithmetic,
evalGround(ctx, p),
normalizeExpression
......
......@@ -81,9 +81,9 @@ class SynthesisSuite extends LeonTestSuite {
}
def forProgram(title: String, opts: SynthesisSettings = SynthesisSettings())(content: String)(strats: PartialFunction[String, SynStrat]) {
def forProgram(title: String, opts: Seq[LeonOption[Any]] = Nil)(content: String)(strats: PartialFunction[String, SynStrat]) {
test(f"Synthesizing ${nextInt()}%3d: [$title]") {
val ctx = testContext
val ctx = testContext.copy(options = opts ++ testContext.options)
val pipeline = leon.utils.TemporaryInputPhase andThen leon.frontends.scalac.ExtractionPhase andThen PreprocessingPhase andThen SynthesisProblemExtractionPhase
......@@ -92,9 +92,12 @@ class SynthesisSuite extends LeonTestSuite {
for ((f,cis) <- results; ci <- cis) {
info(f"${ci.fd.id.toString}%-20s")
val sctx = SynthesisContext(ctx, opts, ci.fd, program)
val sctx = SynthesisContext(ctx,
SynthesisSettings(),
ci.fd,
program)
val p = ci.problem
val p = ci.problem
if (strats.isDefinedAt(f.id.name)) {
val search = new TestSearch(ctx, ci, p, strats(f.id.name))
......@@ -109,7 +112,7 @@ class SynthesisSuite extends LeonTestSuite {
}
}
forProgram("Ground Enum", SynthesisSettings(selectedSolvers = Set("enum")))(
forProgram("Ground Enum", Seq(LeonOption(SharedOptions.optSelectedSolvers)(Set("enum"))))(
"""
import leon.annotation._
import leon.lang._
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment