Skip to content
Snippets Groups Projects
Commit 1915ae9a authored by Régis Blanc's avatar Régis Blanc
Browse files

some refactoring of inequality solver

parent fcb7e959
No related branches found
No related tags found
No related merge requests found
...@@ -65,7 +65,11 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth ...@@ -65,7 +65,11 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth
RuleStep(List(newProblem), onSuccess) RuleStep(List(newProblem), onSuccess)
} else { } else {
val (eqPre, eqWitness, freshxs) = elimVariable(eqas, normalizedEq) val (eqPre0, eqWitness, freshxs) = elimVariable(eqas, normalizedEq)
val eqPre = eqPre0 match {
case Equals(Modulo(_, IntLiteral(1)), _) => BooleanLiteral(true)
case c => c
}
val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap
val freshFormula = simplify(replace(eqSubstMap, And(allOthers))) val freshFormula = simplify(replace(eqSubstMap, And(allOthers)))
......
...@@ -13,149 +13,155 @@ import ArithmeticNormalization.simplify ...@@ -13,149 +13,155 @@ import ArithmeticNormalization.simplify
class IntegerInequality(synth: Synthesizer) extends Rule("Integer Inequality", synth, 300) { class IntegerInequality(synth: Synthesizer) extends Rule("Integer Inequality", synth, 300) {
def applyOn(task: Task): RuleResult = { def applyOn(task: Task): RuleResult = {
val problem = task.problem val problem = task.problem
val TopLevelAnds(exprs) = problem.phi val TopLevelAnds(exprs) = problem.phi
//assume that we only have inequalities //assume that we only have inequalities
var nonIneq = false var lhsSides: List[Expr] = Nil
var exprNotUsed: List[Expr] = Nil
//normalized all inequalities to LessEquals(t, 0) //normalized all inequalities to LessEquals(t, 0)
val lhsSides: List[Expr] = exprs.map{ exprs.foreach{
case LessThan(a, b) => Plus(Minus(a, b), IntLiteral(1)) case LessThan(a, b) => lhsSides ::= Plus(Minus(a, b), IntLiteral(1))
case LessEquals(a, b) => Minus(a, b) case LessEquals(a, b) => lhsSides ::= Minus(a, b)
case GreaterThan(a, b) => Plus(Minus(b, a), IntLiteral(1)) case GreaterThan(a, b) => lhsSides ::= Plus(Minus(b, a), IntLiteral(1))
case GreaterEquals(a, b) => Minus(b, a) case GreaterEquals(a, b) => lhsSides ::= Minus(b, a)
case _ => {nonIneq = true; null} case e => exprNotUsed ::= e
}.toList }
if(nonIneq) RuleInapplicable else { val ineqVars = lhsSides.foldLeft(Set[Identifier]())((acc, lhs) => acc ++ variablesOf(lhs))
var processedVar: Option[Identifier] = None val nonIneqVars = exprNotUsed.foldLeft(Set[Identifier]())((acc, x) => acc ++ variablesOf(x))
for(e <- lhsSides if processedVar == None) { val candidateVars = ineqVars.intersect(problem.xs.toSet).filterNot(nonIneqVars.contains(_))
val vars = variablesOf(e).intersect(problem.xs.toSet) if(candidateVars.isEmpty) RuleInapplicable else {
if(!vars.isEmpty) val processedVar = candidateVars.head
processedVar = Some(vars.head) val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar)
val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList)
var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t
var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x
normalizedLhs.foreach{
case List(t, IntLiteral(i)) =>
if(i > 0) upperBounds ::= (simplify(UMinus(t)), i)
else if(i < 0) lowerBounds ::= (simplify(t), -i)
else /*if (i == 0)*/ exprNotUsed ::= LessEquals(t, IntLiteral(0))
case err => sys.error("unexpected from normal form: " + err)
} }
processedVar match {
case None => RuleInapplicable
case Some(processedVar) => {
println("processed Var: " + processedVar)
val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList)
var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t
var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x
normalizedLhs.foreach{
case List(t, IntLiteral(i)) =>
if(i > 0) upperBounds ::= (simplify(UMinus(t)), i)
else if(i < 0) lowerBounds ::= (simplify(t), -i)
case err => sys.error("unexpected from normal form: " + err)
}
val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar)
println("otherVars: " + otherVars)
if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a bound
val witness = if(upperBounds.isEmpty)
Division(lowerBounds.head._1, IntLiteral(lowerBounds.head._2))
else
Division(upperBounds.head._1, IntLiteral(upperBounds.head._2))
val pre = if(lowerBounds.isEmpty || upperBounds.isEmpty) BooleanLiteral(true) else sys.error("TODO")
RuleSuccess(Solution(pre, Set(), witness))
} else {
val L = GCD.lcm((upperBounds ::: lowerBounds).map(_._2))
val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
val remainderIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type))
val quotientIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type))
val subProblemFormula = simplify(And(
newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{
case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) ::
newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound)))
}))
val subProblem = Problem(problem.as ++ remainderIds, problem.c, subProblemFormula, otherVars ++ quotientIds)
def onSuccess(sols: List[Solution]): Solution = sols match {
case List(Solution(pre, defs, term)) => {
val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls)
def maxRec(bounds: List[Expr]): Expr = bounds match {
case (x1 :: x2 :: xs) => {
val v = FreshIdentifier("m").setType(Int32Type)
Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs))
}
case (x :: Nil) => x
case Nil => sys.error("cannot build a max expression with no argument")
}
if(!lowerBounds.isEmpty)
maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList))
val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls)
def minRec(bounds: List[Expr]): Expr = bounds match {
case (x1 :: x2 :: xs) => {
val v = FreshIdentifier("m").setType(Int32Type)
Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs))
}
case (x :: Nil) => x
case Nil => sys.error("cannot build a min expression with no argument")
}
if(!upperBounds.isEmpty)
minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList))
//define max function
val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls)
def maxRec(bounds: List[Expr]): Expr = bounds match {
case (x1 :: x2 :: xs) => {
val v = FreshIdentifier("m").setType(Int32Type)
Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs))
}
case (x :: Nil) => x
case Nil => sys.error("cannot build a max expression with no argument")
}
if(!lowerBounds.isEmpty)
maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList))
def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun, xs)
//define min function
val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type))
val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls)
def minRec(bounds: List[Expr]): Expr = bounds match {
case (x1 :: x2 :: xs) => {
val v = FreshIdentifier("m").setType(Int32Type)
Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs))
}
case (x :: Nil) => x
case Nil => sys.error("cannot build a min expression with no argument")
}
if(!upperBounds.isEmpty)
minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList))
def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun, xs)
val floorFun = new FunDef(FreshIdentifier("floorDiv"), Int32Type, Seq(
VarDecl(FreshIdentifier("x"), Int32Type),
VarDecl(FreshIdentifier("x"), Int32Type)))
val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Int32Type, Seq(
VarDecl(FreshIdentifier("x"), Int32Type),
VarDecl(FreshIdentifier("x"), Int32Type)))
def floor(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun, Seq(x, y))
def ceiling(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun, Seq(x, y))
val witness: Expr = if(upperBounds.isEmpty) {
if(lowerBounds.size > 1) max(lowerBounds.map{case (b, c) => ceiling(b, IntLiteral(c))})
else ceiling(lowerBounds.head._1, IntLiteral(lowerBounds.head._2))
} else {
if(upperBounds.size > 1) min(upperBounds.map{case (b, c) => floor(b, IntLiteral(c))})
else ceiling(upperBounds.head._1, IntLiteral(upperBounds.head._2))
}
if(newUpperBounds.isEmpty) { if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a witness
Solution(pre, defs, val pre = if(lowerBounds.isEmpty || upperBounds.isEmpty) BooleanLiteral(true) else sys.error("TODO")
LetTuple(otherVars++quotientIds, term, RuleSuccess(Solution(pre, Set(), witness))
Let(processedVar, } else {
FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2)))), val L = GCD.lcm((upperBounds ::: lowerBounds).map(_._2))
Tuple(problem.xs.map(Variable(_)))))) val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
} else if(newLowerBounds.isEmpty) { val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)}
Solution(pre, defs,
LetTuple(otherVars++quotientIds, term, val remainderIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("k", true).setType(Int32Type))
Let(processedVar, val quotientIds: List[Identifier] = newUpperBounds.map(_ => FreshIdentifier("l", true).setType(Int32Type))
val subProblemFormula = simplify(And(
newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{
case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) ::
newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound)))
} ++ exprNotUsed))
val subProblemxs: List[Identifier] = otherVars ++ quotientIds
val subProblem = Problem(problem.as ++ remainderIds, problem.c, subProblemFormula, subProblemxs)
def onSuccess(sols: List[Solution]): Solution = sols match {
case List(Solution(pre, defs, term)) => {
val involvedVariables = (upperBounds++lowerBounds).foldLeft(Set[Identifier]())((acc, t) => {
acc ++ variablesOf(t._1)
}).intersect(problem.xs.toSet) //output variables involved in the bounds of the process variables
if(involvedVariables.isEmpty) { //here we can just evaluate the lower and upper bound
val newPre = And(
for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds)
yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))))
Solution(And(newPre, pre), defs,
LetTuple(subProblemxs, term,
Let(processedVar, witness,
Tuple(problem.xs.map(Variable(_))))))
} else if(upperBounds.isEmpty || lowerBounds.isEmpty) {
Solution(pre, defs,
LetTuple(otherVars++quotientIds, term,
Let(processedVar, witness,
Tuple(problem.xs.map(Variable(_))))))
} else if(upperBounds.size > 1)
Solution.none
else {
val k = remainderIds.head
val loopCounter = Variable(FreshIdentifier("i").setType(Int32Type))
val concretePre = replace(Map(Variable(k) -> loopCounter), pre)
val returnType = TupleType(problem.xs.map(_.getType))
val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type)))
val funBody = IfExpr(
LessThan(loopCounter, IntLiteral(0)),
Error("No solution exists"),
IfExpr(
concretePre,
LetTuple(otherVars++quotientIds, term,
Let(processedVar,
if(newUpperBounds.isEmpty)
FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2))))
else
FunctionInvocation(minFun, upperBounds.map(t => Division(t._1, IntLiteral(t._2)))), FunctionInvocation(minFun, upperBounds.map(t => Division(t._1, IntLiteral(t._2)))),
Tuple(problem.xs.map(Variable(_)))))) Tuple(problem.xs.map(Variable(_))))
} else if(newUpperBounds.size > 1) ),
Solution.none FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1))))
else { )
val k = remainderIds.head )
funDef.body = Some(funBody)
val loopCounter = Variable(FreshIdentifier("i").setType(Int32Type))
val concretePre = replace(Map(Variable(k) -> loopCounter), pre) Solution(pre, defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))
val returnType = TupleType(problem.xs.map(_.getType))
val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type)))
val funBody = IfExpr(
LessThan(loopCounter, IntLiteral(0)),
Error("No solution exists"),
IfExpr(
concretePre,
LetTuple(otherVars++quotientIds, term,
Let(processedVar,
if(newUpperBounds.isEmpty)
FunctionInvocation(maxFun, lowerBounds.map(t => Division(t._1, IntLiteral(t._2))))
else
FunctionInvocation(minFun, upperBounds.map(t => Division(t._1, IntLiteral(t._2)))),
Tuple(problem.xs.map(Variable(_))))
),
FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1))))
)
)
funDef.body = Some(funBody)
println("generated code: " + funDef)
Solution(pre, defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))
}
}
case _ => Solution.none
} }
RuleStep(List(subProblem), onSuccess)
} }
case _ => Solution.none
} }
RuleStep(List(subProblem), onSuccess)
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment