From a42e72ed3c213f736cdf993d15776036eb4f7903 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Wed, 10 Feb 2016 16:27:28 +0100
Subject: [PATCH] Added Z3 string capacities to all Z3 solvers

---
 .../leon/solvers/QuantificationSolver.scala   |  2 +-
 .../scala/leon/solvers/SolverFactory.scala    | 13 ++-
 .../combinators/Z3StringCapableSolver.scala   | 99 ++++++++++++++-----
 3 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala
index fa11ab661..4f56903c5 100644
--- a/src/main/scala/leon/solvers/QuantificationSolver.scala
+++ b/src/main/scala/leon/solvers/QuantificationSolver.scala
@@ -25,7 +25,7 @@ class HenkinModelBuilder(domains: HenkinDomains)
   override def result = new HenkinModel(mapBuilder.result, domains)
 }
 
-trait QuantificationSolver {
+trait QuantificationSolver extends Solver {
   val program: Program
   def getModel: HenkinModel
 
diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala
index 4e87a2888..42d7ead02 100644
--- a/src/main/scala/leon/solvers/SolverFactory.scala
+++ b/src/main/scala/leon/solvers/SolverFactory.scala
@@ -79,10 +79,12 @@ object SolverFactory {
 
   def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match {
     case "fairz3" =>
-      SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver)
+      // Previously:      new FairZ3Solver(ctx, program) with TimeoutSolver
+      SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver)
 
     case "unrollz3" =>
-      SolverFactory(() => new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver)
+      // Previously:      new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver
+      SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver)
 
     case "enum"   =>
       SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver)
@@ -91,11 +93,12 @@ object SolverFactory {
       SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver)
 
     case "smt-z3" =>
-      SolverFactory(() => new Z3StringCapableSolver(ctx, program, (program: Program) =>
-                              new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program))) with TimeoutSolver)
+      // Previously:      new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver
+      SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver)
 
     case "smt-z3-q" =>
-      SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver)
+      // Previously:      new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver
+      SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver)
 
     case "smt-cvc4" =>
       SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver)
diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
index 9648c1785..64f8161f1 100644
--- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
@@ -55,25 +55,53 @@ object Z3StringCapableSolver {
   }
 }
 
-class Z3StringCapableSolver(val context: LeonContext, val program: Program, f: Program => UnrollingSolver)  extends Solver
-     with NaiveAssumptionSolver
-     with EvaluatingSolver
-     with QuantificationSolver  {
-  
-  val ((new_program, mappings), converter, idMap) = Z3StringCapableSolver.convert(program)
+abstract class Z3StringCapableSolver[+TUnderlying <: Solver](val context: LeonContext, val program: Program, val underlyingConstructor: Program => TUnderlying)
+extends Solver {
+  protected val ((new_program, mappings), converter, idMap) = Z3StringCapableSolver.convert(program)
 
   val idMapReverse = idMap.map(kv => kv._2 -> kv._1).toMap
-  val underlying = f(new_program)
-
-  // Members declared in leon.solvers.EvaluatingSolver
-  val useCodeGen: Boolean = underlying.useCodeGen
-
+  val underlying = underlyingConstructor(new_program)
+  
+  def getModel: leon.solvers.Model = underlying.getModel
+  
   // Members declared in leon.utils.Interruptible
   def interrupt(): Unit = underlying.interrupt()
   def recoverInterrupt(): Unit = underlying.recoverInterrupt()
 
+  // Members declared in leon.solvers.Solver
+  def assertCnstr(expression: leon.purescala.Expressions.Expr): Unit = {
+    val expression2 = DefOps.replaceFunCalls(expression, mappings.withDefault { x => x }.apply _)
+    import converter.Forward._
+    val newExpression = convertExpr(expression2)(idMap.mapValues(Variable))
+    underlying.assertCnstr(newExpression)
+  }
+  def getUnsatCore: Set[Expr] = {
+    import converter.Backward._
+    underlying.getUnsatCore map (e => convertExpr(e)(Map()))
+  }
+  def check: Option[Boolean] = underlying.check
+  def free(): Unit = underlying.free()
+  def pop(): Unit = underlying.pop()
+  def push(): Unit = underlying.push()
+  def reset(): Unit = underlying.reset()
+}
+
+import z3._
+
+trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] =>
+}
+
+trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self:  Z3StringCapableSolver[TUnderlying] =>
+}
+
+trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self:  Z3StringCapableSolver[TUnderlying] =>
+  // Members declared in leon.solvers.EvaluatingSolver
+  val useCodeGen: Boolean = underlying.useCodeGen
+}
+
+trait Z3StringQuantificationSolver[TUnderlying <: QuantificationSolver] extends QuantificationSolver { self:  Z3StringCapableSolver[TUnderlying] =>
   // Members declared in leon.solvers.QuantificationSolver
-  def getModel: leon.solvers.HenkinModel = {
+  override def getModel: leon.solvers.HenkinModel = {
     val model = underlying.getModel
     val ids = model.ids.toSeq
     val exprs = ids.map(model.apply)
@@ -92,18 +120,37 @@ class Z3StringCapableSolver(val context: LeonContext, val program: Program, f: P
     
     new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain)
   }
+}
+
+class Z3StringFairZ3Solver(context: LeonContext, program: Program)
+  extends Z3StringCapableSolver(context, program, (program: Program) => new z3.FairZ3Solver(context, program)) 
+  with Z3StringAbstractZ3Solver[FairZ3Solver]
+  with Z3ModelReconstruction
+  with FairZ3Component
+  with Z3StringEvaluatingSolver[FairZ3Solver]
+  with Z3StringQuantificationSolver[FairZ3Solver] {
+     // Members declared in leon.solvers.z3.AbstractZ3Solver
+    protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg
+    override def reset() = super[Z3StringAbstractZ3Solver].reset()
+    override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
+      underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+    }
+}
+
+class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver)
+  extends Z3StringCapableSolver(context, program, (program: Program) => new UnrollingSolver(context, program, underlyingSolverConstructor(program)))
+  with Z3StringNaiveAssumptionSolver[UnrollingSolver]
+  with Z3StringEvaluatingSolver[UnrollingSolver]
+  with Z3StringQuantificationSolver[UnrollingSolver] {
+    def name = underlying.name
+    override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore
+}
+
+class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program)
+  extends Z3StringCapableSolver(context, program, (program: Program) => new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) {
+     def name: String = underlying.name
+     def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
+      underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+    }
+}
 
-  // Members declared in leon.solvers.Solver
-  def assertCnstr(expression: leon.purescala.Expressions.Expr): Unit = {
-    val expression2 = DefOps.replaceFunCalls(expression, mappings.withDefault { x => x }.apply _)
-    import converter.Forward._
-    val newExpression = convertExpr(expression2)(idMap.mapValues(Variable))
-    underlying.assertCnstr(newExpression)
-  }
-  def check: Option[Boolean] = underlying.check
-  def free(): Unit = underlying.free()
-  def name: String = "String" + underlying.name
-  def pop(): Unit = underlying.pop()
-  def push(): Unit = underlying.push()
-  def reset(): Unit = underlying.reset()
-}
\ No newline at end of file
-- 
GitLab