From e4a278b40e3be339c83500a14596414f53927986 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Fri, 11 Jan 2013 02:42:09 +0100
Subject: [PATCH] Propagate expected types to onSuccess

This allows CostModels to estimate correctly the minimal cost of a
applying a rule.

With type information on the expected types of a solution
reconstruction, the cost model can provide dummy values of the correct
type, avoiding assertion errors when composing solutions.
---
 src/main/scala/leon/purescala/TreeOps.scala   |  1 +
 src/main/scala/leon/synthesis/CostModel.scala |  2 +-
 .../scala/leon/synthesis/Heuristics.scala     |  5 +++-
 src/main/scala/leon/synthesis/Rules.scala     | 23 ++++++++-----------
 src/main/scala/leon/synthesis/Solution.scala  |  5 ++--
 5 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index d0b9859e2..476f0ca85 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -854,6 +854,7 @@ object TreeOps {
       CaseClass(ccd, fields.map(f => simplestValue(f.getType)))
     case SetType(baseType) => FiniteSet(Seq()).setType(tpe)
     case MapType(fromType, toType) => FiniteMap(Seq()).setType(tpe)
+    case TupleType(tpes) => Tuple(tpes.map(simplestValue))
     case _ => throw new Exception("I can't choose simplest value for type " + tpe)
   }
 
diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala
index 19a975ed2..7fd051d8e 100644
--- a/src/main/scala/leon/synthesis/CostModel.scala
+++ b/src/main/scala/leon/synthesis/CostModel.scala
@@ -12,7 +12,7 @@ abstract class CostModel(val name: String) {
   def problemCost(p: Problem): Cost
 
   def ruleAppCost(app: RuleInstantiation): Cost = new Cost {
-    val subSols = (1 to app.onSuccess.arity).map {i => Solution.simplest }.toList
+    val subSols = app.onSuccess.types.map {t => Solution.simplest(t) }.toList
     val simpleSol = app.onSuccess(subSols)
 
     val value = simpleSol match {
diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 837c08d3b..95ac39a13 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -2,6 +2,7 @@ package leon
 package synthesis
 
 import purescala.Trees._
+import purescala.TypeTrees.TupleType
 
 import heuristics._
 
@@ -38,7 +39,9 @@ object HeuristicInstantiation {
   }
 
   def apply(problem: Problem, rule: Rule, subProblems: List[Problem], onSuccess: List[Solution] => Option[Solution]): RuleInstantiation = {
-    val builder = new SolutionBuilder(subProblems.size) {
+    val subTypes = subProblems.map(p => TupleType(p.xs.map(_.getType)))
+
+    val builder = new SolutionBuilder(subProblems.size, subTypes) {
       def apply(sols: List[Solution]) = {
         onSuccess(sols)
       }
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 548d8fb30..6f02e767a 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -3,6 +3,7 @@ package synthesis
 
 import purescala.Common._
 import purescala.Trees._
+import purescala.TypeTrees._
 import purescala.TreeOps._
 import rules._
 
@@ -27,11 +28,13 @@ object Rules {
   )
 }
 
-abstract class SolutionBuilder(val arity: Int) {
+abstract class SolutionBuilder(val arity: Int, val types: Seq[TypeTree]) {
   def apply(sols: List[Solution]): Option[Solution]
+
+  assert(types.size == arity)
 }
 
-class SolutionCombiner(arity: Int, f: List[Solution] => Option[Solution]) extends SolutionBuilder(arity) {
+class SolutionCombiner(arity: Int, types: Seq[TypeTree],  f: List[Solution] => Option[Solution]) extends SolutionBuilder(arity, types) {
   def apply(sols: List[Solution]) = {
     assert(sols.size == arity)
     f(sols)
@@ -39,7 +42,7 @@ class SolutionCombiner(arity: Int, f: List[Solution] => Option[Solution]) extend
 }
 
 object SolutionBuilder {
-  val none = new SolutionBuilder(0) {
+  val none = new SolutionBuilder(0, Seq()) {
     def apply(sols: List[Solution]) = None
   }
 }
@@ -48,14 +51,6 @@ abstract class RuleInstantiation(val problem: Problem, val rule: Rule, val onSuc
   def apply(sctx: SynthesisContext): RuleApplicationResult
 }
 
-//abstract class RuleApplication(val subProblemsCount: Int,
-//                               val onSuccess: List[Solution] => Solution) {
-//
-//  def apply(sctx: SynthesisContext): RuleApplicationResult
-//}
-//
-//abstract class RuleImmediateApplication extends RuleApplication(0, s => Solution.simplest)
-
 sealed abstract class RuleApplicationResult
 case class RuleSuccess(solution: Solution)    extends RuleApplicationResult
 case class RuleDecomposed(sub: List[Problem]) extends RuleApplicationResult
@@ -63,13 +58,15 @@ case object RuleApplicationImpossible         extends RuleApplicationResult
 
 object RuleInstantiation {
   def immediateDecomp(problem: Problem, rule: Rule, sub: List[Problem], onSuccess: List[Solution] => Option[Solution]) = {
-    new RuleInstantiation(problem, rule, new SolutionCombiner(sub.size, onSuccess)) {
+    val subTypes = sub.map(p => TupleType(p.xs.map(_.getType)))
+
+    new RuleInstantiation(problem, rule, new SolutionCombiner(sub.size, subTypes, onSuccess)) {
       def apply(sctx: SynthesisContext) = RuleDecomposed(sub)
     }
   }
 
   def immediateSuccess(problem: Problem, rule: Rule, solution: Solution) = {
-    new RuleInstantiation(problem, rule, new SolutionCombiner(0, ls => Some(solution))) {
+    new RuleInstantiation(problem, rule, new SolutionCombiner(0, Seq(), ls => Some(solution))) {
       def apply(sctx: SynthesisContext) = RuleSuccess(solution)
     }
   }
diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala
index 07c37915b..902ec8e76 100644
--- a/src/main/scala/leon/synthesis/Solution.scala
+++ b/src/main/scala/leon/synthesis/Solution.scala
@@ -2,6 +2,7 @@ package leon
 package synthesis
 
 import leon.purescala.Trees._
+import leon.purescala.TypeTrees.TypeTree
 import leon.purescala.Definitions._
 import leon.purescala.TreeOps._
 import leon.xlang.Trees.LetDef
@@ -45,7 +46,7 @@ object Solution {
     new Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(id => simplestValue(id.getType))))
   }
 
-  def simplest: Solution = {
-    new Solution(BooleanLiteral(true), Set(), BooleanLiteral(true))
+  def simplest(t: TypeTree): Solution = {
+    new Solution(BooleanLiteral(true), Set(), simplestValue(t))
   }
 }
-- 
GitLab