From fdc3b193844c30c94777fc19c5886e4c030d4bda Mon Sep 17 00:00:00 2001
From: Robin Steiger <robin.steiger@epfl.ch>
Date: Mon, 12 Jul 2010 18:49:56 +0000
Subject: [PATCH] Two non-trivial bug fixes in pure scala (in
 pulloutAndKeepLets and inlineFunctionsAndContracts). InsertSort.sort can now
 be verified.

---
 src/purescala/Analysis.scala | 12 ++++++++++
 src/purescala/Trees.scala    | 46 +++++++++++++++++++++++++++++-------
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala
index 0519d109d..11b438939 100644
--- a/src/purescala/Analysis.scala
+++ b/src/purescala/Analysis.scala
@@ -166,11 +166,23 @@ object Analysis {
         val substMap = Map[Expr,Expr]((fArgsAsVars zip fParamsAsLetVarVars) : _*)
         if(fd.hasPostcondition) {
           val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType)
+          /* START CHANGE */
+          //  Code before
+          /*
           extras = mkBigLet(And(
             replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get),
             Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType))
           )) :: extras
           Some(newVar)
+          */
+          
+          // Fixed code ?!?
+          extras = And(
+            replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get),
+            Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType))
+          ) :: extras
+          Some(mkBigLet(newVar))
+          /* END CHANGE */
         } else if(fd.hasImplementation && !program.isRecursive(fd)) { // means we can inline at least one level...
           Some(mkBigLet(replace(substMap, fd.body.get)))
         } else { // we can't do much for calls to recursive functions or to functions with no bodies
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 94bf03bf7..d4354853f 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -125,7 +125,10 @@ object Trees {
 
   /* For all types that don't have their own XXXEquals */
   object Equals {
-    def apply(l : Expr, r : Expr) : Equals = new Equals(l,r)
+    def apply(l : Expr, r : Expr) : Expr = (l.getType, r.getType) match {
+      case (BooleanType, BooleanType) => Iff(l, r)
+      case _ => new Equals(l, r)
+    }
     def unapply(e : Equals) : Option[(Expr,Expr)] = if (e == null) None else Some((e.left, e.right))
   }
 
@@ -272,9 +275,9 @@ object Trees {
 
   object BinaryOperator {
     def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match {
-      case Equals(t1,t2) => Some((t1,t2,Equals(_,_)))
+      case Equals(t1,t2) => Some((t1,t2,Equals.apply))
       case Iff(t1,t2) => Some((t1,t2,Iff))
-      case Implies(t1,t2) => Some((t1,t2, ((e1,e2) => Implies(e1,e2))))
+      case Implies(t1,t2) => Some((t1,t2,Implies.apply))
       case Plus(t1,t2) => Some((t1,t2,Plus))
       case Minus(t1,t2) => Some((t1,t2,Minus))
       case Times(t1,t2) => Some((t1,t2,Times))
@@ -308,8 +311,8 @@ object Trees {
     def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match {
       case FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _)))
       case CaseClass(cd, args) => Some((args, CaseClass(cd, _)))
-      case And(args) => Some((args, And(_)))
-      case Or(args) => Some((args, Or(_)))
+      case And(args) => Some((args, And.apply))
+      case Or(args) => Some((args, Or.apply))
       case FiniteSet(args) => Some((args, FiniteSet))
       case FiniteMultiset(args) => Some((args, FiniteMultiset))
       case _ => None
@@ -321,20 +324,21 @@ object Trees {
     case Not(e) => e
     case Iff(e1,e2) => Iff(negate(e1),e2)
     case Implies(e1,e2) => And(e1, negate(e2))
-    case Or(exs) => And(exs.map(negate(_)))
-    case And(exs) => Or(exs.map(negate(_)))
+    case Or(exs) => And(exs map negate)
+    case And(exs) => Or(exs map negate)
     case LessThan(e1,e2) => GreaterEquals(e1,e2)
     case LessEquals(e1,e2) => GreaterThan(e1,e2)
     case GreaterThan(e1,e2) => LessEquals(e1,e2)
     case GreaterEquals(e1,e2) => LessThan(e1,e2)
     case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)).setType(i.getType)
+    case BooleanLiteral(b) => BooleanLiteral(!b)
     case _ => Not(expr)
   }
  
   // Warning ! This may loop forever if the substitutions are not
   // well-formed!
   def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = {
-    searchAndReplace(substs.get(_))(expr)
+    searchAndReplace(substs.get)(expr)
   }
 
   def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = {
@@ -447,6 +451,9 @@ object Trees {
     rebuildLets(storedLets, noLets)
   }
 
+  /* START CHANGE */
+  // Previous code (keep this if nested lets can only appear in the body)
+  /*
   def pulloutAndKeepLets(expr: Expr) : (Seq[(Identifier,Expr)], Expr) = {
     var storedLets: List[(Identifier,Expr)] = Nil
 
@@ -463,6 +470,29 @@ object Trees {
     val noLets = searchAndReplace(killLet)(expr)
     (storedLets, noLets)
   }
+  */
+  
+  // new code (keep this if nested lets can appear in the value part, too)
+  def pulloutAndKeepLets(expr: Expr) : (List[(Identifier,Expr)], Expr) = {
+    var storedLets: List[(Identifier,Expr)] = Nil
+
+    def storeLet(t: Expr) : Option[Expr] = t match {
+      case l @ Let(i, e, b) =>
+        // Easy fix, but breaks define-before-use order !!
+        //val noLets = searchAndReplace(storeLet)(e)
+        //storedLets ::= i -> noLets
+        
+        // Better fix, but please check
+        val (stored, value) = pulloutAndKeepLets(e)
+        storedLets :::= stored
+        storedLets ::= i -> value
+        Some(b)
+      case _ => None
+    }
+    val noLets = searchAndReplace(storeLet)(expr)
+    (storedLets, noLets)
+  }
+  /* END CHANGE */
 
   def rebuildLets(lets: Seq[(Identifier,Expr)], expr: Expr) : Expr = {
     lets.foldLeft(expr)((e,p) => Let(p._1, p._2, e))
-- 
GitLab