diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 986feb80be957fa9ef3e8525b915f49feac04882..5afcb0f5a78b3e7dca07aaadf590ddcd41371bb2 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -7,6 +7,8 @@ import purescala.TypeTrees._ object Rules { def all(synth: Synthesizer) = List( + new Unification.DecompTrivialClash(synth), + new Unification.OccursCheck(synth), new OnePoint(synth), new Ground(synth), new CaseSplit(synth), @@ -22,6 +24,11 @@ abstract class Rule(val name: String, val synth: Synthesizer) { def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in) + val forward: List[Solution] => Solution = { + case List(s) => s + case _ => Solution.none + } + override def toString = name } @@ -151,12 +158,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth) { if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) - val onSuccess: List[Solution] => Solution = { - case List(s) => s - case _ => Solution.none - } - - List(task.decompose(this, List(sub), onSuccess, 300)) + List(task.decompose(this, List(sub), forward, 300)) } else { Nil } @@ -173,14 +175,9 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy val onSuccess: List[Solution] => Solution = { case List(s) => - synth.solveSAT(And(unconstr.map(id => Equals(Variable(id), Variable(id))).toSeq)) match { - case (Some(true), model) => - Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) model(id) else Variable(id))))) - case _ => - Solution.none - } - - case _ => Solution.none + Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id))))) + case _ => + Solution.none } List(task.decompose(this, List(sub), onSuccess, 300)) @@ -191,3 +188,58 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy } } + +object Unification { + class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth) { + def isApplicable(task: Task): List[DecomposedTask] = { + val p = task.problem + + val TopLevelAnds(exprs) = p.phi + + val (toRemove, toAdd) = exprs.collect { + case eq @ Equals(cc1 @ CaseClass(cd1, args1), cc2 @ CaseClass(cd2, args2)) => + if (cc1 == cc2) { + (eq, List(BooleanLiteral(true))) + } else if (cd1 == cd2) { + (eq, (args1 zip args2).map((Equals(_, _)).tupled)) + } else { + (eq, List(BooleanLiteral(false))) + } + }.unzip + + if (!toRemove.isEmpty) { + val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) + + + List(task.decompose(this, List(sub), forward, 100)) + } else { + Nil + } + } + } + + class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth) { + def isApplicable(task: Task): List[DecomposedTask] = { + val p = task.problem + + val TopLevelAnds(exprs) = p.phi + + val isImpossible = exprs.exists { + case eq @ Equals(cc : CaseClass, Variable(id)) if variablesOf(cc) contains id => + true + case eq @ Equals(Variable(id), cc : CaseClass) if variablesOf(cc) contains id => + true + case _ => + false + } + + if (isImpossible) { + val tpe = TupleType(p.xs.map(_.getType)) + + List(task.solveUsing(this, Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), 200), 200)) + } else { + Nil + } + } + } +} diff --git a/testcases/Choose.scala b/testcases/Choose.scala index bf9ef06fa146a5afffca2acc62fafddc42d7cea3..229d07286fcfa8442d60b0e9612efba5193b3282 100644 --- a/testcases/Choose.scala +++ b/testcases/Choose.scala @@ -15,17 +15,4 @@ object ChooseTest { def z0(a: Int): (Int, Int, List) = choose{ (x1: Int, x2: Int, x3: List) => x1 > a && x2 > a } - - sealed abstract class List - case class Nil() extends List - case class Cons(head : Int, tail : List) extends List - - def size(lst : List) : Int = (lst match { - case Nil() => 0 - case Cons(_, xs) => 1 + size(xs) - }) ensuring(_ >= 0) - - def k0() : List = choose { - (l : List) => size(l) == 1 - } } diff --git a/testcases/synthesis/Simple.scala b/testcases/synthesis/Simple.scala new file mode 100644 index 0000000000000000000000000000000000000000..6c5fdba0f82d0fdaf83c75026028208a940497b3 --- /dev/null +++ b/testcases/synthesis/Simple.scala @@ -0,0 +1,7 @@ +import leon.Utils._ + +object SimpleSynthesis { + + def c1(x: Int): Int = choose { (y: Int) => y > x } + +} diff --git a/testcases/synthesis/Unification.scala b/testcases/synthesis/Unification.scala new file mode 100644 index 0000000000000000000000000000000000000000..8e4a5c6145b1992bca76f2012f30b1ec3d328ea0 --- /dev/null +++ b/testcases/synthesis/Unification.scala @@ -0,0 +1,17 @@ +import leon.Utils._ + +object UnificationSynthesis { + + def u1(a1: Int): Int = choose { (x1: Int) => Cons(x1, Nil()) == Cons(a1, Nil()) } + def u2(a1: Int): Int = choose { (x1: Int) => Cons(x1, Nil()) == Cons(x1, Cons(2, Nil())) } + def u3(a1: Int): List = choose { (xs: List) => Cons(a1, xs) == xs } + + sealed abstract class List + case class Nil() extends List + case class Cons(head : Int, tail : List) extends List + + def size(lst : List) : Int = (lst match { + case Nil() => 0 + case Cons(_, xs) => 1 + size(xs) + }) ensuring(_ >= 0) +}