Skip to content
Snippets Groups Projects
Commit e0dd0a63 authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Add unification rules (Trivial, Symbol Clash, Decomposition, Occurs Check)

parent 147d699d
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ import purescala.TypeTrees._
object Rules {
def all(synth: Synthesizer) = List(
new Unification.DecompTrivialClash(synth),
new Unification.OccursCheck(synth),
new OnePoint(synth),
new Ground(synth),
new CaseSplit(synth),
......@@ -22,6 +24,11 @@ abstract class Rule(val name: String, val synth: Synthesizer) {
def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in)
val forward: List[Solution] => Solution = {
case List(s) => s
case _ => Solution.none
}
override def toString = name
}
......@@ -151,12 +158,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth) {
if (!unused.isEmpty) {
val sub = p.copy(as = p.as.filterNot(unused))
val onSuccess: List[Solution] => Solution = {
case List(s) => s
case _ => Solution.none
}
List(task.decompose(this, List(sub), onSuccess, 300))
List(task.decompose(this, List(sub), forward, 300))
} else {
Nil
}
......@@ -173,14 +175,9 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy
val onSuccess: List[Solution] => Solution = {
case List(s) =>
synth.solveSAT(And(unconstr.map(id => Equals(Variable(id), Variable(id))).toSeq)) match {
case (Some(true), model) =>
Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) model(id) else Variable(id)))))
case _ =>
Solution.none
}
case _ => Solution.none
Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id)))))
case _ =>
Solution.none
}
List(task.decompose(this, List(sub), onSuccess, 300))
......@@ -191,3 +188,58 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy
}
}
object Unification {
class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth) {
def isApplicable(task: Task): List[DecomposedTask] = {
val p = task.problem
val TopLevelAnds(exprs) = p.phi
val (toRemove, toAdd) = exprs.collect {
case eq @ Equals(cc1 @ CaseClass(cd1, args1), cc2 @ CaseClass(cd2, args2)) =>
if (cc1 == cc2) {
(eq, List(BooleanLiteral(true)))
} else if (cd1 == cd2) {
(eq, (args1 zip args2).map((Equals(_, _)).tupled))
} else {
(eq, List(BooleanLiteral(false)))
}
}.unzip
if (!toRemove.isEmpty) {
val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq))
List(task.decompose(this, List(sub), forward, 100))
} else {
Nil
}
}
}
class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth) {
def isApplicable(task: Task): List[DecomposedTask] = {
val p = task.problem
val TopLevelAnds(exprs) = p.phi
val isImpossible = exprs.exists {
case eq @ Equals(cc : CaseClass, Variable(id)) if variablesOf(cc) contains id =>
true
case eq @ Equals(Variable(id), cc : CaseClass) if variablesOf(cc) contains id =>
true
case _ =>
false
}
if (isImpossible) {
val tpe = TupleType(p.xs.map(_.getType))
List(task.solveUsing(this, Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), 200), 200))
} else {
Nil
}
}
}
}
......@@ -15,17 +15,4 @@ object ChooseTest {
def z0(a: Int): (Int, Int, List) = choose{ (x1: Int, x2: Int, x3: List) => x1 > a && x2 > a }
sealed abstract class List
case class Nil() extends List
case class Cons(head : Int, tail : List) extends List
def size(lst : List) : Int = (lst match {
case Nil() => 0
case Cons(_, xs) => 1 + size(xs)
}) ensuring(_ >= 0)
def k0() : List = choose {
(l : List) => size(l) == 1
}
}
import leon.Utils._
object SimpleSynthesis {
def c1(x: Int): Int = choose { (y: Int) => y > x }
}
import leon.Utils._
object UnificationSynthesis {
def u1(a1: Int): Int = choose { (x1: Int) => Cons(x1, Nil()) == Cons(a1, Nil()) }
def u2(a1: Int): Int = choose { (x1: Int) => Cons(x1, Nil()) == Cons(x1, Cons(2, Nil())) }
def u3(a1: Int): List = choose { (xs: List) => Cons(a1, xs) == xs }
sealed abstract class List
case class Nil() extends List
case class Cons(head : Int, tail : List) extends List
def size(lst : List) : Int = (lst match {
case Nil() => 0
case Cons(_, xs) => 1 + size(xs)
}) ensuring(_ >= 0)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment