From efd59ede41648b64c146bcd7f14a2f3a25d74458 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com>
Date: Wed, 6 Jul 2011 21:29:56 +0000
Subject: [PATCH] various fixes to translation of case classes with set/map
 fields

---
 src/cp/CodeGeneration.scala                   |  15 +-
 src/purescala/AbstractZ3Solver.scala          |   6 +-
 src/purescala/Evaluator.scala                 |  12 +-
 src/purescala/FairZ3Solver.scala              | 221 ++++++++++--------
 src/purescala/Trees.scala                     |   2 +-
 src/purescala/Z3ModelReconstruction.scala     |  56 ++---
 src/purescala/Z3Solver.scala                  |   2 +-
 .../z3plugins/instantiator/Instantiator.scala |   2 +-
 8 files changed, 162 insertions(+), 154 deletions(-)

diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala
index 05e1e594d..fdfc52b6f 100644
--- a/src/cp/CodeGeneration.scala
+++ b/src/cp/CodeGeneration.scala
@@ -148,16 +148,19 @@ trait CodeGeneration {
             New(TypeTree(scalaSym.tpe), List({
               (ccd.fields.zipWithIndex map {
                 case (VarDecl(id, tpe), idx) =>
-                  val typeArg = tpe match {
-                    case purescala.TypeTrees.BooleanType => definitions.BooleanClass
-                    case purescala.TypeTrees.Int32Type => definitions.IntClass
-                    case c : purescala.TypeTrees.ClassType => reverseClassesToClasses(c.classDef)
-                    case _ => scala.sys.error("Cannot generate method using type : " + tpe)
+                  def typeArg(t: purescala.TypeTrees.TypeTree): Type = t match {
+                    case purescala.TypeTrees.BooleanType => definitions.BooleanClass.tpe
+                    case purescala.TypeTrees.Int32Type => definitions.IntClass.tpe
+                    case c : purescala.TypeTrees.ClassType => reverseClassesToClasses(c.classDef).tpe
+                    case purescala.TypeTrees.SetType(dt) => {
+                      typeRef(NoPrefix, setClass, List(typeArg(dt)))
+                    }
+                    case _ => scala.sys.error("Cannot generate method using type : " + t)
                   }
                   Apply(
                     TypeApply(
                       Ident(castMethodSym), 
-                      List(TypeTree(typeArg.tpe))
+                      List(TypeTree(typeArg(tpe)))
                     ), 
                     List(
                        // cast hack to make typer happy :(
diff --git a/src/purescala/AbstractZ3Solver.scala b/src/purescala/AbstractZ3Solver.scala
index b3830bb19..771dc934d 100644
--- a/src/purescala/AbstractZ3Solver.scala
+++ b/src/purescala/AbstractZ3Solver.scala
@@ -39,11 +39,11 @@ trait AbstractZ3Solver {
   protected[purescala] var funDomainSelectors: Map[TypeTree, Seq[Z3FuncDecl]]
 
   protected[purescala] var exprToZ3Id : Map[Expr,Z3AST]
-  protected[purescala] def fromZ3Formula(tree : Z3AST) : Expr
+  protected[purescala] def fromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: Option[TypeTree]) : Expr
 
-  protected[purescala] def softFromZ3Formula(tree : Z3AST) : Option[Expr] = {
+  protected[purescala] def softFromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: TypeTree) : Option[Expr] = {
     try {
-      Some(fromZ3Formula(tree))
+      Some(fromZ3Formula(model, tree, Some(expectedType)))
     } catch {
       case e: CantTranslateException => None
     }
diff --git a/src/purescala/Evaluator.scala b/src/purescala/Evaluator.scala
index ebaea193e..c4b61eaa0 100644
--- a/src/purescala/Evaluator.scala
+++ b/src/purescala/Evaluator.scala
@@ -34,9 +34,9 @@ object Evaluator {
 
     def rec(ctx: EvaluationContext, expr: Expr) : Expr = if(left <= 0) {
       throw InfiniteComputationEx()
-    } else {
-      // println("Step on : " + expr)
-      // println(ctx)
+    } else { val ret = {
+      println("Step on : " + expr)
+      println(ctx)
       left -= 1
       expr match {
         case Variable(id) => {
@@ -127,9 +127,9 @@ object Evaluator {
           val rv = rec(ctx,re)
 
           (lv,rv) match {
-            case (FiniteSet(el1),FiniteSet(el2)) => BooleanLiteral(el1.toSet == el2.toSet)
+            case (FiniteSet(el1),FiniteSet(el2)) => println("(el1, el2): " + (el1, el2)); BooleanLiteral(el1.toSet == el2.toSet)
             case (FiniteMap(el1),FiniteMap(el2)) => BooleanLiteral(el1.toSet == el2.toSet)
-            case _ => BooleanLiteral(lv == rv)
+            case _ => println("just equals check "); BooleanLiteral(lv == rv)
           }
         }
         case CaseClass(cd, args) => CaseClass(cd, args.map(rec(ctx,_)))
@@ -278,7 +278,7 @@ object Evaluator {
           throw RuntimeErrorEx("unhandled case in Evaluator") 
         }
       }
-    }
+    }; println("ret: " + ret); ret }
 
     evaluator match {
       case Some(evalFun) =>
diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala
index ca3d66603..9bad71102 100644
--- a/src/purescala/FairZ3Solver.scala
+++ b/src/purescala/FairZ3Solver.scala
@@ -653,7 +653,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
                 blockingSet
               } else {
                 // reporter.info(" - Will only unroll literals from core")
-                core.map(ast => fromZ3Formula(ast) match {
+                core.map(ast => fromZ3Formula(m, ast, Some(BooleanType)) match {
                   case n @ Not(Variable(_)) => n
                   case v @ Variable(_) => v
                   case _ => scala.sys.error("Impossible element extracted from core: " + ast)
@@ -1007,107 +1007,136 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
     }
   }
 
-  protected[purescala] def fromZ3Formula(tree : Z3AST) : Expr = {
-    def rec(t: Z3AST) : Expr = z3.getASTKind(t) match {
-      case Z3AppAST(decl, args) => {
-        val argsSize = args.size
-        if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) {
-          val toRet = z3IdToExpr(t)
-          // println("Map says I should replace " + t + " by " + toRet)
-          toRet
-        } else if(isKnownDecl(decl)) {
-          val fd = functionDeclToDef(decl)
-          assert(fd.args.size == argsSize)
-          FunctionInvocation(fd, args.map(rec(_)))
-        } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) {
-          CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0)))
-        } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) {
-          val (ccd, fid) = reverseADTFieldSelectors(decl)
-          CaseClassSelector(ccd, rec(args(0)), fid)
-        } else if(reverseADTConstructors.isDefinedAt(decl)) {
-          val ccd = reverseADTConstructors(decl)
-          assert(argsSize == ccd.fields.size)
-          CaseClass(ccd, args.map(rec(_)))
-        } else {
-          import Z3DeclKind._
-          val rargs = args.map(rec(_))
-          z3.getDeclKind(decl) match {
-            case OpTrue => BooleanLiteral(true)
-            case OpFalse => BooleanLiteral(false)
-            case OpEq => Equals(rargs(0), rargs(1))
-            case OpITE => {
-              assert(argsSize == 3)
-              val r0 = rargs(0)
-              val r1 = rargs(1)
-              val r2 = rargs(2)
-              try {
-                IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType))
-              } catch {
-                case e => {
-                  println("I was asking for lub because of this.")
-                  println(t)
-                  println("which was translated as")
-                  println(IfExpr(r0,r1,r2))
-                  throw e
-                }
-              }
-            }
-            case OpAnd => And(rargs)
-            case OpOr => Or(rargs)
-            case OpIff => Iff(rargs(0), rargs(1))
-            case OpXor => Not(Iff(rargs(0), rargs(1)))
-            case OpNot => Not(rargs(0))
-            case OpImplies => Implies(rargs(0), rargs(1))
-            case OpLE => LessEquals(rargs(0), rargs(1))
-            case OpGE => GreaterEquals(rargs(0), rargs(1))
-            case OpLT => LessThan(rargs(0), rargs(1))
-            case OpGT => GreaterThan(rargs(0), rargs(1))
-            case OpAdd => {
-              assert(argsSize == 2)
-              Plus(rargs(0), rargs(1))
-            }
-            case OpSub => {
-              assert(argsSize == 2)
-              Minus(rargs(0), rargs(1))
-            }
-            case OpUMinus => UMinus(rargs(0))
-            case OpMul => {
-              assert(argsSize == 2)
-              Times(rargs(0), rargs(1))
-            }
-            case OpDiv => {
-              assert(argsSize == 2)
-              Division(rargs(0), rargs(1))
-            }
-            case OpIDiv => {
-              assert(argsSize == 2)
-              Division(rargs(0), rargs(1))
+  protected[purescala] def fromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: Option[TypeTree] = None) : Expr = {
+    def rec(t: Z3AST, expType: Option[TypeTree] = None) : Expr = expType match {
+      case Some(MapType(kt,vt)) => 
+        model.getArrayValue(t) match {
+          case None => throw new CantTranslateException(t)
+          case Some((map, elseValue)) => 
+            val singletons = map.map(e => (e, z3.getASTKind(e._2))).collect {
+              case ((index, value), Z3AppAST(someCons, arg :: Nil)) if someCons == mapRangeSomeConstructors(vt) => SingletonMap(rec(index, Some(kt)), rec(arg, Some(vt)))
             }
-            // case OpAsArray => {
-            //   assert(argsSize == 0)
-            //   throw new Exception()
-            // }
-            case other => {
-              System.err.println("Don't know what to do with this declKind : " + other)
-              System.err.println("The arguments are : " + args)
-              throw new CantTranslateException(t)
+            (if (singletons.isEmpty) EmptyMap(kt, vt) else FiniteMap(singletons.toSeq)).setType(expType.get)
+        }
+      case funType @ Some(FunctionType(fts, tt)) =>
+        model.getArrayValue(t) match {
+          case None => throw new CantTranslateException(t)
+          case Some((es, ev)) =>
+            val entries: Seq[(Seq[Expr], Expr)] = es.toSeq.map(e => (e, z3.getASTKind(e._1))).collect {
+              case ((key, value), Z3AppAST(cons, args)) if cons == funDomainConstructors(funType.get) => ((args zip fts) map (p => rec(p._1, Some(p._2))), rec(value, Some(tt)))
             }
+            val elseValue = rec(ev, Some(tt))
+            AnonymousFunction(entries, elseValue).setType(expType.get)
+        }
+      case Some(SetType(dt)) => 
+        model.getSetValue(t) match {
+          case None => throw new CantTranslateException(t)
+          case Some(set) => {
+            val elems = set.map(e => rec(e, Some(dt)))
+            (if (elems.isEmpty) EmptySet(dt) else FiniteSet(elems.toSeq)).setType(expType.get)
           }
         }
-      }
+      case other => 
+        z3.getASTKind(t) match {
+          case Z3AppAST(decl, args) => {
+            val argsSize = args.size
+            if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) {
+              val toRet = z3IdToExpr(t)
+              // println("Map says I should replace " + t + " by " + toRet)
+              toRet
+            } else if(isKnownDecl(decl)) {
+              val fd = functionDeclToDef(decl)
+              assert(fd.args.size == argsSize)
+              FunctionInvocation(fd, (args zip fd.args).map(p => rec(p._1,Some(p._2.tpe))))
+            } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) {
+              CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0)))
+            } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) {
+              val (ccd, fid) = reverseADTFieldSelectors(decl)
+              CaseClassSelector(ccd, rec(args(0)), fid)
+            } else if(reverseADTConstructors.isDefinedAt(decl)) {
+              val ccd = reverseADTConstructors(decl)
+              assert(argsSize == ccd.fields.size)
+              CaseClass(ccd, (args zip ccd.fields).map(p => rec(p._1, Some(p._2.tpe))))
+            } else {
+              import Z3DeclKind._
+              val rargs = args.map(rec(_))
+              z3.getDeclKind(decl) match {
+                case OpTrue => BooleanLiteral(true)
+                case OpFalse => BooleanLiteral(false)
+                case OpEq => Equals(rargs(0), rargs(1))
+                case OpITE => {
+                  assert(argsSize == 3)
+                  val r0 = rargs(0)
+                  val r1 = rargs(1)
+                  val r2 = rargs(2)
+                  try {
+                    IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType))
+                  } catch {
+                    case e => {
+                      println("I was asking for lub because of this.")
+                      println(t)
+                      println("which was translated as")
+                      println(IfExpr(r0,r1,r2))
+                      throw e
+                    }
+                  }
+                }
+                case OpAnd => And(rargs)
+                case OpOr => Or(rargs)
+                case OpIff => Iff(rargs(0), rargs(1))
+                case OpXor => Not(Iff(rargs(0), rargs(1)))
+                case OpNot => Not(rargs(0))
+                case OpImplies => Implies(rargs(0), rargs(1))
+                case OpLE => LessEquals(rargs(0), rargs(1))
+                case OpGE => GreaterEquals(rargs(0), rargs(1))
+                case OpLT => LessThan(rargs(0), rargs(1))
+                case OpGT => GreaterThan(rargs(0), rargs(1))
+                case OpAdd => {
+                  assert(argsSize == 2)
+                  Plus(rargs(0), rargs(1))
+                }
+                case OpSub => {
+                  assert(argsSize == 2)
+                  Minus(rargs(0), rargs(1))
+                }
+                case OpUMinus => UMinus(rargs(0))
+                case OpMul => {
+                  assert(argsSize == 2)
+                  Times(rargs(0), rargs(1))
+                }
+                case OpDiv => {
+                  assert(argsSize == 2)
+                  Division(rargs(0), rargs(1))
+                }
+                case OpIDiv => {
+                  assert(argsSize == 2)
+                  Division(rargs(0), rargs(1))
+                }
+                case OpAsArray => {
+                  assert(argsSize == 0)
+                  throw new Exception("encountered OpAsArray")
+                }
+                case other => {
+                  System.err.println("Don't know what to do with this declKind : " + other)
+                  System.err.println("The arguments are : " + args)
+                  throw new CantTranslateException(t)
+                }
+              }
+            }
+          }
 
-      case Z3NumeralAST(Some(v)) => IntLiteral(v)
-      case other @ _ => {
-        System.err.println("Don't know what this is " + other) 
-        System.err.println("REVERSE FUNCTION MAP:")
-        System.err.println(reverseFunctionMap.toSeq.mkString("\n"))
-        System.err.println("REVERSE CONS MAP:")
-        System.err.println(reverseADTConstructors.toSeq.mkString("\n"))
-        throw new CantTranslateException(t)
-      }
+          case Z3NumeralAST(Some(v)) => IntLiteral(v)
+          case other @ _ => {
+            System.err.println("Don't know what this is " + other) 
+            System.err.println("REVERSE FUNCTION MAP:")
+            System.err.println(reverseFunctionMap.toSeq.mkString("\n"))
+            System.err.println("REVERSE CONS MAP:")
+            System.err.println(reverseADTConstructors.toSeq.mkString("\n"))
+            throw new CantTranslateException(t)
+          }
+        }
     }
-
-    rec(tree)
+    rec(tree, expectedType)
   }
 
   // This remembers everything that was unrolled, which literal is blocking
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 2230f9a6f..338b5d24a 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -1215,7 +1215,7 @@ object Trees {
     case CaseClassType(ccd) =>
       val fields = ccd.fields
       CaseClass(ccd, fields.map(f => simplestValue(f.getType)))
-    case SetType(baseType) => FiniteSet(Nil).setType(tpe)
+    case SetType(baseType) => EmptySet(baseType).setType(tpe)
     case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe)
     case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe)
     case _ => throw new Exception("I can't choose simplest value for type " + tpe)
diff --git a/src/purescala/Z3ModelReconstruction.scala b/src/purescala/Z3ModelReconstruction.scala
index 66b478d31..59fea83e1 100644
--- a/src/purescala/Z3ModelReconstruction.scala
+++ b/src/purescala/Z3ModelReconstruction.scala
@@ -21,52 +21,28 @@ trait Z3ModelReconstruction {
     if(exprToZ3Id.isDefinedAt(id.toVariable)) {
       val z3ID : Z3AST = exprToZ3Id(id.toVariable)
 
-      def rec(ast: Z3AST, expTpe: TypeTree): Option[Expr] = expTpe match {
-        case BooleanType => model.evalAs[Boolean](ast).map(BooleanLiteral(_))
-        case Int32Type => model.evalAs[Int](ast).map(IntLiteral(_))
-        case MapType(kt,vt) => model.eval(ast) match {
+      expectedType match {
+        case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral(_))
+        case Int32Type => model.evalAs[Int](z3ID).map(IntLiteral(_))
+        case other => model.eval(z3ID) match {
           case None => None
-          case Some(t) => model.getArrayValue(t) match {
-            case None => None
-            case Some((map, elseValue)) => 
-              val singletons = map.map(e => (e, z3.getASTKind(e._2))).collect {
-                case ((index, value), Z3AppAST(someCons, arg :: Nil)) if someCons == mapRangeSomeConstructors(vt) => SingletonMap(rec(index, kt).get, rec(arg, vt).get)
-              }
-              (if (singletons.isEmpty) Some(EmptyMap(kt, vt)) else Some(FiniteMap(singletons.toSeq))).map(_.setType(expTpe))
-          }
-        }
-        case funType @ FunctionType(fts, tt) => model.eval(ast) match {
-          case None => None
-          case Some(t) => model.getArrayValue(t) match {
-            case None => None
-            case Some((es, ev)) =>
-              val entries: Seq[(Seq[Expr], Expr)] = es.toSeq.map(e => (e, z3.getASTKind(e._1))).collect {
-                case ((key, value), Z3AppAST(cons, args)) if cons == funDomainConstructors(funType) => ((args zip fts) map (p => rec(p._1, p._2).get), rec(value, tt).get)
-              }
-              val elseValue = rec(ev, tt).get
-              Some(AnonymousFunction(entries, elseValue).setType(expTpe))
-          }
-        }
-        case SetType(dt) => model.eval(ast) match {
-          case None => None
-          case Some(t) => model.getSetValue(t) match {
-            case None => None
-            case Some(set) => {
-              val elems = set.map(e => rec(e, dt).get)
-              (if (elems.isEmpty) Some(EmptySet(dt)) else Some(FiniteSet(elems.toSeq))).map(_.setType(expTpe))
-            }
-          }
-        }
-        case other => model.eval(ast) match {
-          case None => None
-          case Some(t) => softFromZ3Formula(t)
+          case Some(t) => softFromZ3Formula(model, t, expectedType)
         }
       }
-
-      rec(z3ID, expectedType)
     } else None
   }
 
+  // def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = {
+  //   val expectedType = if(tpe == null) id.getType else tpe
+  //   
+  //   if(exprToZ3Id.isDefinedAt(id.toVariable)) {
+  //     val z3ID : Z3AST = exprToZ3Id(id.toVariable)
+
+
+  //     rec(z3ID, expectedType)
+  //   } else None
+  // }
+
   def modelToMap(model: Z3Model, ids: Iterable[Identifier]) : Map[Identifier,Expr] = {
     var asMap = Map.empty[Identifier,Expr]
 
diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala
index 98cb563b4..23406bf71 100644
--- a/src/purescala/Z3Solver.scala
+++ b/src/purescala/Z3Solver.scala
@@ -690,7 +690,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3Solve
     }
   }
 
-  protected[purescala] def fromZ3Formula(tree : Z3AST) : Expr = {
+  protected[purescala] def fromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: Option[TypeTree] = None) : Expr = {
     def rec(t: Z3AST) : Expr = z3.getASTKind(t) match {
       case Z3AppAST(decl, args) => {
         val argsSize = args.size
diff --git a/src/purescala/z3plugins/instantiator/Instantiator.scala b/src/purescala/z3plugins/instantiator/Instantiator.scala
index 29e46e00a..33641d97c 100644
--- a/src/purescala/z3plugins/instantiator/Instantiator.scala
+++ b/src/purescala/z3plugins/instantiator/Instantiator.scala
@@ -88,7 +88,7 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan
       seen += ast
     }
 
-    val aps = fromZ3Formula(ast)
+    val aps = fromZ3Formula(null,ast)
     val fis : Set[FunctionInvocation] = if(allFunctions) {
       functionCallsOf(aps)
     } else {
-- 
GitLab