Skip to content
Snippets Groups Projects
Commit f34ab79b authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Rework complexities a bit

parent 9a44a404
No related branches found
No related tags found
No related merge requests found
package leon
package synthesis
abstract class Complexity extends Ordered[Complexity] {
def compare(that: Complexity): Int = (this.compute, that.compute) match {
case (x, y) if x < y => -1
case (x, y) if x > y => +1
case _ => 0
}
import purescala.Trees._
abstract class Complexity[T <: Complexity[T]] extends Ordered[T] {
def compare(that: T) = this.value - that.value
def compute : Double
def value : Int
}
case class TaskComplexity(p: Problem, r: Option[Rule]) extends Complexity {
def compute = {
r match {
case class TaskComplexity(t: Task) extends Complexity[TaskComplexity] {
def value= {
Option(t.rule) match {
case Some(r) =>
100*p.complexity.compute + (100-r.priority)
100*t.problem.complexity.value + (100-r.priority) + t.minSolutionCost
case None =>
0
}
}
}
object Complexity {
val zero = new Complexity {
override def compute = 0
override def toString = "0"
}
val max = new Complexity {
override def compute = 42
override def toString = "MAX"
}
case class SolutionComplexity(s: Solution) extends Complexity[SolutionComplexity] {
lazy val value = 42
}
case class ProblemComplexity(p: Problem) extends Complexity[ProblemComplexity] {
lazy val value = 42
}
......@@ -45,7 +45,7 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn
predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates
case (Some(false), _) =>
result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe), cost)))
result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe))))
case _ =>
result = Some(RuleInapplicable)
......@@ -53,7 +53,7 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn
case (Some(false), _) =>
if (predicates.isEmpty) {
result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost)))
result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))))
} else {
result = Some(RuleInapplicable)
}
......@@ -105,7 +105,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 8, 5
, LetTuple(postXs, FunctionInvocation(newFun, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr)))
)
Solution(BooleanLiteral(true), LetDef(newFun, FunctionInvocation(newFun, Seq(Variable(origId)))), base.cost+gt.cost+lt.cost+cost)
Solution(BooleanLiteral(true), LetDef(newFun, FunctionInvocation(newFun, Seq(Variable(origId)))))
case _ =>
Solution.none
}
......
......@@ -9,5 +9,5 @@ import leon.purescala.Common._
case class Problem(as: List[Identifier], phi: Expr, xs: List[Identifier]) {
override def toString = "⟦ "+as.mkString(";")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ "
def complexity: Complexity = Complexity.max
val complexity: ProblemComplexity = ProblemComplexity(this)
}
......@@ -34,8 +34,8 @@ abstract class Rule(val name: String, val synth: Synthesizer, val priority: Prio
def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in)
def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replace(what.map(w => Variable(w._1) -> w._2), in)
def forward(cost: Cost): List[Solution] => Solution = {
case List(s) => Solution(s.pre, s.term, s.cost + cost)
val forward: List[Solution] => Solution = {
case List(s) => Solution(s.pre, s.term)
case _ => Solution.none
}
......@@ -65,11 +65,11 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 30, 0) {
val newProblem = Problem(p.as, subst(x -> e, And(others)), oxs)
val onSuccess: List[Solution] => Solution = {
case List(Solution(pre, term, c)) =>
case List(Solution(pre, term)) =>
if (oxs.isEmpty) {
Solution(pre, Tuple(e :: Nil), c + cost)
Solution(pre, Tuple(e :: Nil))
} else {
Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))), c + cost)
Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))))
}
case _ => Solution.none
}
......@@ -91,9 +91,9 @@ class Ground(synth: Synthesizer) extends Rule("Ground", synth, 50, 0) {
synth.solveSAT(p.phi) match {
case (Some(true), model) =>
RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe), cost))
RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe)))
case (Some(false), model) =>
RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost))
RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe)))
case _ =>
RuleInapplicable
}
......@@ -112,7 +112,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 20, 0) {
val sub2 = Problem(p.as, o2, p.xs)
val onSuccess: List[Solution] => Solution = {
case List(Solution(p1, t1, c1), Solution(p2, t2, c2)) => Solution(Or(p1, p2), IfExpr(p1, t1, t2), c1+c2+cost)
case List(Solution(p1, t1), Solution(p2, t2)) => Solution(Or(p1, p2), IfExpr(p1, t1, t2))
case _ => Solution.none
}
......@@ -135,10 +135,10 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 20, 0) {
if (!exprsA.isEmpty) {
if (others.isEmpty) {
RuleSuccess(Solution(And(exprsA), Tuple(p.xs.map(id => simplestValue(Variable(id)))), cost))
RuleSuccess(Solution(And(exprsA), Tuple(p.xs.map(id => simplestValue(Variable(id))))))
} else {
val onSuccess: List[Solution] => Solution = {
case List(s) => Solution(And(s.pre +: exprsA), s.term, s.cost + cost)
case List(s) => Solution(And(s.pre +: exprsA), s.term)
case _ => Solution.none
}
......@@ -163,7 +163,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 10, 0)
if (!unused.isEmpty) {
val sub = p.copy(as = p.as.filterNot(unused))
RuleDecomposed(List(sub), forward(cost))
RuleDecomposed(List(sub), forward)
} else {
RuleInapplicable
}
......@@ -180,7 +180,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy
val onSuccess: List[Solution] => Solution = {
case List(s) =>
Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id)))), s.cost + cost)
Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id)))))
case _ =>
Solution.none
}
......@@ -215,7 +215,7 @@ object Unification {
val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq))
RuleDecomposed(List(sub), forward(cost))
RuleDecomposed(List(sub), forward)
} else {
RuleInapplicable
}
......@@ -240,7 +240,7 @@ object Unification {
if (isImpossible) {
val tpe = TupleType(p.xs.map(_.getType))
RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost))
RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe)))
} else {
RuleInapplicable
}
......@@ -269,7 +269,7 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 20, 0) {
if (!toRemove.isEmpty) {
val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq))
RuleDecomposed(List(sub), forward(cost))
RuleDecomposed(List(sub), forward)
} else {
RuleInapplicable
}
......@@ -278,7 +278,7 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 20, 0) {
class GiveUp(synth: Synthesizer) extends Rule("GiveUp", synth, 0, 100) {
def applyOn(task: Task): RuleResult = {
RuleSuccess(Solution.choose(task.problem, cost))
RuleSuccess(Solution.choose(task.problem))
}
}
......@@ -6,9 +6,11 @@ import leon.purescala.TreeOps.simplifyLets
// Defines a synthesis solution of the form:
// ⟨ P | T ⟩
class Solution(val pre: Expr, val term: Expr, val cost: Cost) {
class Solution(val pre: Expr, val term: Expr) {
override def toString = "⟨ "+pre+" | "+term+" ⟩"
lazy val complexity: SolutionComplexity = new SolutionComplexity(this)
def toExpr = {
if (pre == BooleanLiteral(true)) {
term
......@@ -18,20 +20,19 @@ class Solution(val pre: Expr, val term: Expr, val cost: Cost) {
IfExpr(pre, term, Error("Precondition failed").setType(term.getType))
}
}
def complexity = Complexity.zero
}
object Solution {
def choose(p: Problem, cost: Cost): Solution = new Solution(BooleanLiteral(true), Choose(p.xs, p.phi), cost)
def choose(p: Problem): Solution =
new Solution(BooleanLiteral(true), Choose(p.xs, p.phi))
def none: Solution = throw new Exception("Unexpected failure to construct solution")
def simplify(e: Expr) = simplifyLets(e)
def apply(pre: Expr, term: Expr, cost: Cost) = {
new Solution(simplify(pre), simplify(term), cost)
def apply(pre: Expr, term: Expr) = {
new Solution(simplify(pre), simplify(term))
}
def unapply(s: Solution): Option[(Expr, Expr, Cost)] = if (s eq null) None else Some((s.pre, s.term, s.cost))
def unapply(s: Solution): Option[(Expr, Expr)] = if (s eq null) None else Some((s.pre, s.term))
}
......@@ -6,9 +6,9 @@ class Task(synth: Synthesizer,
val problem: Problem,
val rule: Rule) extends Ordered[Task] {
def compare(that: Task) = -(this.complexity compare that.complexity) // sort by complexity ASC
val complexity: TaskComplexity = new TaskComplexity(this)
val complexity = new TaskComplexity(problem, Option(rule))
def compare(that: Task) = that.complexity compare this.complexity // sort by complexity ASC
var subProblems: List[Problem] = Nil
var onSuccess: List[Solution] => Solution = _
......@@ -16,14 +16,14 @@ class Task(synth: Synthesizer,
var subSolvers : Map[Problem, Task] = Map()
var solution : Option[Solution] = None
def currentComplexityFor(p: Problem): Complexity =
subSolutions.get(p) match {
case Some(s) => s.complexity
case None => Complexity.max
def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean =
osol match {
case Some(s) => s.complexity > sol.complexity
case None => true
}
def partlySolvedBy(t: Task, s: Solution) {
if (s.complexity < currentComplexityFor(t.problem)) {
if (isBetterSolutionThan(s, subSolutions.get(t.problem))) {
subSolutions += t.problem -> s
subSolvers += t.problem -> t
......@@ -52,14 +52,18 @@ class Task(synth: Synthesizer,
}
}
lazy val minSolutionCost: Cost = rule.cost + parent.minSolutionCost
override def toString = "Applying "+rule+" on "+problem
}
class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, problem, null) {
var solver: Option[Task] = None
override lazy val minSolutionCost = 0
override def partlySolvedBy(t: Task, s: Solution) = {
if (s.complexity < solution.map(_.complexity).getOrElse(Complexity.max)) {
if (isBetterSolutionThan(s, solution)) {
solution = Some(s)
solver = Some(t)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment