diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala index aacfa20512fee4ad095587cc73dac929f4bc6580..7e916afedee3d7bbdca77a18124f4dec23acda93 100644 --- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala +++ b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala @@ -10,6 +10,8 @@ import leon.purescala.Common._ object ArithmeticNormalization { + case class NonLinearExpressionException(msg: String) extends Exception + //assume the function is an arithmetic expression, not a relation //return a normal form where the [t a1 ... an] where //expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn] @@ -19,7 +21,7 @@ object ArithmeticNormalization { case Times(e1, e2) => containsId(e1, id) || containsId(e2, id) case IntLiteral(_) => false case Variable(id2) => id == id2 - case _ => sys.error("unexpected format: " + e) + case err => throw NonLinearExpressionException("unexpected in containsId: " + err) } def group(es: Seq[Expr], id: Identifier): Expr = { @@ -53,7 +55,7 @@ object ArithmeticNormalization { def rec(e: Expr): Unit = e match { case IntLiteral(i) => coef = coef*i - case Variable(id2) => if(id.isEmpty) id = Some(id2) else sys.error("multiple variables") + case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") case Times(e1, e2) => rec(e1); rec(e2) } @@ -71,10 +73,11 @@ object ArithmeticNormalization { def expand(expr: Expr): Seq[Expr] = expr match { case Plus(es1, es2) => expand(es1) ++ expand(es2) case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr) + case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr) case Times(es1, es2) => multiply(expand(es1), expand(es2)) case v@Variable(_) => Seq(v) case n@IntLiteral(_) => Seq(n) - case _ => sys.error("Unexpected") + case err => throw NonLinearExpressionException("unexpected in expand: " + err) } //simple, local simplifications diff --git a/src/main/scala/leon/synthesis/IntegerSynthesis.scala b/src/main/scala/leon/synthesis/IntegerSynthesis.scala index 5ed2abfa38feb0fc8f95f62581dd9a15e58f81a2..a1fac69fcf1d0165884a39e4775edd709eb62735 100644 --- a/src/main/scala/leon/synthesis/IntegerSynthesis.scala +++ b/src/main/scala/leon/synthesis/IntegerSynthesis.scala @@ -14,13 +14,20 @@ object IntegerSynthesis { // case (None, f) => (f, xs.map(id => (Variable(id)))) // case (Some(eq), f) => { // val vars: Set[Identifier] = variablesOf(eq) - // val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList - // val ys: Set[Identifier] = xs.toSet.difference(vars).toList // val eqas: Set[Identifier] = as.intersect(vars) + + // val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList + // val ys: Set[Identifier] = xs.toSet.diff(vars) + // val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList // val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq) + // val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), e)}.toMap // val freshFormula = simplify(replace(eqSubstMap, f)) + // (eqPre, freshFormula) + + + // /* // val (recPre, recSubst) = apply(as, ys ++ eqFreshVars, freshFormula) // val freshPre = simplify(replace( @@ -29,7 +36,8 @@ object IntegerSynthesis { // }.toMap, // eqPre)) - // (And(freshPre, recPre), + // (And(freshPre, recPre), recSubst) + // */ // } //} diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala index 6774c99f9814b0aff5e26bb23a6cfaf0e1940f73..a3ebe75da74da3b9db54207889b66736889a9d3d 100644 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ b/src/main/scala/leon/synthesis/LinearEquations.scala @@ -1,6 +1,7 @@ package leon.synthesis import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.Evaluator @@ -9,18 +10,26 @@ object LinearEquations { //eliminate one variable from normalizedEquation t + a1*x1 + ... + an*xn = 0 //return a mapping for each of the n variables in (pre, map, freshVars) def elimVariable(as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr], List[Identifier]) = { + println("elim in normalized: " + normalizedEquation) + require(normalizedEquation.tail.forall{case IntLiteral(i) if i != 0 => true case _ => false}) val t: Expr = normalizedEquation.head val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i} val orderedParams: Array[Identifier] = as.toArray - val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList - val d: Int = GCD.gcd((coefsParams.tail ++ coefsVars).toSeq) - - if(d > 1) - elimVariable(as, normalizedEquation.map(Division(_, IntLiteral(d)))) - else { + val coefsParams0: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList + val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0 + val d: Int = GCD.gcd((coefsParams ++ coefsVars).toSeq) + + if(coefsVars.size == 1) { + val coef = coefsVars.head + (Equals(Modulo(t, IntLiteral(coef)), IntLiteral(0)), List(UMinus(Division(t, IntLiteral(coef)))), List()) + } else if(d > 1) { + val newCoefsParams: List[Expr] = coefsParams.map(i => IntLiteral(i/d) : Expr) + val newT = newCoefsParams.zip(IntLiteral(1)::orderedParams.map(Variable(_)).toList).foldLeft[Expr](IntLiteral(0))((acc, p) => Plus(acc, Times(p._1, p._2))) + elimVariable(as, newT :: normalizedEquation.tail.map{case IntLiteral(i) => IntLiteral(i/d) : Expr}) + } else { val basis: Array[Array[Int]] = linearSet(as, normalizedEquation.tail.map{case IntLiteral(i) => i}.toArray) val (pre, sol) = particularSolution(as, normalizedEquation) - val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true)) + val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true).setType(Int32Type)) val tbasis = basis.transpose assert(freshVars.size == tbasis.size) diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 2ea3725c00547a38a6c77caaebd7cf20d214ab21..4bc36af3dfa0c638f2086500af95a205047635c3 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -8,6 +8,8 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import LinearEquations.elimVariable +import ArithmeticNormalization.simplify object Rules { def all = Set[Synthesizer => Rule]( @@ -22,7 +24,8 @@ object Rules { new OptimisticGround(_), new EqualitySplit(_), new CEGIS(_), - new Assert(_) + new Assert(_), + new IntegerEquation(_) ) } @@ -532,3 +535,52 @@ class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 10) { } } } + +class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { + def applyOn(task: Task): RuleResult = { + + val p = task.problem + + val TopLevelAnds(exprs) = p.phi + val xs = p.xs + val as = p.as + val formula = p.phi + + val (eqs, others) = exprs.partition(_.isInstanceOf[Equals]) + + if (!eqs.isEmpty) { + + val (eq@Equals(_,_), rest) = (eqs.head, eqs.tail) + val allOthers = rest ++ others + + val vars: Set[Identifier] = variablesOf(eq) + val eqas: Set[Identifier] = as.toSet.intersect(vars) + + val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList + val ys: Set[Identifier] = xs.toSet.diff(vars) + + val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList + val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq) + + val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap + val freshFormula = simplify(replace(eqSubstMap, And(allOthers))) + (eqPre, freshFormula) + + val newProblem = Problem(as, And(eqPre, p.c), freshFormula, eqFreshVars) + + val onSuccess: List[Solution] => Solution = { + case List(Solution(pre, defs, term)) => + if (eqFreshVars.isEmpty) { + Solution(pre, defs, replace(eqSubstMap, Tuple(eqxs.map(Variable(_))))) + } else { + Solution(pre, defs, LetTuple(eqFreshVars, term, replace(eqSubstMap, Tuple(eqxs.map(Variable(_)))))) + } + case _ => Solution.none + } + + RuleStep(List(newProblem), onSuccess) + } else { + RuleInapplicable + } + } +} diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala index a2dc9ec686766709b9cf3a1bc7be87edc65fa744..992e668f7c70ce82dd372f5fb0bb36c721fcfcd1 100644 --- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala +++ b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala @@ -60,6 +60,11 @@ class ArithmeticNormalizationSuite extends FunSuite { val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) checkSameExpr(toSum(expand(e2)), e2, xs) + val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) + checkSameExpr(toSum(expand(e3)), e3, xs) + + val e4 = UMinus(Plus(x, Times(i(2), y))) + checkSameExpr(toSum(expand(e4)), e4, xs) } test("apply") { diff --git a/src/test/scala/leon/test/synthesis/GCDSuite.scala b/src/test/scala/leon/test/synthesis/GCDSuite.scala index 81404f3a520f374b62dd70ca1c0adbba78f55ec0..0df607c446ec9efca9975495f73b3f8e906c7990 100644 --- a/src/test/scala/leon/test/synthesis/GCDSuite.scala +++ b/src/test/scala/leon/test/synthesis/GCDSuite.scala @@ -48,7 +48,10 @@ class GCDSuite extends FunSuite { assert(gcd(10,4) === 2) assert(gcd(12,8) === 4) assert(gcd(23,41) === 1) + assert(gcd(0,41) === 41) + assert(gcd(4,0) === 4) + assert(gcd(-4,0) === 4) assert(gcd(-23,41) === 1) assert(gcd(23,-41) === 1) assert(gcd(-23,-41) === 1) @@ -75,9 +78,12 @@ class GCDSuite extends FunSuite { assert(gcd(23,41,11) === 1) assert(gcd(2,4,8,12,16,4) === 2) assert(gcd(2,4,8,11,16,4) === 1) + assert(gcd(6,3,8, 0) === 1) + assert(gcd(2,12, 0,16,4) === 2) assert(gcd(-12,8,6) === 2) assert(gcd(23,-41,11) === 1) + assert(gcd(23,-41, 0,11) === 1) assert(gcd(2,4,-8,-12,16,4) === 2) assert(gcd(-2,-4,-8,-11,-16,-4) === 1) } @@ -103,11 +109,14 @@ class GCDSuite extends FunSuite { assert(gcd(Seq(23,41,11)) === 1) assert(gcd(Seq(2,4,8,12,16,4)) === 2) assert(gcd(Seq(2,4,8,11,16,4)) === 1) + assert(gcd(Seq(6,3,8, 0)) === 1) + assert(gcd(Seq(2,12, 0,16,4)) === 2) assert(gcd(Seq(-1)) === 1) assert(gcd(Seq(-7)) === 7) assert(gcd(Seq(-12,8,6)) === 2) assert(gcd(Seq(23,-41,11)) === 1) + assert(gcd(Seq(23,-41, 0,11)) === 1) assert(gcd(Seq(2,4,-8,-12,16,4)) === 2) assert(gcd(Seq(-2,-4,-8,-11,-16,-4)) === 1) } diff --git a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala index fe507bea381069479939414339c490d6f3a7942d..f84c3d3b8934d141ef1a397438b93394d9e5486b 100644 --- a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala +++ b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala @@ -199,6 +199,15 @@ class LinearEquationsSuite extends FunSuite { val c2 = List(IntLiteral(1), IntLiteral(-1)) val (pre2, wit2, f2) = elimVariable(Set(), t2::c2) + + val t3 = Minus(Times(IntLiteral(2), a), IntLiteral(3)) + val c3 = List(IntLiteral(2)) + val (pre3, wit3, f3) = elimVariable(Set(aId), t3::c3) + + val t4 = Times(IntLiteral(2), a) + val c4 = List(IntLiteral(2), IntLiteral(4)) + val (pre4, wit4, f4) = elimVariable(Set(aId), t4::c4) + } }