From 70f68bc7169fe4552993e5ed0b8c25163a600c03 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Thu, 1 Nov 2012 02:02:44 +0100
Subject: [PATCH] Score solution programatically, not by hand. Allow arbitrary
 scrutinees (to check)

---
 .../scala/leon/purescala/Extractors.scala     |  1 +
 src/main/scala/leon/purescala/TreeOps.scala   | 17 ++++++-----
 src/main/scala/leon/synthesis/Rules.scala     | 30 +++++++++----------
 src/main/scala/leon/synthesis/Solution.scala  |  8 +++--
 .../scala/leon/synthesis/Synthesizer.scala    |  9 ------
 src/main/scala/leon/synthesis/Task.scala      |  9 +++---
 6 files changed, 35 insertions(+), 39 deletions(-)

diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 94eb5fba0..392c7af33 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -73,6 +73,7 @@ object Extractors {
       case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect))
       case Concat(t1,t2) => Some((t1,t2,Concat))
       case ListAt(t1,t2) => Some((t1,t2,ListAt))
+      case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, body)))
       case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, body)))
       case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh)))
       case ex: BinaryExtractable => ex.extract
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 14b9ae018..dbabda723 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -1107,10 +1107,10 @@ object TreeOps {
         if (cases.forall(_.isInstanceOf[CaseClassInstanceOf])) {
           // matchingOn might initially be: a : T1, a.tail : T2, b: T2
           def selectorDepth(e: Expr): Int = e match {
-            case v: Variable =>
-              0
             case cd: CaseClassSelector =>
               1+selectorDepth(cd.caseClass)
+            case _ =>
+              0
           }
 
           var scrutSet = Set[Expr]()
@@ -1121,14 +1121,14 @@ object TreeOps {
             conditions += expr -> cd
 
             expr match {
-              case v: Variable =>
-                scrutSet += v 
               case cd: CaseClassSelector =>
                 if (!scrutSet.contains(cd.caseClass)) {
                   // we found a test looking like "a.foo.isInstanceof[..]"
                   // without a check on "a".
                   scrutSet += cd
                 }
+              case e =>
+                scrutSet += e
             }
           }
 
@@ -1137,12 +1137,13 @@ object TreeOps {
 
           def computePatternFor(cd: CaseClassDef, prefix: Expr): Pattern = {
 
-            val id = prefix match {
-              case CaseClassSelector(_, _, id) => id
-              case Variable(id) => id
+            val name = prefix match {
+              case CaseClassSelector(_, _, id) => id.name
+              case Variable(id) => id.name
+              case _ => "tmp"
             }
 
-            val binder = FreshIdentifier(id.name, true).setType(id.getType) // Is it full of women though?
+            val binder = FreshIdentifier(name, true).setType(prefix.getType) // Is it full of women though?
 
             // prefix becomes binder
             substMap += prefix -> Variable(binder)
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index ea1220929..16351d9b4 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -27,7 +27,7 @@ case class RuleSuccess(solution: Solution) extends RuleResult
 case class RuleDecomposed(subProblems: List[Problem], onSuccess: List[Solution] => Solution) extends RuleResult
 
 
-abstract class Rule(val name: String, val synth: Synthesizer) {
+abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority) {
   def applyOn(task: Task): RuleResult
 
   def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in)
@@ -40,7 +40,7 @@ abstract class Rule(val name: String, val synth: Synthesizer) {
   override def toString = name
 }
 
-class OnePoint(synth: Synthesizer) extends Rule("One-point", synth) {
+class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) {
   def applyOn(task: Task): RuleResult = {
 
     val p = task.problem
@@ -63,11 +63,11 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth) {
       val newProblem = Problem(p.as, subst(x -> e, And(others)), oxs)
 
       val onSuccess: List[Solution] => Solution = { 
-        case List(Solution(pre, term, sc)) =>
+        case List(Solution(pre, term)) =>
           if (oxs.isEmpty) {
-            Solution(pre, Tuple(e :: Nil), sc) 
+            Solution(pre, Tuple(e :: Nil)) 
           } else {
-            Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))), sc) 
+            Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_)))))) 
           }
         case _ => Solution.none
       }
@@ -79,7 +79,7 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth) {
   }
 }
 
-class Ground(synth: Synthesizer) extends Rule("Ground", synth) {
+class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
@@ -101,7 +101,7 @@ class Ground(synth: Synthesizer) extends Rule("Ground", synth) {
   }
 }
 
-class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth) {
+class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
     p.phi match {
@@ -110,7 +110,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth) {
         val sub2 = Problem(p.as, o2, p.xs)
 
         val onSuccess: List[Solution] => Solution = { 
-          case List(s1, s2) => Solution(Or(s1.pre, s2.pre), IfExpr(s1.pre, s1.term, s2.term), 100)
+          case List(s1, s2) => Solution(Or(s1.pre, s2.pre), IfExpr(s1.pre, s1.term, s2.term))
           case _ => Solution.none
         }
 
@@ -121,7 +121,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth) {
   }
 }
 
-class Assert(synth: Synthesizer) extends Rule("Assert", synth) {
+class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
@@ -136,7 +136,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth) {
             RuleSuccess(Solution(And(exprsA), Tuple(p.xs.map(id => simplestValue(Variable(id))))))
           } else {
             val onSuccess: List[Solution] => Solution = { 
-              case List(s) => Solution(And(s.pre +: exprsA), s.term, 150)
+              case List(s) => Solution(And(s.pre +: exprsA), s.term)
               case _ => Solution.none
             }
 
@@ -153,7 +153,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth) {
   }
 }
 
-class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth) {
+class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 500) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
     val unused = p.as.toSet -- variablesOf(p.phi)
@@ -168,7 +168,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth) {
   }
 }
 
-class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth) {
+class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 500) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
     val unconstr = p.xs.toSet -- variablesOf(p.phi)
@@ -192,7 +192,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy
 }
 
 object Unification {
-  class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth) {
+  class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 300) {
     def applyOn(task: Task): RuleResult = {
       val p = task.problem
 
@@ -220,7 +220,7 @@ object Unification {
     }
   }
 
-  class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth) {
+  class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 300) {
     def applyOn(task: Task): RuleResult = {
       val p = task.problem
 
@@ -247,7 +247,7 @@ object Unification {
 }
 
 
-class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth) {
+class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 300) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala
index 25e0d5899..0dbdc5dbf 100644
--- a/src/main/scala/leon/synthesis/Solution.scala
+++ b/src/main/scala/leon/synthesis/Solution.scala
@@ -5,7 +5,7 @@ import leon.purescala.Trees._
 
 // Defines a synthesis solution of the form:
 // ⟨ P | T ⟩
-case class Solution(pre: Expr, term: Expr, score: Score = 0) {
+case class Solution(pre: Expr, term: Expr) {
   override def toString = "⟨ "+pre+" | "+term+" ⟩" 
 
   def toExpr = {
@@ -17,10 +17,14 @@ case class Solution(pre: Expr, term: Expr, score: Score = 0) {
       IfExpr(pre, term, Error("Precondition failed").setType(term.getType))
     }
   }
+
+  def score: Score = 10
 }
 
 object Solution {
-  def choose(p: Problem): Solution = Solution(BooleanLiteral(true), Choose(p.xs, p.phi), 0)
+  def choose(p: Problem): Solution = new Solution(BooleanLiteral(true), Choose(p.xs, p.phi)) {
+    override def score: Score = 0
+  }
 
   def none: Solution = throw new Exception("Unexpected failure to construct solution")
 }
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index b1c443494..1d75db8e0 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -36,9 +36,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation
     while (!workList.isEmpty && solution.isEmpty) {
       val task = workList.dequeue()
 
-      println("Running "+task+"...")
       val subtasks = task.run
-      println(subtasks)
 
       workList ++= subtasks
     }
@@ -119,11 +117,4 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation
 
     solutions
   }
-
-
-
-  def solutionToString(solution: Solution): String = {
-    ScalaPrinter(simplifyLets(solution.toExpr))
-  }
-
 }
diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala
index d971b8ac0..a81ac85a6 100644
--- a/src/main/scala/leon/synthesis/Task.scala
+++ b/src/main/scala/leon/synthesis/Task.scala
@@ -33,7 +33,7 @@ class SimpleTask(synth: Synthesizer,
   }
 
   def run: List[Task] = {
-    synth.rules.map(r => new ApplyRuleTask(synth, this, problem, priority, r))
+    synth.rules.map(r => new ApplyRuleTask(synth, this, problem, r))
   }
 
   var failed = Set[Rule]()
@@ -50,8 +50,7 @@ class RootTask(synth: Synthesizer, problem: Problem) extends SimpleTask(synth, n
 class ApplyRuleTask(synth: Synthesizer,
                     override val parent: SimpleTask,
                     problem: Problem,
-                    priority: Priority,
-                    val rule: Rule) extends Task(synth, parent, problem, priority) {
+                    val rule: Rule) extends Task(synth, parent, problem, rule.priority) {
 
   var subProblems: List[Problem]            = _
   var onSuccess: List[Solution] => Solution = _
@@ -60,7 +59,7 @@ class ApplyRuleTask(synth: Synthesizer,
   def subSucceeded(p: Problem, s: Solution) {
     assert(subProblems contains p, "Problem "+p+" is unknown to me ?!?")
 
-    if (subSolutions.get(p).map(_.score).getOrElse(-1) < s.score) {
+    if (subSolutions.get(p).map(_.score).getOrElse(-1) <= s.score) {
       subSolutions += p -> s
 
       if (subSolutions.size == subProblems.size) {
@@ -90,5 +89,5 @@ class ApplyRuleTask(synth: Synthesizer,
     }
   }
 
-  override def toString = "Trying "+rule+" on "+problem
+  override def toString = "Applying "+rule+" on "+problem
 }
-- 
GitLab