diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index e289e6f838acbf266265a1b459f748e284b116b0..27eba7ac6fc8b78b22fc8b6bef17cbbb483118df 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -20,6 +20,7 @@ object Rules { UnconstrainedOutput, OptimisticGround, EqualitySplit, + InequalitySplit, CEGIS, Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 205440cf3bbc85f7e427df619ba04399e4399731..125b2e6060f115f699b45bd5bb682e785c486c6b 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -10,7 +10,7 @@ import purescala.Extractors._ case object EqualitySplit extends Rule("Eq. Split.") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val candidates = p.as.groupBy(_.getType).map(_._2.toList).filter { + val candidates = p.as.groupBy(_.getType).mapValues(_.combinations(2).filter { case List(a1, a2) => val toValEQ = Implies(p.pc, Equals(Variable(a1), Variable(a2))) @@ -32,10 +32,10 @@ case object EqualitySplit extends Rule("Eq. Split.") { false } case _ => false - } + }).values.flatten - candidates.map(_ match { + candidates.flatMap(_ match { case List(a1, a2) => val sub1 = p.copy(pc = And(Equals(Variable(a1), Variable(a2)), p.pc)) @@ -51,6 +51,6 @@ case object EqualitySplit extends Rule("Eq. Split.") { Some(RuleInstantiation.immediateDecomp(p, this, List(sub1, sub2), onSuccess)) case _ => None - }).flatten + }) } } diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala new file mode 100644 index 0000000000000000000000000000000000000000..d79bd66f06bfaac93de04e2288ed766d29e9c52b --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -0,0 +1,78 @@ +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.Common._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Extractors._ + +case object InequalitySplit extends Rule("Ineq. Split.") { + def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + val candidates = p.as.filter(_.getType == Int32Type).combinations(2).toList.filter { + case List(a1, a2) => + val toValLT = Implies(p.pc, LessThan(Variable(a1), Variable(a2))) + + val impliesLT = sctx.solver.solveSAT(Not(toValLT)) match { + case (Some(false), _) => true + case _ => false + } + + if (!impliesLT) { + val toValGT = Implies(p.pc, GreaterThan(Variable(a1), Variable(a2))) + + val impliesGT = sctx.solver.solveSAT(Not(toValGT)) match { + case (Some(false), _) => true + case _ => false + } + + if (!impliesGT) { + val toValEQ = Implies(p.pc, Equals(Variable(a1), Variable(a2))) + + val impliesEQ = sctx.solver.solveSAT(Not(toValEQ)) match { + case (Some(false), _) => true + case _ => false + } + + !impliesEQ + } else { + false + } + } else { + false + } + case _ => false + } + + + candidates.flatMap(_ match { + case List(a1, a2) => + + val subLT = p.copy(pc = And(LessThan(Variable(a1), Variable(a2)), p.pc)) + val subEQ = p.copy(pc = And(Equals(Variable(a1), Variable(a2)), p.pc)) + val subGT = p.copy(pc = And(GreaterThan(Variable(a1), Variable(a2)), p.pc)) + + val onSuccess: List[Solution] => Option[Solution] = { + case sols @ List(sLT, sEQ, sGT) => + val pre = Or(sols.map(_.pre)) + val defs = sLT.defs ++ sEQ.defs ++ sGT.defs + + val term = IfExpr(LessThan(Variable(a1), Variable(a2)), + sLT.term, + IfExpr(Equals(Variable(a1), Variable(a2)), + sEQ.term, + sGT.term)) + + Some(Solution(pre, defs, term)) + case _ => + None + } + + Some(RuleInstantiation.immediateDecomp(p, this, List(subLT, subEQ, subGT), onSuccess)) + case _ => + None + }) + } +}