From f751527f2a080ff0355eab97692437360ea6a1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Wed, 14 Nov 2012 02:15:10 +0000 Subject: [PATCH] integrate equation synthesis as a rule --- .../synthesis/ArithmeticNormalization.scala | 3 +- .../leon/synthesis/IntegerSynthesis.scala | 14 +++-- .../leon/synthesis/LinearEquations.scala | 22 +++++--- src/main/scala/leon/synthesis/Rules.scala | 54 ++++++++++++++++++- .../ArithmeticNormalizationSuite.scala | 5 ++ .../test/synthesis/LinearEquationsSuite.scala | 9 ++++ 6 files changed, 95 insertions(+), 12 deletions(-) diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala index aacfa2051..ba22b8f0d 100644 --- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala +++ b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala @@ -71,10 +71,11 @@ object ArithmeticNormalization { def expand(expr: Expr): Seq[Expr] = expr match { case Plus(es1, es2) => expand(es1) ++ expand(es2) case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr) + case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr) case Times(es1, es2) => multiply(expand(es1), expand(es2)) case v@Variable(_) => Seq(v) case n@IntLiteral(_) => Seq(n) - case _ => sys.error("Unexpected") + case err => sys.error("Unexpected in expand: " + err) } //simple, local simplifications diff --git a/src/main/scala/leon/synthesis/IntegerSynthesis.scala b/src/main/scala/leon/synthesis/IntegerSynthesis.scala index 5ed2abfa3..a1fac69fc 100644 --- a/src/main/scala/leon/synthesis/IntegerSynthesis.scala +++ b/src/main/scala/leon/synthesis/IntegerSynthesis.scala @@ -14,13 +14,20 @@ object IntegerSynthesis { // case (None, f) => (f, xs.map(id => (Variable(id)))) // case (Some(eq), f) => { // val vars: Set[Identifier] = variablesOf(eq) - // val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList - // val ys: Set[Identifier] = xs.toSet.difference(vars).toList // val eqas: Set[Identifier] = as.intersect(vars) + + // val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList + // val ys: Set[Identifier] = xs.toSet.diff(vars) + // val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList // val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq) + // val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), e)}.toMap // val freshFormula = simplify(replace(eqSubstMap, f)) + // (eqPre, freshFormula) + + + // /* // val (recPre, recSubst) = apply(as, ys ++ eqFreshVars, freshFormula) // val freshPre = simplify(replace( @@ -29,7 +36,8 @@ object IntegerSynthesis { // }.toMap, // eqPre)) - // (And(freshPre, recPre), + // (And(freshPre, recPre), recSubst) + // */ // } //} diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala index 6774c99f9..2f7ed20de 100644 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ b/src/main/scala/leon/synthesis/LinearEquations.scala @@ -1,6 +1,7 @@ package leon.synthesis import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.Evaluator @@ -9,18 +10,25 @@ object LinearEquations { //eliminate one variable from normalizedEquation t + a1*x1 + ... + an*xn = 0 //return a mapping for each of the n variables in (pre, map, freshVars) def elimVariable(as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr], List[Identifier]) = { + println("elim in normalized: " + normalizedEquation) val t: Expr = normalizedEquation.head val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i} val orderedParams: Array[Identifier] = as.toArray - val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList - val d: Int = GCD.gcd((coefsParams.tail ++ coefsVars).toSeq) - - if(d > 1) - elimVariable(as, normalizedEquation.map(Division(_, IntLiteral(d)))) - else { + val coefsParams0: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList + val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0 + val d: Int = GCD.gcd((coefsParams ++ coefsVars).toSeq) + + if(coefsVars.size == 1) { + val coef = coefsVars.head + (Equals(Modulo(t, IntLiteral(coef)), IntLiteral(0)), List(UMinus(Division(t, IntLiteral(coef)))), List()) + } else if(d > 1) { + val newCoefsParams: List[Expr] = coefsParams.map(i => IntLiteral(i/d) : Expr) + val newT = newCoefsParams.zip(IntLiteral(1)::orderedParams.map(Variable(_)).toList).foldLeft[Expr](IntLiteral(0))((acc, p) => Plus(acc, Times(p._1, p._2))) + elimVariable(as, newT :: normalizedEquation.tail.map{case IntLiteral(i) => IntLiteral(i/d) : Expr}) + } else { val basis: Array[Array[Int]] = linearSet(as, normalizedEquation.tail.map{case IntLiteral(i) => i}.toArray) val (pre, sol) = particularSolution(as, normalizedEquation) - val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true)) + val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true).setType(Int32Type)) val tbasis = basis.transpose assert(freshVars.size == tbasis.size) diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index b79ade73c..519a8b3cc 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -8,6 +8,8 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import LinearEquations.elimVariable +import ArithmeticNormalization.simplify object Rules { def all = Set[Synthesizer => Rule]( @@ -21,7 +23,8 @@ object Rules { new UnconstrainedOutput(_), new OptimisticGround(_), new CEGIS(_), - new Assert(_) + new Assert(_), + new IntegerEquation(_) ) } @@ -512,3 +515,52 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn } } } + +class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { + def applyOn(task: Task): RuleResult = { + + val p = task.problem + + val TopLevelAnds(exprs) = p.phi + val xs = p.xs + val as = p.as + val formula = p.phi + + val (eqs, others) = exprs.partition(_.isInstanceOf[Equals]) + + if (!eqs.isEmpty) { + + val (eq@Equals(_,_), rest) = (eqs.head, eqs.tail) + val allOthers = rest ++ others + + val vars: Set[Identifier] = variablesOf(eq) + val eqas: Set[Identifier] = as.toSet.intersect(vars) + + val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList + val ys: Set[Identifier] = xs.toSet.diff(vars) + + val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList + val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq) + + val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap + val freshFormula = simplify(replace(eqSubstMap, And(allOthers))) + (eqPre, freshFormula) + + val newProblem = Problem(as, And(eqPre, p.c), freshFormula, eqFreshVars) + + val onSuccess: List[Solution] => Solution = { + case List(Solution(pre, defs, term)) => + if (eqFreshVars.isEmpty) { + Solution(pre, defs, replace(eqSubstMap, Tuple(eqxs.map(Variable(_))))) + } else { + Solution(pre, defs, LetTuple(eqFreshVars, term, replace(eqSubstMap, Tuple(eqxs.map(Variable(_)))))) + } + case _ => Solution.none + } + + RuleStep(List(newProblem), onSuccess) + } else { + RuleInapplicable + } + } +} diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala index a2dc9ec68..992e668f7 100644 --- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala +++ b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala @@ -60,6 +60,11 @@ class ArithmeticNormalizationSuite extends FunSuite { val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) checkSameExpr(toSum(expand(e2)), e2, xs) + val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) + checkSameExpr(toSum(expand(e3)), e3, xs) + + val e4 = UMinus(Plus(x, Times(i(2), y))) + checkSameExpr(toSum(expand(e4)), e4, xs) } test("apply") { diff --git a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala index fe507bea3..f84c3d3b8 100644 --- a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala +++ b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala @@ -199,6 +199,15 @@ class LinearEquationsSuite extends FunSuite { val c2 = List(IntLiteral(1), IntLiteral(-1)) val (pre2, wit2, f2) = elimVariable(Set(), t2::c2) + + val t3 = Minus(Times(IntLiteral(2), a), IntLiteral(3)) + val c3 = List(IntLiteral(2)) + val (pre3, wit3, f3) = elimVariable(Set(aId), t3::c3) + + val t4 = Times(IntLiteral(2), a) + val c4 = List(IntLiteral(2), IntLiteral(4)) + val (pre4, wit4, f4) = elimVariable(Set(aId), t4::c4) + } } -- GitLab