Skip to content
Snippets Groups Projects
Commit f5fb158f authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Implement the concept of Normalizing rules

Normalizing rules are rules that:
1) always help synthesis
2) are commutative
3) should be applied as early as possible

Here we apply normalizing rules explicitly before all other rules, and
in a deterministic order. This should dramatically reduce the search
space in cases where such rules apply.

Note that rules that are said to be normalizing should never fail once
instantiated.
parent d2b25e38
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ import purescala.TypeTrees.TupleType ...@@ -7,7 +7,7 @@ import purescala.TypeTrees.TupleType
import heuristics._ import heuristics._
object Heuristics { object Heuristics {
def all = Set[Rule]( def all = List[Rule](
IntInduction, IntInduction,
InnerCaseSplit, InnerCaseSplit,
//new OptimisticInjection(_), //new OptimisticInjection(_),
......
...@@ -8,7 +8,7 @@ import solvers.TrivialSolver ...@@ -8,7 +8,7 @@ import solvers.TrivialSolver
class ParallelSearch(synth: Synthesizer, class ParallelSearch(synth: Synthesizer,
problem: Problem, problem: Problem,
rules: Set[Rule], rules: Seq[Rule],
costModel: CostModel, costModel: CostModel,
nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) { nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) {
...@@ -69,14 +69,22 @@ class ParallelSearch(synth: Synthesizer, ...@@ -69,14 +69,22 @@ class ParallelSearch(synth: Synthesizer,
} }
def expandOrTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskTryRules) = { def expandOrTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskTryRules) = {
val sub = rules.flatMap { r => val (normRules, otherRules) = rules.partition(_.isInstanceOf[NormalizingRule])
r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
} val normApplications = normRules.flatMap(_.instantiateOn(sctx, t.p))
if (!sub.isEmpty) { if (!normApplications.isEmpty) {
Expanded(sub.toList) Expanded(List(TaskRunRule(normApplications.head)))
} else { } else {
ExpandFailure() val sub = otherRules.flatMap { r =>
r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
}
if (!sub.isEmpty) {
Expanded(sub.toList)
} else {
ExpandFailure()
}
} }
} }
} }
...@@ -8,7 +8,7 @@ import purescala.TreeOps._ ...@@ -8,7 +8,7 @@ import purescala.TreeOps._
import rules._ import rules._
object Rules { object Rules {
def all = Set[Rule]( def all = List[Rule](
Unification.DecompTrivialClash, Unification.DecompTrivialClash,
Unification.OccursCheck, // probably useless Unification.OccursCheck, // probably useless
Disunification.Decomp, Disunification.Decomp,
...@@ -89,3 +89,10 @@ abstract class Rule(val name: String) { ...@@ -89,3 +89,10 @@ abstract class Rule(val name: String) {
override def toString = "R: "+name override def toString = "R: "+name
} }
// Note: Rules that extend NormalizingRule should all be commutative. They will
// be applied before others in a deterministic order and their application
// should never fail!
abstract class NormalizingRule(name: String) extends Rule(name) {
override def toString = "N: "+name
}
...@@ -32,7 +32,7 @@ case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskT ...@@ -32,7 +32,7 @@ case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskT
class SimpleSearch(synth: Synthesizer, class SimpleSearch(synth: Synthesizer,
problem: Problem, problem: Problem,
rules: Set[Rule], rules: Seq[Rule],
costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) {
import synth.reporter._ import synth.reporter._
...@@ -64,12 +64,22 @@ class SimpleSearch(synth: Synthesizer, ...@@ -64,12 +64,22 @@ class SimpleSearch(synth: Synthesizer,
} }
def expandOrTask(t: TaskTryRules): ExpandResult[TaskRunRule] = { def expandOrTask(t: TaskTryRules): ExpandResult[TaskRunRule] = {
val sub = rules.flatMap ( r => r.instantiateOn(sctx, t.p).map(TaskRunRule(_)) ) val (normRules, otherRules) = rules.partition(_.isInstanceOf[NormalizingRule])
if (!sub.isEmpty) { val normApplications = normRules.flatMap(_.instantiateOn(sctx, t.p))
Expanded(sub.toList)
if (!normApplications.isEmpty) {
Expanded(List(TaskRunRule(normApplications.head)))
} else { } else {
ExpandFailure() val sub = otherRules.flatMap { r =>
r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
}
if (!sub.isEmpty) {
Expanded(sub.toList)
} else {
ExpandFailure()
}
} }
} }
......
...@@ -22,7 +22,7 @@ class Synthesizer(val context : LeonContext, ...@@ -22,7 +22,7 @@ class Synthesizer(val context : LeonContext,
val solver: Solver, val solver: Solver,
val program: Program, val program: Program,
val problem: Problem, val problem: Problem,
val rules: Set[Rule], val rules: Seq[Rule],
val options: SynthesizerOptions) { val options: SynthesizerOptions) {
protected[synthesis] val reporter = context.reporter protected[synthesis] val reporter = context.reporter
......
...@@ -6,7 +6,7 @@ import purescala.Trees._ ...@@ -6,7 +6,7 @@ import purescala.Trees._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.Extractors._ import purescala.Extractors._
case object Assert extends Rule("Assert") { case object Assert extends NormalizingRule("Assert") {
def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
p.phi match { p.phi match {
case TopLevelAnds(exprs) => case TopLevelAnds(exprs) =>
......
...@@ -6,7 +6,7 @@ import purescala.Trees._ ...@@ -6,7 +6,7 @@ import purescala.Trees._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.Extractors._ import purescala.Extractors._
case object OnePoint extends Rule("One-point") { case object OnePoint extends NormalizingRule("One-point") {
def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
val TopLevelAnds(exprs) = p.phi val TopLevelAnds(exprs) = p.phi
......
...@@ -6,7 +6,7 @@ import purescala.Trees._ ...@@ -6,7 +6,7 @@ import purescala.Trees._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.Extractors._ import purescala.Extractors._
case object UnconstrainedOutput extends Rule("Unconstr.Output") { case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") {
def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
val unconstr = p.xs.toSet -- variablesOf(p.phi) val unconstr = p.xs.toSet -- variablesOf(p.phi)
......
...@@ -6,7 +6,7 @@ import purescala.Trees._ ...@@ -6,7 +6,7 @@ import purescala.Trees._
import purescala.TreeOps._ import purescala.TreeOps._
import purescala.Extractors._ import purescala.Extractors._
case object UnusedInput extends Rule("UnusedInput") { case object UnusedInput extends NormalizingRule("UnusedInput") {
def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc) val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment