From b5edabade30cda324771e5059583f415603ae765 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 9 Feb 2015 14:22:22 +0100
Subject: [PATCH] Added some tests and guarded reference equality for lambdas

---
 .../solvers/templates/LambdaManager.scala     | 33 ++++++-
 .../solvers/templates/TemplateGenerator.scala |  2 +-
 .../leon/solvers/templates/Templates.scala    | 87 +++++++++++++++----
 .../purescala/invalid/BraunTree.scala         | 39 +++++++++
 .../purescala/valid/FlatMap.scala.BAK         | 69 +++++++++++++++
 .../verification/purescala/valid/Lists5.scala | 32 +++++++
 .../purescala/valid/Monads1.scala             | 24 +++++
 .../verification/purescala/valid/Sets1.scala  | 12 +--
 .../verification/purescala/valid/Sets2.scala  | 23 +++++
 .../verification/purescala/valid/Trees1.scala | 23 +++++
 10 files changed, 318 insertions(+), 26 deletions(-)
 create mode 100644 src/test/resources/regression/verification/purescala/invalid/BraunTree.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/FlatMap.scala.BAK
 create mode 100644 src/test/resources/regression/verification/purescala/valid/Lists5.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/Monads1.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/Sets2.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/Trees1.scala

diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala
index 930e2a691..cfecfacda 100644
--- a/src/main/scala/leon/solvers/templates/LambdaManager.scala
+++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala
@@ -38,11 +38,19 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     freeLambdasStack = map :: freeLambdasStack.tail
   }
 
+  private type StructuralMap = Map[Lambda, List[(T, LambdaTemplate[T])]]
+  private var structuralLambdasStack : List[StructuralMap] = List(Map.empty.withDefaultValue(List.empty))
+  private def structuralLambdas : StructuralMap = structuralLambdasStack.head
+  private def structuralLambdas_=(map: StructuralMap) : Unit = {
+    structuralLambdasStack = map :: structuralLambdasStack.tail
+  }
+
   def push(): Unit = {
     byIDStack = byID :: byIDStack
     byTypeStack = byType :: byTypeStack
     applicationsStack = applications :: applicationsStack
     freeLambdasStack = freeLambdas :: freeLambdasStack
+    structuralLambdasStack = structuralLambdas :: structuralLambdasStack
   }
 
   def pop(lvl: Int): Unit = {
@@ -50,6 +58,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     byTypeStack = byTypeStack.drop(lvl)
     applicationsStack = applicationsStack.drop(lvl)
     freeLambdasStack = freeLambdasStack.drop(lvl)
+    structuralLambdasStack = structuralLambdasStack.drop(lvl)
   }
 
   def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = {
@@ -70,8 +79,10 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     }
 
     for (lambda @ (idT, template) <- lambdas) {
-      // make sure concrete lambdas can't be equal to free lambdas
-      clauses ++= freeLambdas(template.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT)))
+      // get all lambda references...
+      val lambdaRefs = freeLambdas(template.tpe) ++ byType(template.tpe).map(_._1)
+      // ... and make sure the new lambda isn't equal to one of them!
+      clauses ++= lambdaRefs.map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT)))
 
       byID += idT -> template
       byType += template.tpe -> (byType(template.tpe) + (idT -> template))
@@ -106,5 +117,23 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     (clauses, callBlockers, appBlockers)
   }
 
+  def equalityClauses(template: LambdaTemplate[T], idT: T, substMap: Map[T,T]): Seq[T] = {
+    val key : Lambda = template.key
+    val t : LambdaTemplate[T] = template.substitute(substMap)
+
+    val newClauses = structuralLambdas(key).map { case (thatIdT, that) =>
+      val equals = encoder.mkEquals(idT, thatIdT)
+      if (t.dependencies.isEmpty) {
+        equals
+      } else {
+        encoder.mkImplies(t.contextEquality(that), equals)
+      }
+    }
+
+    structuralLambdas += key -> (structuralLambdas(key) :+ (idT -> t))
+
+    newClauses
+  }
+
 }
 
diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
index 75fd757a6..a69ec41e9 100644
--- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
+++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
@@ -258,7 +258,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) {
           val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates) = mkClauses(pathVar, clause, clauseSubst)
 
           val ids: (Identifier, T) = lid -> encoder.encodeId(lid)
-          val dependencies: Set[T] = variablesOf(l).map(localSubst)
+          val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap
           val template = LambdaTemplate(ids, encoder, lambdaManager, pathVar -> encodedCond(pathVar), idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l)
           storeLambda(ids._2, template)
 
diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala
index f0c808a99..fb73c599d 100644
--- a/src/main/scala/leon/solvers/templates/Templates.scala
+++ b/src/main/scala/leon/solvers/templates/Templates.scala
@@ -33,7 +33,6 @@ trait Template[T] { self =>
   val lambdas : Map[T, LambdaTemplate[T]]
 
   private var substCache : Map[Seq[T],Map[T,T]] = Map.empty
-  private var lambdaCache : Map[(T, Map[T,T]), T] = Map.empty
 
   def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = {
 
@@ -48,20 +47,10 @@ trait Template[T] { self =>
 
     val (lambdaSubstMap, lambdaClauses) = lambdas.foldLeft((Map.empty[T,T], Seq.empty[T])) {
       case ((subst, clauses), (idT, lambda)) =>
-        val closureMap = lambda.dependencies.map(idT => idT -> baseSubstMap(idT)).toMap
-        val key : (T, Map[T,T]) = idT -> closureMap
-
         val newIdT = encoder.encodeId(lambda.id)
-        val prevIdT = lambdaCache.get(key) match {
-          case Some(id) =>
-            Some(id)
-          case None =>
-            lambdaCache += key -> newIdT
-            None
-        }
+        val eqClauses = lambdaManager.equalityClauses(lambda, newIdT, baseSubstMap)
 
-        val newClause = prevIdT.map(id => encoder.mkEquals(newIdT, id))
-        (subst + (idT -> newIdT), clauses ++ newClause)
+        (subst + (idT -> newIdT), clauses ++ eqClauses)
     }
 
     val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap + (start -> aVar)
@@ -238,6 +227,41 @@ class FunctionTemplate[T] private(
 
 object LambdaTemplate {
 
+  private var typedIds : Map[TypeTree, List[Identifier]] = Map.empty.withDefaultValue(List.empty)
+
+  private def templateKey[T](lambda: LambdaTemplate[T]): Lambda = {
+
+    def closureIds(expr: Expr): Seq[Identifier] = {
+      val vars = variablesOf(expr)
+      val allVars : Seq[Identifier] = foldRight[Seq[Identifier]] {
+        (expr, idSeqs) => idSeqs.foldLeft(expr match {
+          case Variable(id) => Seq(id)
+          case _ => Seq.empty[Identifier]
+        })((acc, seq) => acc ++ seq)
+      } (expr)
+
+      allVars.filter(vars(_)).distinct
+    }
+
+    val grouped : Map[TypeTree, Seq[Identifier]] = closureIds(lambda.lambda).groupBy(_.getType)
+    val subst : Map[Identifier, Identifier] = grouped.foldLeft(Map.empty[Identifier,Identifier]) { case (subst, (tpe, ids)) =>
+      val currentVars = typedIds(tpe)
+
+      val freshCount = ids.size - currentVars.size
+      val typedVars = if (freshCount > 0) {
+        val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", true).setType(tpe))
+        typedIds += tpe -> allIds
+        allIds
+      } else {
+        currentVars
+      }
+
+      subst ++ (ids zip typedVars)
+    }
+
+    replaceFromIDs(subst.mapValues(_.toVariable), lambda.lambda).asInstanceOf[Lambda]
+  }
+
   def apply[T](
     ids: (Identifier, T),
     encoder: TemplateEncoder[T],
@@ -249,7 +273,7 @@ object LambdaTemplate {
     guardedExprs: Map[Identifier, Seq[Expr]],
     lambdas: Map[T, LambdaTemplate[T]],
     baseSubstMap: Map[Identifier, T],
-    dependencies: Set[T],
+    dependencies: Map[Identifier, T],
     lambda: Lambda
   ) : LambdaTemplate[T] = {
 
@@ -294,8 +318,8 @@ class LambdaTemplate[T] private (
   val blockers: Map[T, Set[TemplateCallInfo[T]]],
   val applications: Map[T, Set[App[T]]],
   val lambdas: Map[T, LambdaTemplate[T]],
-  val dependencies: Set[T],
-  val lambda: Lambda,
+  private[templates] val dependencies: Map[Identifier, T],
+  private val lambda: Lambda,
   stringRepr: () => String) extends Template[T] {
 
   val tpe = id.getType
@@ -317,7 +341,7 @@ class LambdaTemplate[T] private (
 
     val newLambdas = lambdas.map { case (idT, template) => idT -> template.substitute(substMap) }
 
-    val newDependencies = dependencies.map(substituter)
+    val newDependencies = dependencies.map(p => p._1 -> substituter(p._2))
 
     new LambdaTemplate[T](
       id,
@@ -339,4 +363,33 @@ class LambdaTemplate[T] private (
 
   private lazy val str : String = stringRepr()
   override def toString : String = str
+
+  def contextEquality(that: LambdaTemplate[T]) : T = {
+    assert(key == that.key, "Can't generate equality clause for lambdas that don't share structure")
+    assert(dependencies.nonEmpty, "No closures implies obvious equality")
+
+    def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match {
+      case (Variable(id1), Variable(id2)) =>
+        if (dependencies.isDefinedAt(id1)) {
+          Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2)))
+        } else {
+          Seq.empty
+        }
+
+      case (NAryOperator(es1, _), NAryOperator(es2, _)) =>
+        (es1 zip es2).flatMap(p => rec(p._1, p._2))
+
+      case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) =>
+        rec(e11, e21) ++ rec(e12, e22)
+
+      case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) =>
+        rec(ue1, ue2)
+
+      case _ => Seq.empty
+    }
+
+    encoder.mkAnd(rec(lambda, that.lambda) : _*)
+  }
+
+  def key : Lambda = LambdaTemplate.templateKey(this)
 }
diff --git a/src/test/resources/regression/verification/purescala/invalid/BraunTree.scala b/src/test/resources/regression/verification/purescala/invalid/BraunTree.scala
new file mode 100644
index 000000000..e1de4180f
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/invalid/BraunTree.scala
@@ -0,0 +1,39 @@
+import leon.lang._
+
+object BraunTree {
+  abstract class Tree
+  case class Node(value: Int, left: Tree, right: Tree) extends Tree
+  case class Leaf() extends Tree
+
+  def insert(tree: Tree, x: Int): Tree = {
+    require(isBraun(tree))
+    tree match {
+      case Node(value, left, right) =>
+        Node(value, insert(left, x), right)
+      case Leaf() => Node(x, Leaf(), Leaf())
+    }
+  } ensuring { res => isBraun(res) }
+
+  def height(tree: Tree): Int = {
+    tree match {
+      case Node(value, left, right) =>
+        val l = height(left)
+        val r = height(right)
+        val max = if (l > r) l else r
+        1 + max
+      case Leaf() => 0
+    }
+  }
+
+  def isBraun(tree: Tree): Boolean = {
+    tree match {
+      case Node(value, left, right) =>
+        isBraun(left) && isBraun(right) && {
+          val l = height(left)
+          val r = height(right)
+          l == r || l == r + 1
+        }
+      case Leaf() => true
+    }
+  }
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/FlatMap.scala.BAK b/src/test/resources/regression/verification/purescala/valid/FlatMap.scala.BAK
new file mode 100644
index 000000000..fc177bd1f
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/FlatMap.scala.BAK
@@ -0,0 +1,69 @@
+import leon.lang._
+import leon.collection._
+
+object Lists5 {
+
+  def append[T](l1: List[T], l2: List[T]): List[T] = l1 match {
+    case Cons(head, tail) => Cons(head, append(tail, l2))
+    case Nil() => l2
+  }
+
+  def associative_append_lemma[T](l1: List[T], l2: List[T], l3: List[T]): Boolean = {
+    append(append(l1, l2), l3) == append(l1, append(l2, l3))
+  }
+
+  def associative_append_lemma_induct[T](l1: List[T], l2: List[T], l3: List[T]): Boolean = {
+    l1 match {
+      case Nil() => associative_append_lemma(l1, l2, l3)
+      case Cons(head, tail) => associative_append_lemma(l1, l2, l3) && associative_append_lemma_induct(tail, l2, l3)
+    }
+  }.holds
+
+  def flatMap[T,U](list: List[T], f: T => List[U]): List[U] = list match {
+    case Cons(head, tail) => append(f(head), flatMap(tail, f))
+    case Nil() => Nil()
+  }
+
+  def associative_lemma[T,U,V](list: List[T], f: T => List[U], g: U => List[V]): Boolean = {
+    flatMap(flatMap(list, f), g) == flatMap(list, (x: T) => flatMap(f(x), g))
+  }
+
+  def associative_lemma_helper[T,U,V](fprev: List[U], fcur: List[U], fgprev: List[V], fgcur: List[V], tail: List[T], f: T => List[U], g: U => List[V]): Boolean = {
+    val left = fcur match {
+      case Cons(fhead, ftail) =>
+        flatMap(append(fprev, Cons(fhead, append(ftail, flatMap(tail, f)))), g)
+      case Nil() =>
+        flatMap(append(fprev, flatMap(tail, f)), g)
+    }
+
+    val right = fgcur match {
+      case Cons(fghead, fgtail) =>
+        append(fgprev, Cons(fghead, append(fgtail, flatMap(tail, (x: T) => flatMap(f(x), g)))))
+      case Nil() =>
+        append(fgprev, flatMap(tail, (x: T) => flatMap(f(x), g)))
+    }
+        
+    left == right && (fgcur match {
+      case Cons(fghead, fgtail) => associative_lemma_helper(fprev, fcur, append(fgprev, Cons(fghead, Nil())), fgtail, tail, f, g)
+      case Nil() => fcur match {
+        case Cons(fhead, ftail) => associative_lemma_helper(append(fprev, Cons(fhead, Nil())), ftail, fgprev, fgcur, tail, f, g)
+        case Nil() => tail match {
+          case Nil() => true
+          case Cons(head, tail) => f(head) match {
+            case fl @ Cons(fh, ft) => associative_lemma_helper(Nil(), fl, Nil(), flatMap(fl, g), tail, f, g)
+            case Nil() => associative_lemma_helper(Nil(), Nil(), Nil(), Nil(), tail, f, g)
+          }
+        }
+      }
+    })
+  }
+
+  def associative_lemma_induct[T,U,V](list: List[T], f: T => List[U], g: U => List[V]): Boolean = {
+    list match {
+      case Nil() => associative_lemma(list, f, g)
+      case Cons(head, tail) =>
+        associative_lemma(list, f, g) && associative_lemma_induct(tail, f, g) && associative_lemma_helper(Nil(), Nil(), Nil(), Nil(), list, f, g)
+    }
+  }.holds
+
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/Lists5.scala b/src/test/resources/regression/verification/purescala/valid/Lists5.scala
new file mode 100644
index 000000000..51ff810c6
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/Lists5.scala
@@ -0,0 +1,32 @@
+import leon.lang._
+import leon.collection._
+
+object Lists5 {
+  abstract class Option[T]
+  case class Some[T](value: T) extends Option[T]
+  case class None[T]() extends Option[T]
+
+  def find[T](f: T => Boolean, list: List[T]): Option[T] = list match {
+    case Cons(head, tail) => if (f(head)) Some(head) else find(f, tail)
+    case Nil() => None()
+  }
+
+  def exists[T](f: T => Boolean, list: List[T]): Boolean = list match {
+    case Cons(head, tail) => f(head) || exists(f, tail)
+    case Nil() => false
+  }
+
+  def equivalence_lemma[T](f: T => Boolean, list: List[T]): Boolean = {
+    find(f, list) match {
+      case Some(e) => exists(f, list)
+      case None() => !exists(f, list)
+    }
+  }
+
+  def equivalence_lemma_induct[T](f: T => Boolean, list: List[T]): Boolean = {
+    list match {
+      case Cons(head, tail) => equivalence_lemma(f, list) && equivalence_lemma_induct(f, tail)
+      case Nil() => equivalence_lemma(f, list)
+    }
+  }.holds
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/Monads1.scala b/src/test/resources/regression/verification/purescala/valid/Monads1.scala
new file mode 100644
index 000000000..da9cd6cec
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/Monads1.scala
@@ -0,0 +1,24 @@
+import leon.lang._
+
+object Monads1 {
+  abstract class Try[T]
+  case class Success[T](value: T) extends Try[T]
+  case class Failure[T](error: Int) extends Try[T]
+
+  def flatMap[T,U](t: Try[T], f: T => Try[U]): Try[U] = t match {
+    case Success(value) => f(value)
+    case fail @ Failure(error) => Failure(error)
+  }
+
+  def associative_law[T,U,V](t: Try[T], f: T => Try[U], g: U => Try[V]): Boolean = {
+    flatMap(flatMap(t, f), g) == flatMap(t, (x: T) => flatMap(f(x), g))
+  }.holds
+
+  def left_unit_law[T,U](x: T, f: T => Try[U]): Boolean = {
+    flatMap(Success(x), f) == f(x)
+  }.holds
+
+  def right_unit_law[T,U](t: Try[T]): Boolean = {
+    flatMap(t, (x: T) => Success(x)) == t
+  }.holds
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/Sets1.scala b/src/test/resources/regression/verification/purescala/valid/Sets1.scala
index e46ba25d0..4cfca4890 100644
--- a/src/test/resources/regression/verification/purescala/valid/Sets1.scala
+++ b/src/test/resources/regression/verification/purescala/valid/Sets1.scala
@@ -9,15 +9,15 @@ object Sets1 {
 
   def intersection(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) && s2(x)
 
-  def associativity(sa1: Int => Boolean, sa2: Int => Boolean, sa3: Int => Boolean, x: Int): Boolean = {
-    val u1 = union(union(sa1, sa2), sa3)
-    val u2 = union(sa1, union(sa2, sa3))
+  def de_morgan_1(s1: Int => Boolean, s2: Int => Boolean, x: Int): Boolean = {
+    val u1 = union(s1, s2)
+    val u2 = complement(intersection(complement(s1), complement(s2)))
     u1(x) == u2(x)
   }.holds
 
-  def lemma(s1: Int => Boolean, s2: Int => Boolean, x: Int): Boolean = {
-    val u1 = union(s1, s2)
-    val u2 = complement(intersection(complement(s1), complement(s2)))
+  def de_morgan_2(s1: Int => Boolean, s2: Int => Boolean, x: Int): Boolean = {
+    val u1 = intersection(s1, s2)
+    val u2 = complement(union(complement(s1), complement(s2)))
     u1(x) == u2(x)
   }.holds
 }
diff --git a/src/test/resources/regression/verification/purescala/valid/Sets2.scala b/src/test/resources/regression/verification/purescala/valid/Sets2.scala
new file mode 100644
index 000000000..f25449028
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/Sets2.scala
@@ -0,0 +1,23 @@
+import leon.lang._
+
+object Sets1 {
+  def set(i: Int): Int => Boolean = x => x == i
+
+  def complement(s: Int => Boolean): Int => Boolean = x => !s(x)
+
+  def union(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) || s2(x)
+
+  def intersection(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) && s2(x)
+
+  def union_associativity(sa1: Int => Boolean, sa2: Int => Boolean, sa3: Int => Boolean, x: Int): Boolean = {
+    val u1 = union(union(sa1, sa2), sa3)
+    val u2 = union(sa1, union(sa2, sa3))
+    u1(x) == u2(x)
+  }.holds
+
+  def intersection_associativity(sa1: Int => Boolean, sa2: Int => Boolean, sa3: Int => Boolean, x: Int): Boolean = {
+    val u1 = intersection(intersection(sa1, sa2), sa3)
+    val u2 = intersection(sa1, intersection(sa2, sa3))
+    u1(x) == u2(x)
+  }.holds
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/Trees1.scala b/src/test/resources/regression/verification/purescala/valid/Trees1.scala
new file mode 100644
index 000000000..4c01ae06e
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/Trees1.scala
@@ -0,0 +1,23 @@
+import leon.lang._
+
+object Trees1 {
+  abstract class Tree[T]
+  case class Node[T](left: Tree[T], right: Tree[T]) extends Tree[T]
+  case class Leaf[T](value: T) extends Tree[T]
+
+  def map[T,U](tree: Tree[T], f: T => U): Tree[U] = tree match {
+    case Node(left, right) => Node(map(left, f), map(right, f))
+    case Leaf(value) => Leaf(f(value))
+  }
+
+  def associative_lemma[T,U,V](tree: Tree[T], f: T => U, g: U => V): Boolean = {
+    map(map(tree, f), g) == map(tree, (x: T) => g(f(x)))
+  }
+
+  def associative_lemma_induct[T,U,V](tree: Tree[T], f: T => U, g: U => V): Boolean = {
+    tree match {
+      case Node(left, right) => associative_lemma_induct(left, f, g) && associative_lemma_induct(right, f, g) && associative_lemma(tree, f, g)
+      case Leaf(value) => associative_lemma(tree, f, g)
+    }
+  }.holds
+}
-- 
GitLab