From c9f0f7de007a9d14f88bffab74b327542c896521 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Tue, 13 Nov 2012 16:45:33 +0100
Subject: [PATCH] Move cegis and optimistic ground to rules

---
 .../scala/leon/synthesis/Heuristics.scala     | 219 +---------------
 src/main/scala/leon/synthesis/Rules.scala     | 236 +++++++++++++++++-
 .../scala/leon/synthesis/Synthesizer.scala    |   6 +-
 src/main/scala/leon/synthesis/Task.scala      |   6 +-
 4 files changed, 246 insertions(+), 221 deletions(-)

diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 8cd507344..97f1cb96e 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -10,9 +10,7 @@ import purescala.Definitions._
 
 object Heuristics {
   def all = Set[Synthesizer => Rule](
-    new OptimisticGround(_),
-    //new IntInduction(_),
-    new CEGIS(_),
+    new IntInduction(_),
     new OptimisticInjection(_)
   )
 }
@@ -23,62 +21,6 @@ trait Heuristic {
   override def toString = "H: "+name
 }
 
-class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) with Heuristic {
-  def applyOn(task: Task): RuleResult = {
-    val p = task.problem
-
-    if (!p.as.isEmpty && !p.xs.isEmpty) {
-      val xss = p.xs.toSet
-      val ass = p.as.toSet
-
-      val tpe = TupleType(p.xs.map(_.getType))
-
-      var i = 0;
-      var maxTries = 3;
-
-      var result: Option[RuleResult]   = None
-      var predicates: Seq[Expr]        = Seq()
-
-      while (result.isEmpty && i < maxTries) {
-        val phi = And(p.phi +: predicates)
-        synth.solver.solveSAT(phi) match {
-          case (Some(true), satModel) =>
-            val satXsModel = satModel.filterKeys(xss) 
-
-            val newPhi = valuateWithModelIn(phi, xss, satModel)
-
-            synth.solver.solveSAT(Not(newPhi)) match {
-              case (Some(true), invalidModel) =>
-                // Found as such as the xs break, refine predicates
-                predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates
-
-              case (Some(false), _) =>
-                result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe))))
-
-              case _ =>
-                result = Some(RuleInapplicable)
-            }
-
-          case (Some(false), _) =>
-            if (predicates.isEmpty) {
-              result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))))
-            } else {
-              result = Some(RuleInapplicable)
-            }
-          case _ =>
-            result = Some(RuleInapplicable)
-        }
-
-        i += 1 
-      }
-
-      result.getOrElse(RuleInapplicable)
-    } else {
-      RuleInapplicable
-    }
-  }
-}
-
 
 class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) with Heuristic {
   def applyOn(task: Task): RuleResult = {
@@ -105,7 +47,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80)
         val onSuccess: List[Solution] => Solution = {
           case List(base, gt, lt) =>
             val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType)))
-            newFun.body = Some( 
+            newFun.body = Some(
               IfExpr(Equals(Variable(inductOn), IntLiteral(0)),
                 base.toExpr,
               IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)),
@@ -197,160 +139,3 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth,
   }
 }
 
-class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 50) with Heuristic {
-  def applyOn(task: Task): RuleResult = {
-    val p = task.problem
-
-    case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]);
-
-    var generators = Map[TypeTree, Generator]()
-    def getGenerator(t: TypeTree): Generator = generators.get(t) match {
-      case Some(g) => g
-      case None =>
-        val alternatives: () => List[(Expr, Set[Identifier])] = t match {
-          case BooleanType =>
-            { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) }
-
-          case Int32Type =>
-            { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) }
-
-          case TupleType(tps) =>
-            { () =>
-              val ids = tps.map(t => FreshIdentifier("t", true).setType(t))
-              List((Tuple(ids.map(Variable(_))), ids.toSet))
-            }
-
-          case CaseClassType(cd) =>
-            { () =>
-              val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
-              List((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
-            }
-
-          case AbstractClassType(cd) =>
-            { () =>
-              val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match {
-                  case acd: AbstractClassDef =>
-                    synth.reporter.error("Unnexpected abstract class in descendants!")
-                    None
-                  case cd: CaseClassDef =>
-                    val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
-                    Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
-              })
-              alts.toList
-            }
-
-          case _ =>
-            synth.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]");
-            { () => Nil }
-        }
-        val g = Generator(t, alternatives)
-        generators += t -> g
-        g
-    }
-
-    def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = {
-      p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]()))
-    }
-
-    case class TentativeFormula(phi: Expr,
-                                program: Expr,
-                                mappings: Map[Identifier, (Identifier, Expr)],
-                                recTerms: Map[Identifier, Set[Identifier]]) {
-      def unroll: TentativeFormula = {
-        var newProgram  = List[Expr]()
-        var newRecTerms = Map[Identifier, Set[Identifier]]()
-        var newMappings = Map[Identifier, (Identifier, Expr)]()
-
-        for ((_, recIds) <- recTerms; recId <- recIds) {
-          val gen  = getGenerator(recId.getType)
-          val alts = gen.altBuilder() ::: inputAlternatives(recId.getType)
-
-          val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt)
-
-          val pre = Or(altsWithBranches.map{ case (id, _) => Variable(id) }) // b1 OR b2
-          val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2)     [b1 -> {gen1, gen2}]
-            if (!rec.isEmpty) {
-              newRecTerms += bid -> rec
-            } else {
-              newMappings += bid -> (recId -> ex)
-            }
-
-            Implies(Variable(bid), Equals(Variable(recId), ex))
-          }
-
-          newProgram = newProgram ::: pre :: cases
-        }
-
-        TentativeFormula(phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms)
-      }
-
-      def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList
-      def bss = mappings.keySet
-
-      def entireFormula = And(phi :: program :: bounds)
-    }
-
-    var result: Option[RuleResult]   = None
-
-    var ass = p.as.toSet
-    var xss = p.xs.toSet
-
-    var lastF     = TentativeFormula(And(p.c, p.phi), BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x)))
-    var currentF  = lastF.unroll
-    var unrolings = 0
-    val maxUnrolings = 2
-    do {
-      println("Was: "+lastF.entireFormula)
-      println("Now Trying : "+currentF.entireFormula)
-
-      val tpe = TupleType(p.xs.map(_.getType))
-      val bss = currentF.bss
-
-      var predicates: Seq[Expr]        = Seq()
-      var continue = true
-
-      while (result.isEmpty && continue) {
-        val basePhi = currentF.entireFormula
-        val constrainedPhi = And(basePhi +: predicates)
-        println("-"*80)
-        println("To satisfy: "+constrainedPhi)
-        synth.solver.solveSAT(constrainedPhi) match {
-          case (Some(true), satModel) =>
-            println("Found candidate!: "+satModel.filterKeys(bss))
-
-            val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)
-            println("Phi with fixed sat bss: "+fixedBss)
-
-            val counterPhi = Implies(And(fixedBss, currentF.program), currentF.phi)
-            println("Formula to validate: "+counterPhi)
-
-            synth.solver.solveSAT(Not(counterPhi)) match {
-              case (Some(true), invalidModel) =>
-                // Found as such as the xs break, refine predicates
-                println("Found counter EX: "+invalidModel)
-                predicates = Not(And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)) +: predicates
-                println("Let's avoid this case: "+bss.map(b => Equals(Variable(b), satModel(b))).mkString(" "))
-
-              case (Some(false), _) =>
-                val mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap
-                println("Mapping: "+mapping)
-                result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe))))
-
-              case _ =>
-            }
-
-          case (Some(false), _) =>
-            continue = false
-          case _ =>
-            continue = false
-        }
-      }
-
-      lastF = currentF
-      currentF = currentF.unroll
-      unrolings += 1
-    } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty)
-
-    result.getOrElse(RuleInapplicable)
-  }
-}
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index fc711b1e7..563ea4166 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -7,6 +7,7 @@ import purescala.Trees._
 import purescala.Extractors._
 import purescala.TreeOps._
 import purescala.TypeTrees._
+import purescala.Definitions._
 
 object Rules {
   def all = Set[Synthesizer => Rule](
@@ -18,6 +19,8 @@ object Rules {
     new CaseSplit(_),
     new UnusedInput(_),
     new UnconstrainedOutput(_),
+    new OptimisticGround(_),
+    new CEGIS(_),
     new Assert(_)
   )
 }
@@ -160,7 +163,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) {
 class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
-    val unused = p.as.toSet -- variablesOf(p.phi)
+    val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.c)
 
     if (!unused.isEmpty) {
       val sub = p.copy(as = p.as.filterNot(unused))
@@ -278,3 +281,234 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) {
   }
 }
 
+class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 50) {
+  def applyOn(task: Task): RuleResult = {
+    val p = task.problem
+
+    case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]);
+
+    var generators = Map[TypeTree, Generator]()
+    def getGenerator(t: TypeTree): Generator = generators.get(t) match {
+      case Some(g) => g
+      case None =>
+        val alternatives: () => List[(Expr, Set[Identifier])] = t match {
+          case BooleanType =>
+            { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) }
+
+          case Int32Type =>
+            { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) }
+
+          case TupleType(tps) =>
+            { () =>
+              val ids = tps.map(t => FreshIdentifier("t", true).setType(t))
+              List((Tuple(ids.map(Variable(_))), ids.toSet))
+            }
+
+          case CaseClassType(cd) =>
+            { () =>
+              val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
+              List((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
+            }
+
+          case AbstractClassType(cd) =>
+            { () =>
+              val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match {
+                  case acd: AbstractClassDef =>
+                    synth.reporter.error("Unnexpected abstract class in descendants!")
+                    None
+                  case cd: CaseClassDef =>
+                    val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
+                    Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
+              })
+              alts.toList
+            }
+
+          case _ =>
+            synth.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]");
+            { () => Nil }
+        }
+        val g = Generator(t, alternatives)
+        generators += t -> g
+        g
+    }
+
+    def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = {
+      p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]()))
+    }
+
+    case class TentativeFormula(phi: Expr,
+                                program: Expr,
+                                mappings: Map[Identifier, (Identifier, Expr)],
+                                recTerms: Map[Identifier, Set[Identifier]]) {
+      def unroll: TentativeFormula = {
+        var newProgram  = List[Expr]()
+        var newRecTerms = Map[Identifier, Set[Identifier]]()
+        var newMappings = Map[Identifier, (Identifier, Expr)]()
+
+        for ((_, recIds) <- recTerms; recId <- recIds) {
+          val gen  = getGenerator(recId.getType)
+          val alts = gen.altBuilder() ::: inputAlternatives(recId.getType)
+
+          val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt)
+
+          val bvs = altsWithBranches.map(alt => Variable(alt._1))
+          val distinct = if (bvs.size > 1) {
+            (for (i <- (1 to bvs.size-1); j <- 0 to i-1) yield {
+              Or(Not(bvs(i)), Not(bvs(j)))
+            }).toList
+          } else {
+            List(BooleanLiteral(true))
+          }
+          val pre = And(Or(bvs) :: distinct) // (b1 OR b2) AND (Not(b1) OR Not(b2))
+          val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2)     [b1 -> {gen1, gen2}]
+            if (!rec.isEmpty) {
+              newRecTerms += bid -> rec
+            }
+            newMappings += bid -> (recId -> ex)
+
+            Implies(Variable(bid), Equals(Variable(recId), ex))
+          }
+
+          newProgram = newProgram ::: pre :: cases
+        }
+
+        TentativeFormula(phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms)
+      }
+
+      def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList
+      def bss = mappings.keySet
+
+      def entireFormula = And(phi :: program :: bounds)
+    }
+
+    var result: Option[RuleResult]   = None
+
+    var ass = p.as.toSet
+    var xss = p.xs.toSet
+
+    var lastF     = TentativeFormula(Implies(p.c, p.phi), BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x)))
+    var currentF  = lastF.unroll
+    var unrolings = 0
+    val maxUnrolings = 2
+    do {
+      //println("Was: "+lastF.entireFormula)
+      //println("Now Trying : "+currentF.entireFormula)
+
+      val tpe = TupleType(p.xs.map(_.getType))
+      val bss = currentF.bss
+
+      var predicates: Seq[Expr]        = Seq()
+      var continue = true
+
+      while (result.isEmpty && continue) {
+        val basePhi = currentF.entireFormula
+        val constrainedPhi = And(basePhi +: predicates)
+        //println("-"*80)
+        //println("To satisfy: "+constrainedPhi)
+        synth.solver.solveSAT(constrainedPhi) match {
+          case (Some(true), satModel) =>
+            //println("Found candidate!: "+satModel.filterKeys(bss))
+
+            //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel)))
+            val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)
+            //println("Phi with fixed sat bss: "+fixedBss)
+
+            val counterPhi = Implies(And(fixedBss, currentF.program), currentF.phi)
+            //println("Formula to validate: "+counterPhi)
+
+            synth.solver.solveSAT(Not(counterPhi)) match {
+              case (Some(true), invalidModel) =>
+                // Found as such as the xs break, refine predicates
+                //println("Found counter EX: "+invalidModel)
+                predicates = Not(And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)) +: predicates
+                //println("Let's avoid this case: "+bss.map(b => Equals(Variable(b), satModel(b))).mkString(" "))
+
+              case (Some(false), _) =>
+                //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", "))
+                var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap
+
+                //println("Mapping: "+mapping)
+
+                // Resolve mapping
+                for ((c, e) <- mapping) {
+                  mapping += c -> substAll(mapping, e)
+                }
+
+                result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe))))
+
+              case _ =>
+            }
+
+          case (Some(false), _) =>
+            //println("%%%% UNSAT")
+            continue = false
+          case _ =>
+            //println("%%%% WOOPS")
+            continue = false
+        }
+      }
+
+      lastF = currentF
+      currentF = currentF.unroll
+      unrolings += 1
+    } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty)
+
+    result.getOrElse(RuleInapplicable)
+  }
+}
+
+class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) {
+  def applyOn(task: Task): RuleResult = {
+    val p = task.problem
+
+    if (!p.as.isEmpty && !p.xs.isEmpty) {
+      val xss = p.xs.toSet
+      val ass = p.as.toSet
+
+      val tpe = TupleType(p.xs.map(_.getType))
+
+      var i = 0;
+      var maxTries = 3;
+
+      var result: Option[RuleResult]   = None
+      var predicates: Seq[Expr]        = Seq()
+
+      while (result.isEmpty && i < maxTries) {
+        val phi = And(p.phi +: predicates)
+        synth.solver.solveSAT(phi) match {
+          case (Some(true), satModel) =>
+            val satXsModel = satModel.filterKeys(xss) 
+
+            val newPhi = valuateWithModelIn(phi, xss, satModel)
+
+            synth.solver.solveSAT(Not(newPhi)) match {
+              case (Some(true), invalidModel) =>
+                // Found as such as the xs break, refine predicates
+                predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates
+
+              case (Some(false), _) =>
+                result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe))))
+
+              case _ =>
+                result = Some(RuleInapplicable)
+            }
+
+          case (Some(false), _) =>
+            if (predicates.isEmpty) {
+              result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))))
+            } else {
+              result = Some(RuleInapplicable)
+            }
+          case _ =>
+            result = Some(RuleInapplicable)
+        }
+
+        i += 1 
+      }
+
+      result.getOrElse(RuleInapplicable)
+    } else {
+      RuleInapplicable
+    }
+  }
+}
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index ba8ec79c4..a811fe6b7 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -42,9 +42,11 @@ class Synthesizer(val reporter: Reporter,
 
     val ts = System.currentTimeMillis
 
+    def currentDurationMs = System.currentTimeMillis-ts
+
     def timeoutExpired(): Boolean = {
       timeoutMs match {
-        case Some(t) if (System.currentTimeMillis-ts)/1000 > t => true
+        case Some(t) if currentDurationMs/1000 > t => true
         case _ => false
       }
     }
@@ -64,6 +66,8 @@ class Synthesizer(val reporter: Reporter,
       }
     }
 
+    info("Finished in "+currentDurationMs+"ms")
+
     if (generateDerivationTrees) {
       val deriv = new DerivationTree(rootTask)
       deriv.toDotFile("derivation"+derivationCounter+".dot")
diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala
index ed3bd980d..7f312deef 100644
--- a/src/main/scala/leon/synthesis/Task.scala
+++ b/src/main/scala/leon/synthesis/Task.scala
@@ -121,8 +121,10 @@ class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, p
   }
 
   override def partlySolvedBy(t: Task, s: Solution) {
-    solution = Some(s)
-    solver   = Some(t)
+    if (isBetterSolutionThan(s, solution)) {
+      solution = Some(s)
+      solver   = Some(t)
+    }
   }
 
   override def unsolvedBy(t: Task) {
-- 
GitLab