diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index 07a60c02b34e4f15c3bd3d44d044f7f5ab7cef96..741b31bd458858948cde07aa2f96c662568d6673 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -11,7 +11,7 @@ import purescala.Definitions._ import LinearEquations.elimVariable class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { - def attemptToApplyOn(problem: Problem): RuleResult = { + def attemptToApplyOn(problem: Problem): RuleResult = if(problem.xs.isEmpty) RuleInapplicable else { val TopLevelAnds(exprs) = problem.phi diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 3249c928942f946c9337d370ebc9f6c0f3164d4c..5aa459614154cd4ec4e2daa4c5f6fbdb5ca47eec 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -34,14 +34,16 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities val processedVar = candidateVars.head val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar) + println("lhsSides: " + lhsSides) val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList) + println("normalized: " + normalizedLhs.mkString("\n")) var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x normalizedLhs.foreach{ case List(t, IntLiteral(i)) => if(i > 0) upperBounds ::= (simplifyArithmetic(UMinus(t)), i) - else if(i < 0) lowerBounds ::= (simplify(t), -i) - else /*if (i == 0)*/ exprNotUsed ::= LessEquals(t, IntLiteral(0)) + else if(i < 0) lowerBounds ::= (simplifyArithmetic(t), -i) + else exprNotUsed ::= LessEquals(t, IntLiteral(0)) //TODO: make sure that these are added as preconditions case err => sys.error("unexpected from normal form: " + err) } @@ -92,9 +94,10 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities } if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a witness - val pre = And( - for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) - yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc)))) + + val constraints: List[Expr] = for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) + yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) + val pre = And(exprNotUsed ++ constraints) RuleFastSuccess(Solution(pre, Set(), witness)) } else { val L = lcm((upperBounds ::: lowerBounds).map(_._2)) @@ -104,7 +107,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities val remainderIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type)) val quotientIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type)) - val subProblemFormula = simplify(And( + val subProblemFormula = simplifyArithmetic(And( newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{ case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) :: newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound))) @@ -140,7 +143,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities val concreteTerm = replace(Map(Variable(k) -> loopCounter), term) val returnType = TupleType(problem.xs.map(_.getType)) val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type))) - val funBody = IfExpr( + val funBody = simplifyArithmetic(IfExpr( LessThan(loopCounter, IntLiteral(0)), Error("No solution exists"), IfExpr( @@ -151,7 +154,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities ), FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1)))) ) - ) + )) funDef.body = Some(funBody) Solution(pre, defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))