From 726d0107b584c042aae387e6faf0bddad202b7d2 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Tue, 26 Jan 2016 16:35:36 +0100
Subject: [PATCH] Fix empty LetDef's, tidy up CEGIS, better CEGIS errors

---
 .../frontends/scalac/CodeExtraction.scala     |  2 +-
 src/main/scala/leon/purescala/ExprOps.scala   |  2 +-
 .../scala/leon/purescala/Expressions.scala    |  2 +-
 .../leon/purescala/ScopeSimplifier.scala      |  3 ++-
 src/main/scala/leon/purescala/TypeOps.scala   |  2 +-
 .../leon/solvers/SolverUnsupportedError.scala |  2 +-
 .../leon/solvers/z3/AbstractZ3Solver.scala    |  4 ++--
 .../leon/synthesis/rules/CEGISLike.scala      | 20 +++++++++++------
 .../leon/synthesis/rules/TEGISLike.scala      | 22 +++++--------------
 .../scala/leon/utils/UnitElimination.scala    |  2 +-
 .../xlang/ImperativeCodeElimination.scala     |  2 +-
 11 files changed, 30 insertions(+), 33 deletions(-)

diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 3e9366400..96215b68b 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1138,7 +1138,7 @@ trait CodeExtraction extends ASTExtractors {
             case _ =>
               (Nil, restTree)
           }
-          LetDef(funDefWithBody +: other_fds, block)
+          letDef(funDefWithBody +: other_fds, block)
 
         // FIXME case ExDefaultValueFunction
 
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 49d945ec6..ed86c1806 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -2129,7 +2129,7 @@ object ExprOps {
 
         fds ++= nfds
 
-        Some(LetDef(nfds.map(_._2), b))
+        Some(letDef(nfds.map(_._2), b))
 
       case FunctionInvocation(tfd, args) =>
         if (fds contains tfd.fd) {
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index bb70a6676..f6edb01a4 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -165,7 +165,7 @@ object Expressions {
     * @param body The body of the expression after the function
     */
   case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr {
-    assert(fds.nonEmpty)
+    require(fds.nonEmpty)
     val getType = body.getType
   }
 
diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala
index f0ff379ff..380554d8c 100644
--- a/src/main/scala/leon/purescala/ScopeSimplifier.scala
+++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala
@@ -7,6 +7,7 @@ import Common._
 import Definitions._
 import Expressions._
 import Extractors._
+import Constructors.letDef
 
 class ScopeSimplifier extends Transformer {
   case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) {
@@ -59,7 +60,7 @@ class ScopeSimplifier extends Transformer {
       for((newFd, fd) <- fds_mapping) {
         newFd.fullBody = rec(fd.fullBody, newScope)
       }
-      LetDef(fds_mapping.map(_._1), rec(body, newScope))
+      letDef(fds_mapping.map(_._1), rec(body, newScope))
    
     case MatchExpr(scrut, cases) =>
       val rs = rec(scrut, scope)
diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index db655365c..81b8420c3 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -335,7 +335,7 @@ object TypeOps {
             }
             val newBd = srec(subCalls(bd)).copiedFrom(bd)
 
-            LetDef(newFds, newBd).copiedFrom(l)
+            letDef(newFds, newBd).copiedFrom(l)
 
           case l @ Lambda(args, body) =>
             val newArgs = args.map { arg =>
diff --git a/src/main/scala/leon/solvers/SolverUnsupportedError.scala b/src/main/scala/leon/solvers/SolverUnsupportedError.scala
index 5d519160d..2efc8ea39 100644
--- a/src/main/scala/leon/solvers/SolverUnsupportedError.scala
+++ b/src/main/scala/leon/solvers/SolverUnsupportedError.scala
@@ -7,7 +7,7 @@ import purescala.Common.Tree
 
 object SolverUnsupportedError {
   def msg(t: Tree, s: Solver, reason: Option[String]) = {
-    s" is unsupported by solver ${s.name}" + reason.map(":\n  " + _ ).getOrElse("")
+    s"(of ${t.getClass}) is unsupported by solver ${s.name}" + reason.map(":\n  " + _ ).getOrElse("")
   }
 }
 
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index ac1e8855a..04496ae9c 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -276,7 +276,7 @@ trait AbstractZ3Solver extends Solver {
       // variable has to remain in a map etc.
       variables.aToB.collect{ case (Variable(id), p2) => id -> p2 }
     }
-     new Z3StringConversion[Z3AST] {
+    new Z3StringConversion[Z3AST] {
         def getProgram = AbstractZ3Solver.this.program
         def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = {
           rec(e)
@@ -538,7 +538,7 @@ trait AbstractZ3Solver extends Solver {
     
             rec(RawArrayValue(from, elems.map{
               case (k, v) => (k, CaseClass(library.someType(t), Seq(v)))
-            }.toMap, CaseClass(library.noneType(t), Seq())))
+            }, CaseClass(library.noneType(t), Seq())))
     
           case MapApply(m, k) =>
             val mt @ MapType(_, t) = normalizeType(m.getType)
diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
index d577f7f9f..f643a7bb0 100644
--- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala
+++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
@@ -287,7 +287,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
               case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e)
             }
           } else {
-            Error(c.getType, "Impossibru")
+            Error(c.getType, s"Empty production rule: $c")
           }
 
           cToFd(c).fullBody = body
@@ -408,11 +408,21 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
 
       def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = {
+
         tester(ex, bValues) match {
           case EvaluationResults.Successful(res) =>
             res == BooleanLiteral(true)
 
           case EvaluationResults.RuntimeError(err) =>
+            /*if (err.contains("Empty production rule")) {
+              println(programCTree.asString)
+              println(bValues)
+              println(ex)
+              println(this.getExpr(bValues))
+              (new Throwable).printStackTrace()
+              println(err)
+              println()
+            }*/
             sctx.reporter.debug("RE testing CE: "+err)
             false
 
@@ -431,7 +441,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
             case (b, builder, cs) =>
               builder(cs.map(getCValue))
           }.getOrElse {
-            simplestValue(c.getType)
+            Error(c.getType, "Impossible assignment of bs")
           }
         }
 
@@ -542,9 +552,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
         //println(" --- Constraints ---")
         //println(" - "+toFind.asString)
         try {
-          //TODO: WHAT THE F IS THIS?
-          //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))
-          //solver.assertCnstr(bsOrNotBs)
           solver.assertCnstr(toFind)
 
           for ((c, alts) <- cTree) {
@@ -813,7 +820,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
               var doFilter = true
 
               if (validateUpTo > 0) {
-                // Validate the first N programs individualy
+                // Validate the first N programs individually
                 ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match {
                   case Left(sols) if sols.nonEmpty =>
                     doFilter = false
@@ -850,7 +857,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
                       // make sure by validating this candidate with z3
                       true
                     } else {
-                      println("testing failed ?!")
                       // One valid input failed with this candidate, we can skip
                       ndProgram.excludeProgram(bs, false)
                       false
diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala
index 91084ae4f..a6e060f89 100644
--- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala
+++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala
@@ -40,7 +40,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) {
 
         val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20
 
-        val useVanuatoo      = sctx.settings.cegisUseVanuatoo.getOrElse(false)
+        val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false)
 
         val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) {
           new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000)
@@ -53,8 +53,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) {
 
         val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0)
 
-        def hasInputExamples = gi.nonEmpty
-
         var n = 1
         def allInputExamples() = {
           if (n == 10 || n == 50 || n % 500 == 0) {
@@ -64,12 +62,10 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) {
           gi.iterator
         }
 
-        var tests = p.eb.valids.map(_.ins).distinct
-
         if (gi.nonEmpty) {
 
-          val evalParams            = CodeGenParams.default.copy(maxFunctionInvocations = 2000)
-          val evaluator             = new DualEvaluator(sctx.context, sctx.program, evalParams)
+          val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000)
+          val evaluator  = new DualEvaluator(sctx.context, sctx.program, evalParams)
 
           val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions)
 
@@ -80,7 +76,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) {
           val allExprs = enum.iterator(params.rootLabel(targetType))
 
           var candidate: Option[Expr] = None
-          var n = 1
 
           def findNext(): Option[Expr] = {
             candidate = None
@@ -111,14 +106,9 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) {
             candidate
           }
 
-          def toStream: Stream[Solution] = {
-            findNext() match {
-              case Some(e) =>
-                Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream)
-              case None =>
-                Stream.empty
-            }
-          }
+          val toStream = Stream.continually(findNext()).takeWhile(_.nonEmpty).map( e =>
+            Solution(BooleanLiteral(true), Set(), e.get, isTrusted = false)
+          )
 
           RuleClosed(toStream)
         } else {
diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala
index f4f603393..45fa8bea4 100644
--- a/src/main/scala/leon/utils/UnitElimination.scala
+++ b/src/main/scala/leon/utils/UnitElimination.scala
@@ -125,7 +125,7 @@ object UnitElimination extends TransformationPhase {
             }
           }
           
-          LetDef(newFds, rest)
+          letDef(newFds, rest)
         }
 
       case ite@IfExpr(cond, tExpr, eExpr) =>
diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
index 45bb36770..20048c0ba 100644
--- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
@@ -218,7 +218,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
       case LetDef(fds, b) =>
 
         if(fds.size > 1) {
-          //TODO: no support for true mutually recursion
+          //TODO: no support for true mutual recursion
           toFunction(LetDef(Seq(fds.head), LetDef(fds.tail, b)))
         } else {
 
-- 
GitLab