From 63376f70fcb5409203286fa7d0abcec5de15dda9 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Wed, 10 Feb 2016 14:15:50 +0100
Subject: [PATCH] Moar quantifier fixes

---
 .../scala/leon/codegen/CodeGeneration.scala     | 13 ++-----------
 .../scala/leon/codegen/CompilationUnit.scala    | 17 +++++++++++++++++
 .../leon/evaluators/RecursiveEvaluator.scala    | 16 +++++++++-------
 src/main/scala/leon/purescala/ExprOps.scala     |  6 +++---
 src/main/scala/leon/purescala/Extractors.scala  | 10 +++++-----
 src/main/scala/leon/purescala/TypeOps.scala     |  6 ++++++
 .../leon/solvers/smtlib/SMTLIBTarget.scala      |  8 ++++++--
 .../templates/QuantificationManager.scala       |  3 +--
 .../scala/leon/solvers/z3/FairZ3Solver.scala    |  2 --
 9 files changed, 49 insertions(+), 32 deletions(-)

diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index cc85fe1a5..9030ba304 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -8,6 +8,7 @@ import purescala.Definitions._
 import purescala.Expressions._
 import purescala.ExprOps._
 import purescala.Types._
+import purescala.TypeOps._
 import purescala.Constructors._
 import purescala.Extractors._
 import purescala.Quantification._
@@ -223,16 +224,6 @@ trait CodeGeneration {
     ch.freeze
   }
 
-  private def typeParameters(expr: Expr): Seq[TypeParameter] = {
-    var tparams: Set[TypeParameter] = Set.empty
-    def extractParameters(tpe: TypeTree): Unit = tpe match {
-      case tp: TypeParameter => tparams += tp
-      case NAryType(tps, _) => tps.foreach(extractParameters)
-    }
-    preTraversal(e => extractParameters(e.getType))(expr)
-    tparams.toSeq.sortBy(_.id.uniqueName)
-  }
-
   private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String]
   private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda]
 
@@ -241,7 +232,7 @@ trait CodeGeneration {
     val reverseSubst = structSubst.map(p => p._2 -> p._1)
     val nl = normalized.asInstanceOf[Lambda]
 
-    val tparams: Seq[TypeParameter] = typeParameters(nl)
+    val tparams: Seq[TypeParameter] = typeParamsOf(nl).toSeq.sortBy(_.id.uniqueName)
 
     val closedVars = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName)
     val closuresWithoutMonitor = closedVars.map(id => id -> typeToJVM(id.getType))
diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala
index ff3c4af10..707eae8fa 100644
--- a/src/main/scala/leon/codegen/CompilationUnit.scala
+++ b/src/main/scala/leon/codegen/CompilationUnit.scala
@@ -8,6 +8,7 @@ import purescala.Definitions._
 import purescala.Expressions._
 import purescala.ExprOps._
 import purescala.Types._
+import purescala.TypeOps._
 import purescala.Extractors._
 import purescala.Constructors._
 import utils.UniqueCounter
@@ -248,6 +249,22 @@ class CompilationUnit(val ctx: LeonContext,
       }
       l
 
+    case l @ Lambda(args, body) =>
+      val (afName, closures, tparams, consSig) = compileLambda(l)
+      val args = closures.map { case (id, _) =>
+        if (id == monitorID) monitor
+        else if (id == tpsID) typeParamsOf(l).toSeq.sortBy(_.id.uniqueName).map(registerType).toArray
+        else throw CompilationException(s"Unexpected closure $id in Lambda compilation")
+      }
+
+      val lc = loader.loadClass(afName)
+      val conss = lc.getConstructors.sortBy(_.getParameterTypes.length)
+      println(conss)
+      assert(conss.nonEmpty)
+      val lambdaConstructor = conss.last
+      println(args.toArray)
+      lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef]
+
     case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) =>
       if (length < 0) {
         throw LeonFatalError(
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 4bbdecc1d..2ebb5ade5 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -501,13 +501,15 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
       FiniteSet(els.map(e), base)
 
     case l @ Lambda(_, _) =>
-      val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l))
-      val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap
-      val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda]
-      if (!gctx.lambdas.isDefinedAt(newLambda)) {
-        gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda])
-      }
-      newLambda
+      val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap
+      val newLambda = replaceFromIDs(mapping, l).asInstanceOf[Lambda]
+      val (normalized, _) = normalizeStructure(matchToIfThenElse(newLambda))
+      val nl = normalized.asInstanceOf[Lambda]
+      if (!gctx.lambdas.isDefinedAt(nl)) {
+        val (norm, _) = normalizeStructure(matchToIfThenElse(l))
+        gctx.lambdas += (nl -> norm.asInstanceOf[Lambda])
+      }
+      nl
 
     case FiniteLambda(mapping, dflt, tpe) =>
       FiniteLambda(mapping.map(p => p._1.map(e) -> e(p._2)), e(dflt), tpe)
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index af5a04eb8..2672cfe7f 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1994,8 +1994,8 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] {
           Let(i, e, apply(b, args))
         case LetTuple(is, es, b) =>
           letTuple(is, es, apply(b, args))
-        case l @ Lambda(params, body) =>
-          l.withParamSubst(args, body)
+        //case l @ Lambda(params, body) =>
+        //  l.withParamSubst(args, body)
         case _ => Application(expr, args)
       }
 
@@ -2017,7 +2017,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] {
         case Application(caller, args) =>
           val newArgs = args.map(rec(_, true))
           val newCaller = rec(caller, false)
-          extract(application(newCaller, newArgs), build)
+          extract(Application(newCaller, newArgs), build)
         case FunctionInvocation(fd, args) =>
           val newArgs = args.map(rec(_, true))
           extract(FunctionInvocation(fd, newArgs), build)
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 00541f445..49e6afd3a 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -145,7 +145,7 @@ object Extractors {
         Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1)))
       case SetDifference(t1, t2) =>
         Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1)))
-      case mg@MapApply(t1, t2) =>
+      case mg @ MapApply(t1, t2) =>
         Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1)))
       case MapUnion(t1, t2) =>
         Some(Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1)))
@@ -165,9 +165,9 @@ object Extractors {
         Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1)))
 
       /* Other operators */
-      case fi@FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _)))
-      case mi@MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail)))
-      case fa@Application(caller, args) => Some(caller +: args, as => application(as.head, as.tail))
+      case fi @ FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _)))
+      case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail)))
+      case fa @ Application(caller, args) => Some(caller +: args, as => Application(as.head, as.tail))
       case CaseClass(cd, args) => Some((args, CaseClass(cd, _)))
       case And(args) => Some((args, and))
       case Or(args) => Some((args, or))
@@ -197,7 +197,7 @@ object Extractors {
           val l = as.length
           nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1))))
         }))
-      case na@NonemptyArray(elems, None) =>
+      case na @ NonemptyArray(elems, None) =>
         val ArrayType(tpe) = na.getType
         val (indexes, elsOrdered) = elems.toSeq.unzip
 
diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index 3310a8a8b..fedac58ec 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -23,6 +23,12 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree
     })(t)
   }
 
+  def typeParamsOf(expr: Expr): Set[TypeParameter] = {
+    var tparams: Set[TypeParameter] = Set.empty
+    ExprOps.preTraversal(e => typeParamsOf(e.getType))(expr)
+    tparams
+  }
+
   def canBeSubtypeOf(
     tpe: TypeTree,
     freeParams: Seq[TypeParameter],
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index d239de90d..4b8386949 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -20,7 +20,7 @@ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter }
 import _root_.smtlib.parser.Commands.{
   Constructor => SMTConstructor,
   FunDef => SMTFunDef,
-  Assert => _,
+  Assert => SMTAssert,
   _
 }
 import _root_.smtlib.parser.Terms.{
@@ -533,7 +533,11 @@ trait SMTLIBTarget extends Interruptible {
 
       case gv @ GenericValue(tpe, n) =>
         genericValues.cachedB(gv) {
-          declareVariable(FreshIdentifier("gv" + n, tpe))
+          val v = declareVariable(FreshIdentifier("gv" + n, tpe))
+          for ((ogv, ov) <- genericValues.aToB if ogv.getType == tpe) {
+            emit(SMTAssert(Core.Not(Core.Equals(v, ov))))
+          }
+          v
         }
 
       /**
diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
index b60ac4eb9..84f8b707f 100644
--- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala
+++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
@@ -441,8 +441,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
         case Left(quant) if quantified(quant) =>
           subst += quant -> arg
         case Right(qam) =>
-          val argVal = arg.encoded
-          eqConstraints += (qam.encoded -> argVal)
+          eqConstraints += (qam.encoded -> arg.encoded)
       }
 
       val substituter = encoder.substitute(subst.mapValues(_.encoded))
diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
index 1fb008834..ca1b929e1 100644
--- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
@@ -58,14 +58,12 @@ class FairZ3Solver(val context: LeonContext, val program: Program)
     r
   }
 
-  /*
   override def solverCheckAssumptions[R](assumptions: Seq[Z3AST])(block: Option[Boolean] => R): R = {
     solver.push() // FIXME: remove when z3 bug is fixed
     val res = solver.checkAssumptions(assumptions : _*)
     solver.pop()  // FIXME: remove when z3 bug is fixed
     block(res)
   }
-  */
 
   def solverGetModel: ModelWrapper = new ModelWrapper {
     val model = solver.getModel
-- 
GitLab