diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 7d36b09775395376a38b6584ac404a8701e59af2..8db62a9149b14b45c3549a7f91f239722a7ae2ce 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -20,7 +20,7 @@ object Rules { new EqualitySplit(_), new CEGIS(_), new Assert(_), - new ADTSplit(_), +// new ADTSplit(_), new IntegerEquation(_) ) } diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 41258337e1c0818a7a9066a4e6714981c937e110..74dadefe73ed2634e0f1c4b14f0713a7fcb6bd47 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -9,7 +9,7 @@ import purescala.TreeOps._ import purescala.Extractors._ import purescala.Definitions._ -class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 90) { +class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 70) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -50,12 +50,14 @@ class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 90) { val subInfo = for(ccd <- cases) yield { val subId = FreshIdentifier(ccd.id.name, true).setType(CaseClassType(ccd)) - val subPre = CaseClassInstanceOf(ccd, Variable(id)) + val args = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList + val subPre = Equals(CaseClass(ccd, args.map(Variable(_))), Variable(subId)) + val subPhi = subst(id -> Variable(subId), p.phi) - val subProblem = Problem(subId :: oas, And(p.c, subPre), subPhi, p.xs) - val subPattern = CaseClassPattern(Some(subId), ccd, ccd.fieldsIds.map(id => WildcardPattern(None))) + val subProblem = Problem(subId :: args ::: oas, And(p.c, subPre), subPhi, p.xs) + val subPattern = CaseClassPattern(Some(subId), ccd, args.map(id => WildcardPattern(Some(id)))) - (subProblem, subPre, subPattern) + (ccd, subProblem, subPattern) } @@ -63,8 +65,8 @@ class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 90) { case sols => var globalPre = List[Expr]() - val cases = for ((sol, (problem, pre, pattern)) <- (sols zip subInfo)) yield { - globalPre ::= And(pre, sol.pre) + val cases = for ((sol, (ccd, problem, pattern)) <- (sols zip subInfo)) yield { + globalPre ::= And(CaseClassInstanceOf(ccd, Variable(id)), sol.pre) SimpleCase(pattern, sol.term) } @@ -72,7 +74,7 @@ class ADTSplit(synth: Synthesizer) extends Rule("ADT Split.", synth, 90) { Solution(Or(globalPre), sols.flatMap(_.defs).toSet, MatchExpr(Variable(id), cases)) } - HeuristicStep(synth, p, subInfo.map(_._1).toList, onSuccess) + HeuristicStep(synth, p, subInfo.map(_._2).toList, onSuccess) case _ => RuleInapplicable }