package leon package synthesis package rules import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TreeNormalizations.linearArithmeticForm import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable import leon.synthesis.Algebra.lcm case object IntegerInequalities extends Rule("Integer Inequalities") { def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = problem.phi //assume that we only have inequalities var lhsSides: List[Expr] = Nil var exprNotUsed: List[Expr] = Nil //normalized all inequalities to LessEquals(t, 0) exprs.foreach{ case LessThan(a, b) => lhsSides ::= Plus(Minus(a, b), IntLiteral(1)) case LessEquals(a, b) => lhsSides ::= Minus(a, b) case GreaterThan(a, b) => lhsSides ::= Plus(Minus(b, a), IntLiteral(1)) case GreaterEquals(a, b) => lhsSides ::= Minus(b, a) case e => exprNotUsed ::= e } val ineqVars = lhsSides.foldLeft(Set[Identifier]())((acc, lhs) => acc ++ variablesOf(lhs)) val nonIneqVars = exprNotUsed.foldLeft(Set[Identifier]())((acc, x) => acc ++ variablesOf(x)) val candidateVars = ineqVars.intersect(problem.xs.toSet).filterNot(nonIneqVars.contains(_)) if (candidateVars.isEmpty) { Nil } else { val processedVar: Identifier = candidateVars.map(v => { val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(v)).toList) if(normalizedLhs.isEmpty) (v, 0) else (v, lcm(normalizedLhs.map{ case List(t, IntLiteral(i)) => if(i == 0) 1 else i.abs case _ => sys.error("shouldn't happen") })) }).toList.sortWith((t1, t2) => t1._2 <= t2._2).head._1 val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar) val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(processedVar)).toList) var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x normalizedLhs.foreach{ case List(t, IntLiteral(i)) => if(i > 0) upperBounds ::= (expandAndSimplifyArithmetic(UMinus(t)), i) else if(i < 0) lowerBounds ::= (expandAndSimplifyArithmetic(t), -i) else exprNotUsed ::= LessEquals(t, IntLiteral(0)) //TODO: make sure that these are added as preconditions case err => sys.error("unexpected from normal form: " + err) } val L = if(upperBounds.isEmpty && lowerBounds.isEmpty) -1 else lcm((upperBounds ::: lowerBounds).map(_._2)) //optimization when coef = 1 and when ub - lb is a constant greater than LCM upperBounds = upperBounds.filterNot{case (ub, uc) => if(uc == 1) { exprNotUsed ++= lowerBounds.map{case (lb, lc) => LessEquals(lb, Times(IntLiteral(lc), ub))} true } else false } lowerBounds = lowerBounds.filterNot{case (lb, lc) => if(lc == 1) { exprNotUsed ++= upperBounds.map{case (ub, uc) => LessEquals(Times(IntLiteral(uc), lb), ub)} true } else false } upperBounds = upperBounds.filterNot{case (ub, uc) => { lowerBounds.forall{case (lb, lc) => { expandAndSimplifyArithmetic(Minus(ub, lb)) match { case IntLiteral(n) => L - 1 <= n case _ => false }}} }} //define max function val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type)) val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls) def maxRec(bounds: List[Expr]): Expr = bounds match { case (x1 :: x2 :: xs) => { val v = FreshIdentifier("m").setType(Int32Type) Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs)) } case (x :: Nil) => x case Nil => sys.error("cannot build a max expression with no argument") } if(!lowerBounds.isEmpty) maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList)) def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun, xs) //define min function val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type)) val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls) def minRec(bounds: List[Expr]): Expr = bounds match { case (x1 :: x2 :: xs) => { val v = FreshIdentifier("m").setType(Int32Type) Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs)) } case (x :: Nil) => x case Nil => sys.error("cannot build a min expression with no argument") } if(!upperBounds.isEmpty) minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList)) def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun, xs) val floorFun = new FunDef(FreshIdentifier("floorDiv"), Int32Type, Seq( VarDecl(FreshIdentifier("x"), Int32Type), VarDecl(FreshIdentifier("x"), Int32Type))) val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Int32Type, Seq( VarDecl(FreshIdentifier("x"), Int32Type), VarDecl(FreshIdentifier("x"), Int32Type))) ceilingFun.body = Some(IntLiteral(0)) def floorDiv(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun, Seq(x, y)) def ceilingDiv(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun, Seq(x, y)) val witness: Expr = if(upperBounds.isEmpty) { if(lowerBounds.size > 1) max(lowerBounds.map{case (b, c) => ceilingDiv(b, IntLiteral(c))}) else ceilingDiv(lowerBounds.head._1, IntLiteral(lowerBounds.head._2)) } else { if(upperBounds.size > 1) min(upperBounds.map{case (b, c) => floorDiv(b, IntLiteral(c))}) else floorDiv(upperBounds.head._1, IntLiteral(upperBounds.head._2)) } if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a witness val constraints: List[Expr] = for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) val pre = And(exprNotUsed ++ constraints) List(RuleInstantiation.immediateSuccess(problem, this, Solution(pre, Set(), Tuple(Seq(witness))))) } else { val involvedVariables = (upperBounds++lowerBounds).foldLeft(Set[Identifier]())((acc, t) => { acc ++ variablesOf(t._1) }).intersect(problem.xs.toSet) //output variables involved in the bounds of the process variables var newPre: Expr = BooleanLiteral(true) if(involvedVariables.isEmpty) { newPre = And( for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) ) lowerBounds = Nil upperBounds = Nil } val remainderIds: List[Identifier] = upperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type)) val quotientIds: List[Identifier] = lowerBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type)) val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)} val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)} val subProblemFormula = expandAndSimplifyArithmetic(And( newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{ case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) :: newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound))) } ++ exprNotUsed)) val subProblemxs: List[Identifier] = quotientIds ++ otherVars val subProblem = Problem(problem.as ++ remainderIds, problem.pc, subProblemFormula, subProblemxs) def onSuccess(sols: List[Solution]): Option[Solution] = sols match { case List(Solution(pre, defs, term)) => { if(remainderIds.isEmpty) { Some(Solution(And(newPre, pre), defs, LetTuple(subProblemxs, term, Let(processedVar, witness, Tuple(problem.xs.map(Variable(_))))))) } else if(remainderIds.size > 1) { sys.error("TODO") } else { val k = remainderIds.head val loopCounter = Variable(FreshIdentifier("i", true).setType(Int32Type)) val concretePre = replace(Map(Variable(k) -> loopCounter), pre) val concreteTerm = replace(Map(Variable(k) -> loopCounter), term) val returnType = TupleType(problem.xs.map(_.getType)) val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type))) val funBody = expandAndSimplifyArithmetic(IfExpr( LessThan(loopCounter, IntLiteral(0)), Error("No solution exists"), IfExpr( concretePre, LetTuple(subProblemxs, concreteTerm, Let(processedVar, witness, Tuple(problem.xs.map(Variable(_)))) ), FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1)))) ) )) funDef.body = Some(funBody) Some(Solution(And(newPre, pre), defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))) } } case _ => None } List(RuleInstantiation.immediateDecomp(problem, this, List(subProblem), onSuccess, this.name)) } } } }