From 6cb4213c3827e75c192c0ec84e249393b709b696 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Tue, 26 Jan 2016 17:06:25 +0100
Subject: [PATCH] Added ExprOps.canBeHomomorphic which returns a mapping from
 identifiers if the expressions are homomorphic. Addded failing test of
 partial string synthesis by example and made it work at the same time.

---
 src/main/scala/leon/purescala/ExprOps.scala   | 163 ++++++++++++++++++
 .../scala/leon/synthesis/ExamplesFinder.scala |  21 ++-
 .../leon/synthesis/rules/StringRender.scala   |   2 +-
 .../solvers/StringRenderSuite.scala           | 160 +++++++++++------
 4 files changed, 280 insertions(+), 66 deletions(-)

diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index e6c87698c..6899a422e 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1534,6 +1534,169 @@ object ExprOps {
     case _ =>
       false
   }
+  
+  /** Checks whether two expressions can be homomorphic and returns the corresponding mapping */
+  def canBeHomomorphic(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = {
+    def mergeContexts(a: Option[Map[Identifier, Identifier]], b: =>Option[Map[Identifier, Identifier]]) = a match {
+      case Some(m) =>
+        b match {
+          case Some(n) if (m.keySet & n.keySet) forall (key => m(key) == n(key)) =>
+            Some(m ++ n)        
+          case _ =>None
+        }
+      case _ => None
+    }
+    object Same {
+      def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = {
+        if (tt._1.getClass == tt._2.getClass) {
+          Some(tt)
+        } else {
+          None
+        }
+      }
+    }
+    implicit class AugmentedContext(c: Option[Map[Identifier, Identifier]]) {
+      def &&(other: => Option[Map[Identifier, Identifier]]) = mergeContexts(c, other)
+    }
+    implicit class AugmentedBooleant(c: Boolean) {
+      def &&(other: => Option[Map[Identifier, Identifier]]) = if(c) other else None
+    }
+    implicit class AugmentedSeq[T](c: Seq[T]) {
+      def mergeall(p: T => Option[Map[Identifier, Identifier]]) =
+        (Option(Map[Identifier, Identifier]()) /: c) {
+          case (s, c) => s && p(c)
+        }
+    }
+
+
+    def idHomo(i1: Identifier, i2: Identifier): Option[Map[Identifier, Identifier]] = {
+      Some(Map(i1 -> i2))
+    }
+
+    def fdHomo(fd1: FunDef, fd2: FunDef): Option[Map[Identifier, Identifier]] = {
+      if(fd1.params.size == fd2.params.size) {
+         val newMap = Map((
+           (fd1.id -> fd2.id) +:
+           (fd1.paramIds zip fd2.paramIds)): _*)
+         Option(newMap) && isHomo(fd1.fullBody, fd2.fullBody)
+      } else None
+    }
+
+    def isHomo(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = {
+      def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase]) : Option[Map[Identifier, Identifier]] = {
+        def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match {
+          case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) =>
+            (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*))
+
+          case (WildcardPattern(ob1), WildcardPattern(ob2)) =>
+            (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*))
+
+          case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) =>
+            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)
+
+            if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) {
+              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
+                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
+              }
+            } else {
+              (false, Map())
+            }
+
+          case (UnapplyPattern(ob1, fd1, subs1), UnapplyPattern(ob2, fd2, subs2)) =>
+            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)
+
+            if (ob1.size == ob2.size && fd1 == fd2 && subs1.size == subs2.size) {
+              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
+                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
+              }
+            } else {
+              (false, Map())
+            }
+
+          case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) =>
+            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)
+
+            if (ob1.size == ob2.size && subs1.size == subs2.size) {
+              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
+                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
+              }
+            } else {
+              (false, Map())
+            }
+
+          case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) =>
+            (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap)
+
+          case _ =>
+            (false, Map())
+        }
+
+        (cs1 zip cs2).mergeall {
+          case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) =>
+            val (h, nm) = patternHomo(p1, p2)
+            val g: Option[Map[Identifier, Identifier]] = (g1, g2) match {
+              case (Some(g1), Some(g2)) => Some(nm) && isHomo(g1,g2)
+              case (None, None) => Some(Map())
+              case _ => None
+            }
+            val e = Some(nm) && isHomo(e1, e2)
+
+            h && g && e
+        }
+
+      }
+
+      import synthesis.Witnesses.Terminating
+
+      val res: Option[Map[Identifier, Identifier]] = (t1, t2) match {
+        case (Variable(i1), Variable(i2)) =>
+          idHomo(i1, i2)
+
+        case (Let(id1, v1, e1), Let(id2, v2, e2)) =>
+          isHomo(v1, v2) &&
+          isHomo(e1, e2) && Some(Map(id1 -> id2))
+
+        case (LetDef(fds1, e1), LetDef(fds2, e2)) =>
+          fds1.size == fds2.size &&
+          {
+            val zipped = fds1.zip(fds2)
+            (zipped mergeall (fds => fdHomo(fds._1, fds._2))) && Some(zipped.map(fds => fds._1.id -> fds._2.id).toMap) &&
+            isHomo(e1, e2)
+          }
+
+        case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) =>
+          cs1.size == cs2.size && casesMatch(cs1,cs2) && isHomo(s1, s2)
+
+        case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) =>
+          (cs1.size == cs2.size && casesMatch(cs1,cs2)) && isHomo(in1,in2) && isHomo(out1,out2)
+
+        case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) =>
+          // TODO: Check type params
+          fdHomo(tfd1.fd, tfd2.fd) &&
+          (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) }
+
+        case (Terminating(tfd1, args1), Terminating(tfd2, args2)) =>
+          // TODO: Check type params
+          fdHomo(tfd1.fd, tfd2.fd) &&
+          (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) }
+
+        // TODO: Seems a lot is missing, like Literals
+
+        case Same(Operator(es1, _), Operator(es2, _)) =>
+          (es1.size == es2.size) &&
+          (es1 zip es2).mergeall{ case (e1, e2) => isHomo(e1, e2) }
+
+        case _ =>
+          None
+      }
+
+      res
+    }
+
+    isHomo(t1,t2)
+    
+    
+  } // ensuring (res => res.isEmpty || isHomomorphic(t1, t2)(res.get))
 
   /** Checks whether two trees are homomoprhic modulo an identifier map.
     *
diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala
index f535d2f56..c22c7e7f3 100644
--- a/src/main/scala/leon/synthesis/ExamplesFinder.scala
+++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala
@@ -144,14 +144,15 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
             case _                                 => test
           }
         }
-
-        // Check whether we can extract all ids from example
-        val results = exs.collect { case e if infos.forall(_._2.isDefinedAt(e)) || this.keepAbstractExamples =>
-          infos.map{ case (id, f) => id -> f(e) }.toMap
+        try {
+          // Check whether we can extract all ids from example
+          val results = exs.collect { case e if this.keepAbstractExamples || infos.forall(_._2.isDefinedAt(e)) =>
+            infos.map{ case (id, f) => id -> f(e) }.toMap
+          }
+          results.toSet
+        } catch {
+          case e: IDExtractionException => Set()
         }
-
-        results.toSet
-
       case _ =>
         Set()
     }(e)
@@ -272,6 +273,8 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
     }
     consolidated
   }
+  
+  case class IDExtractionException(msg: String) extends Exception(msg)
 
   /** Extract ids in ins/outs args, and compute corresponding extractors for values map
     *
@@ -291,13 +294,13 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
     case Tuple(vs) =>
       vs.map(extractIds).zipWithIndex.flatMap{ case (ids, i) =>
         ids.map{ case (id, e) =>
-          (id, andThen({ case Tuple(vs) => vs(i) }, e))
+          (id, andThen({ case Tuple(vs) => vs(i) case e => throw new IDExtractionException("Expected Tuple, got " + e) }, e))
         }
       }
     case CaseClass(cct, args) =>
       args.map(extractIds).zipWithIndex.flatMap { case (ids, i) =>
         ids.map{ case (id, e) =>
-          (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) } ,e))
+          (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) case e => throw new IDExtractionException("Expected Case class of type " + cct + ", got " + e) } ,e))
         }
       }
 
diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala
index 319ec9ad3..b3172538e 100644
--- a/src/main/scala/leon/synthesis/rules/StringRender.scala
+++ b/src/main/scala/leon/synthesis/rules/StringRender.scala
@@ -134,7 +134,7 @@ case object StringRender extends Rule("StringRender") {
           case (lhs, rhs) => Some((accEqs += ((lhs, rhs))).toList)
         }
       case (OtherStringFormToken(e)::lhstail, OtherStringChunk(f)::rhstail) =>
-        if(e == f) {
+        if(ExprOps.canBeHomomorphic(e, f).nonEmpty) {
           rec(lhstail, rhstail, accEqs += ((accLeft.toList, accRight.toString)), ListBuffer[StringFormToken](), new StringBuffer)
         } else None
       case (OtherStringFormToken(e)::lhstail, Nil) =>
diff --git a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala
index b6beaa104..66aa7f59e 100644
--- a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala
+++ b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala
@@ -10,6 +10,7 @@ import leon.purescala.Common.Identifier
 import leon.purescala.Expressions._
 import leon.purescala.Definitions._
 import leon.purescala.DefOps
+import leon.purescala.ExprOps
 import leon.purescala.Types._
 import leon.purescala.TypeOps
 import leon.purescala.Constructors._
@@ -206,6 +207,10 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
     |  def bListToString[T](b: BList[T], f: T => String) = ???[String] ensuring
     |  { (res: String) => (b, res) passes { case BNil() => "[]" case BCons(a, BCons(b, BCons(c, BNil()))) => "[" + f(a._1) + "-" + f(a._2) + ", " + f(b._1) + "-" + f(b._2) + ", " + f(c._1) + "-" + f(c._2) + "]" }}
     |  
+    |  case class BConfig(flags: BList[Dummy])
+    |  def bConfigToString(b: BConfig): String = ???[String] ensuring
+    |  { (res: String) => (b, res) passes { case BConfig(BNil()) => "Config" + bListToString[Dummy](BNil(), (x: Dummy) => dummyToString(x)) } }
+    |  
     |  case class Node(tag: String, l: List[Edge])
     |  case class Edge(start: Node, end: Node)
     |  
@@ -221,7 +226,6 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
   def synthesizeAndTest(functionName: String, tests: (Seq[Expr], String)*) {
     val (fd, program) = applyStringRenderOn(functionName)
     val when = new DefaultEvaluator(ctx, program)
-    val when_abstract = new AbstractEvaluator(ctx, program)
     val args = getFunctionArguments(functionName)
     for((in, out) <- tests) {
       val expr = functionInvocation(fd, in)
@@ -232,6 +236,81 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
       }
     }
   }
+  def synthesizeAndAbstractTest(functionName: String)(tests: (FunDef, Program) => Seq[(Seq[Expr], Expr)]) {
+    val (fd, program) = applyStringRenderOn(functionName)
+    val when_abstract = new AbstractEvaluator(ctx, program)
+    val args = getFunctionArguments(functionName)
+    for((in, out) <- tests(fd, program)) {
+      val expr = functionInvocation(fd, in)
+      when_abstract.eval(expr) match {
+        case EvaluationResults.Successful(value) => val m = ExprOps.canBeHomomorphic(value._1, out)
+          assert(m.nonEmpty, value._1 + " was not homomorphic with " + out)
+        case EvaluationResults.EvaluatorError(msg) => fail(/*program + "\n" + */msg)
+        case EvaluationResults.RuntimeError(msg) => fail(/*program + "\n" + */"Runtime: " + msg)
+      }
+    }
+  }
+  class TreeBuilder(program: Program) {
+    object Knot {
+      def apply(left: Expr, right: Expr): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.Knot").get.typed,
+            Seq(left, right))
+      }
+    }
+    object Bud {
+      def apply(s: String): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.Bud").get.typed,
+            Seq(StringLiteral(s)))
+      }
+    }
+  }
+  class DummyBuilder(program: Program) {
+    object Dummy {
+      def getType: TypeTree = program.lookupCaseClass("StringRenderSuite.Dummy").get.typed
+      def apply(s: String): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.Dummy").get.typed,
+            Seq(StringLiteral(s)))
+      }
+    }
+  }
+  
+  class BListBuilder(program: Program) {
+    object Cons {
+      def apply(types: Seq[TypeTree])(left: Expr, right: Expr): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.BCons").get.typed(types),
+            Seq(left, right))
+      }
+    }
+    object Nil {
+      def apply(types: Seq[TypeTree]): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.BNil").get.typed(types),
+            Seq())
+      }
+    }
+    object List {
+      def apply(types: Seq[TypeTree])(elems: Expr*): CaseClass = {
+        elems.toList match {
+          case collection.immutable.Nil => Nil(types)
+          case a::b => Cons(types)(a, List(types)(b: _*))
+        }
+      }
+    }
+  }
+  case class ConfigBuilder(program: Program) {
+    def apply(i: Int, b: (Int, String)): CaseClass = {
+      CaseClass(program.lookupCaseClass("StringRenderSuite.Config").get.typed,
+          Seq(InfiniteIntegerLiteral(i), tupleWrap(Seq(IntLiteral(b._1), StringLiteral(b._2)))))
+    }
+  }
+  class BConfigBuilder(program: Program) {
+    object BConfig {
+      def getType: TypeTree = program.lookupCaseClass("StringRenderSuite.BConfig").get.typed
+      def apply(s: Expr): CaseClass = {
+        CaseClass(program.lookupCaseClass("StringRenderSuite.BConfig").get.typed,
+            Seq(s))
+      }
+    }
+  }
   
   test("Literal synthesis"){ case (ctx: LeonContext, program: Program) =>
     synthesizeAndTest("literalSynthesis",
@@ -257,40 +336,20 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
         Seq(StringLiteral("\t")) -> "done...\\t")
         
   }*/
-  case class ConfigBuilder(program: Program) {
-    def apply(i: Int, b: (Int, String)): CaseClass = {
-      CaseClass(program.lookupCaseClass("StringRenderSuite.Config").get.typed,
-          Seq(InfiniteIntegerLiteral(i), tupleWrap(Seq(IntLiteral(b._1), StringLiteral(b._2)))))
-    }
-  }
-  StringRender.enforceDefaultStringMethodsIfAvailable = false
+  
   test("Case class synthesis"){ case (ctx: LeonContext, program: Program) =>
     val Config = ConfigBuilder(program)
-    
+    StringRender.enforceDefaultStringMethodsIfAvailable = false
     synthesizeAndTest("configToString",
         Seq(Config(4, (5, "foo"))) -> "4: 5 -> foo")
   }
   
   test("Out of order synthesis"){ case (ctx: LeonContext, program: Program) =>
     val Config = ConfigBuilder(program)
-    
+    StringRender.enforceDefaultStringMethodsIfAvailable = false
     synthesizeAndTest("configToString2",
         Seq(Config(4, (5, "foo"))) -> "foo: 4 -> 5")
   }
-  class TreeBuilder(program: Program) {
-    object Knot {
-      def apply(left: Expr, right: Expr): CaseClass = {
-        CaseClass(program.lookupCaseClass("StringRenderSuite.Knot").get.typed,
-            Seq(left, right))
-      }
-    }
-    object Bud {
-      def apply(s: String): CaseClass = {
-        CaseClass(program.lookupCaseClass("StringRenderSuite.Bud").get.typed,
-            Seq(StringLiteral(s)))
-      }
-    }
-  }
   
   test("Recursive case class synthesis"){ case (ctx: LeonContext, program: Program) =>
     val tb = new TreeBuilder(program)
@@ -300,30 +359,6 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
         "<<aaYbb>Y<mmYnn>>")
   }
   
-  class DummyBuilder(program: Program) {
-    object Dummy {
-      def getType: TypeTree = program.lookupCaseClass("StringRenderSuite.Dummy").get.typed
-      def apply(s: String): CaseClass = {
-        CaseClass(program.lookupCaseClass("StringRenderSuite.Dummy").get.typed,
-            Seq(StringLiteral(s)))
-      }
-    }
-  }
-  
-  class BListBuilder(program: Program) {
-    object Cons {
-      def apply(types: Seq[TypeTree])(left: Expr, right: Expr): CaseClass = {
-        CaseClass(program.lookupCaseClass("StringRenderSuite.BCons").get.typed(types),
-            Seq(left, right))
-      }
-    }
-    object Nil {
-      def apply(types: Seq[TypeTree]): CaseClass = {
-        CaseClass(program.lookupCaseClass("StringRenderSuite.BNil").get.typed(types),
-            Seq())
-      }
-    }
-  }
   test("Abstract synthesis"){ case (ctx: LeonContext, program: Program) =>
     val db = new DummyBuilder(program)
     import db._
@@ -334,20 +369,33 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal
     val dummyToString = program.lookupFunDef("StringRenderSuite.dummyToString").get
     
     synthesizeAndTest("bListToString",
-        Seq(Cons(DT)(tupleWrap(Seq(Dummy("a"), Dummy("b"))),
-            Cons(DT)(tupleWrap(Seq(Dummy("c"), Dummy("d"))),
-            Nil(DT))),
+        Seq(List(DT)(
+              tupleWrap(Seq(Dummy("a"), Dummy("b"))),
+              tupleWrap(Seq(Dummy("c"), Dummy("d")))),
             Lambda(Seq(ValDef(d)), FunctionInvocation(dummyToString.typed, Seq(Variable(d)))))
             ->
         "[{a}-{b}, {c}-{d}]")
     
   }
   
-  test("Use of existing functions"){ case (ctx: LeonContext, program: Program) =>
-    
-  }
   
-  test("Pretty-printing using existing functions"){ case (ctx: LeonContext, program: Program) =>
-    // Lists of size 2
+  test("Pretty-printing using existing not yet defined functions"){ case (ctx: LeonContext, program: Program) =>
+    StringRender.enforceDefaultStringMethodsIfAvailable = true
+    synthesizeAndAbstractTest("bConfigToString"){ (fd: FunDef, program: Program) =>
+      val db = new DummyBuilder(program)
+      import db._
+      val DT = Seq(Dummy.getType)
+      val bcb = new BConfigBuilder(program)
+      import bcb._
+      val blb = new BListBuilder(program)
+      import blb._
+      val d = FreshIdentifier("d", Dummy.getType)
+      val arg = List(DT)(tupleWrap(Seq(Dummy("a"), Dummy("b"))))
+      val dummyToString = program.lookupFunDef("StringRenderSuite.dummyToString").get
+      val lambdaDummyToString = Lambda(Seq(ValDef(d)), FunctionInvocation(dummyToString.typed, Seq(Variable(d))))
+      val listDummyToString = functionInvocation(program.lookupFunDef("StringRenderSuite.bListToString").get, Seq(arg, lambdaDummyToString))
+      Seq(Seq(BConfig(arg)) ->
+      StringConcat(StringLiteral("Config"), listDummyToString))
+    }
   }
 }
\ No newline at end of file
-- 
GitLab