From c9f0f7de007a9d14f88bffab74b327542c896521 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Tue, 13 Nov 2012 16:45:33 +0100 Subject: [PATCH] Move cegis and optimistic ground to rules --- .../scala/leon/synthesis/Heuristics.scala | 219 +--------------- src/main/scala/leon/synthesis/Rules.scala | 236 +++++++++++++++++- .../scala/leon/synthesis/Synthesizer.scala | 6 +- src/main/scala/leon/synthesis/Task.scala | 6 +- 4 files changed, 246 insertions(+), 221 deletions(-) diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala index 8cd507344..97f1cb96e 100644 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ b/src/main/scala/leon/synthesis/Heuristics.scala @@ -10,9 +10,7 @@ import purescala.Definitions._ object Heuristics { def all = Set[Synthesizer => Rule]( - new OptimisticGround(_), - //new IntInduction(_), - new CEGIS(_), + new IntInduction(_), new OptimisticInjection(_) ) } @@ -23,62 +21,6 @@ trait Heuristic { override def toString = "H: "+name } -class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) with Heuristic { - def applyOn(task: Task): RuleResult = { - val p = task.problem - - if (!p.as.isEmpty && !p.xs.isEmpty) { - val xss = p.xs.toSet - val ass = p.as.toSet - - val tpe = TupleType(p.xs.map(_.getType)) - - var i = 0; - var maxTries = 3; - - var result: Option[RuleResult] = None - var predicates: Seq[Expr] = Seq() - - while (result.isEmpty && i < maxTries) { - val phi = And(p.phi +: predicates) - synth.solver.solveSAT(phi) match { - case (Some(true), satModel) => - val satXsModel = satModel.filterKeys(xss) - - val newPhi = valuateWithModelIn(phi, xss, satModel) - - synth.solver.solveSAT(Not(newPhi)) match { - case (Some(true), invalidModel) => - // Found as such as the xs break, refine predicates - predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates - - case (Some(false), _) => - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) - - case _ => - result = Some(RuleInapplicable) - } - - case (Some(false), _) => - if (predicates.isEmpty) { - result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe)))) - } else { - result = Some(RuleInapplicable) - } - case _ => - result = Some(RuleInapplicable) - } - - i += 1 - } - - result.getOrElse(RuleInapplicable) - } else { - RuleInapplicable - } - } -} - class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) with Heuristic { def applyOn(task: Task): RuleResult = { @@ -105,7 +47,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) val onSuccess: List[Solution] => Solution = { case List(base, gt, lt) => val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType))) - newFun.body = Some( + newFun.body = Some( IfExpr(Equals(Variable(inductOn), IntLiteral(0)), base.toExpr, IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)), @@ -197,160 +139,3 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth, } } -class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 50) with Heuristic { - def applyOn(task: Task): RuleResult = { - val p = task.problem - - case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]); - - var generators = Map[TypeTree, Generator]() - def getGenerator(t: TypeTree): Generator = generators.get(t) match { - case Some(g) => g - case None => - val alternatives: () => List[(Expr, Set[Identifier])] = t match { - case BooleanType => - { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) } - - case Int32Type => - { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) } - - case TupleType(tps) => - { () => - val ids = tps.map(t => FreshIdentifier("t", true).setType(t)) - List((Tuple(ids.map(Variable(_))), ids.toSet)) - } - - case CaseClassType(cd) => - { () => - val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) - List((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) - } - - case AbstractClassType(cd) => - { () => - val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match { - case acd: AbstractClassDef => - synth.reporter.error("Unnexpected abstract class in descendants!") - None - case cd: CaseClassDef => - val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) - Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) - }) - alts.toList - } - - case _ => - synth.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]"); - { () => Nil } - } - val g = Generator(t, alternatives) - generators += t -> g - g - } - - def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { - p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) - } - - case class TentativeFormula(phi: Expr, - program: Expr, - mappings: Map[Identifier, (Identifier, Expr)], - recTerms: Map[Identifier, Set[Identifier]]) { - def unroll: TentativeFormula = { - var newProgram = List[Expr]() - var newRecTerms = Map[Identifier, Set[Identifier]]() - var newMappings = Map[Identifier, (Identifier, Expr)]() - - for ((_, recIds) <- recTerms; recId <- recIds) { - val gen = getGenerator(recId.getType) - val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) - - val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt) - - val pre = Or(altsWithBranches.map{ case (id, _) => Variable(id) }) // b1 OR b2 - val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}] - if (!rec.isEmpty) { - newRecTerms += bid -> rec - } else { - newMappings += bid -> (recId -> ex) - } - - Implies(Variable(bid), Equals(Variable(recId), ex)) - } - - newProgram = newProgram ::: pre :: cases - } - - TentativeFormula(phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms) - } - - def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList - def bss = mappings.keySet - - def entireFormula = And(phi :: program :: bounds) - } - - var result: Option[RuleResult] = None - - var ass = p.as.toSet - var xss = p.xs.toSet - - var lastF = TentativeFormula(And(p.c, p.phi), BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x))) - var currentF = lastF.unroll - var unrolings = 0 - val maxUnrolings = 2 - do { - println("Was: "+lastF.entireFormula) - println("Now Trying : "+currentF.entireFormula) - - val tpe = TupleType(p.xs.map(_.getType)) - val bss = currentF.bss - - var predicates: Seq[Expr] = Seq() - var continue = true - - while (result.isEmpty && continue) { - val basePhi = currentF.entireFormula - val constrainedPhi = And(basePhi +: predicates) - println("-"*80) - println("To satisfy: "+constrainedPhi) - synth.solver.solveSAT(constrainedPhi) match { - case (Some(true), satModel) => - println("Found candidate!: "+satModel.filterKeys(bss)) - - val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) - println("Phi with fixed sat bss: "+fixedBss) - - val counterPhi = Implies(And(fixedBss, currentF.program), currentF.phi) - println("Formula to validate: "+counterPhi) - - synth.solver.solveSAT(Not(counterPhi)) match { - case (Some(true), invalidModel) => - // Found as such as the xs break, refine predicates - println("Found counter EX: "+invalidModel) - predicates = Not(And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)) +: predicates - println("Let's avoid this case: "+bss.map(b => Equals(Variable(b), satModel(b))).mkString(" ")) - - case (Some(false), _) => - val mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap - println("Mapping: "+mapping) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) - - case _ => - } - - case (Some(false), _) => - continue = false - case _ => - continue = false - } - } - - lastF = currentF - currentF = currentF.unroll - unrolings += 1 - } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty) - - result.getOrElse(RuleInapplicable) - } -} diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index fc711b1e7..563ea4166 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -7,6 +7,7 @@ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ +import purescala.Definitions._ object Rules { def all = Set[Synthesizer => Rule]( @@ -18,6 +19,8 @@ object Rules { new CaseSplit(_), new UnusedInput(_), new UnconstrainedOutput(_), + new OptimisticGround(_), + new CEGIS(_), new Assert(_) ) } @@ -160,7 +163,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { def applyOn(task: Task): RuleResult = { val p = task.problem - val unused = p.as.toSet -- variablesOf(p.phi) + val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.c) if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) @@ -278,3 +281,234 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { } } +class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 50) { + def applyOn(task: Task): RuleResult = { + val p = task.problem + + case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]); + + var generators = Map[TypeTree, Generator]() + def getGenerator(t: TypeTree): Generator = generators.get(t) match { + case Some(g) => g + case None => + val alternatives: () => List[(Expr, Set[Identifier])] = t match { + case BooleanType => + { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) } + + case Int32Type => + { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) } + + case TupleType(tps) => + { () => + val ids = tps.map(t => FreshIdentifier("t", true).setType(t)) + List((Tuple(ids.map(Variable(_))), ids.toSet)) + } + + case CaseClassType(cd) => + { () => + val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) + List((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) + } + + case AbstractClassType(cd) => + { () => + val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match { + case acd: AbstractClassDef => + synth.reporter.error("Unnexpected abstract class in descendants!") + None + case cd: CaseClassDef => + val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) + Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) + }) + alts.toList + } + + case _ => + synth.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]"); + { () => Nil } + } + val g = Generator(t, alternatives) + generators += t -> g + g + } + + def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { + p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) + } + + case class TentativeFormula(phi: Expr, + program: Expr, + mappings: Map[Identifier, (Identifier, Expr)], + recTerms: Map[Identifier, Set[Identifier]]) { + def unroll: TentativeFormula = { + var newProgram = List[Expr]() + var newRecTerms = Map[Identifier, Set[Identifier]]() + var newMappings = Map[Identifier, (Identifier, Expr)]() + + for ((_, recIds) <- recTerms; recId <- recIds) { + val gen = getGenerator(recId.getType) + val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) + + val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt) + + val bvs = altsWithBranches.map(alt => Variable(alt._1)) + val distinct = if (bvs.size > 1) { + (for (i <- (1 to bvs.size-1); j <- 0 to i-1) yield { + Or(Not(bvs(i)), Not(bvs(j))) + }).toList + } else { + List(BooleanLiteral(true)) + } + val pre = And(Or(bvs) :: distinct) // (b1 OR b2) AND (Not(b1) OR Not(b2)) + val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}] + if (!rec.isEmpty) { + newRecTerms += bid -> rec + } + newMappings += bid -> (recId -> ex) + + Implies(Variable(bid), Equals(Variable(recId), ex)) + } + + newProgram = newProgram ::: pre :: cases + } + + TentativeFormula(phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms) + } + + def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList + def bss = mappings.keySet + + def entireFormula = And(phi :: program :: bounds) + } + + var result: Option[RuleResult] = None + + var ass = p.as.toSet + var xss = p.xs.toSet + + var lastF = TentativeFormula(Implies(p.c, p.phi), BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x))) + var currentF = lastF.unroll + var unrolings = 0 + val maxUnrolings = 2 + do { + //println("Was: "+lastF.entireFormula) + //println("Now Trying : "+currentF.entireFormula) + + val tpe = TupleType(p.xs.map(_.getType)) + val bss = currentF.bss + + var predicates: Seq[Expr] = Seq() + var continue = true + + while (result.isEmpty && continue) { + val basePhi = currentF.entireFormula + val constrainedPhi = And(basePhi +: predicates) + //println("-"*80) + //println("To satisfy: "+constrainedPhi) + synth.solver.solveSAT(constrainedPhi) match { + case (Some(true), satModel) => + //println("Found candidate!: "+satModel.filterKeys(bss)) + + //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) + val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) + //println("Phi with fixed sat bss: "+fixedBss) + + val counterPhi = Implies(And(fixedBss, currentF.program), currentF.phi) + //println("Formula to validate: "+counterPhi) + + synth.solver.solveSAT(Not(counterPhi)) match { + case (Some(true), invalidModel) => + // Found as such as the xs break, refine predicates + //println("Found counter EX: "+invalidModel) + predicates = Not(And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)) +: predicates + //println("Let's avoid this case: "+bss.map(b => Equals(Variable(b), satModel(b))).mkString(" ")) + + case (Some(false), _) => + //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) + var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap + + //println("Mapping: "+mapping) + + // Resolve mapping + for ((c, e) <- mapping) { + mapping += c -> substAll(mapping, e) + } + + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) + + case _ => + } + + case (Some(false), _) => + //println("%%%% UNSAT") + continue = false + case _ => + //println("%%%% WOOPS") + continue = false + } + } + + lastF = currentF + currentF = currentF.unroll + unrolings += 1 + } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty) + + result.getOrElse(RuleInapplicable) + } +} + +class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) { + def applyOn(task: Task): RuleResult = { + val p = task.problem + + if (!p.as.isEmpty && !p.xs.isEmpty) { + val xss = p.xs.toSet + val ass = p.as.toSet + + val tpe = TupleType(p.xs.map(_.getType)) + + var i = 0; + var maxTries = 3; + + var result: Option[RuleResult] = None + var predicates: Seq[Expr] = Seq() + + while (result.isEmpty && i < maxTries) { + val phi = And(p.phi +: predicates) + synth.solver.solveSAT(phi) match { + case (Some(true), satModel) => + val satXsModel = satModel.filterKeys(xss) + + val newPhi = valuateWithModelIn(phi, xss, satModel) + + synth.solver.solveSAT(Not(newPhi)) match { + case (Some(true), invalidModel) => + // Found as such as the xs break, refine predicates + predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates + + case (Some(false), _) => + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) + + case _ => + result = Some(RuleInapplicable) + } + + case (Some(false), _) => + if (predicates.isEmpty) { + result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe)))) + } else { + result = Some(RuleInapplicable) + } + case _ => + result = Some(RuleInapplicable) + } + + i += 1 + } + + result.getOrElse(RuleInapplicable) + } else { + RuleInapplicable + } + } +} diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index ba8ec79c4..a811fe6b7 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -42,9 +42,11 @@ class Synthesizer(val reporter: Reporter, val ts = System.currentTimeMillis + def currentDurationMs = System.currentTimeMillis-ts + def timeoutExpired(): Boolean = { timeoutMs match { - case Some(t) if (System.currentTimeMillis-ts)/1000 > t => true + case Some(t) if currentDurationMs/1000 > t => true case _ => false } } @@ -64,6 +66,8 @@ class Synthesizer(val reporter: Reporter, } } + info("Finished in "+currentDurationMs+"ms") + if (generateDerivationTrees) { val deriv = new DerivationTree(rootTask) deriv.toDotFile("derivation"+derivationCounter+".dot") diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala index ed3bd980d..7f312deef 100644 --- a/src/main/scala/leon/synthesis/Task.scala +++ b/src/main/scala/leon/synthesis/Task.scala @@ -121,8 +121,10 @@ class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, p } override def partlySolvedBy(t: Task, s: Solution) { - solution = Some(s) - solver = Some(t) + if (isBetterSolutionThan(s, solution)) { + solution = Some(s) + solver = Some(t) + } } override def unsolvedBy(t: Task) { -- GitLab