From 446712cedfeaf261150a18e46135aeaa05d0cbb0 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Tue, 1 Mar 2016 13:34:34 +0100
Subject: [PATCH] CEGIS Improvements

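solveSATWithCores now returns Option[Either[Set[Expr], Model]] (unsat core or model) instead of a triple
CEGIS discards tests that crash the evaluator (testForProgram returns None on evaluator errors)
validatePrograms keeps trying to find a verifiable solution before falling back to an unverified (timed-out) one
Increase default CEGIS maximum expression size from 5 to 7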
---
 .../scala/leon/solvers/SimpleSolverAPI.scala  | 11 +--
 .../scala/leon/synthesis/SynthesisPhase.scala |  2 +-
 .../leon/synthesis/SynthesisSettings.scala    |  2 +-
 .../leon/synthesis/rules/CEGISLike.scala      | 92 ++++++++++++-------
 .../scala/leon/utils/GrowableIterable.scala   |  4 +-
 testcases/synthesis/current/run.sh            |  2 +-
 6 files changed, 70 insertions(+), 43 deletions(-)

diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala
index 33f6f1336..e5e7c7dd3 100644
--- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala
+++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala
@@ -3,7 +3,6 @@
 package leon
 package solvers
 
-import purescala.Common._
 import purescala.Expressions._
 
 class SimpleSolverAPI(sf: SolverFactory[Solver]) {
@@ -34,17 +33,17 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) {
     }
   }
 
-  def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Model, Set[Expr]) = {
+  def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): Option[Either[Set[Expr], Model]] = {
     val s = sf.getNewSolver()
     try {
       s.assertCnstr(expression)
       s.checkAssumptions(assumptions) match {
-        case Some(true) =>
-          (Some(true), s.getModel, Set())
         case Some(false) =>
-          (Some(false), Model.empty, s.getUnsatCore)
+          Some(Left(s.getUnsatCore))
+        case Some(true) =>
+          Some(Right(s.getModel))
         case None =>
-          (None, Model.empty, Set())
+          None
       }
     } finally {
       sf.reclaim(s)
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index da8b5f748..7c90d3e3b 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -22,7 +22,7 @@ object SynthesisPhase extends TransformationPhase {
   val optCEGISOptTimeout   = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true )
   val optCEGISVanuatoo     = LeonFlagOptionDef("cegis:vanuatoo",   "Generate inputs using new korat-style generator",        false)
   val optCEGISNaiveGrammar = LeonFlagOptionDef("cegis:naive",      "Use the old naive grammar for CEGIS",                    false)
-  val optCEGISMaxSize      = LeonLongOptionDef("cegis:maxsize",    "Maximum size of expressions synthesized by CEGIS", 5L, "N")
+  val optCEGISMaxSize      = LeonLongOptionDef("cegis:maxsize",    "Maximum size of expressions synthesized by CEGIS", 7L, "N")
 
   // Other rule options
   val optSpecifyRecCalls = LeonFlagOptionDef("reccalls", "Use full value as spec for introduced recursive calls", true)
diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala
index 61dc24ece..f7951c464 100644
--- a/src/main/scala/leon/synthesis/SynthesisSettings.scala
+++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala
@@ -18,6 +18,6 @@ case class SynthesisSettings(
   // Cegis related options
   cegisUseOptTimeout: Boolean = true,
   cegisUseVanuatoo  : Boolean = false,
-  cegisMaxSize: Int           = 5
+  cegisMaxSize: Int           = 7
 
 )
diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
index a12f0ed34..b3ce025da 100644
--- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala
+++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
@@ -396,7 +396,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
       }
 
       // Tests a candidate solution against an example in the correct environment
-      def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = {
+      // None -> evaluator error
+      def testForProgram(bValues: Set[Identifier])(ex: Example): Option[Boolean] = {
 
         def redundant(e: Expr): Boolean = {
           val (op1, op2) = e match {
@@ -441,7 +442,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
         // Deactivated for now, since it doesnot seem to help
         if (redundancyCheck && params.optimizations && exists(redundant)(outerSol)) {
           excludeProgram(bs, true)
-          return false
+          return Some(false)
         }
         val innerSol = outerExprToInnerExpr(outerSol)
         val cnstr = letTuple(p.xs, innerSol, innerPhi)
@@ -462,7 +463,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
         res match {
           case EvaluationResults.Successful(res) =>
-            res == BooleanLiteral(true)
+            Some(res == BooleanLiteral(true))
 
           case EvaluationResults.RuntimeError(err) =>
             /*if (err.contains("Empty production rule")) {
@@ -475,11 +476,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
               println()
             }*/
             hctx.reporter.debug("RE testing CE: "+err)
-            false
+            Some(false)
 
           case EvaluationResults.EvaluatorError(err) =>
             hctx.reporter.debug("Error testing CE: "+err)
-            false
+            None
         }
 
       }
@@ -508,6 +509,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
         var cexs = Seq[Seq[Expr]]()
 
+        var best: Option[Solution] = None
+
         for (bs <- bss.toSeq) {
           // We compute the corresponding expr and replace it in place of the C-tree
           val outerSol = getExpr(bs)
@@ -521,7 +524,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
           val eval = new DefaultEvaluator(hctx, innerProgram)
 
           if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) {
-            //println(s"Program $outerSol fails!")
+            hctx.reporter.debug(s"Rejected by CEX: $outerSol")
             excludeProgram(bs, true)
             cTreeFd.fullBody = origImpl
           } else {
@@ -530,9 +533,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
             val solverf = SolverFactory.getFromSettings(hctx, innerProgram).withTimeout(cexSolverTo)
             val solver = solverf.getNewSolver()
             try {
+              hctx.reporter.debug("Sending candidate to solver...")
               solver.assertCnstr(cnstr)
               solver.check match {
                 case Some(true) =>
+                  hctx.reporter.debug(s"Proven invalid: $outerSol")
                   excludeProgram(bs, true)
                   val model = solver.getModel
                   //println("Found counter example: ")
@@ -547,14 +552,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
                 case Some(false) =>
                   // UNSAT, valid program
+                  hctx.reporter.debug("Found valid program!")
                   return Right(Solution(BooleanLiteral(true), Set(), outerSol, true))
 
                 case None =>
                   if (useOptTimeout) {
-                    // Interpret timeout in CE search as "the candidate is valid"
-                    hctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
                     // Optimistic valid solution
-                    return Right(Solution(BooleanLiteral(true), Set(), outerSol, false))
+                    hctx.reporter.debug("Found a non-verifiable solution...")
+                    best = Some(Solution(BooleanLiteral(true), Set(), outerSol, false))
                   }
               }
             } finally {
@@ -565,7 +570,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
           }
         }
 
-        Left(cexs)
+        best.map{ sol =>
+          // Interpret timeout in CE search as "the candidate is valid"
+          hctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
+          Right(sol)
+        }.getOrElse(Left(cexs))
       }
 
       def allProgramsClosed = prunedPrograms.isEmpty
@@ -802,7 +811,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
             def nPassing = ndProgram.prunedPrograms.size
 
-            def programsReduced() = nInitial / nPassing > testReductionRatio || nPassing <= 10
+            def programsReduced() = nPassing <= 10 || nInitial / nPassing > testReductionRatio 
+
             // This is the starting test-base
             val gi = new GrowableIterable[Example](baseExampleInputs, inputGenerator, programsReduced)
 
@@ -827,21 +837,40 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
             //    printer(" - ...")
             //  }
             //}
+    
+            hctx.reporter.debug("#Tests: "+baseExampleInputs.size)
+            hctx.reporter.ifDebug{ printer =>
+              for (e <- baseExampleInputs.take(10)) {
+                printer(" - "+e.asString)
+              }
+              if(baseExampleInputs.size > 10) {
+                printer(" - ...")
+              }
+            }
 
             // We further filter the set of working programs to remove those that fail on known examples
             if (hasInputExamples) {
               timers.filter.start()
               for (bs <- ndProgram.prunedPrograms if !interruptManager.isInterrupted) {
                 val examples = allInputExamples()
-                examples.find(e => !ndProgram.testForProgram(bs)(e)).foreach { e =>
-                  failedTestsStats(e) += 1
-                  hctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}")
-                  ndProgram.excludeProgram(bs, true)
-                }
-
-                if (ndProgram.excludedPrograms.size+1 % 1000 == 0) {
-                  hctx.reporter.debug("..."+ndProgram.excludedPrograms.size)
+                var badExamples = List[Example]()
+                var stop = false
+                for (e <- examples if !stop) {
+                  ndProgram.testForProgram(bs)(e) match {
+                    case Some(true) => // ok, passes
+                    case Some(false) =>
+                      // Program fails the test
+                      stop = true
+                      failedTestsStats(e) += 1
+                      hctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}")
+                      ndProgram.excludeProgram(bs, true)
+                    case None =>
+                      // Eval. error -> bad example
+                      hctx.reporter.debug(s" Test $e failed, removing...")
+                      badExamples ::= e
+                  }
                 }
+                gi --= badExamples
               }
               timers.filter.stop()
             }
@@ -855,17 +884,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
                 printer(" - ...")
               }
             }
-            hctx.reporter.debug("#Tests: "+baseExampleInputs.size)
-            hctx.reporter.ifDebug{ printer =>
-              for (e <- baseExampleInputs.take(10)) {
-                printer(" - "+e.asString)
-              }
-              if(baseExampleInputs.size > 10) {
-                printer(" - ...")
-              }
-            }
-
-            // CEGIS Loop at a given unfolding level
+              // CEGIS Loop at a given unfolding level
             while (result.isEmpty && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) {
               timers.loop.start()
               hctx.reporter.debug("Programs left: " + ndProgram.prunedPrograms.size)
@@ -891,6 +910,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
               if (result.isEmpty && !ndProgram.allProgramsClosed) {
                 // Phase 1: Find a candidate program that works for at least 1 input
+                hctx.reporter.debug("Looking for program that works on at least 1 input...")
                 ndProgram.solveForTentativeProgram() match {
                   case Some(Some(bs)) =>
                     hctx.reporter.debug(s"Found tentative model ${ndProgram.getExpr(bs)}, need to validate!")
@@ -906,9 +926,15 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
 
                         // Retest whether the newly found C-E invalidates some programs
                         ndProgram.prunedPrograms.foreach { p =>
-                          if (!ndProgram.testForProgram(p)(ce)) {
-                            failedTestsStats(ce) += 1
-                            ndProgram.excludeProgram(p, true)
+                          ndProgram.testForProgram(p)(ce) match {
+                            case Some(true) =>
+                            case Some(false) =>
+                              hctx.reporter.debug(f" Program: ${ndProgram.getExpr(p).asString}%-80s failed on: ${ce.asString}")
+                              failedTestsStats(ce) += 1
+                              ndProgram.excludeProgram(p, true)
+                            case None =>
+                              hctx.reporter.debug(s" Test $ce failed, removing...")
+                              gi -= ce
                           }
                         }
 
diff --git a/src/main/scala/leon/utils/GrowableIterable.scala b/src/main/scala/leon/utils/GrowableIterable.scala
index 072ca7119..5720f4967 100644
--- a/src/main/scala/leon/utils/GrowableIterable.scala
+++ b/src/main/scala/leon/utils/GrowableIterable.scala
@@ -15,8 +15,10 @@ class GrowableIterable[T](init: Seq[T], growth: Iterator[T], canGrow: () => Bool
     }
   }
 
-  def +=(more: T) = buffer += more
+  def += (more: T)      = buffer +=  more
   def ++=(more: Seq[T]) = buffer ++= more
+  def -= (less: T)      = buffer -=  less
+  def --=(less: Seq[T]) = buffer --= less
 
   def iterator: Iterator[T] = {
     buffer.iterator ++ cachingIterator
diff --git a/testcases/synthesis/current/run.sh b/testcases/synthesis/current/run.sh
index 522e432df..d16431ec7 100755
--- a/testcases/synthesis/current/run.sh
+++ b/testcases/synthesis/current/run.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 function run {
-    cmd="./leon --debug=report --timeout=30 --synthesis --cegis:maxsize=7 $1"
+    cmd="./leon --debug=report --timeout=60 --synthesis $1"
     echo "Running " $cmd
     echo "------------------------------------------------------------------------------------------------------------------"
     $cmd;
-- 
GitLab