From 1915ae9a3969aa856873a233828e7b717a37e8ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Fri, 16 Nov 2012 19:51:17 +0100
Subject: [PATCH] some refactoring of inequality solver

---
 .../synthesis/rules/IntegerEquation.scala     |   6 +-
 .../synthesis/rules/IntegerInequality.scala   | 266 +++++++++---------
 2 files changed, 141 insertions(+), 131 deletions(-)

diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
index a3f5f9e6f..2bc884b42 100644
--- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
+++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
@@ -65,7 +65,11 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth
           RuleStep(List(newProblem), onSuccess)
 
         } else {
-          val (eqPre, eqWitness, freshxs) = elimVariable(eqas, normalizedEq)
+          val (eqPre0, eqWitness, freshxs) = elimVariable(eqas, normalizedEq)
+          val eqPre = eqPre0 match {
+            case Equals(Modulo(_, IntLiteral(1)), _) => BooleanLiteral(true)
+            case c => c
+          }
 
           val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap
           val freshFormula = simplify(replace(eqSubstMap, And(allOthers)))
diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequality.scala b/src/main/scala/leon/synthesis/rules/IntegerInequality.scala
index 807479103..63a6dd65e 100644
--- a/src/main/scala/leon/synthesis/rules/IntegerInequality.scala
+++ b/src/main/scala/leon/synthesis/rules/IntegerInequality.scala
@@ -13,149 +13,155 @@ import ArithmeticNormalization.simplify
 
 class IntegerInequality(synth: Synthesizer) extends Rule("Integer Inequality", synth, 300) {
   def applyOn(task: Task): RuleResult = {
-
     val problem = task.problem
     val TopLevelAnds(exprs) = problem.phi
 
     //assume that we only have inequalities
-    var nonIneq = false
-
+    var lhsSides: List[Expr] = Nil
+    var exprNotUsed: List[Expr] = Nil
     //normalized all inequalities to LessEquals(t, 0)
-    val lhsSides: List[Expr] = exprs.map{
-      case LessThan(a, b) => Plus(Minus(a, b), IntLiteral(1))
-      case LessEquals(a, b) => Minus(a, b)
-      case GreaterThan(a, b) => Plus(Minus(b, a), IntLiteral(1))
-      case GreaterEquals(a, b) => Minus(b, a)
-      case _ => {nonIneq = true; null}
-    }.toList
+    exprs.foreach{
+      case LessThan(a, b) => lhsSides ::= Plus(Minus(a, b), IntLiteral(1))
+      case LessEquals(a, b) => lhsSides ::= Minus(a, b)
+      case GreaterThan(a, b) => lhsSides ::= Plus(Minus(b, a), IntLiteral(1))
+      case GreaterEquals(a, b) => lhsSides ::= Minus(b, a)
+      case e => exprNotUsed ::= e
+    }
 
-    if(nonIneq) RuleInapplicable else {
-      var processedVar: Option[Identifier] = None
-      for(e <- lhsSides if processedVar == None) {
-        val vars = variablesOf(e).intersect(problem.xs.toSet)
-        if(!vars.isEmpty)
-          processedVar = Some(vars.head)
+    val ineqVars = lhsSides.foldLeft(Set[Identifier]())((acc, lhs) => acc ++ variablesOf(lhs))
+    val nonIneqVars = exprNotUsed.foldLeft(Set[Identifier]())((acc, x) => acc ++ variablesOf(x))
+    val candidateVars = ineqVars.intersect(problem.xs.toSet).filterNot(nonIneqVars.contains(_))
+    if(candidateVars.isEmpty) RuleInapplicable else {
+      val processedVar = candidateVars.head
+      val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar)
+
+      val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList)
+      var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t
+      var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x
+      normalizedLhs.foreach{
+        case List(t, IntLiteral(i)) => 
+          if(i > 0) upperBounds ::= (simplify(UMinus(t)), i)
+          else if(i < 0) lowerBounds ::= (simplify(t), -i)
+          else /*if (i == 0)*/ exprNotUsed ::= LessEquals(t, IntLiteral(0))
+        case err => sys.error("unexpected from normal form: " + err)
       }
-      processedVar match {
-        case None => RuleInapplicable 
-        case Some(processedVar) => {
-          println("processed Var: " + processedVar)
-          val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList)
-          var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t
-          var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x
-          normalizedLhs.foreach{
-            case List(t, IntLiteral(i)) => 
-              if(i > 0) upperBounds ::= (simplify(UMinus(t)), i)
-              else if(i < 0) lowerBounds ::= (simplify(t), -i)
-            case err => sys.error("unexpected from normal form: " + err)
-          }
-
-          val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar)
-
-          println("otherVars: " + otherVars)
-          if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a bound
-            val witness = if(upperBounds.isEmpty) 
-                Division(lowerBounds.head._1, IntLiteral(lowerBounds.head._2))
-              else
-                Division(upperBounds.head._1, IntLiteral(upperBounds.head._2))
-            val pre = if(lowerBounds.isEmpty || upperBounds.isEmpty) BooleanLiteral(true) else sys.error("TODO")
-            RuleSuccess(Solution(pre, Set(), witness))
-          } else {
-            val L = GCD.lcm((upperBounds ::: lowerBounds).map(_._2))
-            val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
-            val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
-
-            val remainderIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type))
-            val quotientIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type))
-
-            val subProblemFormula = simplify(And(
-              newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{
-                case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) :: 
-                                    newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound)))
-              }))
-
-            val subProblem = Problem(problem.as ++ remainderIds, problem.c, subProblemFormula, otherVars ++ quotientIds) 
-
-            def onSuccess(sols: List[Solution]): Solution = sols match {
-              case List(Solution(pre, defs, term)) => {
-
-                val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
-                val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls)
-                def maxRec(bounds: List[Expr]): Expr = bounds match {
-                  case (x1 :: x2 :: xs) => {
-                    val v = FreshIdentifier("m").setType(Int32Type)
-                    Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs))
-                  }
-                  case (x :: Nil) => x
-                  case Nil => sys.error("cannot build a max expression with no argument")
-                }
-                if(!lowerBounds.isEmpty)
-                  maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList))
-
-                val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
-                val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls)
-                def minRec(bounds: List[Expr]): Expr = bounds match {
-                  case (x1 :: x2 :: xs) => {
-                    val v = FreshIdentifier("m").setType(Int32Type)
-                    Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs))
-                  }
-                  case (x :: Nil) => x
-                  case Nil => sys.error("cannot build a min expression with no argument")
-                }
-                if(!upperBounds.isEmpty)
-                  minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList))
 
+      //define max function
+      val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
+      val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls)
+      def maxRec(bounds: List[Expr]): Expr = bounds match {
+        case (x1 :: x2 :: xs) => {
+          val v = FreshIdentifier("m").setType(Int32Type)
+          Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs))
+        }
+        case (x :: Nil) => x
+        case Nil => sys.error("cannot build a max expression with no argument")
+      }
+      if(!lowerBounds.isEmpty)
+        maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList))
+      def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun, xs)
+      //define min function
+      val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
+      val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls)
+      def minRec(bounds: List[Expr]): Expr = bounds match {
+        case (x1 :: x2 :: xs) => {
+          val v = FreshIdentifier("m").setType(Int32Type)
+          Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs))
+        }
+        case (x :: Nil) => x
+        case Nil => sys.error("cannot build a min expression with no argument")
+      }
+      if(!upperBounds.isEmpty)
+        minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList))
+      def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun, xs)
+      val floorFun = new FunDef(FreshIdentifier("floorDiv"), Int32Type, Seq(
+                                  VarDecl(FreshIdentifier("x"), Int32Type),
+                                  VarDecl(FreshIdentifier("x"), Int32Type)))
+      val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Int32Type, Seq(
+                                  VarDecl(FreshIdentifier("x"), Int32Type),
+                                  VarDecl(FreshIdentifier("x"), Int32Type)))
+      def floor(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun, Seq(x, y))
+      def ceiling(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun, Seq(x, y))
+
+      val witness: Expr = if(upperBounds.isEmpty) {
+        if(lowerBounds.size > 1) max(lowerBounds.map{case (b, c) => ceiling(b, IntLiteral(c))})
+        else ceiling(lowerBounds.head._1, IntLiteral(lowerBounds.head._2))
+      } else {
+        if(upperBounds.size > 1) min(upperBounds.map{case (b, c) => floor(b, IntLiteral(c))})
+        else ceiling(upperBounds.head._1, IntLiteral(upperBounds.head._2))
+      }
 
-                if(newUpperBounds.isEmpty) {
-                  Solution(pre, defs,
-                    LetTuple(otherVars++quotientIds, term,
-                      Let(processedVar, 
-                          FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2)))),
-                          Tuple(problem.xs.map(Variable(_))))))
-                } else if(newLowerBounds.isEmpty) {
-                  Solution(pre, defs,
-                    LetTuple(otherVars++quotientIds, term,
-                      Let(processedVar, 
+      if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a witness
+        val pre = if(lowerBounds.isEmpty || upperBounds.isEmpty) BooleanLiteral(true) else sys.error("TODO")
+        RuleSuccess(Solution(pre, Set(), witness))
+      } else {
+        val L = GCD.lcm((upperBounds ::: lowerBounds).map(_._2))
+        val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
+        val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
+
+        val remainderIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type))
+        val quotientIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type))
+
+        val subProblemFormula = simplify(And(
+          newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{
+            case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) :: 
+                                newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound)))
+          } ++ exprNotUsed))
+        val subProblemxs: List[Identifier] = otherVars ++ quotientIds
+        val subProblem = Problem(problem.as ++ remainderIds, problem.c, subProblemFormula, subProblemxs)
+
+        def onSuccess(sols: List[Solution]): Solution = sols match {
+          case List(Solution(pre, defs, term)) => {
+            val involvedVariables = (upperBounds++lowerBounds).foldLeft(Set[Identifier]())((acc, t) => {
+              acc ++ variablesOf(t._1)
+            }).intersect(problem.xs.toSet) //output variables involved in the bounds of the process variables
+            if(involvedVariables.isEmpty) { //here we can just evaluate the lower and upper bound
+              val newPre = And(
+                for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) 
+                  yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))))
+              Solution(And(newPre, pre), defs,
+                LetTuple(subProblemxs, term,
+                  Let(processedVar, witness,
+                    Tuple(problem.xs.map(Variable(_))))))
+            } else if(upperBounds.isEmpty || lowerBounds.isEmpty) {
+                Solution(pre, defs,
+                  LetTuple(otherVars++quotientIds, term,
+                    Let(processedVar, witness,
+                      Tuple(problem.xs.map(Variable(_))))))
+            } else if(upperBounds.size > 1)
+              Solution.none
+            else {
+              val k = remainderIds.head
+              
+              val loopCounter = Variable(FreshIdentifier("i").setType(Int32Type))
+              val concretePre = replace(Map(Variable(k) -> loopCounter), pre)
+              val returnType = TupleType(problem.xs.map(_.getType))
+              val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type)))
+              val funBody = IfExpr(
+                LessThan(loopCounter, IntLiteral(0)),
+                Error("No solution exists"),
+                IfExpr(
+                  concretePre,
+                  LetTuple(otherVars++quotientIds, term,
+                    Let(processedVar, 
+                        if(newUpperBounds.isEmpty)
+                          FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2))))
+                        else
                           FunctionInvocation(minFun, upperBounds.map(t => Division(t._1, IntLiteral(t._2)))),
-                          Tuple(problem.xs.map(Variable(_))))))
-                } else if(newUpperBounds.size > 1)
-                  Solution.none
-                else {
-                  val k = remainderIds.head
-                  
-                  val loopCounter = Variable(FreshIdentifier("i").setType(Int32Type))
-                  val concretePre = replace(Map(Variable(k) -> loopCounter), pre)
-                  val returnType = TupleType(problem.xs.map(_.getType))
-                  val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type)))
-                  val funBody = IfExpr(
-                    LessThan(loopCounter, IntLiteral(0)),
-                    Error("No solution exists"),
-                    IfExpr(
-                      concretePre,
-                      LetTuple(otherVars++quotientIds, term,
-                        Let(processedVar, 
-                            if(newUpperBounds.isEmpty)
-                              FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2))))
-                            else
-                              FunctionInvocation(minFun, upperBounds.map(t => Division(t._1, IntLiteral(t._2)))),
-                            Tuple(problem.xs.map(Variable(_))))
-                      ),
-                      FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1))))
-                    )
-                  )
-                  funDef.body = Some(funBody)
-
-                  println("generated code: " + funDef)
-
-                  Solution(pre, defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))
-                }
-              }
-              case _ => Solution.none
+                        Tuple(problem.xs.map(Variable(_))))
+                  ),
+                  FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1))))
+                )
+              )
+              funDef.body = Some(funBody)
+
+              Solution(pre, defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))
             }
-
-            RuleStep(List(subProblem), onSuccess)
           }
+          case _ => Solution.none
         }
+
+        RuleStep(List(subProblem), onSuccess)
       }
     }
   }
-- 
GitLab