From 827d8f01f8975810027b8f90fd92d53641b6e186 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Tue, 9 Feb 2016 11:34:05 +0100
Subject: [PATCH] Added generics to codegen evaluator

---
 .../scala/leon/codegen/runtime/Monitor.scala   | 18 ++++++++++++------
 .../scala/leon/evaluators/DualEvaluator.scala  | 10 +++++-----
 .../leon/evaluators/RecursiveEvaluator.scala   |  2 ++
 src/main/scala/leon/purescala/TypeOps.scala    |  2 +-
 .../solvers/combinators/UnrollingSolver.scala  | 16 +++++++++++-----
 .../templates/QuantificationManager.scala      | 12 ++++++++----
 6 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala
index ac88c35be..64d80eaa8 100644
--- a/src/main/scala/leon/codegen/runtime/Monitor.scala
+++ b/src/main/scala/leon/codegen/runtime/Monitor.scala
@@ -115,15 +115,21 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id
       val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_))
       val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap
 
-      val vars = (variablesOf(p.pc) ++ variablesOf(p.phi)).toSeq.sortBy(_.uniqueName)
-      val newVars = vars.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true))
+      val newXs = p.xs.map { id =>
+        val newTpe = instantiateType(id.getType, tpMap)
+        if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true)
+      }
+
+      val newAs = p.as.map { id =>
+        val newTpe = instantiateType(id.getType, tpMap)
+        if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true)
+      }
 
-      val args = p.as.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true))
-      val inputsMap = (args zip inputs).map {
+      val inputsMap = (newAs zip inputs).map {
         case (id, v) => Equals(Variable(id), unit.jvmToValue(v, id.getType))
       }
 
-      val expr = instantiateType(and(p.pc, p.phi), tpMap, (vars zip newVars).toMap)
+      val expr = instantiateType(and(p.pc, p.phi), tpMap, (p.as zip newAs).toMap ++ (p.xs zip newXs))
       solver.assertCnstr(andJoin(expr +: inputsMap))
 
       try {
@@ -133,7 +139,7 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id
 
             val valModel = valuateWithModel(model) _
 
-            val res = p.xs.map(valModel)
+            val res = newXs.map(valModel)
             val leonRes = tupleWrap(res) 
 
             val total = System.currentTimeMillis-tStart
diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala
index e387e4f0c..6fc5b856d 100644
--- a/src/main/scala/leon/evaluators/DualEvaluator.scala
+++ b/src/main/scala/leon/evaluators/DualEvaluator.scala
@@ -13,14 +13,12 @@ import codegen.runtime.{StdMonitor, Monitor}
 
 class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams)
   extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations)
-  with HasDefaultGlobalContext
-{
+  with HasDefaultGlobalContext {
 
   type RC = DualRecContext
   def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings)
   implicit val debugSection = utils.DebugSectionEvaluation
 
-
   val unit = new CompilationUnit(ctx, prog, params)
 
   var monitor: Monitor = new StdMonitor(unit, params.maxFunctionInvocations, Map())
@@ -39,7 +37,10 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams)
 
     val (className, methodName, _) = unit.leonFunDefToJVMInfo(tfd.fd).get
 
-    val allArgs = if (params.requireMonitor) monitor +: args else args
+    val allArgs =
+      (if (params.requireMonitor) Seq(monitor) else Seq()) ++
+      (if (tfd.fd.tparams.nonEmpty) Seq(tfd.tps.map(unit.registerType(_)).toArray) else Seq()) ++
+      args
 
     ctx.reporter.debug(s"Calling $className.$methodName(${args.mkString(",")})")
 
@@ -126,7 +127,6 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams)
     }
   }
 
-
   override def eval(ex: Expr, model: solvers.Model) = {
     monitor = unit.getMonitor(model, params.maxFunctionInvocations)
     super.eval(ex, model)
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 22aa09a4d..e7ff08ba6 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -567,6 +567,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
               ctx.reporter.debug("Verification took "+total+"ms")
               ctx.reporter.debug("Finished forall evaluation with: "+res)
 
+              println(fargs.map(_.id),replaceFromIDs(mapping, body))
+              println(res)
               frlCache += (f, context) -> res
               res
             case _ =>
diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index 6e694843c..3310a8a8b 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -219,7 +219,7 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree
       id
     }
   }
-  
+
   def instantiateType(id: Identifier, tps: Map[TypeParameterDef, TypeTree]): Identifier = {
     freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType))
   }
diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
index 10447122b..6d7af0ba4 100644
--- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
@@ -263,6 +263,7 @@ trait AbstractUnrollingSolver[T]
 
   private def getTotalModel: Model = {
     val wrapped = solverGetModel
+    println(wrapped)
 
     val typeInsts = templateGenerator.manager.typeInstantiations
     val partialInsts = templateGenerator.manager.partialInstantiations
@@ -310,12 +311,17 @@ trait AbstractUnrollingSolver[T]
               if (mapping.isDefinedAt(conds)) mapping else mapping + (conds -> result)
             }
 
-          val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size)
-          val body = rest.foldLeft(dflt) { case (elze, (conds, res)) =>
-            if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze)
-          }
+          if (filteredConds.isEmpty) {
+            // TODO: warning??
+            value
+          } else {
+            val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size)
+            val body = rest.foldLeft(dflt) { case (elze, (conds, res)) =>
+              if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze)
+            }
 
-          Lambda(params.map(ValDef(_)), body)
+            Lambda(params.map(ValDef(_)), body)
+          }
 
         case _ => value
       })
diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
index ea5369e8c..4765060fd 100644
--- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala
+++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
@@ -422,14 +422,19 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Arg[T]], Boolean) = {
       var constraints: Set[T] = Set.empty
       var eqConstraints: Set[(T, T)] = Set.empty
-      var matcherEqs: List[(T, T)] = Nil
       var subst: Map[T, Arg[T]] = Map.empty
 
+      var matcherEqs: Set[(T, T)] = Set.empty
+      def strictnessCnstr(qarg: Arg[T], arg: Arg[T]): Unit = (qarg, arg) match {
+        case (Right(qam), Right(am)) => (qam.args zip am.args).foreach(p => strictnessCnstr(p._1, p._2))
+        case _ => matcherEqs += qarg.encoded -> arg.encoded
+      }
+
       for {
         (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping
-        _ = constraints ++= bs
-        _ = matcherEqs :+= qm.encoded -> m.encoded
+        _ = constraints ++= bs + encoder.mkEquals(qcaller, caller)
         (qarg, arg) <- (qargs zip args)
+        _ = strictnessCnstr(qarg, arg)
       } qarg match {
         case Left(quant) if subst.isDefinedAt(quant) =>
           eqConstraints += (quant -> arg.encoded)
@@ -438,7 +443,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
         case Right(qam) =>
           val argVal = arg.encoded
           eqConstraints += (qam.encoded -> argVal)
-          matcherEqs :+= qam.encoded -> argVal
       }
 
       val substituter = encoder.substitute(subst.mapValues(_.encoded))
-- 
GitLab