From aab7b7f36f61f37f8cd699cea4f8b0311f276b13 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Tue, 25 Feb 2014 13:29:36 +0100
Subject: [PATCH] EnumeratingSolver / PortfolioSolver

Use a datagen-based solver to find simple counter-examples. Note that
this solver returns Unknown most of the time, so it is best to combine
it with a full-fledged solver.

PortfolioSolver allows us to combine solvers and have them run in
parallel. The first result (!= Unknown) is used. Solvers can be selected
for verification using the --solvers option.
---
 .../scala/leon/datagen/DataGenerator.scala    |  15 +-
 .../scala/leon/datagen/NaiveDataGen.scala     |   1 +
 .../scala/leon/datagen/VanuatooDataGen.scala  |  26 ++--
 .../leon/solvers/EnumerationSolver.scala      |  73 +++++++++
 .../solvers/combinators/PortfolioSolver.scala |  84 +++++++++++
 .../scala/leon/synthesis/Synthesizer.scala    |   2 +-
 .../leon/verification/AnalysisPhase.scala     | 138 ++++++++++--------
 .../verification/VerificationContext.scala    |   2 +-
 .../test/solvers/EnumerationSolverTests.scala |  44 ++++++
 9 files changed, 309 insertions(+), 76 deletions(-)
 create mode 100644 src/main/scala/leon/solvers/EnumerationSolver.scala
 create mode 100644 src/main/scala/leon/solvers/combinators/PortfolioSolver.scala
 create mode 100644 src/test/scala/leon/test/solvers/EnumerationSolverTests.scala

diff --git a/src/main/scala/leon/datagen/DataGenerator.scala b/src/main/scala/leon/datagen/DataGenerator.scala
index c15f07a2b..810083dd1 100644
--- a/src/main/scala/leon/datagen/DataGenerator.scala
+++ b/src/main/scala/leon/datagen/DataGenerator.scala
@@ -5,7 +5,20 @@ package datagen
 
 import purescala.Trees._
 import purescala.Common._
+import utils._
 
-trait DataGenerator {
+import java.util.concurrent.atomic.AtomicBoolean
+
+trait DataGenerator extends Interruptible {
   def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]];
+
+  protected val interrupted: AtomicBoolean = new AtomicBoolean(false)
+
+  def interrupt(): Unit = {
+    interrupted.set(true)
+  }
+
+  def recoverInterrupt(): Unit = {
+    interrupted.set(false)
+  }
 }
diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala
index 5ec8577bd..5bd33cb85 100644
--- a/src/main/scala/leon/datagen/NaiveDataGen.scala
+++ b/src/main/scala/leon/datagen/NaiveDataGen.scala
@@ -118,6 +118,7 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds :
 
       naryProduct(ins.map(id => generate(id.getType, bounds)))
         .take(maxEnumerated)
+        .takeWhile(s => !interrupted.get)
         .filter{s => evalFun(s) == sat }
         .take(maxValid)
         .iterator
diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala
index f934ac4bd..e94415131 100644
--- a/src/main/scala/leon/datagen/VanuatooDataGen.scala
+++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala
@@ -87,11 +87,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
 
       unit.jvmClassToLeonClass(cc.getClass.getName) match {
         case Some(ccd: CaseClassDef) =>
+          val cct = CaseClassType(ccd, ct.tps)
           val c = ct match {
             case act : AbstractClassType =>
-              getConstructorFor(CaseClassType(ccd, ct.tps), act)
+              getConstructorFor(cct, act)
             case cct : CaseClassType =>
-              getConstructors(CaseClassType(ccd, ct.tps))(0)
+              getConstructors(cct)(0)
           }
 
           val fields = cc.productElements()
@@ -99,7 +100,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
           val elems = for (i <- 0 until fields.length) yield {
             if (((r >> i) & 1) == 1) {
               // has been read
-              valueToPattern(fields(i), ct.fieldsTypes(i))
+              valueToPattern(fields(i), cct.fieldsTypes(i))
             } else {
               (AnyPattern[Expr, TypeTree](), false)
             }
@@ -158,6 +159,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
 
           (EvaluationResults.Successful(result), if (!pattern._2) Some(pattern._1) else None)
         } catch {
+          case e : ClassCastException  =>
+            (EvaluationResults.RuntimeError(e.getMessage), None)
+
           case e : ArithmeticException =>
             (EvaluationResults.RuntimeError(e.getMessage), None)
 
@@ -228,7 +232,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
 
 
       def computeNext(): Option[Seq[Expr]] = {
-        while(total < maxEnumerated && found < maxValid && it.hasNext) {
+        while(total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) {
           val model = it.next.asInstanceOf[Tuple]
 
           if (model eq null) {
@@ -250,10 +254,10 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
             }
 
             if (!failed) {
-              println("Got model:")
-              for ((i, v) <- (ins zip model.exprs)) {
-                println(" - "+i+" -> "+v)
-              }
+              //println("Got model:")
+              //for ((i, v) <- (ins zip model.exprs)) {
+              //  println(" - "+i+" -> "+v)
+              //}
 
               found += 1
 
@@ -264,9 +268,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
               return Some(model.exprs);
             }
 
-            if (total % 1000 == 0) {
-              println("... "+total+" ...")
-            }
+            //if (total % 1000 == 0) {
+            //  println("... "+total+" ...")
+            //}
           }
         }
         None
diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala
new file mode 100644
index 000000000..4f5d71659
--- /dev/null
+++ b/src/main/scala/leon/solvers/EnumerationSolver.scala
@@ -0,0 +1,73 @@
+/* Copyright 2009-2013 EPFL, Lausanne */
+
+package leon
+package solvers
+
+import utils._
+import purescala.Common._
+import purescala.Definitions._
+import purescala.Trees._
+import purescala.Extractors._
+import purescala.TreeOps._
+import purescala.TypeTrees._
+
+import datagen._
+
+
+class EnumerationSolver(val context: LeonContext, val program: Program) extends Solver with Interruptible {
+  def name = "Enum"
+
+  val maxTried = 10000;
+
+  var datagen: DataGenerator = _
+
+  var freeVars    = List[Identifier]()
+  var constraints = List[Expr]()
+
+  def assertCnstr(expression: Expr): Unit = {
+    constraints ::= expression
+
+    val newFreeVars = (variablesOf(expression) -- freeVars).toList
+    freeVars = freeVars ::: newFreeVars
+  }
+
+  private var modelMap = Map[Identifier, Expr]()
+
+  def check: Option[Boolean] = {
+    try { 
+      val muteContext = context.copy(reporter = new DefaultReporter(context.settings))
+      datagen = new VanuatooDataGen(muteContext, program)
+
+      modelMap = Map()
+
+      val it = datagen.generateFor(freeVars, And(constraints.reverse), 1, maxTried)
+
+      if (it.hasNext) {
+        val model = it.next
+        modelMap = (freeVars zip model).toMap
+        Some(true)
+      } else {
+        None
+      }
+    } catch {
+      case e: codegen.CompilationException =>
+        None
+    }
+  }
+
+  def getModel: Map[Identifier, Expr] = {
+    modelMap
+  }
+
+  def free() = {
+    constraints = Nil
+  }
+
+  def interrupt(): Unit = {
+    Option(datagen).foreach(_.interrupt)
+  }
+
+  def recoverInterrupt(): Unit = {
+    Option(datagen).foreach(_.recoverInterrupt)
+  }
+}
diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala
new file mode 100644
index 000000000..1c75bf3b8
--- /dev/null
+++ b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala
@@ -0,0 +1,84 @@
+/* Copyright 2009-2013 EPFL, Lausanne */
+
+package leon
+package solvers
+package combinators
+
+import purescala.Common._
+import purescala.Definitions._
+import purescala.Trees._
+import purescala.TreeOps._
+import purescala.TypeTrees._
+
+import utils.Interruptible
+import scala.concurrent._
+import scala.concurrent.duration._
+
+import scala.collection.mutable.{Map=>MutableMap}
+
+import ExecutionContext.Implicits.global
+
+class PortfolioSolver(val context: LeonContext, solvers: Seq[SolverFactory[Solver with Interruptible]])
+        extends Solver with Interruptible {
+
+  val name = "Pfolio"
+
+  var constraints = List[Expr]()
+
+  def assertCnstr(expression: Expr): Unit = {
+    constraints ::= expression
+  }
+
+  private var modelMap = Map[Identifier, Expr]()
+  private var solversInsts = Seq[Solver with Interruptible]()
+
+  def check: Option[Boolean] = {
+    modelMap = Map()
+
+    // create fresh solvers
+    solversInsts = solvers.map(_.getNewSolver)
+
+    // assert
+    solversInsts.foreach { s =>
+      s.assertCnstr(And(constraints.reverse))
+    }
+
+    // solving
+    val fs = solversInsts.map { s =>
+      Future {
+        (s, s.check, s.getModel)
+      }
+    }
+
+    val result = Future.find(fs)(_._2.isDefined)
+
+    val res = Await.result(result, 10.days) match {
+      case Some((s, r, m)) =>
+        modelMap = m
+        solversInsts.foreach(_.interrupt)
+        r
+      case None =>
+        None
+    }
+
+    solversInsts.foreach(_.free)
+
+    res
+  }
+
+  def getModel: Map[Identifier, Expr] = {
+    modelMap
+  }
+
+  def free() = {
+    constraints = Nil
+  }
+
+  def interrupt(): Unit = {
+    solversInsts.foreach(_.interrupt())
+  }
+
+  def recoverInterrupt(): Unit = {
+    solversInsts.foreach(_.recoverInterrupt())
+  }
+}
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index 766b4a42e..f33071f35 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -83,7 +83,7 @@ class Synthesizer(val context : LeonContext,
 
     val solverf = SolverFactory(() => (new FairZ3Solver(context, npr) with TimeoutSolver).setTimeout(timeoutMs))
 
-    val vctx = VerificationContext(context, npr, Seq(solverf), context.reporter)
+    val vctx = VerificationContext(context, npr, solverf, context.reporter)
     val vcs = generateVerificationConditions(vctx, fds.map(_.id.name))
     val vcreport = checkVerificationConditions(vctx, vcs)
 
diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala
index a0bf5fe60..25ec22efc 100644
--- a/src/main/scala/leon/verification/AnalysisPhase.scala
+++ b/src/main/scala/leon/verification/AnalysisPhase.scala
@@ -23,6 +23,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
 
   override val definedOptions : Set[LeonOptionDef] = Set(
     LeonValueOptionDef("functions", "--functions=f1:f2", "Limit verification to f1,f2,..."),
+    LeonValueOptionDef("solvers",   "--solvers=s1,s2",   "Use solvers s1 and s2 (fairz3,enum)", default = Some("fairz3")),
     LeonValueOptionDef("timeout",   "--timeout=T",       "Timeout after T seconds when trying to prove a verification condition.")
   )
 
@@ -83,7 +84,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
 
   def checkVerificationConditions(vctx: VerificationContext, vcs: Map[FunDef, List[VerificationCondition]]) : VerificationReport = {
     import vctx.reporter
-    import vctx.solvers
+    import vctx.solverFactory
     import vctx.program
 
     val interruptManager = vctx.context.interruptManager
@@ -100,60 +101,54 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
       reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====")
       reporter.debug(simplifyLets(vc).asString(vctx.context))
 
-      // try all solvers until one returns a meaningful answer
-      solvers.find(sf => {
-        val s = sf.getNewSolver
-        try {
-          reporter.debug("Trying with solver: " + s.name)
-          val t1 = System.nanoTime
-          s.assertCnstr(Not(vc))
-
-          val satResult = s.check
-          val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map()
-          val solverResult = satResult.map(!_)
-
-          val t2 = System.nanoTime
-          val dt = ((t2 - t1) / 1000000) / 1000.0
-
-          solverResult match {
-            case _ if interruptManager.isInterrupted() =>
-              reporter.info("=== CANCELLED ===")
-              vcInfo.time = Some(dt)
-              false
-
-            case None =>
-              vcInfo.time = Some(dt)
-              false
-
-            case Some(true) =>
-              reporter.info("==== VALID ====")
-
-              vcInfo.hasValue = true
-              vcInfo.value = Some(true)
-              vcInfo.solvedWith = Some(s)
-              vcInfo.time = Some(dt)
-              true
-
-            case Some(false) =>
-              reporter.error("Found counter-example : ")
-              reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n"))
-              reporter.error("==== INVALID ====")
-              vcInfo.hasValue = true
-              vcInfo.value = Some(false)
-              vcInfo.solvedWith = Some(s)
-              vcInfo.counterExample = Some(counterexample)
-              vcInfo.time = Some(dt)
-              true
-          }
-        } finally {
-          s.free()
-        }}) match {
-          case None => {
+      val s = solverFactory.getNewSolver
+      try {
+        reporter.debug("Trying with solver: " + s.name)
+        val t1 = System.nanoTime
+        s.assertCnstr(Not(vc))
+
+        val satResult = s.check
+        val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map()
+        val solverResult = satResult.map(!_)
+
+        val t2 = System.nanoTime
+        val dt = ((t2 - t1) / 1000000) / 1000.0
+
+        solverResult match {
+          case _ if interruptManager.isInterrupted() =>
+            reporter.info("=== CANCELLED ===")
+            vcInfo.time = Some(dt)
+            false
+
+          case None =>
             vcInfo.hasValue = true
             reporter.warning("==== UNKNOWN ====")
-          }
-          case _ =>
+            vcInfo.time = Some(dt)
+            false
+
+          case Some(true) =>
+            reporter.info("==== VALID ====")
+
+            vcInfo.hasValue = true
+            vcInfo.value = Some(true)
+            vcInfo.solvedWith = Some(s)
+            vcInfo.time = Some(dt)
+            true
+
+          case Some(false) =>
+            reporter.error("Found counter-example : ")
+            reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n"))
+            reporter.error("==== INVALID ====")
+            vcInfo.hasValue = true
+            vcInfo.value = Some(false)
+            vcInfo.solvedWith = Some(s)
+            vcInfo.counterExample = Some(counterexample)
+            vcInfo.time = Some(dt)
+            true
         }
+      } finally {
+        s.free()
+      }
     }
 
     val report = new VerificationReport(vcs)
@@ -163,33 +158,52 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
   def run(ctx: LeonContext)(program: Program) : VerificationReport = {
     var functionsToAnalyse   = Set[String]()
     var timeout: Option[Int] = None
+    var selectedSolvers      = Set[String]("fairz3")
+
+    val allSolvers = Map(
+      "fairz3" -> SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver),
+      "enum"   -> SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver)
+    )
+
+    val reporter = ctx.reporter
 
     for(opt <- ctx.options) opt match {
       case LeonValueOption("functions", ListValue(fs)) =>
         functionsToAnalyse = Set() ++ fs
 
+      case LeonValueOption("solvers", ListValue(ss)) =>
+        val unknownSolvers = ss.toSet -- allSolvers.keySet
+        if (unknownSolvers.nonEmpty) {
+          reporter.error("Unknown solver(s): "+unknownSolvers.mkString(", ")+" (Available: "+allSolvers.keys.mkString(", ")+")")
+        }
+        selectedSolvers = Set() ++ ss
+
       case v @ LeonValueOption("timeout", _) =>
         timeout = v.asInt(ctx)
 
       case _ =>
     }
 
-    val reporter = ctx.reporter
+    // Solvers selection and validation
+    val solversToUse = allSolvers.filterKeys(selectedSolvers)
 
-    val baseFactories = Seq(
-      SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver)
-    )
+    val entrySolver = if (solversToUse.isEmpty) {
+      reporter.fatalError("No solver selected. Aborting")
+    } else if (solversToUse.size == 1) {
+      solversToUse.values.head
+    } else {
+      SolverFactory( () => new PortfolioSolver(ctx, solversToUse.values.toSeq) with TimeoutSolver)
+    }
 
-    val solverFactories = timeout match {
+
+    val mainSolver = timeout match {
       case Some(sec) =>
-        baseFactories.map { sf =>
-          new TimeoutSolverFactory(sf, sec*1000L)
-        }
+        new TimeoutSolverFactory(entrySolver, sec*1000L)
       case None =>
-        baseFactories
+        entrySolver
     }
 
-    val vctx = VerificationContext(ctx, program, solverFactories, reporter)
+    val vctx = VerificationContext(ctx, program, mainSolver, reporter)
 
     reporter.debug("Running verification condition generation...")
     val vcs = generateVerificationConditions(vctx, functionsToAnalyse)
diff --git a/src/main/scala/leon/verification/VerificationContext.scala b/src/main/scala/leon/verification/VerificationContext.scala
index 84692c3a7..50d3012ea 100644
--- a/src/main/scala/leon/verification/VerificationContext.scala
+++ b/src/main/scala/leon/verification/VerificationContext.scala
@@ -11,6 +11,6 @@ import java.util.concurrent.atomic.AtomicBoolean
 case class VerificationContext (
   context: LeonContext,
   program: Program,
-  solvers: Seq[SolverFactory[Solver]],
+  solverFactory: SolverFactory[Solver],
   reporter: Reporter
 )
diff --git a/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala b/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala
new file mode 100644
index 000000000..44ceb91e1
--- /dev/null
+++ b/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala
@@ -0,0 +1,44 @@
+/* Copyright 2009-2013 EPFL, Lausanne */
+
+package leon.test
+package solvers
+
+import leon._
+import leon.utils.Interruptible
+import leon.solvers._
+import leon.solvers.combinators._
+import leon.purescala.Common._
+import leon.purescala.Definitions._
+import leon.purescala.Trees._
+import leon.purescala.TypeTrees._
+
+class EnumerationSolverTests extends LeonTestSuite {
+  private def check(sf: SolverFactory[Solver], e: Expr): Option[Boolean] = {
+    val s = sf.getNewSolver
+    s.assertCnstr(e)
+    s.check
+  }
+
+  private def getSolver = {
+    SolverFactory(() => new EnumerationSolver(testContext, Program.empty))
+  }
+
+  test("EnumerationSolver 1 (true)") {
+    val sf = getSolver
+    assert(check(sf, BooleanLiteral(true)) === Some(true))
+  }
+
+  test("EnumerationSolver 2 (x == 1)") {
+    val sf = getSolver
+    val x = Variable(FreshIdentifier("x").setType(Int32Type))
+    val o = IntLiteral(1)
+    assert(check(sf, Equals(x, o)) === Some(true))
+  }
+
+  test("EnumerationSolver 3 (Limited range for ints)") {
+    val sf = getSolver
+    val x = Variable(FreshIdentifier("x").setType(Int32Type))
+    val o = IntLiteral(42)
+    assert(check(sf, Equals(x, o)) === None)
+  }
+}
-- 
GitLab