diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index d4d7bab0180076fe387217d04d8a0496e98f43df..7d36b09775395376a38b6584ac404a8701e59af2 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -20,6 +20,7 @@ object Rules { new EqualitySplit(_), new CEGIS(_), new Assert(_), + new ADTSplit(_), new IntegerEquation(_) ) } diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala new file mode 100644 index 0000000000000000000000000000000000000000..41258337e1c0818a7a9066a4e6714981c937e110 --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -0,0 +1,80 @@ +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.Common._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Extractors._ +import purescala.Definitions._ + +class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 90) { + def applyOn(task: Task): RuleResult = { + val p = task.problem + + val candidate = p.as.collect { + case IsTyped(id, AbstractClassType(cd)) => + + val optCases = for (dcd <- cd.knownDescendents) yield dcd match { + case ccd : CaseClassDef => + val toVal = Implies(p.c, CaseClassInstanceOf(ccd, Variable(id))) + + val isImplied = synth.solver.solveSAT(Not(toVal)) match { + case (Some(false), _) => true + case _ => false + } + + if (!isImplied) { + Some(ccd) + } else { + None + } + case _ => + None + } + + val cases = optCases.flatten + + if (!cases.isEmpty) { + Some((id, cases)) + } else { + None + } + } + + + candidate.find(_.isDefined) match { + case Some(Some((id, cases))) => + val oas = p.as.filter(_ != id) + + val subInfo = for(ccd <- cases) yield { + val subId = FreshIdentifier(ccd.id.name, true).setType(CaseClassType(ccd)) + val subPre = CaseClassInstanceOf(ccd, Variable(id)) + val subPhi = subst(id -> Variable(subId), p.phi) + val subProblem = Problem(subId :: oas, And(p.c, subPre), subPhi, p.xs) + val subPattern = CaseClassPattern(Some(subId), ccd, ccd.fieldsIds.map(id => WildcardPattern(None))) + + (subProblem, subPre, subPattern) + } + + + val onSuccess: List[Solution] => Solution = { + case sols => + var globalPre = List[Expr]() + + val cases = for ((sol, (problem, pre, pattern)) <- (sols zip subInfo)) yield { + globalPre ::= And(pre, sol.pre) + + SimpleCase(pattern, sol.term) + } + + Solution(Or(globalPre), sols.flatMap(_.defs).toSet, MatchExpr(Variable(id), cases)) + } + + HeuristicStep(synth, p, subInfo.map(_._1).toList, onSuccess) + case _ => + RuleInapplicable + } + } +} diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 0b3cc6a345d04981b61a0b75298648a78542c28a..04d4f8bf70128a6c6aa8a95a1c43fa35ac9433a9 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -21,14 +21,18 @@ class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 90) { case _ => false } - val toValNE = Implies(p.c, Not(Equals(Variable(a1), Variable(a2)))) + if (!impliesEQ) { + val toValNE = Implies(p.c, Not(Equals(Variable(a1), Variable(a2)))) - val impliesNE = synth.solver.solveSAT(Not(toValNE)) match { - case (Some(false), _) => true - case _ => false - } + val impliesNE = synth.solver.solveSAT(Not(toValNE)) match { + case (Some(false), _) => true + case _ => false + } - !impliesNE && !impliesEQ + !impliesNE + } else { + false + } case _ => false }