From 197ab1416fb44770026e1a6aab906c1eb7eed49e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Thu, 11 Feb 2016 15:40:40 +0100
Subject: [PATCH] Simplified the converter a bit Ensures that no conversion
 occurs if only fundefs in the library use strings. Forward and backward
 conversion for embedded evaluators in FairZ3

---
 src/main/scala/leon/purescala/DefOps.scala    |   2 +-
 .../combinators/Z3StringCapableSolver.scala   | 171 +++++++++++++-----
 .../leon/solvers/z3/Z3StringConversion.scala  |  15 ++
 3 files changed, 146 insertions(+), 42 deletions(-)

diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala
index b5e5f1921..2eeb5a4d4 100644
--- a/src/main/scala/leon/purescala/DefOps.scala
+++ b/src/main/scala/leon/purescala/DefOps.scala
@@ -315,7 +315,7 @@ object DefOps {
       )
     })
     for(fd <- newP.definedFunctions) {
-      if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache(fd) != None case _ => false }(fd.fullBody)) {
+      if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) {
         fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF)
       }
     }
diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
index 64d871ac6..d8657e9d4 100644
--- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
@@ -24,12 +24,13 @@ import leon.utils.Bijection
 import leon.solvers.z3.StringEcoSystem
 
 object Z3StringCapableSolver {
-  def convert(p: Program): ((Program, Map[FunDef, FunDef]), Z3StringConversion) = {
+  def convert(p: Program): (Program, Option[Z3StringConversion]) = {
     val converter = new Z3StringConversion(p)
-    import converter._
     import converter.Forward._
     var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]()
-    val (new_program, fdMap) = DefOps.replaceFunDefs(converter.getProgram)((fd: FunDef) => {
+    var hasStrings = false
+    val program_with_strings = converter.getProgram
+    val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => {
       globalFdMap.get(fd).map(_._2).orElse(
           if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) ||
               fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) {
@@ -40,26 +41,44 @@ object Z3StringCapableSolver {
                 fd.params.map(vd => ValDef(idMap(vd.id))),
                 convertType(fd.returnType))
             globalFdMap += fd -> ((idMap, newFd))
+            hasStrings = hasStrings || (program_with_strings.library.escape.get != fd)
             Some(newFd)
           } else None
       )
     })
-    converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2))
-    for((fd, (idMap, newFd)) <- globalFdMap) {
-      implicit val idVarMap = idMap.mapValues(id => Variable(id))
-      newFd.fullBody = convertExpr(newFd.fullBody)
+    if(!hasStrings) {
+      (p, None)
+    } else {
+      converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2))
+      for((fd, (idMap, newFd)) <- globalFdMap) {
+        implicit val idVarMap = idMap.mapValues(id => Variable(id))
+        newFd.fullBody = convertExpr(newFd.fullBody)
+      }
+      (new_program, Some(converter))
     }
-    ((new_program, fdMap), converter)
   }
 }
 
-abstract class Z3StringCapableSolver[+TUnderlying <: Solver](val context: LeonContext, val program: Program, val underlyingConstructor: Program => TUnderlying)
+abstract class Z3StringCapableSolver[+TUnderlying <: Solver](val context: LeonContext, val program: Program,
+    val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying)
 extends Solver {
-  protected val ((new_program, mappings), converter) = Z3StringCapableSolver.convert(program)
+  protected val (new_program, someConverter) = Z3StringCapableSolver.convert(program)
 
-  val underlying = underlyingConstructor(new_program)
+  val underlying = underlyingConstructor(new_program, someConverter)
   
-  def getModel: leon.solvers.Model = underlying.getModel
+  def getModel: leon.solvers.Model = {
+    val model = underlying.getModel
+    someConverter match {
+      case None => model
+      case Some(converter) =>
+        val ids = model.ids.toSeq
+        val exprs = ids.map(model.apply)
+        import converter.Backward._
+        val original_ids = ids.map(convertId)
+        val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) }
+        new Model(original_ids.zip(original_exprs).toMap)
+    }
+  }
   
   // Members declared in leon.utils.Interruptible
   def interrupt(): Unit = underlying.interrupt()
@@ -67,15 +86,21 @@ extends Solver {
 
   // Members declared in leon.solvers.Solver
   def assertCnstr(expression: Expr): Unit = {
-    //println("Asserting " + expression)
-    val expression2 = DefOps.replaceFunCalls(expression, mappings.withDefault { x => x }.apply _)
-    import converter.Forward._
-    val newExpression = convertExpr(expression2)(Map())
-    underlying.assertCnstr(newExpression)
+    someConverter match {
+      case None => underlying.assertCnstr(expression)
+      case Some(converter) =>
+        import converter.Forward._
+        val newExpression = convertExpr(expression)(Map())
+        underlying.assertCnstr(newExpression)
+    }
   }
   def getUnsatCore: Set[Expr] = {
-    import converter.Backward._
-    underlying.getUnsatCore map (e => convertExpr(e)(Map()))
+    someConverter match {
+      case None => underlying.getUnsatCore
+      case Some(converter) =>
+        import converter.Backward._
+        underlying.getUnsatCore map (e => convertExpr(e)(Map()))
+    }
   }
   def check: Option[Boolean] = underlying.check
   def free(): Unit = underlying.free()
@@ -102,38 +127,97 @@ trait Z3StringQuantificationSolver[TUnderlying <: QuantificationSolver] extends
   // Members declared in leon.solvers.QuantificationSolver
   override def getModel: leon.solvers.HenkinModel = {
     val model = underlying.getModel
-    val ids = model.ids.toSeq
-    val exprs = ids.map(model.apply)
-    import converter.Backward._
-    val original_ids = ids.map(convertId)
-    val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) }
-    
-    val new_domain = new HenkinDomains(
-        model.doms.lambdas.map(kv =>
-          (convertExpr(kv._1)(Map()).asInstanceOf[Lambda],
-           kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap,
-        model.doms.tpes.map(kv =>
-          (convertType(kv._1),
-           kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap
-        )
-    
-    new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain)
+    someConverter match {
+      case None => model
+      case Some(converter) =>
+        val ids = model.ids.toSeq
+        val exprs = ids.map(model.apply)
+        import converter.Backward._
+        val original_ids = ids.map(convertId)
+        val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) }
+        
+        val new_domain = new HenkinDomains(
+            model.doms.lambdas.map(kv =>
+              (convertExpr(kv._1)(Map()).asInstanceOf[Lambda],
+               kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap,
+            model.doms.tpes.map(kv =>
+              (convertType(kv._1),
+               kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap
+            )
+        
+        new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain)
+    }
+  }
+}
+
+trait EvaluatorCheckConverter extends DeterministicEvaluator {
+  def converter: Z3StringConversion
+  abstract override def check(expression: Expr, model: solvers.Model) : CheckResult = {
+    val c = converter
+    import c.Backward._  // Because the evaluator is going to be called by the underlying solver, but it will use the original program
+    super.check(convertExpr(expression)(Map()), convertModel(model))
+  }
+}
+
+class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion)
+    extends CodeGenEvaluator(context, originalProgram) with EvaluatorCheckConverter {
+  override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = {
+    import converter._ 
+    super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId))
+    .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m)))
+    )
+  }
+}
+
+class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) extends DefaultEvaluator(context, originalProgram) with EvaluatorCheckConverter {
+  override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = {
+    import converter._
+    Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model)))
   }
 }
 
+
+class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program,
+    originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) {
+  override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator
+    someConverter match {
+      case Some(converter) =>
+        if (useCodeGen) {
+          new ConvertibleCodeGenEvaluator(context, originalProgram, converter)
+        } else {
+          new ConvertibleDefaultEvaluator(context, originalProgram, converter)
+        }
+      case None =>
+        if (useCodeGen) {
+          new CodeGenEvaluator(context, program)
+        } else {
+          new DefaultEvaluator(context, program)
+        }
+    }
+  }
+}
+
+
 class Z3StringFairZ3Solver(context: LeonContext, program: Program)
-  extends Z3StringCapableSolver(context, program, (program: Program) => new z3.FairZ3Solver(context, program)) 
+  extends Z3StringCapableSolver(context, program,
+      (prgm: Program, someConverter: Option[Z3StringConversion]) =>
+        new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) 
   with Z3StringEvaluatingSolver[FairZ3Solver]
   with Z3StringQuantificationSolver[FairZ3Solver] {
      // Members declared in leon.solvers.z3.AbstractZ3Solver
     protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg
     override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
-      underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+      someConverter match {
+        case None => underlying.checkAssumptions(assumptions)
+        case Some(converter) =>
+          underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+      }
     }
 }
 
 class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver)
-  extends Z3StringCapableSolver(context, program, (program: Program) => new UnrollingSolver(context, program, underlyingSolverConstructor(program)))
+  extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) =>
+    new UnrollingSolver(context, program, underlyingSolverConstructor(program)))
   with Z3StringNaiveAssumptionSolver[UnrollingSolver]
   with Z3StringEvaluatingSolver[UnrollingSolver]
   with Z3StringQuantificationSolver[UnrollingSolver] {
@@ -141,9 +225,14 @@ class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlying
 }
 
 class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program)
-  extends Z3StringCapableSolver(context, program, (program: Program) => new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) {
-     def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
-      underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+  extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) =>
+    new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) {
+     override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
+      someConverter match {
+        case None => underlying.checkAssumptions(assumptions)
+        case Some(converter) =>
+          underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
+      }
     }
 }
 
diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala
index 5dbf6bef4..2a2b6e596 100644
--- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala
+++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala
@@ -11,6 +11,7 @@ import leon.utils.Bijection
 import leon.purescala.DefOps
 import leon.purescala.TypeOps
 import leon.purescala.Extractors.Operator
+import leon.evaluators.EvaluationResults
 
 object StringEcoSystem {
   private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = {
@@ -237,6 +238,20 @@ trait Z3StringConverters  { self: Z3StringConversion =>
         case e => e
       })
     }
+    
+    def convertModel(model: Model): Model = {
+      new Model(model.ids.map{i =>
+        val id = convertId(i)
+        id -> convertExpr(model(i))(Map())
+      }.toMap)
+    }
+    
+    def convertResult(result: EvaluationResults.Result[Expr]) = {
+      result match {
+        case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map()))
+        case result => result
+      }
+    }
   }
   
   object Forward extends BidirectionalConverters {
-- 
GitLab