diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 3c18e858fcfeba8fa7d02421536e1080c5999847..d39d8abf502cf075688a7839fdb3e8eb4977ae7b 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -246,6 +246,8 @@ object Definitions { fields.map(f => f.id).toSet } + def fieldsIds = fields.map(_.id) + def selectorID2Index(id: Identifier) : Int = { var i : Int = 0 var found = false diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 5afcb0f5a78b3e7dca07aaadf590ddcd41371bb2..1c045aabb35882606b408df6a294ede5699f098e 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -9,6 +9,7 @@ object Rules { def all(synth: Synthesizer) = List( new Unification.DecompTrivialClash(synth), new Unification.OccursCheck(synth), + new ADTDual(synth), new OnePoint(synth), new Ground(synth), new CaseSplit(synth), @@ -243,3 +244,38 @@ object Unification { } } } + + +class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth) { + def isApplicable(task: Task): List[DecomposedTask] = { + val p = task.problem + + val xs = p.xs.toSet + val as = p.as.toSet + + val TopLevelAnds(exprs) = p.phi + + + val (toRemove, toAdd, toPre) = exprs.collect { + case eq @ Equals(cc @ CaseClass(cd, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) -- xs).isEmpty => + (eq, (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) }, CaseClassInstanceOf(cd, e) ) + case eq @ Equals(e, cc @ CaseClass(cd, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) -- xs).isEmpty => + (eq, (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) }, CaseClassInstanceOf(cd, e) ) + }.unzip3 + + if (!toRemove.isEmpty) { + val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) + + val onSuccess: List[Solution] => Solution = { + case List(s) => + Solution(And(s.pre +: toPre), s.term) + case _ => + Solution.none + } + + List(task.decompose(this, List(sub), onSuccess, 80)) + } else { + Nil + } + } +} diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 44a0491346f1a3f635b7887e10ec96cfbc62cc4b..2b927cad09ac10ae4134fcf4cdea0526c86a99c7 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -64,11 +64,11 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { } def onTaskSucceeded(task: Task, solution: Solution) { + info(" => Solved "+task.problem+" ⊢ "+solution) if (task.parent eq null) { info(" SUCCESS!") this.solution = Some(solution) } else { - info(" => Solved "+task.problem+" ⊢ "+solution) task.parent.subSucceeded(task.problem, solution) } } diff --git a/testcases/synthesis/Unification.scala b/testcases/synthesis/Unification.scala index b98f7acd3ac3863a33853c6f917a35937b56dbd7..c2b111352871aa0c3435774a21321418fb8c8e6e 100644 --- a/testcases/synthesis/Unification.scala +++ b/testcases/synthesis/Unification.scala @@ -12,6 +12,8 @@ object UnificationSynthesis { def u5(a1: Int): List = choose { (xs: List) => Cons(a1, Nil()) == xs } + def u6(a1: List): Int = choose { (xs: Int) => Cons(xs, Nil()) == a1 } + sealed abstract class List case class Nil() extends List case class Cons(head : Int, tail : List) extends List