// LinearConstraintUtil.scala
package leon
package invariant.structure

import purescala._
import purescala.Common._
import purescala.Expressions._
import purescala.ExprOps._
import purescala.Extractors._
import scala.collection.mutable.{ Map => MutableMap }
import invariant.util._
import BigInt._
import PredicateUtil._

/**
 * Runtime exception raised by the linearization helpers in this file
 * (`PushMinus`, `PushTimes`) when they meet an operator they do not handle yet.
 *
 * @param message description of the unsupported construct
 */
class NotImplementedException(message: String) extends RuntimeException(message)

//a collection of utility methods that manipulate linear constraints and templates
object LinearConstraintUtil {
  //literals shared by the routines of this object
  val zero = InfiniteIntegerLiteral(0) //integer constant 0, also used as the rhs of normalized relations
  val one = InfiniteIntegerLiteral(1) //integer constant 1, the implicit coefficient of a bare variable
  val mone = InfiniteIntegerLiteral(-1) //integer constant -1, used to negate terms
  val tru = BooleanLiteral(true) //boolean constant true
  val fls = BooleanLiteral(false) //boolean constant false

  //some utility methods
  /**
   * Returns the function invocations that occur as terms of the constraint,
   * i.e. those keys of its coefficient map that are `FunctionInvocation`s.
   */
  def getFIs(ctr: LinearConstraint): Set[FunctionInvocation] =
    ctr.coeffMap.keysIterator.collect { case fi: FunctionInvocation => fi }.toSet

  /**
   * Tries to decide the truth value of a template.
   * Only a `LinearConstraint` without any terms (empty coefficient map) is considered:
   * its expression is simplified and, if that yields a boolean literal, the literal's
   * value is returned. In every other case the result is `None`.
   */
  def evaluate(lt: LinearTemplate): Option[Boolean] = lt match {
    case lc: LinearConstraint if lc.coeffMap.isEmpty =>
      val simplified = ExpressionTransformer.simplify(lc.toExpr)
      Option(simplified).collect { case BooleanLiteral(value) => value }
    case _ =>
      None
  }

   /**
   * Converts a linear atomic predicate (possibly containing template variables) into
   * a linear template or a linear constraint.
   *
   * The expression is first normalized with `MakeLinear`; the left-hand side of the
   * resulting relation is then split into min-terms, which are accumulated into a
   * map from terms to coefficients plus a constant part.
   * For now some of the constructs are not handled.
   *
   * @param expr a linear atomic predicate (or a template)
   * @return a `LinearTemplate` if some coefficient or constant contains variables
   *         (i.e. template variables), a `LinearConstraint` otherwise
   * @throws IllegalStateException if the normalized expression is not a relation against
   *         an integer literal, or if a min-term is not of a supported shape
   */
  def exprToTemplate(expr: Expr): LinearTemplate = {

    //these are the result values
    val coeffMap = MutableMap[Expr, Expr]()
    var constant: Option[Expr] = None
    var isTemplate: Boolean = false

    //adds 'coeff' to the coefficient recorded for 'term'
    def addCoefficient(term: Expr, coeff: Expr): Unit = {
      val newcoeff = coeffMap.get(term) match {
        case Some(value) => simplifyArithmetic(Plus(value, coeff))
        case None => simplifyArithmetic(coeff)
      }
      //a term whose coefficient is zero must not appear in the coeffMap,
      //irrespective of whether this is the first occurrence of the term or not
      if (newcoeff == zero) {
        coeffMap.remove(term)
      } else {
        coeffMap.update(term, newcoeff)
      }

      //a coefficient with variables can only come from a template
      if (variablesOf(coeff).nonEmpty) {
        isTemplate = true
      }
    }

    //adds 'coeff' to the constant part
    def addConstant(coeff: Expr): Unit = {
      val newconst = constant match {
        case Some(value) => simplifyArithmetic(Plus(value, coeff))
        case None => simplifyArithmetic(coeff)
      }
      constant = Some(newconst)

      if (variablesOf(coeff).nonEmpty) {
        isTemplate = true
      }
    }

    //recurse into plus and get all minterms
    def getMinTerms(lexpr: Expr): Seq[Expr] = lexpr match {
      case Plus(e1, e2) => getMinTerms(e1) ++ getMinTerms(e2)
      case _ => Seq(lexpr)
    }

    val linearExpr = MakeLinear(expr)
    //the top most operator should be a relation whose right-hand side is an integer literal.
    //Fail with a descriptive error (instead of a bare MatchError) when it is not.
    val (lhs, op) = linearExpr match {
      case Operator(Seq(l, InfiniteIntegerLiteral(_)), builder) => (l, builder)
      case _ =>
        throw new IllegalStateException("Expression is not a relation in canonical linear form: " + linearExpr)
    }

    //handle each minterm
    getMinTerms(lhs).foreach {
      case minterm if isTemplateExpr(minterm) =>
        //a term made only of constants/template variables contributes to the constant part
        addConstant(minterm)

      case Times(e1, e2) =>
        //the multiplicand (second argument) has to be a constraint variable ...
        e2 match {
          case Variable(_) | ResultVariable(_) | FunctionInvocation(_, _) =>
          case _ => throw new IllegalStateException("Multiplicand not a constraint variable: " + e2)
        }
        //... and the coefficient (first argument) a constant or a template expression
        if (isTemplateExpr(e1)) {
          addCoefficient(e2, e1)
        } else {
          throw new IllegalStateException("Coefficient not a constant or template expression: " + e1)
        }

      case minterm @ (Variable(_) | ResultVariable(_)) =>
        //here the coefficient is 1
        addCoefficient(minterm, one)

      //NOTE(review): a bare FunctionInvocation min-term is rejected here although
      //'c * f(..)' is accepted above -- confirm that this asymmetry is intended
      case minterm =>
        throw new IllegalStateException("Unhandled min term: " + minterm)
    }

    if (coeffMap.isEmpty && constant.isEmpty) {
      //here, there are no terms at all, so the constant term of the generated constraint is zero.
      new LinearConstraint(op, Map.empty, Some(zero))
    } else if (isTemplate) {
      new LinearTemplate(op, coeffMap.toMap, constant)
    } else {
      new LinearConstraint(op, coeffMap.toMap, constant)
    }
  }

  /**
   * Normalizes an atomic predicate towards a linear constraint.
   * A relational atom (`Equals`, `LessThan`, `LessEquals`, `GreaterThan`, `GreaterEquals`)
   * is rewritten to the form `lhs op 0` where `op` is `Equals`, `LessEquals` or `LessThan`;
   * arithmetic sub-expressions have minus and times pushed down to the coefficients.
   * This method may have to do all sorts of transformations to make the expressions linear constraints.
   * It assumes that the input expression is an atomic predicate (i.e., without and, or and not).
   * The set of handled constructs is subject to constant modification.
   *
   * Throws `IllegalStateException` if the expression is found to be non-linear and
   * `NotImplementedException` if a helper meets an operator it does not handle.
   */
  def MakeLinear(atom: Expr): Expr = {

    //pushes the minus inside the arithmetic terms, i.e. returns the negation of inExpr
    //we assume that inExpr is in linear form
    //NOTE(review): the Real* cases below build integer Plus/Times nodes -- confirm this is intended
    def PushMinus(inExpr: Expr): Expr = {
      inExpr match {
        case IntLiteral(v) => IntLiteral(-v)
        case InfiniteIntegerLiteral(v) => InfiniteIntegerLiteral(-v)
        //any other terminal is negated by multiplying it with -1
        case t: Terminal => Times(mone, t)
        case fi @ FunctionInvocation(fdef, args) => Times(mone, fi)
        //negating a negation yields the operand itself
        case UMinus(e1) => e1
        case RealUMinus(e1) => e1
        //-(e1 - e2) == (-e1) + e2
        case Minus(e1, e2) => Plus(PushMinus(e1), e2)
        case RealMinus(e1, e2) => Plus(PushMinus(e1), e2)
        case Plus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2))
        case RealPlus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2))
        case Times(e1, e2) => {
          //here push the minus in to the coefficient which is the first argument
          Times(PushMinus(e1), e2)
        }
        case RealTimes(e1, e2) => Times(PushMinus(e1), e2)
        case _ => throw new NotImplementedException("PushMinus -- Operators not yet handled: " + inExpr)
      }
    }

    import leon.purescala.Types._
    //multiplies 'ine' by 'mul', distributing the multiplication over sums
    //we assume that ine is in linear form
    def PushTimes(mul: Expr, ine: Expr): Expr = {
      //real constructors are used only if both operands have real type
      val isReal = ine.getType == RealType && mul.getType == RealType
      val timesCons =
        if(isReal) RealTimes
        else Times
      ine match {
        case t: Terminal => timesCons(mul, t)
        case fi @ FunctionInvocation(fdef, ars) => timesCons(mul, fi)
        case Plus(e1, e2) => Plus(PushTimes(mul, e1), PushTimes(mul, e2))
        case RealPlus(e1, e2) =>
          val r1 = PushTimes(mul, e1)
          val r2 = PushTimes(mul, e2)
          if (isReal) RealPlus(r1, r2)
          else Plus(r1, r2)
        case Times(e1, e2) => {
          //here push the times into the coefficient which should be the first expression
          Times(PushTimes(mul, e1), e2)
        }
        case RealTimes(e1, e2) =>
          val r = PushTimes(mul, e1)
          if(isReal) RealTimes(r, e2)
          else Times(r, e2)
        case _ => throw new NotImplementedException("PushTimes -- Operators not yet handled: " + ine)
      }
    }

    //collect all the constants in addition and simplify them
    //we assume that ine is in linear form and also that all constants are integers
    //returns the non-constant remainder of the sum (if any) and the sum of the integer constants
    def simplifyConsts(ine: Expr): (Option[Expr], BigInt) = {
      ine match {
        case IntLiteral(v) => (None, v)
        case InfiniteIntegerLiteral(v) => (None, v)
        case Plus(e1, e2) => {
          val (r1, c1) = simplifyConsts(e1)
          val (r2, c2) = simplifyConsts(e2)

          //rebuild the sum from whatever non-constant parts are left
          val newe = (r1, r2) match {
            case (None, None) => None
            case (Some(t), None) => Some(t)
            case (None, Some(t)) => Some(t)
            case (Some(t1), Some(t2)) => Some(Plus(t1, t2))
          }
          (newe, c1 + c2)
        }
        //anything else (including RealPlus) is kept as is and contributes no constant
        case _ => (Some(ine), 0)
      }
    }

    //recursively rewrites inExpr; relations are normalized, arithmetic is flattened
    def mkLinearRecur(inExpr: Expr): Expr = {
      //println("inExpr: "+inExpr + " tpe: "+inExpr.getType)
      val res = inExpr match {
        case e @ Operator(Seq(e1, e2), op)
        if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan]
            || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan]
            || e.isInstanceOf[GreaterEquals])) => {

          //check if the expression has real valued sub-expressions
          val isReal = hasReals(e1) || hasReals(e2)
          //doing something else ... ?
          // println("[DEBUG] Expr 1 " + e1 + " of type " + e1.getType + " and Expr 2 " + e2 + " of type" + e2.getType)
          //move everything to the lhs and pick the normalized relation:
          //'>=' and '>' are flipped; without reals a strict 'a < b' becomes 'a - b + 1 <= 0'
          val (newe, newop) = e match {
            case t: Equals => (Minus(e1, e2), Equals)
            case t: LessEquals => (Minus(e1, e2), LessEquals)
            case t: GreaterEquals => (Minus(e2, e1), LessEquals)
            case t: LessThan => {
              if (isReal)
                (Minus(e1, e2), LessThan)
              else
                (Plus(Minus(e1, e2), one), LessEquals)
            }
            case t: GreaterThan => {
              if(isReal)
                 (Minus(e2,e1),LessThan)
              else
            	 (Plus(Minus(e2, e1), one), LessEquals)
            }
          }
          val r = mkLinearRecur(newe)
          //simplify the resulting constants
          val (r2, const) = simplifyConsts(r)
          //put the accumulated constant back unless it is zero
          val finale = if (r2.isDefined) {
            if (const != 0) Plus(r2.get, InfiniteIntegerLiteral(const))
            else r2.get
          } else InfiniteIntegerLiteral(const)
          //println(r + " simplifies to "+finale)
          newop(finale, zero)
        }
        //a subtraction becomes an addition of the negated second operand
        case Minus(e1, e2) => Plus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2)))
        case RealMinus(e1, e2) => RealPlus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2)))
        case UMinus(e1) => PushMinus(mkLinearRecur(e1))
        case RealUMinus(e1) => PushMinus(mkLinearRecur(e1))
        case Times(_, _) | RealTimes(_, _) => {
          val Operator(Seq(e1, e2), op) = inExpr
          val (r1, r2) = (mkLinearRecur(e1), mkLinearRecur(e2))
          //one of the factors has to be a template expression, it is used as the multiplier
          if(isTemplateExpr(r1)) {
            PushTimes(r1, r2)
          } else if(isTemplateExpr(r2)){
            PushTimes(r2, r1)
          } else
            throw new IllegalStateException("Expression not linear: " + Times(r1, r2))
        }
        case Plus(e1, e2) => Plus(mkLinearRecur(e1), mkLinearRecur(e2))
        case rp@RealPlus(e1, e2) =>
          //println(s"Expr: $rp arg1: $e1 tpe: ${e1.getType} arg2: $e2 tpe: ${e2.getType}")
          val r1 = mkLinearRecur(e1)
          val r2 = mkLinearRecur(e2)
          //println(s"Res1: $r1 tpe: ${r1.getType} Res2: $r2 tpe: ${r2.getType}")
          RealPlus(r1, r2)
        //terminals and function invocations are already atomic terms
        case t: Terminal => t
        case fi: FunctionInvocation => fi
        case _ => throw new IllegalStateException("Expression not linear: " + inExpr)
      }
      //println("Res: "+res+" tpe: "+res.getType)
      res
    }
    val rese = mkLinearRecur(atom)
    rese
  }

  /**
   * Applies the substitution `replaceMap` to the expression of the given linear constraint
   * and converts the simplified result back into a linear constraint.
   *
   * Returns `None` if the constraint reduces to true after the replacement.
   * Throws `IllegalStateException` if the constraint reduces to false.
   */
  def replaceInCtr(replaceMap: Map[Expr, Expr], lc: LinearConstraint): Option[LinearConstraint] = {
    def reducedToFalse(): Nothing =
      throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc)

    val substituted = replace(replaceMap, lc.toExpr)
    ExpressionTransformer.simplify(simplifyArithmetic(substituted)) match {
      case BooleanLiteral(true) => None
      case BooleanLiteral(false) => reducedToFalse()
      case newexpr =>
        val template = exprToTemplate(newexpr)
        //the template may still be a ground constraint that is trivially true or false
        evaluate(template) match {
          case Some(true) => None //constraint reduced to true
          case Some(false) => reducedToFalse()
          case None => Some(template.asInstanceOf[LinearConstraint])
        }
    }
  }

    /** When true, the elimination routines print progress and invoke the supplied debugger. */
  val debugElimination = false

  /**
   * Eliminates the specified variables from a conjunction of linear constraints
   * (a disjunct) that is satisfiable. We assume that the disjunct is in nnf form.
   * The variables are eliminated one at a time, each step producing a new
   * sequence of linear constraints.
   *
   * `debugger` is a function used for debugging: if `debugElimination` is set, it is
   * called with the constraints obtained after every elimination step.
   */
  def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVars: Set[Identifier],
      debugger: Option[(Seq[LinearConstraint] => Unit)]): Seq[LinearConstraint] = {
    elimVars.foldLeft(linearCtrs) { (remaining, elimVar) =>
      val reduced = apply1PRuleOnDisjunct(remaining, elimVar)
      if (debugElimination) {
        debugger.foreach(report => report(reduced))
      }
      reduced
    }
  }

  /**
   * Tries to eliminate a single variable from a conjunction of linear constraints.
   *
   * The constraints mentioning `elimVar` are inspected for
   * (a) an equality in which `elimVar` has coefficient 1 or -1: the equality is solved for
   *     `elimVar` and the solution is substituted into the other constraints mentioning it,
   * (b) only lower bounds on `elimVar`, or (c) only upper bounds on `elimVar`: in both cases
   *     the constraints mentioning `elimVar` are dropped.
   * If none of these apply, the constraints are returned without elimination.
   *
   * Throws `IllegalStateException` if a constraint mentioning `elimVar` is not an
   * `Equals`, `LessEquals` or `LessThan` having `elimVar` as one of its terms.
   */
  def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVar: Identifier): Seq[LinearConstraint] = {

    if (debugElimination) {
      println("Trying to eliminate: " + elimVar)
    }

    //split the constraints into those mentioning elimVar and the rest
    //(prepending reverses the input order within each group)
    val noCtrs = Seq[LinearConstraint]()
    val (relCtrs, rest) = linearCtrs.foldLeft((noCtrs, noCtrs)) {
      case ((mentioning, others), lc) =>
        if (variablesOf(lc.toExpr).contains(elimVar)) (lc +: mentioning, others)
        else (mentioning, lc +: others)
    }

    //state gathered while scanning the relevant constraints
    var elimExpr: Option[Expr] = None
    var elimCtr: Option[LinearConstraint] = None
    var allUpperBounds = true
    var allLowerBounds = true
    var foundEquality = false
    var skippingEquality = false

    for (lc <- relCtrs) {
      val hasElimTerm = lc.coeffMap.contains(elimVar.toVariable)
      lc.toExpr match {
        case _: Equals if hasElimTerm =>
          foundEquality = true

          //here, sometimes we replace an existing expression with a better one if available
          if (elimExpr.forall(current => shouldReplace(current, lc, elimVar))) {
            val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable)
            //make sure the value of the coefficient is 1 or -1
            //TODO: handle cases wherein the coefficient is not 1 or -1
            if (elimCoeff == 1 || elimCoeff == -1) {
              //if the coefficient of elimVar is positive, moving the other terms to the
              //other side of the equality changes the sign of each of them
              val changeSign = elimCoeff > 0
              def orient(v: BigInt): BigInt = if (changeSign) -v else v

              val startval: Expr = lc.const.map { c =>
                val InfiniteIntegerLiteral(cval) = c
                InfiniteIntegerLiteral(orient(cval))
              }.getOrElse(zero)

              //sum up all the terms other than elimVar
              val substExpr = lc.coeffMap.foldLeft(startval) { (acc, summand) =>
                val (term, InfiniteIntegerLiteral(coeff)) = summand
                if (term == elimVar.toVariable) acc
                else {
                  val newcoeff = orient(coeff)
                  val newsummand =
                    if (newcoeff == 1) term
                    else Times(term, InfiniteIntegerLiteral(newcoeff))
                  if (acc == zero) newsummand
                  else Plus(acc, newsummand)
                }
              }

              elimExpr = Some(simplifyArithmetic(substExpr))
              elimCtr = Some(lc)

              if (debugElimination) {
                println("Using ctr: " + lc + " found mapping: " + elimVar + " --> " + substExpr)
              }
            } else {
              skippingEquality = true
            }
          }

        case _: LessEquals | _: LessThan if hasElimTerm =>
          val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable)
          if (elimCoeff > 0) {
            //a positive coefficient gives an upper bound on elimVar
            allLowerBounds = false
          } else {
            //otherwise we have a lower bound on elimVar
            allUpperBounds = false
          }

        case _ =>
          //here, we assume that the operators are normalized to Equals, LessThan and LessEquals
          throw new IllegalStateException("LinearConstraint not in expeceted form : " + lc.toExpr)
      }
    }

    val newctrs: Seq[LinearConstraint] = elimExpr match {
      case Some(solution) =>
        //replace elimVar by its solution in every relevant constraint except the one it came from
        val elimMap = Map[Expr, Expr](elimVar.toVariable -> solution)
        val solvedCtr = elimCtr.get
        relCtrs.foldLeft(Seq[LinearConstraint]()) { (acc, ctr) =>
          if (ctr != solvedCtr) replaceInCtr(elimMap, ctr).fold(acc)(_ +: acc)
          else acc
        }

      case None if !foundEquality && (allLowerBounds || allUpperBounds) =>
        //here, drop all relCtrs. None of them are important
        Seq()

      case None =>
        //for stats
        if (skippingEquality) {
          Stats.updateCumStats(1,"SkippedVar")
        }
        //cannot eliminate the variable
        relCtrs
    }
    newctrs ++ rest
  }

  /**
   * Size of an expression: the number of times the post-order traversal
   * visits a node of the arithmetically simplified expression.
   */
  def sizeExpr(ine: Expr): Int = {
    var nodeCount = 0
    //identity transformation that only counts its invocations
    val countNode = (e: Expr) => {
      nodeCount += 1
      e
    }
    simplePostTransform(countNode)(simplifyArithmetic(ine))
    nodeCount
  }

  /**
   * Size of a linear constraint: a term with coefficient one counts as 1, any other
   * term as the size of its coefficient plus 2, and a defined constant adds 1.
   */
  def sizeCtr(ctr : LinearConstraint) : Int = {
    val termsSize = ctr.coeffMap.valuesIterator.map { coeff =>
      if (coeff == one) 1 else sizeExpr(coeff) + 2
    }.sum
    val constSize = if (ctr.const.isDefined) 1 else 0
    termsSize + constSize
  }

  /**
   * Decides whether the equality `candidateCtr` should be preferred over the expression
   * `currExpr` currently chosen for the eliminated variable.
   * An integer literal is never replaced. Otherwise the candidate wins if it has a single
   * term (so solving it yields a constant) or if `currExpr` is larger than the
   * candidate with one term discounted.
   */
  def shouldReplace(currExpr : Expr, candidateCtr : LinearConstraint, elimVar: Identifier) : Boolean = {
    val currentIsLiteral = currExpr.isInstanceOf[InfiniteIntegerLiteral]
    !currentIsLiteral && (
      candidateCtr.coeffMap.size == 1 ||
      sizeExpr(currExpr) > (sizeCtr(candidateCtr) - 1))
  }

  //remove transitive axioms

  /**
   * Checks if the expression is linear, i.e., is built only from conjunctions, disjunctions,
   * negations and implications of terminals and linear atomic predicates.
   * An atom is accepted if `exprToTemplate` turns it into a `LinearConstraint`;
   * exceptions thrown by `exprToTemplate` are not caught here.
   */
  def isLinear(e: Expr) : Boolean = e match {
    case And(conjuncts) => conjuncts.forall(isLinear)
    case Or(disjuncts) => disjuncts.forall(isLinear)
    case Not(inner) => isLinear(inner)
    case Implies(premise, conclusion) => isLinear(premise) && isLinear(conclusion)
    case _: Terminal => true
    case atom =>
      exprToTemplate(atom) match {
        case _: LinearConstraint => true
        case _ => false
      }
  }
}