package leon package synthesis import purescala.Common._ import purescala.ScalaPrinter import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable import ArithmeticNormalization.simplify object Rules { def all = Set[Synthesizer => Rule]( new Unification.DecompTrivialClash(_), new Unification.OccursCheck(_), new ADTDual(_), new OnePoint(_), new Ground(_), new CaseSplit(_), new UnusedInput(_), new UnconstrainedOutput(_), new OptimisticGround(_), new EqualitySplit(_), new CEGIS(_), new Assert(_), new IntegerEquation(_) ) } sealed abstract class RuleResult case object RuleInapplicable extends RuleResult case class RuleSuccess(solution: Solution) extends RuleResult case class RuleMultiSteps(subProblems: List[Problem], steps: List[List[Solution] => List[Problem]], onSuccess: List[Solution] => (Solution, Boolean)) extends RuleResult object RuleStep { def apply(subProblems: List[Problem], onSuccess: List[Solution] => Solution) = { RuleMultiSteps(subProblems, Nil, onSuccess.andThen((_, true))) } } abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority) { def applyOn(task: Task): RuleResult def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in) def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replace(what.map(w => Variable(w._1) -> w._2), in) val forward: List[Solution] => Solution = { case List(s) => Solution(s.pre, s.defs, s.term) case _ => Solution.none } override def toString = "R: "+name } class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) { def applyOn(task: Task): RuleResult = { val p = task.problem val TopLevelAnds(exprs) = p.phi val candidates = exprs.collect { case eq @ Equals(Variable(x), e) if (p.xs contains x) && !(variablesOf(e) contains x) => (x, e, eq) case eq @ Equals(e, Variable(x)) if (p.xs contains x) && !(variablesOf(e) contains x) => (x, e, eq) } if (!candidates.isEmpty) { val (x, e, eq) = candidates.head val others = exprs.filter(_ != eq) val oxs = p.xs.filter(_ != x) val newProblem = Problem(p.as, p.c, subst(x -> e, And(others)), oxs) val onSuccess: List[Solution] => Solution = { case List(Solution(pre, defs, term)) => if (oxs.isEmpty) { Solution(pre, defs, Tuple(e :: Nil)) } else { Solution(pre, defs, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_)))))) } case _ => Solution.none } RuleStep(List(newProblem), onSuccess) } else { RuleInapplicable } } } class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) { def applyOn(task: Task): RuleResult = { val p = task.problem if (p.as.isEmpty) { val tpe = TupleType(p.xs.map(_.getType)) synth.solver.solveSAT(p.phi) match { case (Some(true), model) => RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe))) case (Some(false), model) => RuleSuccess(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe))) case _ => RuleInapplicable } } else { RuleInapplicable } } } class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem p.phi match { case Or(Seq(o1, o2)) => val sub1 = Problem(p.as, p.c, o1, p.xs) val sub2 = Problem(p.as, p.c, o2, p.xs) val onSuccess: List[Solution] => Solution = { case List(Solution(p1, d1, t1), Solution(p2, d2, t2)) => Solution(Or(p1, p2), d1++d2, IfExpr(p1, t1, t2)) case _ => Solution.none } RuleStep(List(sub1, sub2), onSuccess) case _ => RuleInapplicable } } } class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem p.phi match { case TopLevelAnds(exprs) => val xsSet = p.xs.toSet val (exprsA, others) = exprs.partition(e => (variablesOf(e) & xsSet).isEmpty) if (!exprsA.isEmpty) { if (others.isEmpty) { RuleSuccess(Solution(And(exprsA), Set(), Tuple(p.xs.map(id => simplestValue(Variable(id)))))) } else { val sub = p.copy(c = And(p.c +: exprsA), phi = And(others)) RuleStep(List(sub), forward) } } else { RuleInapplicable } case _ => RuleInapplicable } } } class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { def applyOn(task: Task): RuleResult = { val p = task.problem val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.c) if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) RuleStep(List(sub), forward) } else { RuleInapplicable } } } class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 100) { def applyOn(task: Task): RuleResult = { val p = task.problem val unconstr = p.xs.toSet -- variablesOf(p.phi) if (!unconstr.isEmpty) { val sub = p.copy(xs = p.xs.filterNot(unconstr)) val onSuccess: List[Solution] => Solution = { case List(s) => Solution(s.pre, s.defs, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id))))) case _ => Solution.none } RuleStep(List(sub), onSuccess) } else { RuleInapplicable } } } object Unification { class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem val TopLevelAnds(exprs) = p.phi val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc1 @ CaseClass(cd1, args1), cc2 @ CaseClass(cd2, args2)) => if (cc1 == cc2) { (eq, List(BooleanLiteral(true))) } else if (cd1 == cd2) { (eq, (args1 zip args2).map((Equals(_, _)).tupled)) } else { (eq, List(BooleanLiteral(false))) } }.unzip if (!toRemove.isEmpty) { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) RuleStep(List(sub), forward) } else { RuleInapplicable } } } class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem val TopLevelAnds(exprs) = p.phi val isImpossible = exprs.exists { case eq @ Equals(cc : CaseClass, Variable(id)) if variablesOf(cc) contains id => true case eq @ Equals(Variable(id), cc : CaseClass) if variablesOf(cc) contains id => true case _ => false } if (isImpossible) { val tpe = TupleType(p.xs.map(_.getType)) RuleSuccess(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe))) } else { RuleInapplicable } } } } class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem val xs = p.xs.toSet val as = p.as.toSet val TopLevelAnds(exprs) = p.phi val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc @ CaseClass(cd, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) -- xs).isEmpty => (eq, CaseClassInstanceOf(cd, e) +: (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) } ) case eq @ Equals(e, cc @ CaseClass(cd, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) -- xs).isEmpty => (eq, CaseClassInstanceOf(cd, e) +: (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) } ) }.unzip if (!toRemove.isEmpty) { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) RuleStep(List(sub), forward) } else { RuleInapplicable } } } class CEGIS(synth: Synthesizer) extends Rule("CEGIS", synth, 150) { def applyOn(task: Task): RuleResult = { val p = task.problem case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]); var generators = Map[TypeTree, Generator]() def getGenerator(t: TypeTree): Generator = generators.get(t) match { case Some(g) => g case None => val alternatives: () => List[(Expr, Set[Identifier])] = t match { case BooleanType => { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) } case Int32Type => { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) } case TupleType(tps) => { () => val ids = tps.map(t => FreshIdentifier("t", true).setType(t)) List((Tuple(ids.map(Variable(_))), ids.toSet)) } case CaseClassType(cd) => { () => val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) List((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) } case AbstractClassType(cd) => { () => val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match { case acd: AbstractClassDef => synth.reporter.error("Unnexpected abstract class in descendants!") None case cd: CaseClassDef => val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) }) alts.toList } case _ => synth.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]"); { () => Nil } } val g = Generator(t, alternatives) generators += t -> g g } def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) } case class TentativeFormula(phi: Expr, program: Expr, mappings: Map[Identifier, (Identifier, Expr)], recTerms: Map[Identifier, Set[Identifier]]) { def unroll: TentativeFormula = { var newProgram = List[Expr]() var newRecTerms = Map[Identifier, Set[Identifier]]() var newMappings = Map[Identifier, (Identifier, Expr)]() for ((_, recIds) <- recTerms; recId <- recIds) { val gen = getGenerator(recId.getType) val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt) val bvs = altsWithBranches.map(alt => Variable(alt._1)) val distinct = if (bvs.size > 1) { (for (i <- (1 to bvs.size-1); j <- 0 to i-1) yield { Or(Not(bvs(i)), Not(bvs(j))) }).toList } else { List(BooleanLiteral(true)) } val pre = And(Or(bvs) :: distinct) // (b1 OR b2) AND (Not(b1) OR Not(b2)) val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}] if (!rec.isEmpty) { newRecTerms += bid -> rec } newMappings += bid -> (recId -> ex) Implies(Variable(bid), Equals(Variable(recId), ex)) } newProgram = newProgram ::: pre :: cases } TentativeFormula(phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms) } def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList def bss = mappings.keySet def entireFormula = And(phi :: program :: bounds) } var result: Option[RuleResult] = None var ass = p.as.toSet var xss = p.xs.toSet var lastF = TentativeFormula(Implies(p.c, p.phi), BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x))) var currentF = lastF.unroll var unrolings = 0 val maxUnrolings = 2 do { //println("Was: "+lastF.entireFormula) //println("Now Trying : "+currentF.entireFormula) val tpe = TupleType(p.xs.map(_.getType)) val bss = currentF.bss var predicates: Seq[Expr] = Seq() var continue = true while (result.isEmpty && continue) { val basePhi = currentF.entireFormula val constrainedPhi = And(basePhi +: predicates) //println("-"*80) //println("To satisfy: "+constrainedPhi) synth.solver.solveSAT(constrainedPhi) match { case (Some(true), satModel) => //println("Found candidate!: "+satModel.filterKeys(bss)) //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) //println("Phi with fixed sat bss: "+fixedBss) val counterPhi = Implies(And(fixedBss, currentF.program), currentF.phi) //println("Formula to validate: "+counterPhi) synth.solver.solveSAT(Not(counterPhi)) match { case (Some(true), invalidModel) => // Found as such as the xs break, refine predicates //println("Found counter EX: "+invalidModel) predicates = Not(And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)) +: predicates //println("Let's avoid this case: "+bss.map(b => Equals(Variable(b), satModel(b))).mkString(" ")) case (Some(false), _) => //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap //println("Mapping: "+mapping) // Resolve mapping for ((c, e) <- mapping) { mapping += c -> substAll(mapping, e) } result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) case _ => reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") continue = false } case (Some(false), _) => //println("%%%% UNSAT") continue = false case _ => //println("%%%% WOOPS") continue = false } } lastF = currentF currentF = currentF.unroll unrolings += 1 } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty) result.getOrElse(RuleInapplicable) } } class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 150) { def applyOn(task: Task): RuleResult = { val p = task.problem if (!p.as.isEmpty && !p.xs.isEmpty) { val xss = p.xs.toSet val ass = p.as.toSet val tpe = TupleType(p.xs.map(_.getType)) var i = 0; var maxTries = 3; var result: Option[RuleResult] = None var predicates: Seq[Expr] = Seq() while (result.isEmpty && i < maxTries) { val phi = And(p.phi +: predicates) //println("SOLVING " + phi + " ...") synth.solver.solveSAT(phi) match { case (Some(true), satModel) => val satXsModel = satModel.filterKeys(xss) val newPhi = valuateWithModelIn(phi, xss, satModel) //println("REFUTING " + Not(newPhi) + "...") synth.solver.solveSAT(Not(newPhi)) match { case (Some(true), invalidModel) => // Found as such as the xs break, refine predicates predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates case (Some(false), _) => result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) case _ => result = Some(RuleInapplicable) } case (Some(false), _) => if (predicates.isEmpty) { result = Some(RuleSuccess(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)))) } else { result = Some(RuleInapplicable) } case _ => result = Some(RuleInapplicable) } i += 1 } result.getOrElse(RuleInapplicable) } else { RuleInapplicable } } } class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 90) { def applyOn(task: Task): RuleResult = { val p = task.problem val TopLevelAnds(presSeq) = p.c val pres = presSeq.toSet def combinations(a1: Identifier, a2: Identifier): Set[Expr] = { val v1 = Variable(a1) val v2 = Variable(a2) Set( Equals(v1, v2), Equals(v2, v1), Not(Equals(v1, v2)), Not(Equals(v2, v1)) ) } val candidate = p.as.groupBy(_.getType).map(_._2.toList).find { case List(a1, a2) => (pres & combinations(a1, a2)).isEmpty case _ => false } candidate match { case Some(List(a1, a2)) => val sub1 = p.copy(c = And(Equals(Variable(a1), Variable(a2)), p.c)) val sub2 = p.copy(c = And(Not(Equals(Variable(a1), Variable(a2))), p.c)) val onSuccess: List[Solution] => Solution = { case List(s1, s2) => Solution(Or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(Equals(Variable(a1), Variable(a2)), s1.term, s2.term)) case _ => Solution.none } RuleStep(List(sub1, sub2), onSuccess) case _ => RuleInapplicable } } } class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { def applyOn(task: Task): RuleResult = { val p = task.problem val TopLevelAnds(exprs) = p.phi val xs = p.xs val as = p.as val formula = p.phi val (eqs, others) = exprs.partition(_.isInstanceOf[Equals]) var candidates: Seq[Expr] = eqs var allOthers: Seq[Expr] = others var vars: Set[Identifier] = Set() var eqas: Set[Identifier] = Set() var eqxs: List[Identifier] = List() var ys: Set[Identifier] = Set() var optionNormalizedEq: Option[List[Expr]] = None while(!candidates.isEmpty && optionNormalizedEq == None) { val eq@Equals(_,_) = candidates.head candidates = candidates.tail vars = variablesOf(eq) eqas = as.toSet.intersect(vars) eqxs = xs.toSet.intersect(vars).toList ys = xs.toSet.diff(vars) try { optionNormalizedEq = Some(ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList) } catch { case ArithmeticNormalization.NonLinearExpressionException(_) => } } optionNormalizedEq match { case None => RuleInapplicable case Some(normalizedEq0) => { val (neqxs, normalizedEq) = eqxs.zip(normalizedEq0.tail).filterNot{ case (_, IntLiteral(0)) => true case _ => false}.unzip //if(normalizedEq.size == 1) { //} else { val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq) val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap val freshFormula = simplify(replace(eqSubstMap, And(allOthers))) //} //(eqPre, freshFormula) val newProblem = Problem(as, And(eqPre, p.c), freshFormula, eqFreshVars) val onSuccess: List[Solution] => Solution = { case List(Solution(pre, defs, term)) => if (eqFreshVars.isEmpty) { Solution(pre, defs, replace(eqSubstMap, Tuple(neqxs.map(Variable(_))))) } else { Solution(pre, defs, LetTuple(eqFreshVars, term, replace(eqSubstMap, Tuple(neqxs.map(Variable(_)))))) } case _ => Solution.none } RuleStep(List(newProblem), onSuccess) } } } }