From c08f1602ab1ba00988d76782d93b4fed0b4a453b Mon Sep 17 00:00:00 2001
From: Ravi <ravi.kandhadai@epfl.ch>
Date: Mon, 25 Jan 2016 00:31:09 +0100
Subject: [PATCH] Fixing a bug

---
 .../invariant/engine/SpecInstantiator.scala   |  32 +--
 .../leon/invariant/structure/Constraint.scala |  46 +++--
 .../structure/LinearConstraintUtil.scala      |   4 +-
 .../util/ExpressionTransformer.scala          | 195 ++++++++++--------
 .../laziness/LazinessEliminationPhase.scala   |   7 +-
 .../scala/leon/purescala/Expressions.scala    |   9 +-
 6 files changed, 160 insertions(+), 133 deletions(-)

diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala
index 16403fbf4..88d7cdb9b 100644
--- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala
+++ b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala
@@ -109,23 +109,28 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons
     resetUntempCalls(formula.fd, newUntemplatedCalls ++ calls)
   }
 
+  import leon.purescala.TypeOps._
   def specForCall(call: Call): Option[Expr] = {
     val argmap = formalToActual(call)
-    val callee = call.fi.tfd.fd
+    val tfd = call.fi.tfd
+    val callee = tfd.fd
     if (callee.hasPostcondition) {
+      // instantiate the post
+      val tparamMap = (callee.tparams zip tfd.tps).toMap
+      val trans = freshenLocals _ andThen (e => instantiateType(e, tparamMap, Map())) andThen matchToIfThenElse _
       //get the postcondition without templates
-      val post = callee.getPostWoTemplate
-      val freshPost = freshenLocals(matchToIfThenElse(post))
-      val spec = if (callee.hasPrecondition) {
-        val freshPre = freshenLocals(matchToIfThenElse(callee.precondition.get))
+      val rawpost = trans(callee.getPostWoTemplate)
+      val rawspec = if (callee.hasPrecondition) {
+        val pre = trans(callee.precondition.get)
         if (ctx.assumepre)
-          And(freshPre, freshPost)
+          And(pre, rawpost)
         else
-          Implies(freshPre, freshPost)
+          Implies(pre, rawpost)
       } else {
-        freshPost
+        rawpost
       }
-      val inlinedSpec = ExpressionTransformer.normalizeExpr(replace(argmap, spec), ctx.multOp)
+      val spec = replace(argmap, rawspec)
+      val inlinedSpec = ExpressionTransformer.normalizeExpr(spec, ctx.multOp)
       Some(inlinedSpec)
     } else {
       None
@@ -133,12 +138,15 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons
   }
 
   def templateForCall(call: Call): Option[Expr] = {
-    val callee = call.fi.tfd.fd
+    val tfd = call.fi.tfd
+    val callee = tfd.fd
     if (callee.hasTemplate) {
       val argmap = formalToActual(call)
-      val tempExpr = replace(argmap, callee.getTemplate)
+      val tparamMap = (callee.tparams zip tfd.tps).toMap
+      val tempExpr = replace(argmap, instantiateType(callee.getTemplate, tparamMap, Map()))
       val template = if (callee.hasPrecondition) {
-        val freshPre = replace(argmap, freshenLocals(matchToIfThenElse(callee.precondition.get)))
+        val pre = replace(argmap, instantiateType(callee.precondition.get, tparamMap, Map()))
+        val freshPre =  freshenLocals(matchToIfThenElse(pre))
         if (ctx.assumepre)
           And(freshPre, tempExpr)
         else
diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala
index 2a12f6949..9644ff5c4 100644
--- a/src/main/scala/leon/invariant/structure/Constraint.scala
+++ b/src/main/scala/leon/invariant/structure/Constraint.scala
@@ -251,13 +251,15 @@ case class Call(retexpr: Expr, fi: FunctionInvocation) extends Constraint {
 }
 
 object SetConstraint {
+  def isSetOp(e: Expr) =
+    e match {
+      case SetUnion(_, _) | FiniteSet(_, _) | ElementOfSet(_, _) | SubsetOf(_, _) | Variable(_) =>
+        true
+      case _ => false
+    }
+
   def setConstraintOfBase(e: Expr) = e match {
-    case Equals(lhs@Variable(_), rhs) if lhs.getType.isInstanceOf[SetType] =>
-      rhs match {
-        case SetUnion(_, _) | FiniteSet(_, _) | ElementOfSet(_, _) | SubsetOf(_, _) | Variable(_) =>
-          true
-        case _ => false
-      }
+    case Equals(Variable(_), rhs) if isSetOp(rhs) => true
     case _ => false
   }
 
@@ -271,22 +273,22 @@ object SetConstraint {
 }
 
 case class SetConstraint(expr: Expr) extends Constraint {
-  var union = false
-  var newset = false
-  var equal = false
-  var elemof = false
-  var subset = false
-  // TODO: add more operations here
-  expr match {
-    case Equals(Variable(_), rhs) =>
-      rhs match {
-        case SetUnion(_, _) => union = true
-        case FiniteSet(_, _) => newset = true
-        case ElementOfSet(_, _) => elemof = true
-        case SubsetOf(_, _) => subset = true
-        case Variable(_) => equal = true
-      }
-  }
+//  var union = false
+//  var newset = false
+//  var equal = false
+//  var elemof = false
+//  var subset = false
+//  // TODO: add more operations here
+//  expr match {
+//    case Equals(Variable(_), rhs) =>
+//      rhs match {
+//        case SetUnion(_, _) => union = true
+//        case FiniteSet(_, _) => newset = true
+//        case ElementOfSet(_, _) => elemof = true
+//        case SubsetOf(_, _) => subset = true
+//        case Variable(_) => equal = true
+//      }
+//  }
   override def toString(): String = {
     expr.toString
   }
diff --git a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala
index 5ed3316a4..a69ecb45d 100644
--- a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala
+++ b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala
@@ -258,7 +258,9 @@ object LinearConstraintUtil {
             throw new IllegalStateException("Expression not linear: " + Times(r1, r2))
         }
         case Plus(e1, e2) => Plus(mkLinearRecur(e1), mkLinearRecur(e2))
-        case RealPlus(e1, e2) => RealPlus(mkLinearRecur(e1), mkLinearRecur(e2))
+        case rp@RealPlus(e1, e2) =>
+          println(s"Expr: $rp arg1: $e1 arg2: $e2")
+          RealPlus(mkLinearRecur(e1), mkLinearRecur(e2))
         case t: Terminal => t
         case fi: FunctionInvocation => fi
         case _ => throw new IllegalStateException("Expression not linear: " + inExpr)
diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala
index 58131d249..c2909ced2 100644
--- a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala
+++ b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala
@@ -35,48 +35,37 @@ object ExpressionTransformer {
    */
   def conjoinWithinClause(e: Expr, transformer: (Expr, Boolean) => (Expr, Set[Expr]),
     insideFunction: Boolean): (Expr, Set[Expr]) = {
-      e match {
-        case And(args) if !insideFunction => {
-          val newargs = args.map((arg) => {
-            val (nexp, ncjs) = transformer(arg, false)
-            createAnd(nexp +: ncjs.toSeq)
-          })
-          (createAnd(newargs), Set())
-        }
-
-        case Or(args) if !insideFunction => {
-          val newargs = args.map((arg) => {
-            val (nexp, ncjs) = transformer(arg, false)
-            createAnd(nexp +: ncjs.toSeq)
-          })
-          (createOr(newargs), Set())
-        }
+    e match {
+      case And(args) if !insideFunction => {
+        val newargs = args.map((arg) => {
+          val (nexp, ncjs) = transformer(arg, false)
+          createAnd(nexp +: ncjs.toSeq)
+        })
+        (createAnd(newargs), Set())
+      }
 
-        case t: Terminal => (t, Set())
+      case Or(args) if !insideFunction => {
+        val newargs = args.map((arg) => {
+          val (nexp, ncjs) = transformer(arg, false)
+          createAnd(nexp +: ncjs.toSeq)
+        })
+        (createOr(newargs), Set())
+      }
 
-        /*case BinaryOperator(e1, e2, op) => {
-          val (nexp1, ncjs1) = transformer(e1, true)
-          val (nexp2, ncjs2) = transformer(e2, true)
-          (op(nexp1, nexp2), ncjs1 ++ ncjs2)
-        }
+      case t: Terminal => (t, Set())
 
-        case u @ UnaryOperator(e1, op) => {
-          val (nexp, ncjs) = transformer(e1, true)
-          (op(nexp), ncjs)
-        }*/
-
-        case n @ Operator(args, op) => {
-          var ncjs = Set[Expr]()
-          val newargs = args.map((arg) => {
-            val (nexp, js) = transformer(arg, true)
-            ncjs ++= js
-            nexp
-          })
-          (op(newargs), ncjs)
-        }
-        case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + e)
+      case n @ Operator(args, op) => {
+        var ncjs = Set[Expr]()
+        val newargs = args.map((arg) => {
+          val (nexp, js) = transformer(arg, true)
+          ncjs ++= js
+          nexp
+        })
+        (op(newargs), ncjs)
       }
+      case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + e)
     }
+  }
 
   /**
    * Assumed that that given expression has boolean type
@@ -274,14 +263,14 @@ object ExpressionTransformer {
 
           (freshResVar, newConjuncts)
         }
-        case SetUnion(_, _) | ElementOfSet(_, _) | SubsetOf(_, _)  =>
+        case SetUnion(_, _) | ElementOfSet(_, _) | SubsetOf(_, _) =>
           val Operator(args, op) = e
           val (Seq(a1, a2), newcjs) = flattenArgs(args, true)
           val newexpr = op(Seq(a1, a2))
           val freshResVar = Variable(TVarFactory.createTemp("set", e.getType))
           (freshResVar, newcjs + Equals(freshResVar, newexpr))
 
-        case fs@FiniteSet(es, typ) =>
+        case fs @ FiniteSet(es, typ) =>
           val args = es.toSeq
           val (nargs, newcjs) = flattenArgs(args, true)
           val newexpr = FiniteSet(nargs.toSet, typ)
@@ -322,6 +311,23 @@ object ExpressionTransformer {
     } else nexp
   }
 
+  def testHelp(e: Expr) = {
+    e match {
+      case Operator(args, op) =>
+        args.foreach { arg =>
+          if (arg.getType == Untyped) {
+            println(s"$arg is untyped! ")
+            arg match {
+              case CaseClassSelector(cct, cl, fld) =>
+                println("cl type: " + cl.getType + " cct: " + cct)
+              case _ =>
+            }
+          }
+        }
+      case _ =>
+    }
+  }
+
   /**
    * The following procedure converts the formula into negated normal form by pushing all not's inside.
    * It also handles disequality constraints.
@@ -333,58 +339,65 @@ object ExpressionTransformer {
    */
   def TransformNot(expr: Expr, retainNEQ: Boolean = false): Expr = { // retainIff : Boolean = false
     def nnf(inExpr: Expr): Expr = {
-
+      if(inExpr.getType == Untyped){
+        testHelp(inExpr)
+        println(s"Warning: $inExpr is untyped")
+      }
       if (inExpr.getType != BooleanType) inExpr
-      else inExpr match {
-        case Not(Not(e1)) => nnf(e1)
-        case e @ Not(t: Terminal) => e
-        case e @ Not(FunctionInvocation(_, _)) => e
-        case Not(And(args)) => createOr(args.map(arg => nnf(Not(arg))))
-        case Not(Or(args)) => createAnd(args.map(arg => nnf(Not(arg))))
-        case Not(e @ Operator(Seq(e1, e2), op)) => {
-        	//matches integer binary relation or a boolean equality
-          if (e1.getType == BooleanType || e1.getType == Int32Type || e1.getType == RealType || e1.getType == IntegerType) {
-            e match {
-              case e: Equals => {
-                if (e1.getType == BooleanType && e2.getType == BooleanType) {
-                  Or(And(nnf(e1), nnf(Not(e2))), And(nnf(e2), nnf(Not(e1))))
-                } else {
-                  if (retainNEQ) Not(Equals(e1, e2))
-                  else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2)))
+      else {
+        inExpr match {
+          case Not(Not(e1)) => nnf(e1)
+          case e @ Not(t: Terminal) => e
+          case e @ Not(FunctionInvocation(_, _)) => e
+          case Not(And(args)) => createOr(args.map(arg => nnf(Not(arg))))
+          case Not(Or(args)) => createAnd(args.map(arg => nnf(Not(arg))))
+          case Not(e @ Operator(Seq(e1, e2), op)) => {
+            //matches integer binary relation or a boolean equality
+            if (e1.getType == BooleanType || e1.getType == Int32Type || e1.getType == RealType || e1.getType == IntegerType) {
+              e match {
+                case e: Equals => {
+                  if (e1.getType == BooleanType && e2.getType == BooleanType) {
+                    Or(And(nnf(e1), nnf(Not(e2))), And(nnf(e2), nnf(Not(e1))))
+                  } else {
+                    if (retainNEQ) Not(Equals(e1, e2))
+                    else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2)))
+                  }
                 }
+                case e: LessThan => GreaterEquals(nnf(e1), nnf(e2))
+                case e: LessEquals => GreaterThan(nnf(e1), nnf(e2))
+                case e: GreaterThan => LessEquals(nnf(e1), nnf(e2))
+                case e: GreaterEquals => LessThan(nnf(e1), nnf(e2))
+                case e: Implies => And(nnf(e1), nnf(Not(e2)))
+                case _ => throw new IllegalStateException("Unknown binary operation: " + e)
+              }
+            } else {
+              //in this case e is a binary operation over ADTs
+              e match {
+                case ninst @ Not(IsInstanceOf(e1, cd)) => Not(IsInstanceOf(nnf(e1), cd))
+                case Not(SubsetOf(_, _)) | Not(ElementOfSet(_, _)) | Not(SetUnion(_, _)) | Not(FiniteSet(_, _)) =>
+                  e
+                case e: Equals => Not(Equals(nnf(e1), nnf(e2)))
+                case _ => throw new IllegalStateException("Unknown operation on algebraic data types: " + e)
               }
-              case e: LessThan => GreaterEquals(nnf(e1), nnf(e2))
-              case e: LessEquals => GreaterThan(nnf(e1), nnf(e2))
-              case e: GreaterThan => LessEquals(nnf(e1), nnf(e2))
-              case e: GreaterEquals => LessThan(nnf(e1), nnf(e2))
-              case e: Implies => And(nnf(e1), nnf(Not(e2)))
-              case _ => throw new IllegalStateException("Unknown binary operation: " + e)
-            }
-          } else {
-            //in this case e is a binary operation over ADTs
-            e match {
-              case ninst @ Not(IsInstanceOf(e1, cd)) => Not(IsInstanceOf(nnf(e1), cd))
-              case e: Equals => Not(Equals(nnf(e1), nnf(e2)))
-              case _ => throw new IllegalStateException("Unknown operation on algebraic data types: " + e)
             }
           }
+          case e @ Equals(lhs, SubsetOf(_, _) | ElementOfSet(_, _) | SetUnion(_, _) | FiniteSet(_, _)) =>
+            // all are set operations
+            e
+          case e @ Equals(lhs, IsInstanceOf(_, _) | CaseClassSelector(_, _, _) | TupleSelect(_, _) | FunctionInvocation(_, _)) =>
+            //all case where rhs could use an ADT tree e.g. instanceOF, tupleSelect, fieldSelect, function invocation
+            e
+          case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs))
+          case Equals(lhs, rhs) if (lhs.getType == BooleanType && rhs.getType == BooleanType) => {
+            nnf(And(Implies(lhs, rhs), Implies(rhs, lhs)))
+          }
+          case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze)))
+          case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e)))
+          case t: Terminal => t
+          case n @ Operator(args, op) => op(args.map(nnf(_)))
+
+          case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr)
         }
-        case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs))
-        case e @ Equals(lhs, IsInstanceOf(_, _) | CaseClassSelector(_, _, _) | TupleSelect(_, _) | FunctionInvocation(_, _)) =>
-          //all case where rhs could use an ADT tree e.g. instanceOF, tupleSelect, fieldSelect, function invocation
-          e
-        case Equals(lhs, rhs) if (lhs.getType == BooleanType && rhs.getType == BooleanType) => {
-          nnf(And(Implies(lhs, rhs), Implies(rhs, lhs)))
-        }
-        case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze)))
-        case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e)))
-        //note that Not(LetTuple) is not possible
-        case t: Terminal => t
-        /*case u @ UnaryOperator(e1, op) => op(nnf(e1))
-        case b @ BinaryOperator(e1, e2, op) => op(nnf(e1), nnf(e2))*/
-        case n @ Operator(args, op) => op(args.map(nnf(_)))
-
-        case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr)
       }
     }
     val nnfvc = nnf(expr)
@@ -455,14 +468,14 @@ object ExpressionTransformer {
    */
   def normalizeExpr(expr: Expr, multOp: (Expr, Expr) => Expr): Expr = {
     //reduce the language before applying flatten function
-    // println("Normalizing " + ScalaPrinter(expr) + "\n")
+    //println("Normalizing " + ScalaPrinter(expr) + "\n")
     val redex = reduceLangBlocks(expr, multOp)
-    // println("Redex: "+ScalaPrinter(redex) + "\n")
+    //println("Redex: " + ScalaPrinter(redex) + "\n")
     val nnfExpr = TransformNot(redex)
-    // println("NNFexpr: "+ScalaPrinter(nnfExpr) + "\n")
+    //println("NNFexpr: " + ScalaPrinter(nnfExpr) + "\n")
     //flatten all function calls
     val flatExpr = FlattenFunction(nnfExpr)
-    // println("Flatexpr: "+ScalaPrinter(flatExpr) + "\n")
+    println("Flatexpr: " + ScalaPrinter(flatExpr) + "\n")
     //perform additional simplification
     val simpExpr = pullAndOrs(TransformNot(flatExpr))
     simpExpr
@@ -477,7 +490,7 @@ object ExpressionTransformer {
   def unFlatten(ine: Expr, freevars: Set[Identifier]): Expr = {
     var tempMap = Map[Expr, Expr]()
     val newinst = simplePostTransform {
-      case e@Equals(v@Variable(id), rhs@_) if !freevars.contains(id) =>
+      case e @ Equals(v @ Variable(id), rhs @ _) if !freevars.contains(id) =>
         if (tempMap.contains(v)) e
         else {
           tempMap += (v -> rhs)
@@ -638,7 +651,7 @@ object ExpressionTransformer {
 
     def distribute(e: Expr): Expr = {
       simplePreTransform {
-        case e@FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if isMultFunctions(fd) =>
+        case e @ FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if isMultFunctions(fd) =>
           val newe = (e1, e2) match {
             case (Plus(sum1, sum2), _) =>
               // distribute e2 over e1
diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala
index e121fafac..25a3baf8f 100644
--- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala
+++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala
@@ -47,7 +47,7 @@ object LazinessEliminationPhase extends TransformationPhase {
   val dumpProgWOInstSpecs = true
   val dumpInstrumentedProgram = true
   val debugSolvers = false
-  val skipStateVerification = false
+  val skipStateVerification = true
   val skipResourceVerification = false
   val debugInstVCs = false
 
@@ -111,7 +111,7 @@ object LazinessEliminationPhase extends TransformationPhase {
     val instProg = instrumenter.apply
     if (dumpInstrumentedProgram) {
       //println("After instrumentation: \n" + ScalaPrinter.apply(instProg))
-      prettyPrintProgramToFile(instProg, ctx, "-withinst")
+      prettyPrintProgramToFile(instProg, ctx, "-withinst", uniqueIds = true)
     }
     // check specifications (to be moved to a different phase)
     if (!skipResourceVerification)
@@ -344,7 +344,8 @@ object LazinessEliminationPhase extends TransformationPhase {
         inferOpts.options ++ checkCtx.options)
       val inferctx = new InferenceContext(p,  ctxForInf)
       val vcSolver = (funDef: FunDef, prog: Program) => new VCSolver(inferctx, prog, funDef)
-      (new InferenceEngine(inferctx)).analyseProgram(p, funsToCheck, vcSolver, None)
+      prettyPrintProgramToFile(inferctx.inferProgram, checkCtx, "-inferProg", true)
+      (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, funsToCheck, vcSolver, None)
     } else {
       val vcs = funsToCheck.map { fd =>
         val (ants, post, tmpl) = createVC(fd)
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 5042413fc..e4191ff73 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -495,10 +495,11 @@ object Expressions {
     */
   case class And(exprs: Seq[Expr]) extends Expr {
     require(exprs.size >= 2)
-    val getType = {
-      if (exprs forall (_.getType == BooleanType)) BooleanType
-      else Untyped
-    }
+//    val getType = {
+//      if (exprs forall (_.getType == BooleanType)) BooleanType
+//      else Untyped
+//    }
+    val getType = BooleanType
   }
 
   object And {
-- 
GitLab