Skip to content
Snippets Groups Projects
Commit 4340b6ed authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Merge branch 'master' of laragit.epfl.ch:projects/leon-2.0

Conflicts:
	src/main/scala/leon/synthesis/Rules.scala
parents 0b2eb621 f947b2d3
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,8 @@ import leon.purescala.Common._
object ArithmeticNormalization {
case class NonLinearExpressionException(msg: String) extends Exception
//assume the function is an arithmetic expression, not a relation
//return a normal form where the [t a1 ... an] where
//expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn]
......@@ -19,7 +21,7 @@ object ArithmeticNormalization {
case Times(e1, e2) => containsId(e1, id) || containsId(e2, id)
case IntLiteral(_) => false
case Variable(id2) => id == id2
case _ => sys.error("unexpected format: " + e)
case err => throw NonLinearExpressionException("unexpected in containsId: " + err)
}
def group(es: Seq[Expr], id: Identifier): Expr = {
......@@ -53,7 +55,7 @@ object ArithmeticNormalization {
def rec(e: Expr): Unit = e match {
case IntLiteral(i) => coef = coef*i
case Variable(id2) => if(id.isEmpty) id = Some(id2) else sys.error("multiple variables")
case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable")
case Times(e1, e2) => rec(e1); rec(e2)
}
......@@ -71,10 +73,11 @@ object ArithmeticNormalization {
def expand(expr: Expr): Seq[Expr] = expr match {
case Plus(es1, es2) => expand(es1) ++ expand(es2)
case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr)
case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr)
case Times(es1, es2) => multiply(expand(es1), expand(es2))
case v@Variable(_) => Seq(v)
case n@IntLiteral(_) => Seq(n)
case _ => sys.error("Unexpected")
case err => throw NonLinearExpressionException("unexpected in expand: " + err)
}
//simple, local simplifications
......
......@@ -14,13 +14,20 @@ object IntegerSynthesis {
// case (None, f) => (f, xs.map(id => (Variable(id))))
// case (Some(eq), f) => {
// val vars: Set[Identifier] = variablesOf(eq)
// val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
// val ys: Set[Identifier] = xs.toSet.difference(vars).toList
// val eqas: Set[Identifier] = as.intersect(vars)
// val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
// val ys: Set[Identifier] = xs.toSet.diff(vars)
// val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList
// val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq)
// val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), e)}.toMap
// val freshFormula = simplify(replace(eqSubstMap, f))
// (eqPre, freshFormula)
// /*
// val (recPre, recSubst) = apply(as, ys ++ eqFreshVars, freshFormula)
// val freshPre = simplify(replace(
......@@ -29,7 +36,8 @@ object IntegerSynthesis {
// }.toMap,
// eqPre))
// (And(freshPre, recPre),
// (And(freshPre, recPre), recSubst)
// */
// }
//}
......
package leon.synthesis
import leon.purescala.Trees._
import leon.purescala.TypeTrees._
import leon.purescala.Common._
import leon.Evaluator
......@@ -9,18 +10,26 @@ object LinearEquations {
//eliminate one variable from normalizedEquation t + a1*x1 + ... + an*xn = 0
//return a mapping for each of the n variables in (pre, map, freshVars)
def elimVariable(as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr], List[Identifier]) = {
println("elim in normalized: " + normalizedEquation)
require(normalizedEquation.tail.forall{case IntLiteral(i) if i != 0 => true case _ => false})
val t: Expr = normalizedEquation.head
val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i}
val orderedParams: Array[Identifier] = as.toArray
val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList
val d: Int = GCD.gcd((coefsParams.tail ++ coefsVars).toSeq)
if(d > 1)
elimVariable(as, normalizedEquation.map(Division(_, IntLiteral(d))))
else {
val coefsParams0: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList
val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0
val d: Int = GCD.gcd((coefsParams ++ coefsVars).toSeq)
if(coefsVars.size == 1) {
val coef = coefsVars.head
(Equals(Modulo(t, IntLiteral(coef)), IntLiteral(0)), List(UMinus(Division(t, IntLiteral(coef)))), List())
} else if(d > 1) {
val newCoefsParams: List[Expr] = coefsParams.map(i => IntLiteral(i/d) : Expr)
val newT = newCoefsParams.zip(IntLiteral(1)::orderedParams.map(Variable(_)).toList).foldLeft[Expr](IntLiteral(0))((acc, p) => Plus(acc, Times(p._1, p._2)))
elimVariable(as, newT :: normalizedEquation.tail.map{case IntLiteral(i) => IntLiteral(i/d) : Expr})
} else {
val basis: Array[Array[Int]] = linearSet(as, normalizedEquation.tail.map{case IntLiteral(i) => i}.toArray)
val (pre, sol) = particularSolution(as, normalizedEquation)
val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true))
val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true).setType(Int32Type))
val tbasis = basis.transpose
assert(freshVars.size == tbasis.size)
......
......@@ -8,6 +8,8 @@ import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._
import LinearEquations.elimVariable
import ArithmeticNormalization.simplify
object Rules {
def all = Set[Synthesizer => Rule](
......@@ -22,7 +24,8 @@ object Rules {
new OptimisticGround(_),
new EqualitySplit(_),
new CEGIS(_),
new Assert(_)
new Assert(_),
new IntegerEquation(_)
)
}
......@@ -532,3 +535,52 @@ class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 10) {
}
}
}
class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) {
def applyOn(task: Task): RuleResult = {
val p = task.problem
val TopLevelAnds(exprs) = p.phi
val xs = p.xs
val as = p.as
val formula = p.phi
val (eqs, others) = exprs.partition(_.isInstanceOf[Equals])
if (!eqs.isEmpty) {
val (eq@Equals(_,_), rest) = (eqs.head, eqs.tail)
val allOthers = rest ++ others
val vars: Set[Identifier] = variablesOf(eq)
val eqas: Set[Identifier] = as.toSet.intersect(vars)
val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
val ys: Set[Identifier] = xs.toSet.diff(vars)
val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList
val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq)
val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap
val freshFormula = simplify(replace(eqSubstMap, And(allOthers)))
(eqPre, freshFormula)
val newProblem = Problem(as, And(eqPre, p.c), freshFormula, eqFreshVars)
val onSuccess: List[Solution] => Solution = {
case List(Solution(pre, defs, term)) =>
if (eqFreshVars.isEmpty) {
Solution(pre, defs, replace(eqSubstMap, Tuple(eqxs.map(Variable(_)))))
} else {
Solution(pre, defs, LetTuple(eqFreshVars, term, replace(eqSubstMap, Tuple(eqxs.map(Variable(_))))))
}
case _ => Solution.none
}
RuleStep(List(newProblem), onSuccess)
} else {
RuleInapplicable
}
}
}
......@@ -60,6 +60,11 @@ class ArithmeticNormalizationSuite extends FunSuite {
val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
checkSameExpr(toSum(expand(e2)), e2, xs)
val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
checkSameExpr(toSum(expand(e3)), e3, xs)
val e4 = UMinus(Plus(x, Times(i(2), y)))
checkSameExpr(toSum(expand(e4)), e4, xs)
}
test("apply") {
......
......@@ -48,7 +48,10 @@ class GCDSuite extends FunSuite {
assert(gcd(10,4) === 2)
assert(gcd(12,8) === 4)
assert(gcd(23,41) === 1)
assert(gcd(0,41) === 41)
assert(gcd(4,0) === 4)
assert(gcd(-4,0) === 4)
assert(gcd(-23,41) === 1)
assert(gcd(23,-41) === 1)
assert(gcd(-23,-41) === 1)
......@@ -75,9 +78,12 @@ class GCDSuite extends FunSuite {
assert(gcd(23,41,11) === 1)
assert(gcd(2,4,8,12,16,4) === 2)
assert(gcd(2,4,8,11,16,4) === 1)
assert(gcd(6,3,8, 0) === 1)
assert(gcd(2,12, 0,16,4) === 2)
assert(gcd(-12,8,6) === 2)
assert(gcd(23,-41,11) === 1)
assert(gcd(23,-41, 0,11) === 1)
assert(gcd(2,4,-8,-12,16,4) === 2)
assert(gcd(-2,-4,-8,-11,-16,-4) === 1)
}
......@@ -103,11 +109,14 @@ class GCDSuite extends FunSuite {
assert(gcd(Seq(23,41,11)) === 1)
assert(gcd(Seq(2,4,8,12,16,4)) === 2)
assert(gcd(Seq(2,4,8,11,16,4)) === 1)
assert(gcd(Seq(6,3,8, 0)) === 1)
assert(gcd(Seq(2,12, 0,16,4)) === 2)
assert(gcd(Seq(-1)) === 1)
assert(gcd(Seq(-7)) === 7)
assert(gcd(Seq(-12,8,6)) === 2)
assert(gcd(Seq(23,-41,11)) === 1)
assert(gcd(Seq(23,-41, 0,11)) === 1)
assert(gcd(Seq(2,4,-8,-12,16,4)) === 2)
assert(gcd(Seq(-2,-4,-8,-11,-16,-4)) === 1)
}
......
......@@ -199,6 +199,15 @@ class LinearEquationsSuite extends FunSuite {
val c2 = List(IntLiteral(1), IntLiteral(-1))
val (pre2, wit2, f2) = elimVariable(Set(), t2::c2)
val t3 = Minus(Times(IntLiteral(2), a), IntLiteral(3))
val c3 = List(IntLiteral(2))
val (pre3, wit3, f3) = elimVariable(Set(aId), t3::c3)
val t4 = Times(IntLiteral(2), a)
val c4 = List(IntLiteral(2), IntLiteral(4))
val (pre4, wit4, f4) = elimVariable(Set(aId), t4::c4)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment