From 0d28a8d63436458b04ba2f6e21c9cfb5bb7237d5 Mon Sep 17 00:00:00 2001 From: Samuel Gruetter <samuel.gruetter@epfl.ch> Date: Wed, 3 Jun 2015 16:02:11 +0200 Subject: [PATCH] relation comparator to compare bitvector arg lists lexicographically --- .../ComplexTerminationChecker.scala | 12 ++++ .../leon/termination/RelationComparator.scala | 66 ++++++++++++++----- .../leon/termination/RelationProcessor.scala | 4 +- .../leon/termination/StructuralSize.scala | 39 ++++++++--- 4 files changed, 96 insertions(+), 25 deletions(-) diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala index a1fc1d85b..024792e23 100644 --- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala +++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala @@ -35,9 +35,21 @@ class ComplexTerminationChecker(context: LeonContext, program: Program) extends val checker = ComplexTerminationChecker.this } + val modulesBV = + new StructuralSize + with BVRelationComparator + with ChainComparator + with Strengthener + with RelationBuilder + with ChainBuilder + { + val checker = ComplexTerminationChecker.this + } + def processors = List( new RecursionProcessor(this, modules), // RelationProcessor is the only Processor which benefits from trying a different RelationComparator + new RelationProcessor(this, modulesBV), new RelationProcessor(this, modules), new RelationProcessor(this, modulesLexicographic), new ChainProcessor(this, modules), diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala index 65c00adb2..c147f12aa 100644 --- a/src/main/scala/leon/termination/RelationComparator.scala +++ b/src/main/scala/leon/termination/RelationComparator.scala @@ -5,13 +5,18 @@ package termination import purescala.Expressions._ import leon.purescala.Constructors._ +import leon.purescala.Types.Int32Type trait RelationComparator { self : StructuralSize => val comparisonMethod: String + + def isApplicableFor(p: Problem): Boolean + /** strictly decreasing: args1 > args2 */ def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr + /** weakly decreasing: args1 >= args2 */ def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr } @@ -21,6 +26,8 @@ trait ArgsSizeSumRelationComparator extends RelationComparator { self : Structur val comparisonMethod = "comparing sum of argument sizes" + def isApplicableFor(p: Problem): Boolean = true + def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = { GreaterThan(self.size(tupleWrap(args1)), self.size(tupleWrap(args2))) } @@ -35,27 +42,56 @@ trait ArgsSizeSumRelationComparator extends RelationComparator { self : Structur trait LexicographicRelationComparator extends RelationComparator { self : StructuralSize => val comparisonMethod = "comparing argument lists lexicographically" + + def isApplicableFor(p: Problem): Boolean = true + + def sizeDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = { + lexicographicDecreasing(s1, s2, strict = true, sizeOfOneExpr = e => self.size(e)) + } + + def softDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = { + lexicographicDecreasing(s1, s2, strict = false, sizeOfOneExpr = e => self.size(e)) + } + +} + +// for bitvector Ints +trait BVRelationComparator extends RelationComparator { self : StructuralSize => + + /* + Note: It might seem that comparing the sum of all Int arguments is more flexible, but on + bitvectors, this causes overflow problems, so we won't be able to prove anything! + So that's why this function is useless: - def lexicographicDecreasing(s1: Seq[Expr], s2: Seq[Expr], strict: Boolean): Expr = { - val sameSizeExprs = for ((arg1, arg2) <- (s1 zip s2)) yield Equals(self.size(arg1), self.size(arg2)) - - val greaterBecauseGreaterAtFirstDifferentPos = - orJoin(for (firstDifferent <- 0 until scala.math.min(s1.length, s2.length)) yield and( - andJoin(sameSizeExprs.take(firstDifferent)), - GreaterThan(self.size(s1(firstDifferent)), self.size(s2(firstDifferent))) - )) - - if (s1.length > s2.length || (s1.length == s2.length && !strict)) { - or(andJoin(sameSizeExprs), greaterBecauseGreaterAtFirstDifferentPos) - } else { - greaterBecauseGreaterAtFirstDifferentPos + def bvSize(args: Seq[Expr]): Expr = { + val absValues: Seq[Expr] = args.collect{ + case e if e.getType == Int32Type => FunctionInvocation(typedAbsIntFun, Seq(e)) } + absValues.foldLeft[Expr](IntLiteral(0)){ case (sum, expr) => BVPlus(sum, expr) } } + */ - def sizeDecreasing(s1: Seq[Expr], s2: Seq[Expr]) = lexicographicDecreasing(s1, s2, strict = true) + val comparisonMethod = "comparing Int arguments lexicographically" - def softDecreasing(s1: Seq[Expr], s2: Seq[Expr]) = lexicographicDecreasing(s1, s2, strict = false) + def isApplicableFor(p: Problem): Boolean = { + p.funDefs.forall(fd => fd.params.exists(valdef => valdef.getType == Int32Type)) + } + + def bvSize(e: Expr): Expr = FunctionInvocation(typedAbsIntFun, Seq(e)) + + def sizeDecreasing(s10: Seq[Expr], s20: Seq[Expr]): Expr = { + val s1 = s10.filter(_.getType == Int32Type) + val s2 = s20.filter(_.getType == Int32Type) + lexicographicDecreasing(s1, s2, strict = true, sizeOfOneExpr = bvSize) + } + + def softDecreasing(s10: Seq[Expr], s20: Seq[Expr]): Expr = { + val s1 = s10.filter(_.getType == Int32Type) + val s2 = s20.filter(_.getType == Int32Type) + lexicographicDecreasing(s1, s2, strict = false, sizeOfOneExpr = bvSize) + } } + // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 8169bff8c..2909d6c4f 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -18,7 +18,9 @@ class RelationProcessor( val name: String = "Relation Processor " + modules.comparisonMethod - def run(problem: Problem) = { + def run(problem: Problem): Option[Seq[Result]] = { + if (!modules.isApplicableFor(problem)) return None + reporter.debug("- Strengthening postconditions") modules.strengthenPostconditions(problem.funSet)(this) diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index 693f336a5..11c2b6a66 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -16,22 +16,25 @@ trait StructuralSize { private val sizeCache : MutableMap[TypeTree, FunDef] = MutableMap.empty - // function abs(x: BigInt): BigInt = if (x >= 0) x else -x - val typedAbsFun = makeAbsFun + // function absBigInt(x: BigInt): BigInt = if (x >= 0) x else -x + val typedAbsBigIntFun = makeAbsFun(IntegerType, "absBigInt", e => UMinus(e), InfiniteIntegerLiteral(0)) - def makeAbsFun: TypedFunDef = { - val x = FreshIdentifier("x", IntegerType, alwaysShowUniqueID = true) + // function absInt(x: Int): Int = if (x >= 0) x else -x + val typedAbsIntFun = makeAbsFun(Int32Type, "absInt", e => BVUMinus(e), IntLiteral(0)) + + def makeAbsFun(tp: TypeTree, name: String, uminus: Expr => Expr, zero: Expr): TypedFunDef = { + val x = FreshIdentifier("x", tp, alwaysShowUniqueID = true) val absFun = new FunDef( - FreshIdentifier("abs", alwaysShowUniqueID = true), + FreshIdentifier(name, alwaysShowUniqueID = true), Seq(), // no type params - IntegerType, // returns BigInt + tp, // return type Seq(ValDef(x)), DefType.MethodDef ) absFun.body = Some(IfExpr( - GreaterEquals(Variable(x), InfiniteIntegerLiteral(0)), + GreaterEquals(Variable(x), zero), Variable(x), - Minus(InfiniteIntegerLiteral(0), Variable(x)))) + uminus(Variable(x)))) TypedFunDef(absFun, Seq()) // Seq() because no generic type params } @@ -86,11 +89,29 @@ trait StructuralSize { case (_, index) => size(tupleSelect(expr, index + 1, true)) }).foldLeft[Expr](InfiniteIntegerLiteral(0))(Plus) case IntegerType => - FunctionInvocation(typedAbsFun, Seq(expr)) + FunctionInvocation(typedAbsBigIntFun, Seq(expr)) case _ => InfiniteIntegerLiteral(0) } } + def lexicographicDecreasing(s1: Seq[Expr], s2: Seq[Expr], strict: Boolean, sizeOfOneExpr: Expr => Expr): Expr = { + // Note: The Equal and GreaterThan ASTs work for both BigInt and Bitvector + + val sameSizeExprs = for ((arg1, arg2) <- (s1 zip s2)) yield Equals(sizeOfOneExpr(arg1), sizeOfOneExpr(arg2)) + + val greaterBecauseGreaterAtFirstDifferentPos = + orJoin(for (firstDifferent <- 0 until scala.math.min(s1.length, s2.length)) yield and( + andJoin(sameSizeExprs.take(firstDifferent)), + GreaterThan(sizeOfOneExpr(s1(firstDifferent)), sizeOfOneExpr(s2(firstDifferent))) + )) + + if (s1.length > s2.length || (s1.length == s2.length && !strict)) { + or(andJoin(sameSizeExprs), greaterBecauseGreaterAtFirstDifferentPos) + } else { + greaterBecauseGreaterAtFirstDifferentPos + } + } + def defs : Set[FunDef] = Set(sizeCache.values.toSeq : _*) } -- GitLab