From d1aea47602e17ad5f0d44cfcd17674c5c6ad3aa5 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Wed, 13 May 2015 19:52:47 +0200
Subject: [PATCH] Finished up call-by-name

---
 src/main/scala/leon/purescala/ExprOps.scala   |   4 +-
 .../leon/solvers/smtlib/SMTLIBTarget.scala    | 118 ++++++++++--------
 .../leon/solvers/smtlib/SMTLIBZ3Target.scala  |   4 +
 .../leon/solvers/z3/AbstractZ3Solver.scala    |  54 ++++----
 .../scala/leon/synthesis/rules/ADTDual.scala  |   4 +-
 .../scala/leon/synthesis/rules/ADTSplit.scala |   2 +-
 .../leon/verification/InductionTactic.scala   |   2 +-
 7 files changed, 98 insertions(+), 90 deletions(-)

diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 565f1d3f6..99e246512 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1480,8 +1480,8 @@ object ExprOps {
 
           val isType = IsInstanceOf(Variable(on), cct)
 
-          val recSelectors = cct.fields.collect {
-            case vd if vd.getType == on.getType => vd.id
+          val recSelectors = (cct.classDef.fields zip cct.fieldsTypes).collect { 
+            case (vd, tpe) if tpe == on.getType => vd.id
           }
 
           if (recSelectors.isEmpty) {
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index ac5e0fd4b..ed8e75fbf 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -125,6 +125,7 @@ trait SMTLIBTarget extends Interruptible {
 
   /* Symbol handling */
   protected object SimpleSymbol {
+    def apply(sym: SSymbol) = QualifiedIdentifier(SMTIdentifier(sym))
     def unapply(term: Term): Option[SSymbol] = term match {
       case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => Some(sym)
       case _ => None
@@ -132,9 +133,7 @@ trait SMTLIBTarget extends Interruptible {
   }
 
   import scala.language.implicitConversions
-  protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = {
-    QualifiedIdentifier(SMTIdentifier(s))
-  }
+  protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = SimpleSymbol(s)
 
   protected val adtManager = new ADTManager(context)
 
@@ -666,65 +665,78 @@ trait SMTLIBTarget extends Interruptible {
 
       case (SNumeral(n), Some(ft @ FunctionType(from, to))) =>
         val dynLambda = lambdas.toB(ft)
-        letDefs.get(dynLambda) match {
-          case Some(DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body))) =>
-
-            object EQ {
-              def unapply(t: Term): Option[(Term, Term)] = t match {
-                case Core.Equals(e1, e2) => Some((e1, e2))
-                case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2))
-                case _ => None
-              }
-            }
+        val DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body)) = letDefs(dynLambda)
 
-            object Num {
-              def unapply(t: Term): Option[BigInt] = t match {
-                case SNumeral(n) => Some(n)
-                case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n)
-                case _ => None
-              }
-            }
+        object EQ {
+          def unapply(t: Term): Option[(Term, Term)] = t match {
+            case Core.Equals(e1, e2) => Some((e1, e2))
+            case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2))
+            case _ => None
+          }
+        }
 
-            val d = symbolToQualifiedId(dispatcher)
-            def dispatch(t: Term): Term = t match {
-              case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d =>
-                if (ni == n) thenn else dispatch(elze)
-              case Core.ITE(Core.And(EQ(di, Num(ni)), _), thenn, elze) if di == d =>
-                if (ni == n) thenn else dispatch(elze)
-              case _ => t
-            }
+        object AND {
+          def unapply(t: Term): Option[Seq[Term]] = t match {
+            case Core.And(e1, e2) => Some(Seq(e1, e2))
+            case FunctionApplication(SimpleSymbol(SSymbol("and")), args) => Some(args)
+            case _ => None
+          }
+          def apply(ts: Seq[Term]): Term = ts match {
+            case Seq() => throw new IllegalArgumentException
+            case Seq(t) => t
+            case _ => FunctionApplication(SimpleSymbol(SSymbol("and")), ts)
+          }
+        }
 
-            def extract(t: Term): Expr = {
-              def recCond(term: Term, index: Int): Seq[Expr] = term match {
-                case Core.And(e1, e2) =>
-                  val e1s = recCond(e1, index)
-                  e1s ++ recCond(e2, index + e1s.size)
-                case EQ(e1, e2) =>
-                  recCond(e2, index)
-                case _ => Seq(fromSMT(term, from(index)))
-              }
+        object Num {
+          def unapply(t: Term): Option[BigInt] = t match {
+            case SNumeral(n) => Some(n)
+            case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n)
+            case _ => None
+          }
+        }
 
-              def recCases(term: Term, matchers: Seq[Expr]): Seq[(Seq[Expr], Expr)] = term match {
-                case Core.ITE(cond, thenn, elze) =>
-                  val cs = recCond(cond, matchers.size)
-                  recCases(thenn, matchers ++ cs) ++ recCases(elze, matchers)
-                case _ => Seq(matchers -> fromSMT(term, to))
-              }
+        val d = symbolToQualifiedId(dispatcher)
+        def dispatch(t: Term): Term = t match {
+          case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d =>
+            if (ni == n) thenn else dispatch(elze)
+          case Core.ITE(AND(EQ(di, Num(ni)) +: rest), thenn, elze) if di == d =>
+            if (ni == n) Core.ITE(AND(rest), thenn, dispatch(elze)) else dispatch(elze)
+          case _ => t
+        }
 
-              val cases = recCases(t, Seq.empty)
-              val (default, rest) = cases.partition(_._1.isEmpty)
-              
-              assert(default.size == 1 && rest.forall(_._1.size == from.size))
-              PartialLambda(rest, Some(default.head._2), ft)
-            }
+        def extract(t: Term): Expr = {
+          def recCond(term: Term, index: Int): Seq[Expr] = term match {
+            case AND(es) =>
+              es.foldLeft(Seq.empty[Expr]) { case (seq, e) => seq ++ recCond(e, index + seq.size) }
+            case EQ(e1, e2) =>
+              recCond(e2, index)
+            case _ => Seq(fromSMT(term, from(index)))
+          }
 
-            val lambdaTerm = dispatch(body)
-            val lambda = extract(lambdaTerm)
-            lambda
+          def recCases(term: Term, matchers: Seq[Expr]): Seq[(Seq[Expr], Expr)] = term match {
+            case Core.ITE(cond, thenn, elze) =>
+              val cs = recCond(cond, matchers.size)
+              recCases(thenn, matchers ++ cs) ++ recCases(elze, matchers)
+            case AND(es) if to == BooleanType =>
+              Seq((matchers ++ recCond(term, matchers.size)) -> BooleanLiteral(true))
+            case EQ(e1, e2) if to == BooleanType =>
+              Seq((matchers ++ recCond(term, matchers.size)) -> BooleanLiteral(true))
+            case _ => Seq(matchers -> fromSMT(term, to))
+          }
 
-          case None => unsupported(InfiniteIntegerLiteral(n), "Unknown function ref")
+          val cases = recCases(t, Seq.empty)
+          val (default, rest) = cases.partition(_._1.isEmpty)
+          val leonDefault = if (default.isEmpty && to == BooleanType) BooleanLiteral(false) else default.head._2
+
+          assert(rest.forall(_._1.size == from.size))
+          PartialLambda(rest, Some(leonDefault), ft)
         }
 
+        val lambdaTerm = dispatch(body)
+        val lambda = extract(lambdaTerm)
+        lambda
+
       case (SNumeral(n), Some(RealType)) =>
         FractionalLiteral(n, 1)
 
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
index eac8f47f9..334f9f0ad 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
@@ -76,6 +76,10 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
         val n = s.name.split("!").toList.last
         GenericValue(tp, n.toInt)
 
+      // XXX: (NV) Z3 doesn't seem to produce models for uninterpreted functions that
+      //      don't impact satisfiability...
+      case (SNumeral(n), Some(ft: FunctionType)) if !letDefs.isDefinedAt(lambdas.toB(ft)) =>
+        purescala.ExprOps.simplestValue(ft)
 
       case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) =>
         if (letDefs contains k) {
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index ca2c0e714..d7a3fd8fa 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -257,7 +257,7 @@ trait AbstractZ3Solver extends Solver {
 
     case ft @ FunctionType(from, to) =>
       sorts.cachedB(ft) {
-        val symbol = z3.mkFreshStringSymbol("fun")
+        val symbol = z3.mkFreshStringSymbol(ft.toString)
         z3.mkUninterpretedSort(symbol)
       }
 
@@ -494,7 +494,7 @@ trait AbstractZ3Solver extends Solver {
         z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*)
 
       case fa @ Application(caller, args) =>
-        val ft @ FunctionType(froms, to) = bestRealType(caller.getType)
+        val ft @ FunctionType(froms, to) = normalizeType(caller.getType)
         val funDecl = lambdas.cachedB(ft) {
           val sortSeq    = (ft +: froms).map(tpe => typeToSort(tpe))
           val returnSort = typeToSort(to)
@@ -569,25 +569,8 @@ trait AbstractZ3Solver extends Solver {
     def rec(t: Z3AST, tpe: TypeTree): Expr = {
       val kind = z3.getASTKind(t)
 
-      (kind, tpe) match {
-        case (Z3NumeralIntAST(Some(v)), ft @ FunctionType(fts, tt)) => lambdas.getB(ft) match {
-          case None => throw new IllegalArgumentException
-          case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match {
-            case None => throw new IllegalArgumentException
-            case Some((_, mapping, elseValue)) =>
-              val lambdaID = InfiniteIntegerLiteral(v)
-              val leonElseValue = rec(elseValue, tt)
-              PartialLambda(mapping.flatMap { case (z3Args, z3Result) =>
-                z3.getASTKind(z3Args.head) match {
-                  case Z3NumeralIntAST(Some(v)) if InfiniteIntegerLiteral(v) == lambdaID =>
-                    List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt))
-                  case _ => Nil
-                }
-              }, Some(leonElseValue), ft)
-          }
-        }
-
-        case (Z3NumeralIntAST(Some(v)), _) =>
+      kind match {
+        case Z3NumeralIntAST(Some(v)) =>
           val leading = t.toString.substring(0, 2 min t.toString.length)
           if(leading == "#x") {
             _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match {
@@ -604,7 +587,7 @@ trait AbstractZ3Solver extends Solver {
             InfiniteIntegerLiteral(v)
           }
 
-        case (Z3NumeralIntAST(None), _) =>
+        case Z3NumeralIntAST(None) =>
           _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match {
               case Some(hexa) =>
                 tpe match {
@@ -615,9 +598,9 @@ trait AbstractZ3Solver extends Solver {
             case None => unsound(t, "could not translate Z3NumeralIntAST numeral")
           }
 
-        case (Z3NumeralRealAST(n: BigInt, d: BigInt), _) => FractionalLiteral(n, d)
+        case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d)
 
-        case (Z3AppAST(decl, args), _) =>
+        case Z3AppAST(decl, args) =>
           val argsSize = args.size
           if(argsSize == 0 && (variables containsB t)) {
             variables.toA(t)
@@ -673,6 +656,22 @@ trait AbstractZ3Solver extends Solver {
                   case None => unsound(t, "invalid array AST")
                 }
 
+              case ft @ FunctionType(fts, tt) => lambdas.getB(ft) match {
+                case None => throw new IllegalArgumentException
+                case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match {
+                  case None => throw new IllegalArgumentException
+                  case Some((_, mapping, elseValue)) =>
+                    val leonElseValue = rec(elseValue, tt)
+                    PartialLambda(mapping.flatMap { case (z3Args, z3Result) =>
+                      if (t == z3Args.head) {
+                        List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt))
+                      } else {
+                        Nil
+                      }
+                    }, Some(leonElseValue), ft)
+                }
+              }
+
               case tp: TypeParameter =>
                 val id = t.toString.split("!").last.toInt
                 GenericValue(tp, id)
@@ -695,13 +694,6 @@ trait AbstractZ3Solver extends Solver {
                     FiniteMap(elems, from, to)
                 }
 
-              case ft @ FunctionType(fts, tt) =>
-                rec(t, RawArrayType(tupleTypeWrap(fts), tt)) match {
-                  case r: RawArrayValue =>
-                    val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, fts.size) -> v }
-                    PartialLambda(elems, Some(r.default), ft)
-                }
-
               case tpe @ SetType(dt) =>
                 model.getSetValue(t) match {
                   case None => unsound(t, "invalid set AST")
diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala
index 392670edf..004a88d04 100644
--- a/src/main/scala/leon/synthesis/rules/ADTDual.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala
@@ -18,10 +18,10 @@ case object ADTDual extends NormalizingRule("ADTDual") {
 
     val (toRemove, toAdd) = exprs.collect {
       case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty =>
-        (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
+        (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
 
       case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty =>
-        (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
+        (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
     }.unzip
 
     if (toRemove.nonEmpty) {
diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
index 7ed086087..32848ae1c 100644
--- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
@@ -94,7 +94,7 @@ case object ADTSplit extends Rule("ADT Split.") {
 
             val cases = for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield {
               if (sol.pre != BooleanLiteral(true)) {
-                val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield {
+                val substs = (for ((field,arg) <- cct.classDef.fields zip problem.as ) yield {
                   (arg, caseClassSelector(cct, id.toVariable, field.id))
                 }).toMap
                 globalPre ::= and(IsInstanceOf(Variable(id), cct), replaceFromIDs(substs, sol.pre))
diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala
index 65f96a090..dd437c224 100644
--- a/src/main/scala/leon/verification/InductionTactic.scala
+++ b/src/main/scala/leon/verification/InductionTactic.scala
@@ -21,7 +21,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) {
   }
 
   private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = {
-    val childrenOfSameType = cct.fields.filter(_.getType == parentType)
+    val childrenOfSameType = (cct.classDef.fields zip cct.fieldsTypes).collect { case (vd, tpe) if tpe == parentType => vd }
     for (field <- childrenOfSameType) yield {
       caseClassSelector(cct, expr, field.id)
     }
-- 
GitLab