From 7bc504cdaa483dd8e81d2e76659e36b128d356e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com>
Date: Tue, 24 May 2011 09:05:33 +0000
Subject: [PATCH] Better creation of VCs for map accesses (introduce error
 checks only when creating the VC).

---
 demo/AssociativeListReloaded.scala |  5 +++--
 src/funcheck/CodeExtraction.scala  |  4 +---
 src/purescala/DefaultTactic.scala  |  4 ++--
 src/purescala/FairZ3Solver.scala   |  8 ++++----
 src/purescala/Trees.scala          | 25 +++++++++++++++++++++++++
 5 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/demo/AssociativeListReloaded.scala b/demo/AssociativeListReloaded.scala
index 927de5636..a8ea94c72 100644
--- a/demo/AssociativeListReloaded.scala
+++ b/demo/AssociativeListReloaded.scala
@@ -59,8 +59,9 @@ object AssociativeList {
   } holds
 
   def weird(m : Map[Int,Int], k : Int, v : Int) : Boolean = {
-    !(m(k) == v) || m.isDefinedAt(k)
-  } holds
+    m(k) == v && !m.isDefinedAt(k)
+    // m.isDefinedAt(k) || !(m(k) == v) 
+  } ensuring(res => !res)
 
   // def prop0(l : List, m : Map[Int,Int]) : Boolean = {
   //   size(l) > 4 && asMap(l) == m
diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index 49174d469..a7c24795a 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -550,9 +550,7 @@ trait CodeExtraction extends Extractors {
             throw ImpureCodeEncounteredException(tree)
           }
         }
-        val mg = MapGet(rm, rf).setType(tpe)
-        val ida = MapIsDefinedAt(rm, rf)
-        IfExpr(ida, mg, Error("key not found for map access").setType(tpe)).setType(tpe)
+        MapGet(rm, rf).setType(tpe)
       }
       case ExMapIsDefinedAt(m,k) => {
         val rm = rec(m)
diff --git a/src/purescala/DefaultTactic.scala b/src/purescala/DefaultTactic.scala
index e4797d1be..ee452eb35 100644
--- a/src/purescala/DefaultTactic.scala
+++ b/src/purescala/DefaultTactic.scala
@@ -157,7 +157,7 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) {
 
     def generateMapAccessChecks(function: FunDef) : Seq[VerificationCondition] = {
       val toRet = if (function.hasBody) {
-        val cleanBody = matchToIfThenElse(function.body.get)
+        val cleanBody = mapGetWithChecks(matchToIfThenElse(function.body.get))
 
         val allPathConds = collectWithPathCondition((t => t match {
           case Error("key not found for map access") => true
@@ -165,7 +165,7 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) {
         }), cleanBody)
 
         def withPrecIfDefined(conds: Seq[Expr]) : Expr = if (function.hasPrecondition) {
-          Not(And(matchToIfThenElse(function.precondition.get), And(conds)))
+          Not(And(mapGetWithChecks(matchToIfThenElse(function.precondition.get)), And(conds)))
         } else {
           Not(And(conds))
         }
diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala
index c1a61b698..a47375f2d 100644
--- a/src/purescala/FairZ3Solver.scala
+++ b/src/purescala/FairZ3Solver.scala
@@ -504,7 +504,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
           }
         }
         case (Some(true), m) => { // SAT
-          // println("the model is")
+          // println("Model returned by Z3:")
           // println(m)
           validatingStopwatch.start
           val (trueModel, model) = validateAndDeleteModel(m, toCheckAgainstModels, varsInVC, evaluator)
@@ -786,9 +786,9 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
         case v@Variable(id) => z3Vars.get(id) match {
           case Some(ast) => ast
           case None => {
-            if (id.isLetBinder) {
-              scala.Predef.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition")
-            }
+            // if (id.isLetBinder) {
+            //   scala.Predef.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition")
+            // }
             val newAST = z3.mkFreshConst(id.uniqueName/*name*/, typeToSort(v.getType))
             z3Vars = z3Vars + (id -> newAST)
             exprToZ3Id += (v -> newAST)
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index a24b4c5c7..d66a25d11 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -1110,6 +1110,31 @@ object Trees {
     searchAndReplaceDFS(rewritePM)(expr)
   }
 
+  private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
+  /** Rewrites all map accesses with additional error conditions. */
+  def mapGetWithChecks(expr: Expr) : Expr = {
+    val toRet = if (mapGetConverterCache.isDefinedAt(expr)) {
+      matchConverterCache(expr)
+    } else {
+      val converted = convertMapGet(expr)
+      mapGetConverterCache(expr) = converted
+      converted
+    }
+
+    toRet
+  }
+
+  private def convertMapGet(expr: Expr) : Expr = {
+    def rewriteMapGet(e: Expr) : Option[Expr] = e match {
+      case mg @ MapGet(m,k) => 
+        val ida = MapIsDefinedAt(m, k)
+        Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType)).setType(mg.getType))
+      case _ => None
+    }
+
+    searchAndReplaceDFS(rewriteMapGet)(expr)
+  }
+
   // prec: expression does not contain match expressions
   def measureADTChildrenDepth(expression: Expr) : Int = {
     import scala.math.max
-- 
GitLab