From ff401032147f1dce3084410f80cbbd8dfaf9df16 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Tue, 20 Nov 2012 04:31:17 +0100 Subject: [PATCH] Fix the fixes --- src/main/scala/leon/aographs/Graph.scala | 57 ++++-- src/main/scala/leon/synthesis/Cost.scala | 11 +- src/main/scala/leon/synthesis/Problem.scala | 11 + src/main/scala/leon/synthesis/Rules.scala | 2 +- src/main/scala/leon/synthesis/Solution.scala | 2 +- .../scala/leon/synthesis/SynthesisPhase.scala | 6 +- .../scala/leon/synthesis/Synthesizer.scala | 71 +++---- .../scala/leon/synthesis/rules/Cegis.scala | 188 +++++++++--------- 8 files changed, 193 insertions(+), 155 deletions(-) diff --git a/src/main/scala/leon/aographs/Graph.scala b/src/main/scala/leon/aographs/Graph.scala index 9f986e67e..e535b71ad 100644 --- a/src/main/scala/leon/aographs/Graph.scala +++ b/src/main/scala/leon/aographs/Graph.scala @@ -32,7 +32,7 @@ trait AOSolution { class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val root: OT) { type C = AOCost - var tree: Tree = RootNode + var tree: OrTree = RootNode trait Tree { val task : AOTask[S] @@ -49,11 +49,13 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo abstract class OrTree extends Tree { override val task: OT + + def isUnsolvable: Boolean = false } trait Leaf extends Tree { - val minCost: C = task.cost + def minCost: C = task.cost def isSolved: Boolean = false } @@ -110,13 +112,22 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo } - def notifyParent(s: S) { - parent.notifySolution(this, s) + def notifyParent(sol: S) { + if (parent ne null) { + parent.notifySolution(this, sol) + } } } object RootNode extends OrLeaf(null, root) { + override def expandWith(succ: List[AT]) { + val n = OrNode(null, Map(), root) + n.alternatives = succ.map(t => t -> AndLeaf(n, t)).toMap + n.minAlternative = n.computeMin + + tree = n + } } case class AndLeaf(parent: OrNode, task: AT) extends AndTree with Leaf { @@ -155,6 +166,8 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo n.minCost = n.computeCost alternatives += l.task -> n + + minAlternative = computeMin } def notifySolution(sub: AndTree, sol: S) { @@ -163,14 +176,22 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo solution = Some(sol) minAlternative = sub - parent.notifySolution(this, solution.get) + notifyParent(solution.get) case None => solution = Some(sol) minAlternative = sub - parent.notifySolution(this, solution.get) + notifyParent(solution.get) } } + + def notifyParent(sol: S) { + if (parent ne null) { + parent.notifySolution(this, sol) + } + } + + override def isUnsolvable: Boolean = alternatives.isEmpty } case class OrLeaf(parent: AndNode, task: OT) extends OrTree with Leaf { @@ -193,7 +214,7 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSo res = Some(l) case an: g.AndNode => - c = an.subProblems.values.minBy(_.minCost) + c = (an.subProblems -- an.subSolutions.keySet).values.minBy(_.minCost) case on: g.OrNode => c = on.minAlternative @@ -203,22 +224,21 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSo res } - abstract class ExpandResult - case class ExpandedAnd(sub: List[OT]) extends ExpandResult - case class ExpandedOr(sub: List[AT]) extends ExpandResult - case class ExpandSuccess(sol: S) extends ExpandResult - case object ExpandFailure extends ExpandResult + abstract class ExpandResult[T <: AOTask[S]] + case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T] + case class ExpandSuccess[T <: AOTask[S]](sol: S) extends ExpandResult[T] + case class ExpandFailure[T <: AOTask[S]]() extends ExpandResult[T] var continue = true def search = { - while (!g.tree.isSolved && continue) { + while (!g.tree.isSolved && continue && !g.tree.isUnsolvable) { nextLeaf match { case Some(l) => l match { case al: g.AndLeaf => - processLeaf(al) match { - case ExpandedAnd(ls) => + processAndLeaf(al.task) match { + case Expanded(ls) => al.expandWith(ls) case r @ ExpandSuccess(sol) => al.parent.notifySolution(al, sol) @@ -226,8 +246,8 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSo al.parent.unsolvable(al) } case ol: g.OrLeaf => - processLeaf(ol) match { - case ExpandedOr(ls) => + processOrLeaf(ol.task) match { + case Expanded(ls) => ol.expandWith(ls) case r @ ExpandSuccess(sol) => ol.parent.notifySolution(ol, sol) @@ -241,5 +261,6 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSo } } - def processLeaf(l: g.Leaf): ExpandResult + def processAndLeaf(l: AT): ExpandResult[OT] + def processOrLeaf(l: OT): ExpandResult[AT] } diff --git a/src/main/scala/leon/synthesis/Cost.scala b/src/main/scala/leon/synthesis/Cost.scala index bd479a489..c25e382ed 100644 --- a/src/main/scala/leon/synthesis/Cost.scala +++ b/src/main/scala/leon/synthesis/Cost.scala @@ -9,12 +9,19 @@ import aographs.AOCost case class SolutionCost(s: Solution) extends AOCost { val value = { val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + (1000 * math.pow(2, c.vars.size).toInt + formulaSize(c.pred))) + val chooseCost = chooses.foldLeft(0)((i, c) => i + ProblemCost(Problem.fromChoose(c)).value) formulaSize(s.toExpr) + chooseCost } } case class ProblemCost(p: Problem) extends AOCost { - val value = math.pow(2, p.xs.size).toInt + formulaSize(p.phi) + val value = math.pow(2, p.xs.size).toInt + formulaSize(p.phi)*1000 +} + +case class RuleApplicationCost(rule: Rule, app: RuleApplication) extends AOCost { + val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList + val simpleSol = app.onSuccess(subSols) + + val value = SolutionCost(simpleSol).value*1000 + 1000-rule.priority } diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index 2f85a916f..89b452873 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -2,6 +2,7 @@ package leon package synthesis import leon.purescala.Trees._ +import leon.purescala.TreeOps._ import leon.purescala.Common._ // Defines a synthesis triple of the form: @@ -11,3 +12,13 @@ case class Problem(as: List[Identifier], c: Expr, phi: Expr, xs: List[Identifier val complexity: ProblemComplexity = ProblemComplexity(this) } + +object Problem { + def fromChoose(ch: Choose): Problem = { + val xs = ch.vars + val phi = ch.pred + val as = (variablesOf(phi)--xs).toList + + Problem(as, BooleanLiteral(true), phi, xs) + } +} diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 8fac72bce..a9a7f1735 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -20,7 +20,7 @@ object Rules { new EqualitySplit(_), new CEGIS(_), new Assert(_), - new ADTSplit(_), +// new ADTSplit(_), new IntegerEquation(_), new IntegerInequalities(_) ) diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 255d7fd10..a81c01036 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -11,7 +11,7 @@ import aographs._ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) extends AOSolution { override def toString = "⟨ "+pre+" | "+defs.mkString(" ")+" "+term+" ⟩" - val cost: AOCost = null + val cost: AOCost = SolutionCost(this) def toExpr = { val result = if (pre == BooleanLiteral(true)) { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 78103a3b1..dff4aea5b 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -61,11 +61,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match { case ch @ Choose(vars, pred) => - val xs = vars - val as = (variablesOf(pred)--xs).toList - val phi = pred - - val problem = Problem(as, BooleanLiteral(true), phi, xs) + val problem = Problem.fromChoose(ch) val synth = new Synthesizer(ctx.reporter, mainSolver, problem, Rules.all ++ Heuristics.all, genTrees, filterFun.map(_.toSet), firstOnly, timeoutMs) val sol = synth.synthesize() diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index ecda64562..423e2e81f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -59,59 +59,54 @@ class Synthesizer(val reporter: Reporter, import aographs._ - abstract class Task extends AOTask[Solution] - case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) extends Task { - val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList - val simpleSol = app.onSuccess(subSols) - - def cost = SolutionCost(simpleSol) + case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) extends AOAndTask[Solution] { + def cost = RuleApplicationCost(rule, app) def composeSolution(sols: List[Solution]): Solution = { app.onSuccess(sols) } + + override def toString = rule.name+" ON "+problem } - case class TaskTryRules(p: Problem) extends Task { + case class TaskTryRules(p: Problem) extends AOOrTask[Solution] { def cost = ProblemCost(p) - def composeSolution(sols: List[Solution]): Solution = { - sys.error("Should not be called") - } + override def toString = " Splitting "+problem } - class AOSearch(problem: Problem, rules: Set[Rule]) extends AndOrGraphSearch[Task, Solution](new AndOrGraph(TaskTryRules(problem))) { - def processLeaf(l: g.Leaf) = { - l.task match { - case t: TaskTryRules => - val sub = rules.flatMap ( r => r.attemptToApplyOn(t.p).alternatives.map(TaskRunRule(t.p, r, _)) ) + class AOSearch(problem: Problem, rules: Set[Rule]) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem))) { - if (!sub.isEmpty) { - Expanded(sub.toList) - } else { - ExpandFailure - } + def processAndLeaf(t: TaskRunRule) = { + val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?")) - case t: TaskRunRule => - val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?")) + t.app.apply() match { + case RuleSuccess(sol) => + info(prefix+"Got: "+t.problem) + info(prefix+"Solved with: "+sol) + ExpandSuccess(sol) + case RuleDecomposed(sub, onSuccess) => info(prefix+"Got: "+t.problem) - t.app.apply() match { - case RuleSuccess(sol) => - info(prefix+"Solved with: "+sol) - - ExpandSuccess(sol) - case RuleDecomposed(sub, onSuccess) => - info(prefix+"Got: "+t.problem) - info(prefix+"Decomposed into:") - for(p <- sub) { - info(prefix+" - "+p) - } - - Expanded(sub.map(TaskTryRules(_))) - - case RuleApplicationImpossible => - ExpandFailure + info(prefix+"Decomposed into:") + for(p <- sub) { + info(prefix+" - "+p) } + + Expanded(sub.map(TaskTryRules(_))) + + case RuleApplicationImpossible => + ExpandFailure() + } + } + + def processOrLeaf(t: TaskTryRules) = { + val sub = rules.flatMap ( r => r.attemptToApplyOn(t.p).alternatives.map(TaskRunRule(t.p, r, _)) ) + + if (!sub.isEmpty) { + Expanded(sub.toList) + } else { + ExpandFailure() } } } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 50b944532..f6453f9fd 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -120,100 +120,108 @@ class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 150) { var unrolings = 0 val maxUnrolings = 3 var predicates: Seq[Expr] = Seq() - do { - //println("="*80) - //println("Was: "+lastF.entireFormula) - //println("Now Trying : "+currentF.entireFormula) - - val tpe = TupleType(p.xs.map(_.getType)) - val bss = currentF.bss - - var continue = true - - while (result.isEmpty && continue && synth.continue) { - val basePhi = currentF.entireFormula - val constrainedPhi = And(basePhi +: predicates) - //println("-"*80) - //println("To satisfy: "+constrainedPhi) - synth.solver.solveSAT(constrainedPhi) match { - case (Some(true), satModel) => - //println("Found candidate!: "+satModel.filterKeys(bss)) - - //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) - val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) - //println("Phi with fixed sat bss: "+fixedBss) - - val counterPhi = And(Seq(currentF.pathcond, fixedBss, currentF.program, Not(currentF.phi))) - //println("Formula to validate: "+counterPhi) - - synth.solver.solveSAT(counterPhi) match { - case (Some(true), invalidModel) => - val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq) - - val mustBeUnsat = And(currentF.pathcond :: currentF.program :: fixedAss :: currentF.phi :: Nil) - - val bssAssumptions: Set[Expr] = bss.toSet.map { b: Identifier => satModel(b) match { - case BooleanLiteral(true) => Variable(b) - case BooleanLiteral(false) => Not(Variable(b)) - }} - - val unsatCore = synth.solver.solveSATWithCores(mustBeUnsat, bssAssumptions) match { - case ((Some(false), _, core)) => - //println("Formula: "+mustBeUnsat) - //println("Core: "+core) - //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq))) - //println("maxcore: "+bssAssumptions) - if (core.isEmpty) { - synth.reporter.warning("Got empty core, must be unsat without assumptions!") - Set() - } else { - core - } - case _ => - bssAssumptions - } - - // Found as such as the xs break, refine predicates - //println("Found counter EX: "+invalidModel) - if (unsatCore.isEmpty) { + try { + do { + //println("="*80) + //println("Was: "+lastF.entireFormula) + //println("Now Trying : "+currentF.entireFormula) + + val tpe = TupleType(p.xs.map(_.getType)) + val bss = currentF.bss + + var continue = true + + while (result.isEmpty && continue && synth.continue) { + val basePhi = currentF.entireFormula + val constrainedPhi = And(basePhi +: predicates) + //println("-"*80) + //println("To satisfy: "+constrainedPhi) + synth.solver.solveSAT(constrainedPhi) match { + case (Some(true), satModel) => + //println("Found candidate!: "+satModel.filterKeys(bss)) + + //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) + val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) + //println("Phi with fixed sat bss: "+fixedBss) + + val counterPhi = And(Seq(currentF.pathcond, fixedBss, currentF.program, Not(currentF.phi))) + //println("Formula to validate: "+counterPhi) + + synth.solver.solveSAT(counterPhi) match { + case (Some(true), invalidModel) => + val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq) + + val mustBeUnsat = And(currentF.pathcond :: currentF.program :: fixedAss :: currentF.phi :: Nil) + + val bssAssumptions: Set[Expr] = bss.toSet.map { b: Identifier => satModel(b) match { + case BooleanLiteral(true) => Variable(b) + case BooleanLiteral(false) => Not(Variable(b)) + }} + + val unsatCore = synth.solver.solveSATWithCores(mustBeUnsat, bssAssumptions) match { + case ((Some(false), _, core)) => + //println("Formula: "+mustBeUnsat) + //println("Core: "+core) + //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq))) + //println("maxcore: "+bssAssumptions) + if (core.isEmpty) { + synth.reporter.warning("Got empty core, must be unsat without assumptions!") + Set() + } else { + core + } + case _ => + bssAssumptions + } + + // Found as such as the xs break, refine predicates + //println("Found counter EX: "+invalidModel) + if (unsatCore.isEmpty) { + continue = false + } else { + predicates = Not(And(unsatCore.toSeq)) +: predicates + } + + case (Some(false), _) => + //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) + var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap + + //println("Mapping: "+mapping) + + // Resolve mapping + for ((c, e) <- mapping) { + mapping += c -> substAll(mapping, e) + } + + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) + + case _ => + synth.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") continue = false - } else { - predicates = Not(And(unsatCore.toSeq)) +: predicates - } - - case (Some(false), _) => - //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) - var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap - - //println("Mapping: "+mapping) - - // Resolve mapping - for ((c, e) <- mapping) { - mapping += c -> substAll(mapping, e) - } - - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) - - case _ => - synth.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") - continue = false - } - - case (Some(false), _) => - //println("%%%% UNSAT") - continue = false - case _ => - //println("%%%% WOOPS") - continue = false + } + + case (Some(false), _) => + //println("%%%% UNSAT") + continue = false + case _ => + //println("%%%% WOOPS") + continue = false + } } - } - lastF = currentF - currentF = currentF.unroll - unrolings += 1 - } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty && synth.continue) + lastF = currentF + currentF = currentF.unroll + unrolings += 1 + } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty && synth.continue) + + result.getOrElse(RuleApplicationImpossible) + + } catch { + case e: Throwable => + synth.reporter.warning("CEGIS crashed: "+e.getMessage) + RuleApplicationImpossible + } - result.getOrElse(RuleApplicationImpossible) } } -- GitLab