From 51e2f35ca6125dbb75e044e6cd9a2df92be28f33 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Mon, 4 May 2015 18:25:37 +0200
Subject: [PATCH] Unify and centralize use of solvers

---
 .../codegen/runtime/ChooseEntryPoint.scala    |  6 ++--
 src/main/scala/leon/repair/Repairman.scala    |  4 ++-
 .../scala/leon/solvers/SolverFactory.scala    | 20 ++++++++++---
 .../leon/synthesis/SynthesisContext.scala     | 28 ++++---------------
 .../leon/synthesis/SynthesisSettings.scala    |  1 -
 .../leon/synthesis/rules/CEGISLike.scala      |  6 ++--
 .../scala/leon/termination/Processor.scala    |  8 ++++--
 src/main/scala/leon/utils/Simplifiers.scala   | 11 ++++----
 .../leon/test/synthesis/SynthesisSuite.scala  | 13 +++++----
 9 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala
index b74b46e89..a79fe4a2d 100644
--- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala
+++ b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala
@@ -7,12 +7,12 @@ import utils._
 import purescala.Expressions._
 import purescala.ExprOps.valuateWithModel
 import purescala.Constructors._
-import solvers.TimeoutSolver
-import solvers.z3._
+import solvers.SolverFactory
 
 import java.util.WeakHashMap
 import java.lang.ref.WeakReference
 import scala.collection.mutable.{HashMap => MutableMap}
+import scala.concurrent.duration._
 
 import codegen.CompilationUnit
 
@@ -79,7 +79,7 @@ object ChooseEntryPoint {
     } else {
       val tStart = System.currentTimeMillis
 
-      val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(10000L)
+      val solver = SolverFactory.default(ctx, program).withTimeout(10.second).getNewSolver()
 
       val inputsMap = (p.as zip inputs).map {
         case (id, v) =>
diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala
index bb7c2ef9a..5fbfc60a7 100644
--- a/src/main/scala/leon/repair/Repairman.scala
+++ b/src/main/scala/leon/repair/Repairman.scala
@@ -25,6 +25,8 @@ import rules._
 import graph.DotGenerator
 import leon.utils.ASCIIHelpers.title
 
+import scala.concurrent.duration._
+
 class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) {
   val reporter = ctx.reporter
 
@@ -469,7 +471,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
       None
     } else {
       val diff = and(p.pc, not(Equals(s1, s2)))
-      val solver = (new FairZ3Solver(ctx, program) with TimeoutSolver).setTimeout(1000)
+      val solver = SolverFactory.default(ctx, program).withTimeout(1.second).getNewSolver()
 
       solver.assertCnstr(diff)
       solver.check match {
diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala
index 4e53b8b9e..2b838b68b 100644
--- a/src/main/scala/leon/solvers/SolverFactory.scala
+++ b/src/main/scala/leon/solvers/SolverFactory.scala
@@ -3,6 +3,10 @@
 package leon
 package solvers
 
+import combinators._
+import z3._
+import smtlib._
+
 import purescala.Definitions._
 import scala.reflect.runtime.universe._
 
@@ -40,10 +44,6 @@ object SolverFactory {
   }
 
   def getFromName(ctx: LeonContext, program: Program)(names: String*): SolverFactory[TimeoutSolver] = {
-    import combinators._
-    import z3._
-    import smtlib._
-
 
     def getSolver(name: String): SolverFactory[TimeoutSolver] = name match {
       case "fairz3" =>
@@ -93,4 +93,16 @@ object SolverFactory {
 
   }
 
+  // Solver qualifiers that get used internally:
+
+  // Fast solver used by simplifiactions, to discharge simple tautologies
+  def uninterpreted(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
+    SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3Target with TimeoutSolver)
+  }
+
+  // Full featured solver used by default
+  def default(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
+    getFromName(ctx, program)("fairz3")
+  }
+
 }
diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala
index d9b9a6039..920901e4f 100644
--- a/src/main/scala/leon/synthesis/SynthesisContext.scala
+++ b/src/main/scala/leon/synthesis/SynthesisContext.scala
@@ -22,36 +22,20 @@ case class SynthesisContext(
 
   val rules = settings.rules
 
-  val allSolvers: Map[String, SolverFactory[SynthesisContext.SynthesisSolver]] = Map(
-    "fairz3" -> SolverFactory(() => new FairZ3Solver(context, program) with TimeoutAssumptionSolver),
-    "enum"   -> SolverFactory(() => new EnumerationSolver(context, program) with TimeoutAssumptionSolver)
-  )
+  val solverFactory = SolverFactory.getFromSettings(context, program)
 
-  val solversToUse = allSolvers.filterKeys(settings.selectedSolvers)
-
-  val solverFactory: SolverFactory[SynthesisContext.SynthesisSolver] = solversToUse.values.toSeq match {
-    case Seq() =>
-      reporter.fatalError("No solver selected. Aborting")
-    case Seq(value) =>
-      value
-    case more =>
-      SolverFactory( () => new PortfolioSolverSynth(context, more) with TimeoutAssumptionSolver )
-  }
-
-  def newSolver: SynthesisContext.SynthesisSolver = {
+  def newSolver = {
     solverFactory.getNewSolver()
   }
 
-  def newFastSolver: SynthesisContext.SynthesisSolver = {
-    new UninterpretedZ3Solver(context, program) with TimeoutAssumptionSolver
-  }
-
-  val fastSolverFactory = SolverFactory(() => newFastSolver)
+  val fastSolverFactory = SolverFactory.uninterpreted(context, program)
 
+  def newFastSolver = {
+    fastSolverFactory.getNewSolver()
+  }
 }
 
 object SynthesisContext {
-  type SynthesisSolver = TimeoutAssumptionSolver with IncrementalSolver
 
   def fromSynthesizer(synth: Synthesizer) = {
     SynthesisContext(
diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala
index 9a52e58d6..9d227bdb6 100644
--- a/src/main/scala/leon/synthesis/SynthesisSettings.scala
+++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala
@@ -12,7 +12,6 @@ case class SynthesisSettings(
   rules: Seq[Rule]                    = Rules.all,
   manualSearch: Option[String]        = None,
   searchBound: Option[Int]            = None,
-  selectedSolvers: Set[String]        = Set("fairz3"),
   functions: Option[Set[String]]      = None,
   functionsToIgnore: Set[FunDef]      = Set(),
   
diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
index fce0363fb..8c5a88bf0 100644
--- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala
+++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
@@ -392,7 +392,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
             val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi)))
             //println("Solving for: "+cnstr)
 
-            val solver = (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(cexSolverTo)
+            val solver = SolverFactory.default(ctx, prog).withTimeout(cexSolverTo).getNewSolver()
             try {
               solver.assertCnstr(cnstr)
               solver.check match {
@@ -658,7 +658,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
       }
 
       def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = {
-        val solver  = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo)
+        val solver = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo).getNewSolver()
         val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
         //debugCExpr(cTree)
 
@@ -735,7 +735,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
       }
 
       def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = {
-        val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo)
+        val solver = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo).getNewSolver()
         val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
 
         val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType)
diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala
index 87637e316..b6c6972f9 100644
--- a/src/main/scala/leon/termination/Processor.scala
+++ b/src/main/scala/leon/termination/Processor.scala
@@ -7,6 +7,8 @@ import purescala.Expressions._
 import purescala.Common._
 import purescala.Definitions._
 
+import scala.concurrent.duration._
+
 import leon.solvers._
 import leon.solvers.z3._
 
@@ -35,15 +37,15 @@ trait Solvable extends Processor {
 
   val checker : TerminationChecker with Strengthener with StructuralSize
 
-  private val solver: SolverFactory[Solver] = SolverFactory(() => {
+  private val solver: SolverFactory[Solver] = {
     val program     : Program     = checker.program
     val context     : LeonContext = checker.context
     val sizeModule  : ModuleDef   = ModuleDef(FreshIdentifier("$size"), checker.defs.toSeq, false)
     val sizeUnit    : UnitDef     = UnitDef(FreshIdentifier("$size"),Seq(sizeModule)) 
     val newProgram  : Program     = program.copy( units = sizeUnit :: program.units)
 
-    (new FairZ3Solver(context, newProgram) with TimeoutAssumptionSolver).setTimeout(500L)
-  })
+    SolverFactory.default(context, newProgram).withTimeout(500.millisecond)
+  }
 
   type Solution = (Option[Boolean], Map[Identifier, Expr])
 
diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala
index 2426539ec..82fa58a72 100644
--- a/src/main/scala/leon/utils/Simplifiers.scala
+++ b/src/main/scala/leon/utils/Simplifiers.scala
@@ -7,18 +7,17 @@ import purescala.Definitions._
 import purescala.Expressions._
 import purescala.ExprOps._
 import purescala.ScopeSimplifier
-import solvers.z3.UninterpretedZ3Solver
 import solvers._
 
 object Simplifiers {
   
   def bestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = {
-    val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p))
+    val solver = SolverFactory.uninterpreted(ctx, p)
 
     val simplifiers = List[Expr => Expr](
-      simplifyTautologies(uninterpretedZ3)(_),
+      simplifyTautologies(solver)(_),
       simplifyLets,
-      simplifyPaths(uninterpretedZ3)(_),
+      simplifyPaths(solver)(_),
       simplifyArithmetic,
       evalGround(ctx, p),
       normalizeExpression
@@ -38,10 +37,10 @@ object Simplifiers {
   }
 
   def namePreservingBestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = {
-    val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p))
+    val solver = SolverFactory.uninterpreted(ctx, p)
 
     val simplifiers = List[Expr => Expr](
-      simplifyTautologies(uninterpretedZ3)(_),
+      simplifyTautologies(solver)(_),
       simplifyArithmetic,
       evalGround(ctx, p),
       normalizeExpression
diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
index c043ac0fa..b8ef5681f 100644
--- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
+++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
@@ -81,9 +81,9 @@ class SynthesisSuite extends LeonTestSuite {
 
   }
 
-  def forProgram(title: String, opts: SynthesisSettings = SynthesisSettings())(content: String)(strats: PartialFunction[String, SynStrat]) {
+  def forProgram(title: String, opts: Seq[LeonOption[Any]] = Nil)(content: String)(strats: PartialFunction[String, SynStrat]) {
       test(f"Synthesizing ${nextInt()}%3d: [$title]") {
-        val ctx = testContext
+        val ctx = testContext.copy(options = opts ++ testContext.options)
 
         val pipeline = leon.utils.TemporaryInputPhase andThen leon.frontends.scalac.ExtractionPhase andThen PreprocessingPhase andThen SynthesisProblemExtractionPhase
 
@@ -92,9 +92,12 @@ class SynthesisSuite extends LeonTestSuite {
         for ((f,cis) <- results; ci <- cis) {
           info(f"${ci.fd.id.toString}%-20s")
 
-          val sctx = SynthesisContext(ctx, opts, ci.fd, program)
+          val sctx = SynthesisContext(ctx,
+                                      SynthesisSettings(),
+                                      ci.fd,
+                                      program)
 
-          val p    = ci.problem
+          val p      = ci.problem
 
           if (strats.isDefinedAt(f.id.name)) {
             val search = new TestSearch(ctx, ci, p, strats(f.id.name))
@@ -109,7 +112,7 @@ class SynthesisSuite extends LeonTestSuite {
       }
   }
 
-  forProgram("Ground Enum", SynthesisSettings(selectedSolvers = Set("enum")))(
+  forProgram("Ground Enum", Seq(LeonOption(SharedOptions.optSelectedSolvers)(Set("enum"))))(
     """
 import leon.annotation._
 import leon.lang._
-- 
GitLab