From 4dd4ad694cc13d42aa3a84994783ca849bdd051d Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 7 Dec 2015 12:40:04 +0100
Subject: [PATCH] Optimization for && and || short-circuiting

---
 .../templates/QuantificationManager.scala     | 47 +++++++++------
 .../solvers/templates/TemplateGenerator.scala | 58 ++++++++++++++++++-
 .../invalid/PropositionalLogic.scala          |  2 +-
 .../PureScalaVerificationSuite.scala          |  5 --
 4 files changed, 86 insertions(+), 26 deletions(-)

diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
index c1adbf3f1..9bc278f09 100644
--- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala
+++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
@@ -442,22 +442,24 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
         val lambdaSubstMap = lambdas map (lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1))
         val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enablers)
 
-        instantiation ++= Template.instantiate(encoder, QuantificationManager.this,
-          clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap)
-
-        val msubst = subst.collect { case (c, Right(m)) => c -> m }
-        val substituter = encoder.substitute(substMap)
-
-        for ((b,ms) <- allMatchers; m <- ms) {
-          val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b)))
-          val sm = m.substitute(substituter, matcherSubst = msubst)
-
-          if (matchers(m)) {
-            handled += sb -> sm
-          } else if (transMatchers(m) && isStrict) {
-            instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*)
-          } else {
-            ignored += sb -> sm
+        if (!skip(substMap)) {
+          instantiation ++= Template.instantiate(encoder, QuantificationManager.this,
+            clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap)
+
+          val msubst = subst.collect { case (c, Right(m)) => c -> m }
+          val substituter = encoder.substitute(substMap)
+
+          for ((b,ms) <- allMatchers; m <- ms) {
+            val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b)))
+            val sm = m.substitute(substituter, matcherSubst = msubst)
+
+            if (matchers(m)) {
+              handled += sb -> sm
+            } else if (transMatchers(m) && isStrict) {
+              instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*)
+            } else {
+              ignored += sb -> sm
+            }
           }
         }
       }
@@ -466,6 +468,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     }
 
     protected def instanceSubst(enablers: Set[T]): Map[T, T]
+
+    protected def skip(subst: Map[T, T]): Boolean = false
   }
 
   private class Quantification (
@@ -515,7 +519,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
       }
   }
 
-  private class Axiom (
+  private class LambdaAxiom (
     val start: T,
     val blocker: T,
     val guardVar: T,
@@ -536,6 +540,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
       val guardT = if (optEnabler.isDefined) encoder.mkAnd(start, optEnabler.get) else start
       Map(guardVar -> guardT, blocker -> newBlocker)
     }
+
+    override protected def skip(subst: Map[T, T]): Boolean = {
+      val substituter = encoder.substitute(subst)
+      allMatchers.forall { case (b, ms) =>
+        ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter)))
+      }
+    }
   }
 
   private def extractQuorums(
@@ -630,7 +641,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     var instantiation = Instantiation.empty[T]
 
     for (matchers <- matchQuorums) {
-      val axiom = new Axiom(start, blocker, guardVar, quantified,
+      val axiom = new LambdaAxiom(start, blocker, guardVar, quantified,
         matchers, allMatchers, condVars, exprVars, condTree,
         clauses, blockers, applications, lambdas
       )
diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
index 5c5098b77..665242433 100644
--- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
+++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
@@ -311,10 +311,64 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T],
           }
 
         case a @ And(parts) =>
-          liftToIfExpr(pathVar, parts, andJoin, (a,b) => IfExpr(a, b, BooleanLiteral(false)))
+          val partitions = groupWhile(parts)(!requireDecomposition(_))
+          partitions.map(andJoin) match {
+            case Seq(e) => e
+            case seq =>
+              val newExpr : Identifier = FreshIdentifier("e", BooleanType, true)
+              storeExpr(newExpr)
+
+              def recAnd(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match {
+                case x :: Nil if !requireDecomposition(x) =>
+                  storeGuarded(pathVar, Equals(Variable(newExpr), x))
+
+                case x :: xs =>
+                  val newBool : Identifier = FreshIdentifier("b", BooleanType, true)
+                  storeCond(pathVar, newBool)
+
+                  val xrec = rec(pathVar, x)
+                  storeGuarded(pathVar, Equals(Variable(newBool), xrec))
+                  storeGuarded(pathVar, Implies(Not(Variable(newBool)), Not(Variable(newExpr))))
+
+                  recAnd(newBool, xs)
+
+                case Nil =>
+                  storeGuarded(pathVar, Variable(newExpr))
+              }
+
+              recAnd(pathVar, seq)
+              Variable(newExpr)
+          }
 
         case o @ Or(parts) =>
-          liftToIfExpr(pathVar, parts, orJoin, (a,b) => IfExpr(a, BooleanLiteral(true), b))
+          val partitions = groupWhile(parts)(!requireDecomposition(_))
+          partitions.map(orJoin) match {
+            case Seq(e) => e
+            case seq =>
+              val newExpr : Identifier = FreshIdentifier("e", BooleanType, true)
+              storeExpr(newExpr)
+
+              def recOr(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match {
+                case x :: Nil if !requireDecomposition(x) =>
+                  storeGuarded(pathVar, Equals(Variable(newExpr), x))
+
+                case x :: xs =>
+                  val newBool : Identifier = FreshIdentifier("b", BooleanType, true)
+                  storeCond(pathVar, newBool)
+
+                  val xrec = rec(pathVar, x)
+                  storeGuarded(pathVar, Equals(Not(Variable(newBool)), xrec))
+                  storeGuarded(pathVar, Implies(Not(Variable(newBool)), Variable(newExpr)))
+
+                  recOr(newBool, xs)
+
+                case Nil =>
+                  storeGuarded(pathVar, Not(Variable(newExpr)))
+              }
+
+              recOr(pathVar, seq)
+              Variable(newExpr)
+          }
 
         case i @ IfExpr(cond, thenn, elze) => {
           if(!requireDecomposition(i)) {
diff --git a/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala b/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala
index ea73b834b..a8927f360 100644
--- a/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala
+++ b/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala
@@ -67,7 +67,7 @@ object PropositionalLogic {
   //   nnf(simplify(f)) == simplify(nnf(f))
   // }.holds
 
-  //@induct
+  @induct
   def simplifyBreaksNNF(f: Formula) : Boolean = {
     require(isNNF(f))
     isNNF(simplify(f))
diff --git a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala
index 629e2c677..974615e64 100644
--- a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala
+++ b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala
@@ -66,11 +66,6 @@ class PureScalaValidSuiteCVC4 extends PureScalaValidSuite {
 }
 
 class PureScalaInvalidSuite extends PureScalaVerificationSuite {
-  override val ignored = Seq(
-    "verification/purescala/invalid/PropositionalLogic.scala",
-    "verification/purescala/invalid/InductiveQuantification.scala"
-  )
-
   override def testAll() = testInvalid()
   val optionVariants = opts
 }
-- 
GitLab