diff --git a/library/annotation/package.scala b/library/annotation/package.scala index 63bf384f24187a5ede39dcbdf39e90400803fccf..00ae0743c3977c65376ba1d66a9fb476c24b5397 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -15,5 +15,8 @@ package object annotation { class extern extends StaticAnnotation @ignore class inline extends StaticAnnotation -} - + @ignore + class monotonic extends StaticAnnotation + @ignore + class compose extends StaticAnnotation +} \ No newline at end of file diff --git a/library/instrumentation/package.scala b/library/instrumentation/package.scala new file mode 100644 index 0000000000000000000000000000000000000000..a724eadd725cae95034354c5bd64990f8ff18ba3 --- /dev/null +++ b/library/instrumentation/package.scala @@ -0,0 +1,24 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon + +import leon.annotation._ +import leon.lang._ +import scala.language.implicitConversions + +package object instrumentation { + @library + def time: BigInt = 0 + + @library + def stack: BigInt = 0 + + @library + def rec: BigInt = 0 + + @library + def depth: BigInt = 0 + + @library + def tpr: BigInt = 0 +} diff --git a/library/invariant/package.scala b/library/invariant/package.scala new file mode 100644 index 0000000000000000000000000000000000000000..4382462bdbce3e062d246aefa682db21df557624 --- /dev/null +++ b/library/invariant/package.scala @@ -0,0 +1,26 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon + +import leon.annotation._ +import leon.lang._ +import scala.language.implicitConversions + +package object invariant { + @library + def tmpl(templateFunc: BigInt => Boolean): Boolean = true + @library + def tmpl(templateFunc: (BigInt, BigInt) => Boolean): Boolean = true + @library + def tmpl(templateFunc: (BigInt, BigInt, BigInt) => Boolean): Boolean = true + @library + def tmpl(templateFunc: (BigInt, BigInt, BigInt, BigInt) => Boolean): Boolean = true + @library + def tmpl(templateFunc: (BigInt, BigInt, BigInt, BigInt, BigInt) => Boolean): Boolean = true + + @library + def ? : BigInt = 0 + + @library + def ?(id: BigInt) = id +} diff --git a/library/lang/synthesis/package.scala b/library/lang/synthesis/package.scala index 7084d9c1408c2fc18525173e5c9b22595d0cf22d..cb56992797715423aa1b0c45f94946778cb5780e 100644 --- a/library/lang/synthesis/package.scala +++ b/library/lang/synthesis/package.scala @@ -20,7 +20,7 @@ package object synthesis { @ignore def choose[A, B, C, D](predicate: (A, B, C, D) => Boolean): (A, B, C, D) = noImpl @ignore - def choose[A, B, C, D, E](predicate: (A, B, C, D, E) => Boolean): (A, B, C, D, E) = noImpl + def choose[A, B, C, D, E](predicate: (A, B, C, D, E) => Boolean): (A, B, C, D, E) = noImpl @ignore def ???[T]: T = noImpl diff --git a/library/par/package.scala b/library/par/package.scala index b1bb657d086c9984121cc32432e188670b24daa1..5842210fce915df2c02a0f69b6000703c8865af9 100644 --- a/library/par/package.scala +++ b/library/par/package.scala @@ -8,7 +8,7 @@ import leon.lang.synthesis.choose package object par { - // @library + @library @inline def parallel[A,B](x: => A, y: => B) : (A,B) = { (x,y) diff --git a/src/main/java/leon/codegen/runtime/Rational.java b/src/main/java/leon/codegen/runtime/Rational.java new file mode 100644 index 0000000000000000000000000000000000000000..fc409434b05216591069a94d79af6000c75e62d2 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Rational.java @@ -0,0 +1,127 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.math.BigInteger; + +public final class Rational { + + private final BigInteger _num; + private final BigInteger _denom; + + /** + * class invariant: the fractions are always normalized + * + * @param num + * numerator + * @param denom + * denominator + */ + public Rational(BigInteger num, BigInteger denom) { + BigInteger modNum = num.abs(); + BigInteger modDenom = denom.abs(); + BigInteger divisor = modNum.gcd(modDenom); + BigInteger simpNum = num.divide(divisor); + BigInteger simpDenom = denom.divide(divisor); + if (isLTZ(simpDenom)) { + _num = simpNum.negate(); + _denom = simpDenom.negate(); + } else { + _num = simpNum; + _denom = simpDenom; + } + } + + public Rational(String num, String denom) { + this(new BigInteger(num), new BigInteger(denom)); + } + + public BigInteger numerator() { + return _num; + } + + public BigInteger denominator() { + return _denom; + } + + public boolean isZero(BigInteger bi) { + return bi.signum() == 0; + } + + public boolean isLEZ(BigInteger bi) { + return bi.signum() != 1; + } + + public boolean isLTZ(BigInteger bi) { + return (bi.signum() == -1); + } + + public boolean isGEZ(BigInteger bi) { + return (bi.signum() != -1); + } + + public boolean isGTZ(BigInteger bi) { + return (bi.signum() == 1); + } + + public Rational add(Rational that) { + return new Rational(_num.multiply(that._denom).add( + that._num.multiply(_denom)), _denom.multiply(that._denom)); + } + + public Rational sub(Rational that) { + return new Rational(_num.multiply(that._denom).subtract( + that._num.multiply(_denom)), _denom.multiply(that._denom)); + } + + public Rational mult(Rational that) { + return new Rational(_num.multiply(that._num), + _denom.multiply(that._denom)); + } + + public Rational div(Rational that) { + return new Rational(_num.multiply(that._denom), + _denom.multiply(that._num)); + } + + public Rational neg() { + return new Rational(_num.negate(), _denom); + } + + public boolean lessThan(Rational that) { + return isLTZ(this.sub(that)._num); + } + + public boolean lessEquals(Rational that) { + return isLEZ(this.sub(that)._num); + } + + public boolean greaterThan(Rational that) { + return isGTZ(this.sub(that)._num); + } + + public boolean greaterEquals(Rational that) { + return isGEZ(this.sub(that)._num); + } + + @Override + public boolean equals(Object that) { + if (that == this) + return true; + if (!(that instanceof Rational)) + return false; + + Rational other = (Rational) that; + return isZero(this.sub(other)._num); + } + + @Override + public String toString() { + return _num.toString() + "/" + _denom.toString(); + } + + @Override + public int hashCode() { + return _num.hashCode() ^ _denom.hashCode(); + } +} diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 6e2af3812045f1b7dbc613dabea2162585d30cda..34fc1c98c52a9b8d4d165e5e9d5d3d21d0e20e57 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -24,8 +24,9 @@ object Main { repair.RepairPhase, evaluators.EvaluationPhase, solvers.isabelle.AdaptationPhase, - solvers.isabelle.IsabellePhase - ) + solvers.isabelle.IsabellePhase, + transformations.InstrumentationPhase, + invariant.engine.InferInvariantsPhase) } // Add whatever you need here. @@ -49,10 +50,11 @@ object Main { val optNoop = LeonFlagOptionDef("noop", "No operation performed, just output program", false) val optVerify = LeonFlagOptionDef("verify", "Verify function contracts", false) val optHelp = LeonFlagOptionDef("help", "Show help message", false) + val optInstrument = LeonFlagOptionDef("instrument", "Instrument the code for inferring time/depth/stack bounds", false) + val optInferInv = LeonFlagOptionDef("inferInv", "Infer invariants from (instrumented) the code", false) override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify) - + Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify, optInstrument, optInferInv) } lazy val allOptions: Set[LeonOptionDef[Any]] = allComponents.flatMap(_.definedOptions) @@ -70,8 +72,8 @@ object Main { reporter.info(opt.helpString) } reporter.info("") - - reporter.title("Additional options, by component:") + + reporter.info("Additional options, by component:") for (c <- (allComponents - MainComponent - SharedOptions).toSeq.sortBy(_.name) if c.definedOptions.nonEmpty) { reporter.info("") @@ -149,6 +151,8 @@ object Main { import evaluators.EvaluationPhase import solvers.isabelle.IsabellePhase import MainComponent._ + import invariant.engine.InferInvariantsPhase + import transformations.InstrumentationPhase val helpF = ctx.findOptionOrDefault(optHelp) val noopF = ctx.findOptionOrDefault(optNoop) @@ -159,6 +163,8 @@ object Main { val terminationF = ctx.findOptionOrDefault(optTermination) val verifyF = ctx.findOptionOrDefault(optVerify) val evalF = ctx.findOption(optEval).isDefined + val inferInvF = ctx.findOptionOrDefault(optInferInv) + val instrumentF = ctx.findOptionOrDefault(optInstrument) val analysisF = verifyF && terminationF if (helpF) { @@ -179,7 +185,10 @@ object Main { else if (terminationF) TerminationPhase else if (isabelleF) IsabellePhase else if (evalF) EvaluationPhase + else if (inferInvF) InstrumentationPhase andThen InferInvariantsPhase + else if (instrumentF) InstrumentationPhase andThen FileOutputPhase else analysis + } pipeBegin andThen @@ -242,7 +251,7 @@ object Main { case (vReport: verification.VerificationReport, tReport: termination.TerminationReport) => ctx.reporter.info(vReport.summaryString) ctx.reporter.info(tReport.summaryString) - + case report: verification.VerificationReport => ctx.reporter.info(report.summaryString) diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index cd2881c82b05bfd9b5e7f27ebd8fa163813f328d..3ea170a05732901467b31f939d87406433719895 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -24,7 +24,7 @@ trait CodeGeneration { /** A class providing information about the status of parameters in the function that is being currently compiled. * vars is a mapping from local variables/ parameters to the offset of the respective JVM local register - * isStatic signifies if the current method is static (a function, in Leon terms) + * isStatic signifies if the current method is static (a function, in Leon terms) */ class Locals private[codegen] ( vars : Map[Identifier, Int], @@ -67,6 +67,7 @@ trait CodeGeneration { private[codegen] val MapClass = "leon/codegen/runtime/Map" private[codegen] val BigIntClass = "leon/codegen/runtime/BigInt" private[codegen] val RealClass = "leon/codegen/runtime/Real" + private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" @@ -120,7 +121,7 @@ trait CodeGeneration { "L" + BigIntClass + ";" case RealType => - "L" + RealClass + ";" + "L" + RationalClass + ";" case _ : FunctionType => "L" + LambdaClass + ";" @@ -141,12 +142,12 @@ trait CodeGeneration { case CharType => s"L$BoxedCharClass;" case other => typeToJVM(other) } - + /** * Compiles a function/method definition. * @param funDef The function definition to be compiled * @param owner The module/class that contains `funDef` - */ + */ def compileFunDef(funDef: FunDef, owner: Definition) { val isStatic = owner.isInstanceOf[ModuleDef] @@ -194,7 +195,7 @@ trait CodeGeneration { } val bodyWithPost = funDef.postcondition match { - case Some(post) if params.checkContracts => + case Some(post) if params.checkContracts => Ensuring(bodyWithPre, post).toAssert case _ => bodyWithPre } @@ -569,10 +570,11 @@ trait CodeGeneration { ch << Ldc(v.toString) ch << InvokeSpecial(BigIntClass, constructorName, "(Ljava/lang/String;)V") - case RealLiteral(v) => - ch << New(RealClass) << DUP - ch << Ldc(v.toString) - ch << InvokeSpecial(RealClass, constructorName, "(Ljava/lang/String;)V") + case FractionalLiteral(n, d) => + ch << New(RationalClass) << DUP + ch << Ldc(n.toString) + ch << Ldc(d.toString) + ch << InvokeSpecial(RationalClass, constructorName, "(Ljava/lang/String;Ljava/lang/String;)V") // Case classes case CaseClass(cct, as) => @@ -718,7 +720,7 @@ trait CodeGeneration { val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - + if (requireMonitor) { load(monitorID, ch) ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") @@ -748,7 +750,7 @@ trait CodeGeneration { // ch << ALoad(locals.monitorIndex) //} - // No dynamic dispatching/overriding in Leon, + // No dynamic dispatching/overriding in Leon, // so no need to take care of own vs. "super" methods ch << InvokeVirtual(SetClass, "getElements", s"()L$JavaIteratorClass;") @@ -818,17 +820,17 @@ trait CodeGeneration { val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - + if (requireMonitor) { load(monitorID, ch) ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - // Load receiver - mkExpr(rec,ch) - + // Load receiver + mkExpr(rec,ch) + // Get field ch << GetField(className, fieldName, typeToJVM(tfd.fd.returnType)) - + // unbox field (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => @@ -843,13 +845,13 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - // Receiver of the method call + // Receiver of the method call mkExpr(rec,ch) if (requireMonitor) { load(monitorID, ch) } - + for((a, vd) <- as zip tfd.fd.params) { vd.getType match { case TypeParameter(_) => @@ -860,7 +862,7 @@ trait CodeGeneration { } // No interfaces in Leon, so no need to use InvokeInterface - ch << InvokeVirtual(className, methodName, sig) + ch << InvokeVirtual(className, methodName, sig) (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => @@ -924,26 +926,26 @@ trait CodeGeneration { case RealPlus(l, r) => mkExpr(l, ch) mkExpr(r, ch) - ch << InvokeVirtual(RealClass, "add", s"(L$RealClass;)L$RealClass;") + ch << InvokeVirtual(RationalClass, "add", s"(L$RationalClass;)L$RationalClass;") case RealMinus(l, r) => mkExpr(l, ch) mkExpr(r, ch) - ch << InvokeVirtual(RealClass, "sub", s"(L$RealClass;)L$RealClass;") + ch << InvokeVirtual(RationalClass, "sub", s"(L$RationalClass;)L$RationalClass;") case RealTimes(l, r) => mkExpr(l, ch) mkExpr(r, ch) - ch << InvokeVirtual(RealClass, "mult", s"(L$RealClass;)L$RealClass;") + ch << InvokeVirtual(RationalClass, "mult", s"(L$RationalClass;)L$RationalClass;") case RealDivision(l, r) => mkExpr(l, ch) mkExpr(r, ch) - ch << InvokeVirtual(RealClass, "div", s"(L$RealClass;)L$RealClass;") + ch << InvokeVirtual(RationalClass, "div", s"(L$RationalClass;)L$RationalClass;") case RealUMinus(e) => mkExpr(e, ch) - ch << InvokeVirtual(RealClass, "neg", s"()L$RealClass;") + ch << InvokeVirtual(RationalClass, "neg", s"()L$RationalClass;") //BV arithmetic @@ -1036,7 +1038,7 @@ trait CodeGeneration { case ArrayType(BooleanType) => ch << NewArray.primitive("T_BOOLEAN"); BASTORE case ArrayType(other) => ch << NewArray(typeToJVM(other)); AASTORE case other => throw CompilationException(s"Cannot compile finite array expression whose type is $other.") - } + } //srcArrary and targetArray is on the stack ch << DUP_X1 //insert targetArray under srcArray ch << Ldc(0) << SWAP //srcArray, 0, targetArray @@ -1065,10 +1067,10 @@ trait CodeGeneration { for (i <- 0 until l) { val v = elems.get(i).orElse(default).getOrElse { throw CompilationException(s"No valuation for key '$i' in array") - } + } ch << DUP << Ldc(i) - mkExpr(v, ch) + mkExpr(v, ch) ch << storeInstr } @@ -1104,22 +1106,22 @@ trait CodeGeneration { val id = runtime.GenericValues.register(gv) ch << Ldc(id) ch << InvokeStatic(GenericValuesClass, "get", s"(I)L$ObjectClass;") - + case nt @ NoTree( tp@ValueType() ) => mkExpr(simplestValue(tp), ch) - + case NoTree(_) => ch << ACONST_NULL - + case This(ct) => ch << ALoad(0) - - case p : Passes => + + case p : Passes => mkExpr(matchToIfThenElse(p.asConstraint), ch) - case m : MatchExpr => + case m : MatchExpr => mkExpr(matchToIfThenElse(m), ch) - + case b if b.getType == BooleanType && canDelegateToMkBranch => val fl = ch.getFreshLabel("boolfalse") val al = ch.getFreshLabel("boolafter") @@ -1127,7 +1129,7 @@ trait CodeGeneration { mkBranch(b, al, fl, ch, canDelegateToMkExpr = false) ch << Label(fl) << POP << Ldc(0) << Label(al) - case _ => throw CompilationException("Unsupported expr " + e + " : " + e.getClass) + case _ => throw CompilationException("Unsupported expr " + e + " : " + e.getClass) } } @@ -1200,7 +1202,7 @@ trait CodeGeneration { ch << CheckCast(BigIntClass) case RealType => - ch << CheckCast(RealClass) + ch << CheckCast(RationalClass) case tt : TupleType => ch << CheckCast(TupleClass) @@ -1214,7 +1216,7 @@ trait CodeGeneration { case ft : FunctionType => ch << CheckCast(LambdaClass) - case tp : TypeParameter => + case tp : TypeParameter => case tp : ArrayType => ch << CheckCast(BoxedArrayClass) << InvokeVirtual(BoxedArrayClass, "arrayValue", s"()${typeToJVM(tp)}") @@ -1243,7 +1245,7 @@ trait CodeGeneration { val fl = ch.getFreshLabel("ornext") mkBranch(es.head, thenn, fl, ch) ch << Label(fl) - mkBranch(orJoin(es.tail), thenn, elze, ch) + mkBranch(orJoin(es.tail), thenn, elze, ch) case Implies(l, r) => mkBranch(or(not(l), r), thenn, elze, ch) @@ -1272,12 +1274,12 @@ trait CodeGeneration { mkExpr(r, ch) l.getType match { case Int32Type | CharType => - ch << If_ICmpLt(thenn) << Goto(elze) + ch << If_ICmpLt(thenn) << Goto(elze) case IntegerType => ch << InvokeVirtual(BigIntClass, "lessThan", s"(L$BigIntClass;)Z") ch << IfEq(elze) << Goto(thenn) case RealType => - ch << InvokeVirtual(RealClass, "lessThan", s"(L$RealClass;)Z") + ch << InvokeVirtual(RationalClass, "lessThan", s"(L$RationalClass;)Z") ch << IfEq(elze) << Goto(thenn) } @@ -1286,12 +1288,12 @@ trait CodeGeneration { mkExpr(r, ch) l.getType match { case Int32Type | CharType => - ch << If_ICmpGt(thenn) << Goto(elze) + ch << If_ICmpGt(thenn) << Goto(elze) case IntegerType => ch << InvokeVirtual(BigIntClass, "greaterThan", s"(L$BigIntClass;)Z") ch << IfEq(elze) << Goto(thenn) case RealType => - ch << InvokeVirtual(RealClass, "greaterThan", s"(L$RealClass;)Z") + ch << InvokeVirtual(RationalClass, "greaterThan", s"(L$RationalClass;)Z") ch << IfEq(elze) << Goto(thenn) } @@ -1300,12 +1302,12 @@ trait CodeGeneration { mkExpr(r, ch) l.getType match { case Int32Type | CharType => - ch << If_ICmpLe(thenn) << Goto(elze) + ch << If_ICmpLe(thenn) << Goto(elze) case IntegerType => ch << InvokeVirtual(BigIntClass, "lessEquals", s"(L$BigIntClass;)Z") ch << IfEq(elze) << Goto(thenn) case RealType => - ch << InvokeVirtual(RealClass, "lessEquals", s"(L$RealClass;)Z") + ch << InvokeVirtual(RationalClass, "lessEquals", s"(L$RationalClass;)Z") ch << IfEq(elze) << Goto(thenn) } @@ -1314,16 +1316,16 @@ trait CodeGeneration { mkExpr(r, ch) l.getType match { case Int32Type | CharType => - ch << If_ICmpGe(thenn) << Goto(elze) + ch << If_ICmpGe(thenn) << Goto(elze) case IntegerType => ch << InvokeVirtual(BigIntClass, "greaterEquals", s"(L$BigIntClass;)Z") ch << IfEq(elze) << Goto(thenn) case RealType => - ch << InvokeVirtual(RealClass, "greaterEquals", s"(L$RealClass;)Z") + ch << InvokeVirtual(RationalClass, "greaterEquals", s"(L$RationalClass;)Z") ch << IfEq(elze) << Goto(thenn) } - - case IfExpr(c, t, e) => + + case IfExpr(c, t, e) => val innerThen = ch.getFreshLabel("then") val innerElse = ch.getFreshLabel("else") mkBranch(c, innerThen, innerElse, ch) @@ -1340,7 +1342,7 @@ trait CodeGeneration { mkExpr(other, ch, canDelegateToMkBranch = false) ch << IfEq(elze) << Goto(thenn) - case other => throw CompilationException("Unsupported branching expr. : " + other) + case other => throw CompilationException("Unsupported branching expr. : " + other) } } @@ -1366,38 +1368,38 @@ trait CodeGeneration { } /** Compiles a lazy field. - * + * * To define a lazy field, we have to add an accessor method and an underlying field. * The accessor method has the name of the original (Scala) lazy field and can be public. - * The underlying field has a different name, is private, and is of a boxed type - * to support null value (to signify uninitialized). - * + * The underlying field has a different name, is private, and is of a boxed type + * to support null value (to signify uninitialized). + * * @param lzy The lazy field to be compiled * @param owner The module/class containing `lzy` */ - def compileLazyField(lzy: FunDef, owner: Definition) { + def compileLazyField(lzy: FunDef, owner: Definition) { ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field") - + val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get val cf = classes(owner) val cName = defToJVMName(owner) - + val isStatic = owner.isInstanceOf[ModuleDef] - + // Name of the underlying field val underlyingName = underlyingField(accessorName) // Underlying field is of boxed type val underlyingType = typeToJVMBoxed(lzy.returnType) - + // Underlying field. It is of a boxed type val fh = cf.addField(underlyingType,underlyingName) fh.setFlags( if (isStatic) {( - FIELD_ACC_STATIC | + FIELD_ACC_STATIC | FIELD_ACC_PRIVATE ).asInstanceOf[U2] } else { FIELD_ACC_PRIVATE }) // FIXME private etc? - + // accessor method locally { val parameters = if (requireMonitor) { @@ -1413,11 +1415,11 @@ trait CodeGeneration { METHOD_ACC_PUBLIC ).asInstanceOf[U2] } else { METHOD_ACC_PUBLIC - }) + }) val ch = accM.codeHandler val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) val initLabel = ch.getFreshLabel("isInitialized") - + if (requireMonitor) { load(monitorID, ch)(newLocs) ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") @@ -1429,13 +1431,13 @@ trait CodeGeneration { ch << ALoad(0) << GetField(cName, underlyingName, underlyingType) // if (lzy == null) } // oldValue - ch << DUP << IfNonNull(initLabel) + ch << DUP << IfNonNull(initLabel) // null ch << POP - // - mkBoxedExpr(body,ch)(newLocs) // lzy = <expr> + // + mkBoxedExpr(body,ch)(newLocs) // lzy = <expr> ch << DUP - // newValue, newValue + // newValue, newValue if (isStatic) { ch << PutStatic(cName, underlyingName, underlyingType) //newValue @@ -1446,7 +1448,7 @@ trait CodeGeneration { ch << PutField (cName, underlyingName, underlyingType) //newValue } - ch << Label(initLabel) // return lzy + ch << Label(initLabel) // return lzy //newValue lzy.returnType match { case ValueType() => @@ -1456,14 +1458,14 @@ trait CodeGeneration { case _ => ch << ARETURN } - ch.freeze + ch.freeze } } - + /** Compile the (strict) field `field` which is owned by class `owner` */ def compileStrictField(field : FunDef, owner : Definition) = { - ctx.reporter.internalAssertion(field.canBeStrictField, + ctx.reporter.internalAssertion(field.canBeStrictField, s"Trying to compile ${field.id.name} as a strict field") val (_, fieldName, _) = leonFunDefToJVMInfo(field).get @@ -1471,7 +1473,7 @@ trait CodeGeneration { val fh = cf.addField(typeToJVM(field.returnType),fieldName) fh.setFlags( owner match { case _ : ModuleDef => ( - FIELD_ACC_STATIC | + FIELD_ACC_STATIC | FIELD_ACC_PUBLIC | // FIXME FIELD_ACC_FINAL ).asInstanceOf[U2] @@ -1481,16 +1483,16 @@ trait CodeGeneration { ).asInstanceOf[U2] }) } - + /** Initializes a lazy field to null * @param ch the codehandler to add the initializing code to - * @param className the name of the class in which the field is initialized + * @param className the name of the class in which the field is initialized * @param lzy the lazy field to be initialized * @param isStatic true if this is a static field */ def initLazyField(ch: CodeHandler, className: String, lzy: FunDef, isStatic: Boolean)(implicit locals: Locals) = { val (_, name, _) = leonFunDefToJVMInfo(lzy).get - val underlyingName = underlyingField(name) + val underlyingName = underlyingField(name) val jvmType = typeToJVMBoxed(lzy.returnType) if (isStatic){ ch << ACONST_NULL << PutStatic(className, underlyingName, jvmType) @@ -1498,10 +1500,10 @@ trait CodeGeneration { ch << ALoad(0) << ACONST_NULL << PutField(className, underlyingName, jvmType) } } - + /** Initializes a (strict) field * @param ch the codehandler to add the initializing code to - * @param className the name of the class in which the field is initialized + * @param className the name of the class in which the field is initialized * @param field the field to be initialized * @param isStatic true if this is a static field */ @@ -1519,8 +1521,8 @@ trait CodeGeneration { } } - def compileAbstractClassDef(acd : AbstractClassDef) { - + def compileAbstractClassDef(acd : AbstractClassDef) { + val cName = defToJVMName(acd) val cf = classes(acd) @@ -1533,30 +1535,30 @@ trait CodeGeneration { cf.addInterface(CaseClassClass) - // add special monitor for method invocations + // add special monitor for method invocations if (params.doInstrument) { val fh = cf.addField("I", instrumentedField) fh.setFlags(FIELD_ACC_PUBLIC) } - + val (fields, methods) = acd.methods partition { _.canBeField } val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - + // Compile methods for (method <- methods) { compileFunDef(method,acd) } - + // Compile lazy fields for (lzy <- lazyFields) { compileLazyField(lzy, acd) } - + // Compile strict fields for (field <- strictFields) { compileStrictField(field, acd) } - + // definition of the constructor locally { val constrParams = if (requireMonitor) { @@ -1586,7 +1588,7 @@ trait CodeGeneration { // Call constructor of java.lang.Object cch << InvokeSpecial(ObjectClass, constructorName, "()V") } - + // Initialize special monitor field if (params.doInstrument) { cch << ALoad(0) @@ -1637,9 +1639,9 @@ trait CodeGeneration { val cName = defToJVMName(ccd) val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) - // An instantiation of ccd with its own type parameters + // An instantiation of ccd with its own type parameters val cct = CaseClassType(ccd, ccd.tparams.map(_.tp)) - + val cf = classes(ccd) cf.setFlags(( @@ -1662,25 +1664,25 @@ trait CodeGeneration { case (id, jvmt) => (id, (cName, id.name, jvmt)) }.toMap) - locally { + locally { val (fields, methods) = ccd.methods partition { _.canBeField } val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - + // Compile methods for (method <- methods) { compileFunDef(method,ccd) } - + // Compile lazy fields for (lzy <- lazyFields) { - compileLazyField(lzy, ccd) + compileLazyField(lzy, ccd) } - + // Compile strict fields for (field <- strictFields) { - compileStrictField(field, ccd) + compileStrictField(field, ccd) } - + // definition of the constructor if(!params.doInstrument && !requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { cf.addDefaultConstructor @@ -1692,14 +1694,14 @@ trait CodeGeneration { FIELD_ACC_FINAL ).asInstanceOf[U2]) } - + if (params.doInstrument) { val fh = cf.addField("I", instrumentedField) fh.setFlags(FIELD_ACC_PUBLIC) } - + val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler - + if (params.doInstrument) { cch << ALoad(0) cch << Ldc(0) @@ -1831,7 +1833,7 @@ trait CodeGeneration { ech << InvokeVirtual(ObjectClass, "equals", s"(L$ObjectClass;)Z") << IfEq(notEq) } } - } + } ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN ech.freeze @@ -1859,7 +1861,7 @@ trait CodeGeneration { hch << ALoad(0) << InvokeVirtual(cName, "productName", "()Ljava/lang/String;") hch << InvokeVirtual("java/lang/String", "hashCode", "()I") hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP - hch << ALoad(0) << SWAP << PutField(cName, hashFieldName, "I") + hch << ALoad(0) << SWAP << PutField(cName, hashFieldName, "I") hch << IRETURN hch.freeze diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index c4551d5a455dcf5963bfb987025bbefe10d77e25..a8c513a0cba6354b625446eb575fafeef34037b0 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -75,9 +75,9 @@ class CompilationUnit(val ctx: LeonContext, // Returns className, methodName, methodSignature private[this] var funDefInfo = Map[FunDef, (String, String, String)]() - + /** - * Returns (cn, mn, sig) where + * Returns (cn, mn, sig) where * cn is the module name * mn is the safe method name * sig is the method signature @@ -164,8 +164,8 @@ class CompilationUnit(val ctx: LeonContext, case InfiniteIntegerLiteral(v) => new runtime.BigInt(v.toString) - case RealLiteral(v) => - new runtime.Real(v.toString) + case FractionalLiteral(n, d) => + new runtime.Rational(n.toString, d.toString) case GenericValue(tp, id) => e @@ -218,7 +218,7 @@ class CompilationUnit(val ctx: LeonContext, //case _ => // compileExpression(e, Seq()).evalToJVM(Seq(),monitor) } - + /** Translates JVM objects back to Leon values of the appropriate type */ def jvmToValue(e: AnyRef, tpe: TypeTree): Expr = (e, tpe) match { case (i: Integer, Int32Type) => @@ -227,8 +227,10 @@ class CompilationUnit(val ctx: LeonContext, case (c: runtime.BigInt, IntegerType) => InfiniteIntegerLiteral(BigInt(c.underlying)) - case (c: runtime.Real, RealType) => - RealLiteral(BigDecimal(c.underlying)) + case (c: runtime.Rational, RealType) => + val num = BigInt(c.numerator()) + val denom = BigInt(c.denominator()) + FractionalLiteral(num, denom) case (b: java.lang.Boolean, BooleanType) => BooleanLiteral(b.booleanValue) @@ -257,7 +259,7 @@ class CompilationUnit(val ctx: LeonContext, case (tpl: runtime.Tuple, tpe) => val stpe = unwrapTupleType(tpe, tpl.getArity) - val elems = stpe.zipWithIndex.map { case (tpe, i) => + val elems = stpe.zipWithIndex.map { case (tpe, i) => jvmToValue(tpl.get(i), tpe) } tupleWrap(elems) @@ -379,7 +381,7 @@ class CompilationUnit(val ctx: LeonContext, val (fields, functions) = module.definedFunctions partition { _.canBeField } val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - + // Compile methods for (function <- functions) { compileFunDef(function,module) @@ -394,17 +396,17 @@ class CompilationUnit(val ctx: LeonContext, for (field <- strictFields) { compileStrictField(field, module) } - + // Constructor cf.addDefaultConstructor val cName = defToJVMName(module) // Add class initializer method - locally{ + locally{ val mh = cf.addMethod("V", "<clinit>") mh.setFlags(( - METHOD_ACC_STATIC | + METHOD_ACC_STATIC | METHOD_ACC_PUBLIC ).asInstanceOf[U2]) @@ -412,9 +414,9 @@ class CompilationUnit(val ctx: LeonContext, /* * FIXME : * Dirty hack to make this compatible with monitoring of method invocations. - * Because we don't have access to the monitor object here, we initialize a new one - * that will get lost when this method returns, so we can't hope to count - * method invocations here :( + * Because we don't have access to the monitor object here, we initialize a new one + * that will get lost when this method returns, so we can't hope to count + * method invocations here :( */ val locals = NoLocals.withVar(monitorID -> ch.getFreshVar) ch << New(MonitorClass) << DUP @@ -443,7 +445,7 @@ class CompilationUnit(val ctx: LeonContext, defToModuleOrClass += meth -> cls } } - + for (m <- u.modules) { defineClass(m) for(funDef <- m.definedFunctions) { @@ -453,14 +455,14 @@ class CompilationUnit(val ctx: LeonContext, } } - /** Compiles the program. + /** Compiles the program. * * Uses information provided by [[init]]. */ def compile() { // Compile everything for (u <- program.units) { - + for { ch <- u.classHierarchies c <- ch diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index ad012bb74f4606bfeb53924c97015a2c3ce54fe5..a9d1eda0c5e36e19a6b6c12f99a617b480866f10 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -59,4 +59,4 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, case ite : InvocationTargetException => throw ite.getCause } } -} +} diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 4cc67c3fa143d7d3260aa68185b9139576f308f6..36cd9da0c7c35cfdc9d674162c3279d461afe813 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -22,7 +22,7 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { val toPairs = model.toSeq - compile(expression, toPairs.map(_._1)).map { e => + compile(expression, toPairs.map(_._1)).map { e => ctx.timers.evaluators.codegen.runtime.start() val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() @@ -57,8 +57,8 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva EvaluationResults.EvaluatorError(e.getMessage) case e : java.lang.ExceptionInInitializerError => - EvaluationResults.RuntimeError(e.getException.getMessage) - + EvaluationResults.RuntimeError(e.getException.getMessage) + case so : java.lang.StackOverflowError => EvaluationResults.RuntimeError("Stack overflow") diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 2091b0e6eed6d28d25d4c7c158eae4dd32586caf..bd2a594ec873a079e58be27852d1c3e223f55130 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -12,10 +12,10 @@ import purescala.TypeOps.isSubtypeOf import purescala.Constructors._ import purescala.Extractors._ import purescala.Quantification._ - import solvers.{Model, HenkinModel} import solvers.SolverFactory import synthesis.ConvertHoles.convertHoles +import leon.purescala.ExprOps abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) extends Evaluator(ctx, prog) { val name = "evaluator" @@ -122,11 +122,11 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int if ( exists{ case Hole(_,_) => true case _ => false - }(en)) + }(en)) e(convertHoles(en, ctx)) else e(en.toAssert) - + case Error(tpe, desc) => throw RuntimeError("Error reached in evaluation: " + desc) @@ -147,7 +147,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) def mkCons(h: Expr, t: Expr) = CaseClass(CaseClassType(cons, Seq(tp)), Seq(h,t)) els.foldRight(nil)(mkCons) - + case FunctionInvocation(tfd, args) => if (gctx.stepsLeft < 0) { throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") @@ -177,7 +177,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val callResult = e(body)(frame, gctx) tfd.postcondition match { - case Some(post) => + case Some(post) => e(application(post, Seq(callResult)))(frame, gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.asString + " reached in evaluation.") @@ -237,7 +237,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int if (isSubtypeOf(le.getType, ct)) { le } else { - throw RuntimeError("Cast error: cannot cast "+le.asString+" to "+ct.asString) + throw RuntimeError("Cast error: cannot cast "+le.asString+" to "+ct.asString) } case IsInstanceOf(expr, ct) => @@ -263,17 +263,15 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (le,re) => throw EvalError(typeErrorMsg(le, IntegerType)) } - case RealPlus(l,r) => + case RealPlus(l, r) => (e(l), e(r)) match { - case (RealLiteral(i1), RealLiteral(i2)) => RealLiteral(i1 + i2) - case (le,re) => throw EvalError(typeErrorMsg(le, RealType)) + case (FractionalLiteral(ln, ld), FractionalLiteral(rn, rd)) => + normalizeFraction(FractionalLiteral((ln * rd + rn * ld), (ld * rd))) + case (le, re) => throw EvalError(typeErrorMsg(le, RealType)) } case RealMinus(l,r) => - (e(l), e(r)) match { - case (RealLiteral(i1), RealLiteral(i2)) => RealLiteral(i1 - i2) - case (le,re) => throw EvalError(typeErrorMsg(le, RealType)) - } + e(RealPlus(l, RealUMinus(r))) case BVPlus(l,r) => (e(l), e(r)) match { @@ -301,7 +299,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case RealUMinus(ex) => e(ex) match { - case RealLiteral(i) => RealLiteral(-i) + case FractionalLiteral(n, d) => FractionalLiteral(-n, d) case re => throw EvalError(typeErrorMsg(re, RealType)) } @@ -333,12 +331,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } case Modulo(l,r) => (e(l), e(r)) match { - case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => if(i2 < 0) InfiniteIntegerLiteral(i1 mod (-i2)) - else if(i2 != BigInt(0)) - InfiniteIntegerLiteral(i1 mod i2) - else + else if(i2 != BigInt(0)) + InfiniteIntegerLiteral(i1 mod i2) + else throw RuntimeError("Modulo of division by 0.") case (le,re) => throw EvalError(typeErrorMsg(le, IntegerType)) } @@ -358,21 +356,24 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case BVRemainder(l,r) => (e(l), e(r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => + case (IntLiteral(i1), IntLiteral(i2)) => if(i2 != 0) IntLiteral(i1 % i2) else throw RuntimeError("Remainder of division by 0.") case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } case RealTimes(l,r) => (e(l), e(r)) match { - case (RealLiteral(i1), RealLiteral(i2)) => RealLiteral(i1 * i2) + case (FractionalLiteral(ln, ld), FractionalLiteral(rn, rd)) => + normalizeFraction(FractionalLiteral((ln * rn), (ld * rd))) case (le,re) => throw EvalError(typeErrorMsg(le, RealType)) } case RealDivision(l,r) => (e(l), e(r)) match { - case (RealLiteral(i1), RealLiteral(i2)) => - if (i2 != 0) RealLiteral(i1 / i2) else throw RuntimeError("Division by 0.") + case (FractionalLiteral(ln, ld), FractionalLiteral(rn, rd)) => + if (rn != 0) + normalizeFraction(FractionalLiteral((ln * rd), (ld * rn))) + else throw RuntimeError("Division by 0.") case (le,re) => throw EvalError(typeErrorMsg(le, RealType)) } @@ -417,7 +418,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(l), e(r)) match { case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 < i2) - case (RealLiteral(r1), RealLiteral(r2)) => BooleanLiteral(r1 < r2) + case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n < 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 < c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -426,7 +429,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(l), e(r)) match { case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 > i2) - case (RealLiteral(r1), RealLiteral(r2)) => BooleanLiteral(r1 > r2) + case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n > 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 > c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -435,7 +440,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(l), e(r)) match { case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 <= i2) - case (RealLiteral(r1), RealLiteral(r2)) => BooleanLiteral(r1 <= r2) + case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n <= 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 <= c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -444,14 +451,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(l), e(r)) match { case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 >= i2) - case (RealLiteral(r1), RealLiteral(r2)) => BooleanLiteral(r1 >= r2) + case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n >= 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 >= c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } case SetUnion(s1,s2) => (e(s1), e(s2)) match { - case (f@FiniteSet(els1, _),FiniteSet(els2, _)) => + case (f@FiniteSet(els1, _),FiniteSet(els2, _)) => val SetType(tpe) = f.getType FiniteSet(els1 ++ els2, tpe) case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) @@ -492,7 +501,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case _ => throw EvalError(typeErrorMsg(sr, SetType(Untyped))) } - case f @ FiniteSet(els, base) => + case f @ FiniteSet(els, base) => FiniteSet(els.map(e), base) case l @ Lambda(_, _) => @@ -588,12 +597,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case f @ FiniteArray(elems, default, length) => val ArrayType(tp) = f.getType finiteArray( - elems.map(el => (el._1, e(el._2))), + elems.map(el => (el._1, e(el._2))), default.map{ d => (e(d), e(length)) }, tp ) - case f @ FiniteMap(ss, kT, vT) => + case f @ FiniteMap(ss, kT, vT) => FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct, kT, vT) case g @ MapApply(m,k) => (e(m), e(k)) match { @@ -620,7 +629,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case gv: GenericValue => gv - case p : Passes => + case p : Passes => e(p.asConstraint) case choose: Choose => @@ -684,6 +693,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases") } + case fl : FractionalLiteral => normalizeFraction(fl) case l : Literal[_] => l case other => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 025aa4b9839eb0891899101130f73508c7b23ba9..fbbfd6e3d503c8e8004e1fc0118bb6153087d85b 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -139,7 +139,7 @@ trait CodeExtraction extends ASTExtractors { //This is a bit misleading, if an expr is not mapped then it has no owner, if it is mapped to None it means //that it can have any owner - private var owners: Map[Identifier, Option[FunDef]] = Map() + private var owners: Map[Identifier, Option[FunDef]] = Map() // This one never fails, on error, it returns Untyped def leonType(tpt: Type)(implicit dctx: DefContext, pos: Position): LeonType = { @@ -379,7 +379,7 @@ trait CodeExtraction extends ASTExtractors { val allSels = sels map { prefix :+ _.name.toString } // Make a different import for each selector at the end of the chain - allSels flatMap { selectors => + allSels flatMap { selectors => assert(selectors.nonEmpty) val (thePath, isWild) = selectors.last match { case "_" => (selectors.dropRight(1), true) @@ -448,8 +448,8 @@ trait CodeExtraction extends ASTExtractors { private var methodToClass = Map[FunDef, LeonClassDef]() /** - * For the function in $defs with name $owner, find its parameter with index $index, - * and registers $fd as the default value function for this parameter. + * For the function in $defs with name $owner, find its parameter with index $index, + * and registers $fd as the default value function for this parameter. */ private def registerDefaultMethod( defs : List[Tree], @@ -548,7 +548,7 @@ trait CodeExtraction extends ASTExtractors { //println(s"Body of $sym") - // We collect the methods and fields + // We collect the methods and fields for (d <- tmpl.body) d match { case EmptyTree => // ignore @@ -575,7 +575,7 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) val matcher: PartialFunction[Tree, Symbol] = { - case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym + case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym } registerDefaultMethod(tmpl.body, matcher, index, fd ) @@ -765,7 +765,7 @@ trait CodeExtraction extends ASTExtractors { // Find defining function for params with default value for ((s,vd) <- params zip funDef.params) { - vd.defaultValue = paramsToDefaultValues.get(s.symbol) + vd.defaultValue = paramsToDefaultValues.get(s.symbol) } val newVars = for ((s, vd) <- params zip funDef.params) yield { @@ -775,7 +775,7 @@ trait CodeExtraction extends ASTExtractors { val fctx = dctx.withNewVars(newVars).copy(isExtern = funDef.annotations("extern")) // If this is a lazy field definition, drop the assignment/ accessing - val body = + val body = if (funDef.flags.contains(IsField(true))) { body0 match { case Block(List(Assign(_, realBody)),_ ) => realBody case _ => outOfSubsetError(body0, "Wrong form of lazy accessor") @@ -938,7 +938,7 @@ trait CodeExtraction extends ASTExtractors { val recGuard = extractTree(cd.guard)(ndctx) if(isXLang(recGuard)) { - outOfSubsetError(cd.guard.pos, "Guard expression must be pure") + outOfSubsetError(cd.guard.pos, "Guard expression must be pure") } GuardedCase(recPattern, recGuard, recBody).setPos(cd.pos) @@ -1110,7 +1110,7 @@ trait CodeExtraction extends ASTExtractors { LetDef(funDefWithBody, restTree) // FIXME case ExDefaultValueFunction - + /** * XLang Extractors */ @@ -1218,7 +1218,7 @@ trait CodeExtraction extends ASTExtractors { rec match { case IntLiteral(n) => InfiniteIntegerLiteral(BigInt(n)) - case _ => + case _ => outOfSubsetError(tr, "Conversion from Int to BigInt") } @@ -1227,7 +1227,7 @@ trait CodeExtraction extends ASTExtractors { val rd = extractTree(d) (rn, rd) match { case (InfiniteIntegerLiteral(n), InfiniteIntegerLiteral(d)) => - RealLiteral(BigDecimal(n) / BigDecimal(d)) + FractionalLiteral(n, d) case _ => outOfSubsetError(tr, "Real not build from literals") } @@ -1235,7 +1235,7 @@ trait CodeExtraction extends ASTExtractors { val rn = extractTree(n) rn match { case InfiniteIntegerLiteral(n) => - RealLiteral(BigDecimal(n)) + FractionalLiteral(n, 1) case _ => outOfSubsetError(tr, "Real not build from literals") } @@ -1482,7 +1482,7 @@ trait CodeExtraction extends ASTExtractors { case str @ ExStringLiteral(s) => val chars = s.toList.map(CharLiteral) - + val consChar = CaseClassType(libraryCaseClass(str.pos, "leon.collection.Cons"), Seq(CharType)) val nilChar = CaseClassType(libraryCaseClass(str.pos, "leon.collection.Nil"), Seq(CharType)) @@ -1810,11 +1810,11 @@ trait CodeExtraction extends ASTExtractors { case RefinedType(parents, defs) if defs.isEmpty => /** - * For cases like if(a) e1 else e2 where + * For cases like if(a) e1 else e2 where * e1 <: C1, * e2 <: C2, * with C1,C2 <: C - * + * * Scala might infer a type for C such as: Product with Serializable with C * we generalize to the first known type, e.g. C. */ diff --git a/src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala b/src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..4c443b3f6e08b7cf381b661a1725862156c1bba6 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala @@ -0,0 +1,235 @@ +package leon +package invariant.engine + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import invariant.templateSolvers._ +import invariant.util.Util._ +import transformations._ +import invariant.structure.FunctionUtils._ +import transformations.InstUtil._ +import leon.invariant.structure.Formula +import leon.invariant.structure.Call +import leon.invariant.util.RealToInt +import leon.invariant.util.OrderedMultiMap +import leon.invariant.util.ExpressionTransformer +import leon.invariant.factories.TemplateSolverFactory +import leon.invariant.util.Minimizer +import leon.solvers.Model + +class CompositionalTimeBoundSolver(ctx: InferenceContext, rootFd: FunDef) + extends FunctionTemplateSolver { + + val printIntermediatePrograms = false + val debugDecreaseConstraints = false + val debugComposition = false + val reporter = ctx.reporter + + def inferTemplate(instProg: Program) = { + (new UnfoldingTemplateSolver(ctx.copy(program = instProg), findRoot(instProg)))() + } + + def findRoot(prog: Program) = { + functionByName(rootFd.id.name, prog).get + } + + def apply() = { + // Check if all the three templates have different template variable sets + val (Some(tprTmpl), Some(recTmpl), Some(timeTmpl), othersTmpls) = extractSeparateTemplates(rootFd) + val tmplIds = (Seq(tprTmpl, recTmpl, timeTmpl) ++ othersTmpls) flatMap getTemplateIds + if (tmplIds.toSet.size < tmplIds.size) + throw new IllegalStateException("Templates for tpr, rec, time as well as all other templates " + + " taken together should not have the any common template variables for compositional analysis") + + val origProg = ctx.program + // add only rec templates for all functions + val funToRecTmpl = origProg.definedFunctions.collect { + case fd if fd.hasTemplate && fd == rootFd => + fd -> recTmpl + case fd if fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + val recProg = assignTemplateAndCojoinPost(funToRecTmpl, origProg) + + // add only tpr template for all functions + val funToNonRecTmpl = origProg.definedFunctions.collect { + case fd if fd.hasTemplate && fd == rootFd => + fd -> tprTmpl + case fd if fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + val tprProg = assignTemplateAndCojoinPost(funToNonRecTmpl, origProg) + + if (printIntermediatePrograms) { + reporter.info("RecProg:\n" + recProg) + reporter.info("TRPProg: \n" + tprProg) + } + val recInfRes = inferTemplate(recProg) + val tprInfRes = inferTPRTemplate(tprProg) + + (recInfRes, tprInfRes) match { + case (Some(InferResult(true, Some(recModel), _)), + Some(InferResult(true, Some(tprModel), _))) => + // create a new program by omitting the templates of the root function + val funToTmpl = origProg.definedFunctions.collect { + case fd if fd.hasTemplate && fd != rootFd => + (fd -> fd.getTemplate) + }.toMap + val compProg = assignTemplateAndCojoinPost(funToTmpl, origProg) + val compFunDef = findRoot(compProg) + val nctx = ctx.copy(program = compProg) + + // construct the instantiated tpr bound and check if it monotonically decreases + val Operator(Seq(_, tprFun), _) = tprTmpl + val tprFunInst = (new RealToInt()).mapRealToInt( + replace(tprModel.map { case (k, v) => (k.toVariable -> v) }.toMap, tprFun)) + // TODO: this would fail on non-integers, handle these by approximating to the next bigger integer + + // Upper bound on time time <= recFun * tprFun + tprFun + val (_, multFun) = MultFuncs.getMultFuncs(if (ctx.usereals) RealType else IntegerType) + val Operator(Seq(_, recFun), _) = recTmpl + val recFunInst = (new RealToInt()).mapRealToInt( + replace(recModel.map { case (k, v) => (k.toVariable -> v) }.toMap, recFun)) + + val timeUpperBound = ExpressionTransformer.normalizeMultiplication( + Plus(FunctionInvocation(TypedFunDef(multFun, Seq()), + Seq(recFunInst, tprFunInst)), tprFunInst), ctx.multOp) + // res = body + val plainBody = Equals(getResId(rootFd).get.toVariable, matchToIfThenElse(rootFd.body.get)) + val bodyExpr = if (rootFd.hasPrecondition) { + And(matchToIfThenElse(rootFd.precondition.get), plainBody) + } else plainBody + + val Operator(Seq(timeInstExpr, _), _) = timeTmpl + val compositionAnt = And(Seq(LessEquals(timeInstExpr, timeUpperBound), bodyExpr)) + val prototypeVC = And(compositionAnt, Not(timeTmpl)) + + // map the old functions in the vc using the new functions + val substMap = origProg.definedFunctions.collect { + case fd => + (fd -> functionByName(fd.id.name, compProg).get) + }.toMap + val vcExpr = mapFunctionsInExpr(substMap)(prototypeVC) + + if (printIntermediatePrograms) reporter.info("Comp prog: " + compProg) + if (debugComposition) reporter.info("Compositional VC: " + vcExpr) + + val recTempSolver = new UnfoldingTemplateSolver(nctx, compFunDef) { + val minFunc = { + val mizer = new Minimizer(ctx) + Some(mizer.minimizeBounds(mizer.computeCompositionLevel(timeTmpl)) _) + } + override lazy val templateSolver = + TemplateSolverFactory.createTemplateSolver(ctx, constTracker, rootFd, minFunc) + override def instantiateModel(model: Model, funcs: Seq[FunDef]) = { + funcs.collect { + case `compFunDef` => + compFunDef -> timeTmpl + case fd if fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + } + } + recTempSolver.solveParametricVC(vcExpr) match { + case Some(InferResult(true, Some(timeModel),timeInferredFuncs)) => + val inferredFuns = (recInfRes.get.inferredFuncs ++ tprInfRes.get.inferredFuncs ++ timeInferredFuncs).distinct + Some(InferResult(true, Some(recModel ++ tprModel.toMap ++ timeModel.toMap), + inferredFuns.map(ifd => functionByName(ifd.id.name, origProg).get).distinct)) + case res @ _ => + res + } + case _ => + reporter.info("Could not infer bounds on rec and(or) tpr. Cannot precced with composition.") + None + } + } + + def combineMapsUsingConjunction(maps: List[Map[FunDef, Expr]]) = { + val combMap = new OrderedMultiMap[FunDef, Expr] + maps.foreach { + _.foreach { + case (k, v) => + val origFun = functionByName(k.id.name, ctx.program).get + combMap.addBinding(origFun, v) + } + } + combMap.foldLeft(Map[FunDef, Expr]()) { + case (acc, (k, vs)) if vs.size == 1 => acc + (k -> vs(0)) + case (acc, (k, vs)) => acc + (k -> And(vs.toSeq)) + } + } + + def extractSeparateTemplates(funDef: FunDef): (Option[Expr], Option[Expr], Option[Expr], Seq[Expr]) = { + if (!funDef.hasTemplate) (None, None, None, Seq[Expr]()) + else { + val template = ExpressionTransformer.pullAndOrs(And(funDef.getTemplate, + funDef.getPostWoTemplate)) // note that some bounds can occur in post and not in tmpl + def extractTmplConjuncts(tmpl: Expr): Seq[Expr] = { + tmpl match { + case And(seqExprs) => + seqExprs + case _ => + throw new IllegalStateException("Compositional reasoning requires templates to be conjunctions!" + tmpl) + } + } + val tmplConjuncts = extractTmplConjuncts(template) + val tupleSelectToInst = InstUtil.getInstMap(funDef) + var tprTmpl: Option[Expr] = None + var timeTmpl: Option[Expr] = None + var recTmpl: Option[Expr] = None + var othersTmpls: Seq[Expr] = Seq[Expr]() + tmplConjuncts.foreach(conj => { + conj match { + case Operator(Seq(lhs, _), _) if (tupleSelectToInst.contains(lhs)) => + tupleSelectToInst(lhs) match { + case n if n == TPR.name => + tprTmpl = Some(conj) + case n if n == Time.name => + timeTmpl = Some(conj) + case n if n == Rec.name => + recTmpl = Some(conj) + case _ => + othersTmpls = othersTmpls :+ conj + } + case _ => + othersTmpls = othersTmpls :+ conj + } + }) + (tprTmpl, recTmpl, timeTmpl, othersTmpls) + } + } + + def inferTPRTemplate(tprProg: Program) = { + val tempSolver = new UnfoldingTemplateSolver(ctx.copy(program = tprProg), findRoot(tprProg)) { + override def constructVC(rootFd: FunDef): (Expr, Expr) = { + val body = Equals(getResId(rootFd).get.toVariable, matchToIfThenElse(rootFd.body.get)) + val preExpr = + if (rootFd.hasPrecondition) + matchToIfThenElse(rootFd.precondition.get) + else tru + val tprTmpl = rootFd.getTemplate + val postWithTemplate = matchToIfThenElse(And(rootFd.getPostWoTemplate, tprTmpl)) + // generate constraints characterizing decrease of the tpr function with recursive calls + val Operator(Seq(_, tprFun), op) = tprTmpl + val bodyFormula = new Formula(rootFd, ExpressionTransformer.normalizeExpr(body, ctx.multOp), ctx) + val constraints = bodyFormula.disjunctsInFormula.flatMap { + case (guard, ctrs) => + ctrs.collect { + case call @ Call(_, FunctionInvocation(TypedFunDef(`rootFd`, _), _)) => //direct recursive call ? + Implies(guard, LessEquals(replace(formalToActual(call), tprFun), tprFun)) + } + } + if (debugDecreaseConstraints) + reporter.info("Decrease constraints: " + createAnd(constraints.toSeq)) + + val fullPost = createAnd(postWithTemplate +: constraints.toSeq) + (And(preExpr, bodyFormula.toExpr), fullPost) + } + } + tempSolver() + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala new file mode 100644 index 0000000000000000000000000000000000000000..e50b50bb8b84886bff46c54641f55050fc5f5e52 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala @@ -0,0 +1,45 @@ +package leon +package invariant.engine + +import z3.scala._ +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import java.io._ + +import invariant.factories._ +import invariant.util._ +import invariant.structure._ + +class ConstraintTracker(ctx : InferenceContext, rootFun : FunDef/*, temFactory: TemplateFactory*/) { + + //a mapping from functions to its VCs represented as a CNF formula + protected var funcVCs = Map[FunDef,Formula]() + + val vcRefiner = new RefinementEngine(ctx, this/*, temFactory*/) + val specInstantiator = new SpecInstantiator(ctx, this/*, temFactory*/) + + def getFuncs : Seq[FunDef] = funcVCs.keys.toSeq + def hasVC(fdef: FunDef) = funcVCs.contains(fdef) + def getVC(fd: FunDef) : Formula = funcVCs(fd) + + def addVC(fd: FunDef, vc: Expr) = { + funcVCs += (fd -> new Formula(fd, vc, ctx)) + } + + def initialize = { + //assume specifications + specInstantiator.instantiate + } + + def refineVCs(toUnrollCalls: Option[Set[Call]]) : Set[Call] = { + val unrolledCalls = vcRefiner.refineAbstraction(toUnrollCalls) + specInstantiator.instantiate + unrolledCalls + } +} diff --git a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..30ca3fcdebedca181e58ea6929930d3b823aad9e --- /dev/null +++ b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala @@ -0,0 +1,165 @@ +package leon +package invariant.engine + +import purescala.Common._ +import purescala.Definitions._ +import purescala.ExprOps._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.Types._ +import verification.VerificationReport +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.structure.FunctionUtils._ +import invariant.structure._ +import transformations._ +import verification._ +import verification.VCKinds +import leon.purescala.ScalaPrinter + +/** + * @author ravi + * This phase performs automatic invariant inference. + * TODO: should time be implicitly made positive + */ +object InferInvariantsPhase extends LeonPhase[Program, InferenceReport] { + val name = "InferInv" + val description = "Invariant Inference" + + val optWholeProgram = LeonFlagOptionDef("wholeprogram", "Perform an non-modular whole program analysis", false) + val optFunctionUnroll = LeonFlagOptionDef("fullunroll", "Unroll all calls in every unroll step", false) + val optWithMult = LeonFlagOptionDef("withmult", "Multiplication is not converted to a recursive function in VCs", false) + val optUseReals = LeonFlagOptionDef("usereals", "Interpret the input program as a real program", false) + val optMinBounds = LeonFlagOptionDef("minbounds", "tighten time bounds", false) + val optInferTemp = LeonFlagOptionDef("inferTemp", "Infer templates by enumeration", false) + val optCegis = LeonFlagOptionDef("cegis", "use cegis instead of farkas", false) + val optStatsSuffix = LeonStringOptionDef("stats-suffix", "the suffix of the statistics file", "", "s") + val optTimeout = LeonLongOptionDef("timeout", "Timeout after T seconds when trying to prove a verification condition.", 20, "s") + val optDisableInfer = LeonFlagOptionDef("disableInfer", "Disable automatic inference of auxiliary invariants", false) + + override val definedOptions: Set[LeonOptionDef[Any]] = + Set(optWholeProgram, optFunctionUnroll, optWithMult, optUseReals, + optMinBounds, optInferTemp, optCegis, optStatsSuffix, optTimeout, + optDisableInfer) + + //TODO provide options for analyzing only selected functions + def run(ctx: LeonContext)(prog: Program): InferenceReport = { + + //control printing of statistics + val dumpStats = true + var timeout: Int = 15 + + //defualt true flags + var modularlyAnalyze = true + var targettedUnroll = true + + //default false flags + var tightBounds = false + var withmult = false + var inferTemp = false + var enumerationRelation: (Expr, Expr) => Expr = LessEquals + var useCegis = false + //var maxCegisBound = 200 //maximum bound for the constants in cegis + var maxCegisBound = 1000000000 + var statsSuff = "-stats" + FileCountGUID.getID + var usereals = false + var autoInference = true + + for (opt <- ctx.options) (opt.optionDef.name, opt.value) match { + case ("wholeprogram", true) => { + //do not do a modular analysis + modularlyAnalyze = false + } + + case ("fullunroll", true) => { + //do not do a modular analysis + targettedUnroll = false + } + + case ("minbounds", true) => { + tightBounds = true + } + + case ("withmult", true) => { + withmult = true + } + + case ("usereals", true) => { + usereals = true + } + + case ("disableInfer", true) => + autoInference = false + + case ("inferTemp", true) => { + inferTemp = true + var foundStrongest = false + //go over all post-conditions and pick the strongest relation + prog.definedFunctions.foreach((fd) => { + if (!foundStrongest && fd.hasPostcondition) { + val cond = fd.postcondition.get + simplePostTransform((e) => e match { + case Equals(_, _) => { + enumerationRelation = Equals.apply _ + foundStrongest = true + e + } + case _ => e + })(cond) + } + }) + } + + case ("cegis", true) => { + useCegis = true + } + + case ("timeout", timeOut: Int) => + timeout = timeOut + + case ("stats-suffix", suffix: String) => { + statsSuff = suffix + } + + case _ => + } + + val funToTmpl = prog.definedFunctions.collect { + case fd if fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + val qMarksRemovedProg = Util.assignTemplateAndCojoinPost(funToTmpl, prog, Map()) + + val newprog = if (usereals) { + (new IntToRealProgram())(qMarksRemovedProg) + } else qMarksRemovedProg + val nlelim = new NonlinearityEliminator(withmult, if (usereals) RealType else IntegerType) + val finalprog = nlelim(newprog) + + val toVerifyPost = validateAndCollectNotValidated(qMarksRemovedProg, ctx, timeout) + //populate the inference context and invoke inferenceEngine + val inferctx = new InferenceContext(finalprog, toVerifyPost, ctx, + //multiplication operation + (e1, e2) => FunctionInvocation(TypedFunDef(nlelim.multFun, nlelim.multFun.tparams.map(_.tp)), Seq(e1, e2)), + enumerationRelation = LessEquals, modularlyAnalyze, targettedUnroll, autoInference, + dumpStats, tightBounds, withmult, usereals, inferTemp, useCegis, timeout, maxCegisBound, statsSuff) + (new InferenceEngine(inferctx)).run() + } + + def createLeonContext(ctx: LeonContext, opts: String*): LeonContext = { + Main.processOptions(opts.toList).copy(reporter = ctx.reporter, + interruptManager = ctx.interruptManager, files = ctx.files, timers = ctx.timers) + } + + def validateAndCollectNotValidated(prog: Program, ctx: LeonContext, timeout: Int): Set[String] = { + val verifyPipe = AnalysisPhase + val ctxWithTO = createLeonContext(ctx, "--timeout=" + timeout) + (verifyPipe.run(ctxWithTO)(prog)).results.collect{ + case (VC(_, fd, VCKinds.Postcondition), Some(vcRes)) if vcRes.isInconclusive => + fd.id.name + case (VC(_, fd, vcKind), Some(vcRes)) if vcRes.isInvalid => + throw new IllegalStateException("Invalid" + vcKind + " for function " + fd.id.name) + }.toSet + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/InferenceContext.scala b/src/main/scala/leon/invariant/engine/InferenceContext.scala new file mode 100644 index 0000000000000000000000000000000000000000..243cd54801a40ca9d4c2a60dfcb2d24ac3fdf802 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/InferenceContext.scala @@ -0,0 +1,31 @@ +package leon +package invariant.engine + +import purescala.Definitions._ +import purescala.Expressions._ +import purescala._ + +/** + * @author ravi + */ +case class InferenceContext( + val program : Program, + val toVerifyPostFor: Set[String], + val leonContext : LeonContext, + val multOp: (Expr,Expr) => Expr, + val enumerationRelation : (Expr,Expr) => Expr, + val modularlyAnalyze : Boolean, + val targettedUnroll : Boolean, + val autoInference : Boolean, + val dumpStats : Boolean , + val tightBounds : Boolean, + val withmult : Boolean, + val usereals : Boolean, + val inferTemp : Boolean, + val useCegis : Boolean, + val timeout: Int, //in secs + val maxCegisBound : Int, + val statsSuffix : String) { + + val reporter = leonContext.reporter +} diff --git a/src/main/scala/leon/invariant/engine/InferenceEngine.scala b/src/main/scala/leon/invariant/engine/InferenceEngine.scala new file mode 100644 index 0000000000000000000000000000000000000000..067eac087c5267cbc9c7c8eb7a62c90bfac3465f --- /dev/null +++ b/src/main/scala/leon/invariant/engine/InferenceEngine.scala @@ -0,0 +1,179 @@ +package leon +package invariant.engine + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import solvers._ +import java.io._ +import verification.VerificationReport +import verification.VC +import scala.util.control.Breaks._ +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.util.Util._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import leon.invariant.factories.TemplateFactory + +/** + * @author ravi + * This phase performs automatic invariant inference. + * TODO: should time be implicitly made positive + */ +class InferenceEngine(ctx: InferenceContext) { + + def run(): InferenceReport = { + val reporter = ctx.reporter + val program = ctx.program + reporter.info("Running Inference Engine...") + + //register a shutdownhook + if (ctx.dumpStats) { + sys.ShutdownHookThread({ dumpStats(ctx.statsSuffix) }) + } + val t1 = System.currentTimeMillis() + //compute functions to analyze by sorting based on topological order (this is an ascending topological order) + val callgraph = CallGraphUtil.constructCallGraph(program, withTemplates = true) + val functionsToAnalyze = if (ctx.modularlyAnalyze) { + callgraph.topologicalOrder + } else { + callgraph.topologicalOrder.reverse + } + //reporter.info("Analysis Order: " + functionsToAnalyze.map(_.id)) + var results: Map[FunDef, InferenceCondition] = null + if (!ctx.useCegis) { + results = analyseProgram(functionsToAnalyze) + //println("Inferrence did not succeeded for functions: "+functionsToAnalyze.filterNot(succeededFuncs.contains _).map(_.id)) + } else { + var remFuncs = functionsToAnalyze + var b = 200 + var maxCegisBound = 200 + breakable { + while (b <= maxCegisBound) { + Stats.updateCumStats(1, "CegisBoundsTried") + val succeededFuncs = analyseProgram(remFuncs) + remFuncs = remFuncs.filterNot(succeededFuncs.contains _) + if (remFuncs.isEmpty) break; + b += 5 //increase bounds in steps of 5 + } + //println("Inferrence did not succeeded for functions: " + remFuncs.map(_.id)) + } + } + val t2 = System.currentTimeMillis() + Stats.updateCumTime(t2 - t1, "TotalTime") + //dump stats + if (ctx.dumpStats) { + reporter.info("- Dumping statistics") + dumpStats(ctx.statsSuffix) + } + new InferenceReport(program, results.map(pair => { + val (fd, ic) = pair + (fd -> List[VC](ic)) + }))(ctx) + } + + def dumpStats(statsSuffix: String) = { + //pick the module id. + val modid = ctx.program.modules.last.id + val pw = new PrintWriter(modid + statsSuffix + ".txt") + Stats.dumpStats(pw) + SpecificStats.dumpOutputs(pw) + if (ctx.tightBounds) { + SpecificStats.dumpMinimizationStats(pw) + } + } + + /** + * Returns map from analyzed functions to their inference conditions. + * TODO: use function names in inference conditions, so that + * we an get rid of dependence on origFd in many places. + */ + def analyseProgram(functionsToAnalyze: Seq[FunDef]): Map[FunDef, InferenceCondition] = { + val reporter = ctx.reporter + val funToTmpl = + if (ctx.autoInference) { + //A template generator that generates templates for the functions (here we are generating templates by enumeration) + val tempFactory = new TemplateFactory(Some(new TemplateEnumerator(ctx)), ctx.program, ctx.reporter) + ctx.program.definedFunctions.map(fd => fd -> getOrCreateTemplateForFun(fd)).toMap + } else + ctx.program.definedFunctions.collect { case fd if fd.hasTemplate => fd -> fd.getTemplate }.toMap + val progWithTemplates = assignTemplateAndCojoinPost(funToTmpl, ctx.program) + var analyzedSet = Map[FunDef, InferenceCondition]() + functionsToAnalyze.filterNot((fd) => { + (fd.annotations contains "verified") || + (fd.annotations contains "library") || + (fd.annotations contains "theoryop") + }).foldLeft(progWithTemplates) { (prog, origFun) => + + val funDef = functionByName(origFun.id.name, prog).get + reporter.info("- considering function " + funDef.id.name + "...") + + //skip the function if it has been analyzed + if (!analyzedSet.contains(origFun)) { + if (funDef.hasBody && funDef.hasPostcondition) { + val currCtx = ctx.copy(program = prog) + // for stats + Stats.updateCounter(1, "procs") + val solver = + if (funDef.annotations.contains("compose")) //compositional inference ? + new CompositionalTimeBoundSolver(currCtx, funDef) + else + new UnfoldingTemplateSolver(currCtx, funDef) + val t1 = System.currentTimeMillis() + val infRes = solver() + val funcTime = (System.currentTimeMillis() - t1) / 1000.0 + infRes match { + case Some(InferResult(true, model, inferredFuns)) => + val funsWithTemplates = inferredFuns.filter { fd => + val origFd = Util.functionByName(fd.id.name, ctx.program).get + !analyzedSet.contains(origFd) && origFd.hasTemplate + } + // create a inference condition for reporting + var first = true + funsWithTemplates.foreach { fd => + val origFd = Util.functionByName(fd.id.name, ctx.program).get + val inv = TemplateInstantiator.getAllInvariants(model.get, + Map(origFd -> origFd.getTemplate)) + // record the inferred invariants + val ic = new InferenceCondition(Some(inv(origFd)), fd) + ic.time = if (first) Some(funcTime) else Some(0.0) + // update analyzed set + analyzedSet += (origFd -> ic) + first = false + } + val invs = TemplateInstantiator.getAllInvariants(model.get, + funsWithTemplates.collect { + case fd if fd.hasTemplate => fd -> fd.getTemplate + }.toMap) + if (ctx.modularlyAnalyze) { + // create a new program that has the inferred templates + val funToTmpl = prog.definedFunctions.collect { + case fd if !invs.contains(fd) && fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + assignTemplateAndCojoinPost(funToTmpl, prog, invs) + } else + prog + case _ => + reporter.info("- Exhausted all templates, cannot infer invariants") + val ic = new InferenceCondition(None, origFun) + ic.time = Some(funcTime) + analyzedSet += (origFun -> ic) + prog + } + } else { + //nothing needs to be done here + reporter.info("Function does not have a body or postcondition") + prog + } + } else prog + } + analyzedSet + } +} diff --git a/src/main/scala/leon/invariant/engine/InferenceReport.scala b/src/main/scala/leon/invariant/engine/InferenceReport.scala new file mode 100644 index 0000000000000000000000000000000000000000..21756faac1fb1deae5bd626070f517a935905291 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/InferenceReport.scala @@ -0,0 +1,84 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package invariant.engine + +import purescala.Definitions.FunDef +import verification._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Definitions._ +import purescala.Common._ +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.structure._ +import leon.transformations.InstUtil +import leon.purescala.PrettyPrinter + + +class InferenceCondition(val invariant: Option[Expr], funDef: FunDef) + extends VC(BooleanLiteral(true), funDef, null) { + + var time : Option[Double] = None + + def status: String = invariant match { + case None => "unknown" + case Some(inv) => { + val prettyInv = simplifyArithmetic(InstUtil.replaceInstruVars(Util.multToTimes(inv),fd)) + PrettyPrinter(prettyInv) + } + } +} + +class InferenceReport(p: Program, fvcs: Map[FunDef, List[VC]])(implicit ctx: InferenceContext) + extends VerificationReport(p : Program, Map()) { + + import scala.math.Ordering.Implicits._ + val conditions : Seq[InferenceCondition] = + fvcs.flatMap(_._2.map(_.asInstanceOf[InferenceCondition])).toSeq.sortBy(vc => vc.fd.id.name) + + private def infoSep(size: Int) : String = "â•Ÿ" + ("┄" * size) + "â•¢\n" + private def infoFooter(size: Int) : String = "â•š" + ("â•" * size) + "â•" + private def infoHeader(size: Int) : String = ". ┌─────────â”\n" + + "â•”â•â•¡ Summary â•ž" + ("â•" * (size - 12)) + "â•—\n" + + "â•‘ └─────────┘" + (" " * (size - 12)) + "â•‘" + + private def infoLine(str: String, size: Int) : String = { + "â•‘ "+ str + (" " * (size - str.size - 2)) + " â•‘" + } + + private def fit(str : String, maxLength : Int) : String = { + if(str.length <= maxLength) { + str + } else { + str.substring(0, maxLength - 1) + "…" + } + } + + private def funName(fd: FunDef) = InstUtil.userFunctionName(fd) + + override def summaryString : String = if(conditions.size > 0) { + val maxTempSize = (conditions.map(_.status.size).max + 3) + val outputStrs = conditions.map(vc => { + val timeStr = vc.time.map(t => "%-3.3f".format(t)).getOrElse("") + "%-15s %s %-4s".format(fit(funName(vc.fd), 15), vc.status + (" "*(maxTempSize-vc.status.size)), timeStr) + }) + val summaryStr = { + val totalTime = conditions.foldLeft(0.0)((a, ic) => a + ic.time.getOrElse(0.0)) + val inferredConds = conditions.count((ic) => ic.invariant.isDefined) + "total: %-4d inferred: %-4d unknown: %-4d time: %-3.3f".format( + conditions.size, inferredConds, conditions.size - inferredConds, totalTime) + } + val entrySize = (outputStrs :+ summaryStr).map(_.size).max + 2 + + infoHeader(entrySize) + + outputStrs.map(str => infoLine(str, entrySize)).mkString("\n", "\n", "\n") + + infoSep(entrySize) + + infoLine(summaryStr, entrySize) + "\n" + + infoFooter(entrySize) + + } else { + "No user provided templates were solved." + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/RefinementEngine.scala b/src/main/scala/leon/invariant/engine/RefinementEngine.scala new file mode 100644 index 0000000000000000000000000000000000000000..cce8a3cb45a27693824944ed32c190921691cd2c --- /dev/null +++ b/src/main/scala/leon/invariant/engine/RefinementEngine.scala @@ -0,0 +1,200 @@ +package leon +package invariant.engine + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ + +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.util.Util._ +import invariant.structure._ +import FunctionUtils._ + +//TODO: the parts of the code that collect the new head functions is ugly and has many side-effects. Fix this. +//TODO: there is a better way to compute heads, which is to consider all guards not previous seen +class RefinementEngine(ctx: InferenceContext, ctrTracker: ConstraintTracker) { + + val tru = BooleanLiteral(true) + val reporter = ctx.reporter + val prog = ctx.program + val cg = CallGraphUtil.constructCallGraph(prog) + + //this count indicates the number of times we unroll a recursive call + private val MAX_UNROLLS = 2 + + //debugging flags + private val dumpInlinedSummary = false + + //print flags + val verbose = false + + //the guards of disjuncts that were already processed + private var exploredGuards = Set[Variable]() + + //a set of calls that have not been unrolled (these are potential unroll candidates) + //However, these calls except those given by the unspecdCalls have been assumed specifications + private var headCalls = Map[FunDef, Set[Call]]() + def getHeads(fd: FunDef) = if (headCalls.contains(fd)) headCalls(fd) else Set() + def resetHeads(fd: FunDef, heads: Set[Call]) = { + if (headCalls.contains(fd)) { + headCalls -= fd + headCalls += (fd -> heads) + } else { + headCalls += (fd -> heads) + } + } + + /** + * This procedure refines the existing abstraction. + * Currently, the refinement happens by unrolling the head functions. + */ + def refineAbstraction(toRefineCalls: Option[Set[Call]]): Set[Call] = { + + ctrTracker.getFuncs.flatMap((fd) => { + val formula = ctrTracker.getVC(fd) + val disjuncts = formula.disjunctsInFormula + val newguards = formula.disjunctsInFormula.keySet.diff(exploredGuards) + exploredGuards ++= newguards + + val newheads = newguards.flatMap(g => disjuncts(g).collect { case c: Call => c }) + val allheads = getHeads(fd) ++ newheads + + //unroll each call in the head pointers and in toRefineCalls + val callsToProcess = if (toRefineCalls.isDefined) { + + //pick only those calls that have been least unrolled + val relevCalls = allheads.intersect(toRefineCalls.get) + var minCalls = Set[Call]() + var minUnrollings = MAX_UNROLLS + relevCalls.foreach((call) => { + val calldata = formula.callData(call) + val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd) + if (recInvokes < minUnrollings) { + minUnrollings = recInvokes + minCalls = Set(call) + } else if (recInvokes == minUnrollings) { + minCalls += call + } + }) + minCalls + + } else allheads + + if (verbose) + reporter.info("Unrolling: " + callsToProcess.size + "/" + allheads.size) + + val unrolls = callsToProcess.foldLeft(Set[Call]())((acc, call) => { + + val calldata = formula.callData(call) + val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd) + //if the call is not a recursive call, unroll it unconditionally + if (recInvokes == 0) { + unrollCall(call, formula) + acc + call + } else { + //if the call is recursive, unroll iff the number of times the recursive function occurs in the context is < MAX-UNROLL + if (recInvokes < MAX_UNROLLS) { + unrollCall(call, formula) + acc + call + } else { + //otherwise, do not unroll the call + acc + } + } + //TODO: are there better ways of unrolling ?? + }) + + //update the head functions + resetHeads(fd, allheads.diff(callsToProcess)) + unrolls + }).toSet + } + + def shouldCreateVC(recFun: FunDef): Boolean = { + if (ctrTracker.hasVC(recFun)) false + else { + //need not create vcs for theory operations + !recFun.isTheoryOperation && recFun.hasTemplate && + !recFun.annotations.contains("library") + } + } + + /** + * Returns a set of unrolled calls and a set of new head functions + * here we unroll the methods in the current abstraction by one step. + * This procedure has side-effects on 'headCalls' and 'callDataMap' + */ + def unrollCall(call: Call, formula: Formula) = { + val fi = call.fi + if (fi.tfd.fd.hasBody) { + + //freshen the body and the post + val isRecursive = cg.isRecursive(fi.tfd.fd) + if (isRecursive) { + val recFun = fi.tfd.fd + val recFunTyped = fi.tfd + + //check if we need to create a constraint tree for the call's target + if (shouldCreateVC(recFun)) { + + //create a new verification condition for this recursive function + reporter.info("Creating VC for " + recFun.id) + val freshBody = freshenLocals(matchToIfThenElse(recFun.body.get)) + val resvar = if (recFun.hasPostcondition) { + //create a new result variable here for the same reason as freshening the locals, + //which is to avoid variable capturing during unrolling + val origRes = getResId(recFun).get + Variable(FreshIdentifier(origRes.name, origRes.getType, true)) + } else { + //create a new resvar + Variable(FreshIdentifier("res", recFun.returnType, true)) + } + val plainBody = Equals(resvar, freshBody) + val bodyExpr = if (recFun.hasPrecondition) { + And(matchToIfThenElse(recFun.precondition.get), plainBody) + } else plainBody + + //note: here we are only adding the template as the postcondition + val idmap = Util.formalToActual(Call(resvar, FunctionInvocation(recFunTyped, recFun.params.map(_.toVariable)))) + val postTemp = replace(idmap, recFun.getTemplate) + val vcExpr = ExpressionTransformer.normalizeExpr(And(bodyExpr, Not(postTemp)), ctx.multOp) + ctrTracker.addVC(recFun, vcExpr) + } + + //Here, unroll the call into the caller tree + if (verbose) reporter.info("Unrolling " + Equals(call.retexpr, call.fi)) + inilineCall(call, formula) + } else { + //here we are unrolling a function without template + if (verbose) reporter.info("Unfolding " + Equals(call.retexpr, call.fi)) + inilineCall(call, formula) + } + } else Set() + } + + def inilineCall(call: Call, formula: Formula) = { + //here inline the body and conjoin it with the guard + val callee = call.fi.tfd.fd + + //Important: make sure we use a fresh body expression here + val freshBody = freshenLocals(matchToIfThenElse(callee.body.get)) + val calleeSummary = + Equals(Util.getFunctionReturnVariable(callee), freshBody) + val argmap1 = Util.formalToActual(call) + val inlinedSummary = ExpressionTransformer.normalizeExpr(replace(argmap1, calleeSummary), ctx.multOp) + + if (this.dumpInlinedSummary) + println("Inlined Summary: " + inlinedSummary) + + //conjoin the summary with the disjunct corresponding to the 'guard' + //note: the parents of the summary are the parents of the call plus the callee function + val calldata = formula.callData(call) + formula.conjoinWithDisjunct(calldata.guard, inlinedSummary, (callee +: calldata.parents)) + } +} diff --git a/src/main/scala/leon/invariant/engine/SpecInstatiator.scala b/src/main/scala/leon/invariant/engine/SpecInstatiator.scala new file mode 100644 index 0000000000000000000000000000000000000000..9961df29357d1de52c98fe135f106993ab1e388a --- /dev/null +++ b/src/main/scala/leon/invariant/engine/SpecInstatiator.scala @@ -0,0 +1,270 @@ +package leon +package invariant.engine +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import leon.invariant.templateSolvers.ExtendedUFSolver +import invariant._ +import scala.util.control.Breaks._ +import solvers._ +import scala.concurrent._ +import scala.concurrent.duration._ + +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.util.Util._ +import invariant.structure._ +import FunctionUtils._ + +class SpecInstantiator(ctx: InferenceContext, ctrTracker: ConstraintTracker) { + + val verbose = false + + protected val disableAxioms = false + protected val debugAxiomInstantiation = false + + val tru = BooleanLiteral(true) + val axiomFactory = new AxiomFactory(ctx) //handles instantiation of axiomatic specification + val program = ctx.program + + //the guards of the set of calls that were already processed + protected var exploredGuards = Set[Variable]() + + def instantiate() = { + val funcs = ctrTracker.getFuncs + + funcs.foreach((fd) => { + val formula = ctrTracker.getVC(fd) + val disjuncts = formula.disjunctsInFormula + val newguards = disjuncts.keySet.diff(exploredGuards) + exploredGuards ++= newguards + + val newcalls = newguards.flatMap(g => disjuncts(g).collect { case c: Call => c }) + instantiateSpecs(formula, newcalls, funcs.toSet) + + if (!disableAxioms) { + //remove all multiplication if "withmult" is specified + val relavantCalls = if (ctx.withmult) { + newcalls.filter(call => !Util.isMultFunctions(call.fi.tfd.fd)) + } else newcalls + instantiateAxioms(formula, relavantCalls) + } + }) + } + + /** + * This function refines the formula by assuming the specifications/templates for calls in the formula + * Here, we assume (pre => post ^ template) for each call (templates only for calls with VC) + * Important: adding templates for 'newcalls' of the previous iterations is empirically more effective + */ + //a set of calls for which templates or specifications have not been assumed + private var untemplatedCalls = Map[FunDef, Set[Call]]() + def getUntempCalls(fd: FunDef) = if (untemplatedCalls.contains(fd)) untemplatedCalls(fd) else Set() + def resetUntempCalls(fd: FunDef, calls: Set[Call]) = { + if (untemplatedCalls.contains(fd)) { + untemplatedCalls -= fd + untemplatedCalls += (fd -> calls) + } else { + untemplatedCalls += (fd -> calls) + } + } + + def instantiateSpecs(formula: Formula, calls: Set[Call], funcsWithVC: Set[FunDef]) = { + + //assume specifications + calls.foreach((call) => { + //first get the spec for the call if it exists + val spec = specForCall(call) + if (spec.isDefined && spec.get != tru) { + val cdata = formula.callData(call) + formula.conjoinWithDisjunct(cdata.guard, spec.get, cdata.parents) + } + }) + + //try to assume templates for all the current un-templated calls + var newUntemplatedCalls = Set[Call]() + getUntempCalls(formula.fd).foreach((call) => { + //first get the template for the call if one needs to be added + if (funcsWithVC.contains(call.fi.tfd.fd)) { + templateForCall(call) match { + case Some(temp) => + val cdata = formula.callData(call) + formula.conjoinWithDisjunct(cdata.guard, temp, cdata.parents) + case _ => + ; // here there is no template for the call + } + } else { + newUntemplatedCalls += call + } + }) + resetUntempCalls(formula.fd, newUntemplatedCalls ++ calls) + } + + def specForCall(call: Call): Option[Expr] = { + val argmap = Util.formalToActual(call) + val callee = call.fi.tfd.fd + if (callee.hasPostcondition) { + //get the postcondition without templates + val post = callee.getPostWoTemplate + val freshPost = freshenLocals(matchToIfThenElse(post)) + + val spec = if (callee.hasPrecondition) { + val freshPre = freshenLocals(matchToIfThenElse(callee.precondition.get)) + Implies(freshPre, freshPost) + } else { + freshPost + } + val inlinedSpec = ExpressionTransformer.normalizeExpr(replace(argmap, spec), ctx.multOp) + Some(inlinedSpec) + } else { + None + } + } + + def templateForCall(call: Call): Option[Expr] = { + val callee = call.fi.tfd.fd + if (callee.hasTemplate) { + val argmap = Util.formalToActual(call) + val tempExpr = replace(argmap, callee.getTemplate) + val template = if (callee.hasPrecondition) { + val freshPre = replace(argmap, freshenLocals(matchToIfThenElse(callee.precondition.get))) + Implies(freshPre, tempExpr) + } else { + tempExpr + } + //flatten functions + //TODO: should we freshen locals here ?? + Some(ExpressionTransformer.normalizeExpr(template, ctx.multOp)) + } else None + } + + //axiomatic specification + protected var axiomRoots = Map[Seq[Call], Variable]() //a mapping from axioms keys (a sequence of calls) to the guards + def instantiateAxioms(formula: Formula, calls: Set[Call]) = { + + val debugSolver = if (this.debugAxiomInstantiation) { + val sol = new ExtendedUFSolver(ctx.leonContext, program) + sol.assertCnstr(formula.toExpr) + Some(sol) + } else None + + val inst1 = instantiateUnaryAxioms(formula, calls) + val inst2 = instantiateBinaryAxioms(formula, calls) + val axiomInsts = inst1 ++ inst2 + + Stats.updateCounterStats(Util.atomNum(Util.createAnd(axiomInsts)), "AxiomBlowup", "VC-refinement") + if(verbose) ctx.reporter.info("Number of axiom instances: " + axiomInsts.size) + + if (this.debugAxiomInstantiation) { + println("Instantianting axioms over: " + calls) + println("Instantiated Axioms: ") + axiomInsts.foreach((ainst) => { + println(ainst) + debugSolver.get.assertCnstr(ainst) + val res = debugSolver.get.check + res match { + case Some(false) => + println("adding axiom made formula unsat!!") + case _ => ; + } + }) + debugSolver.get.free + } + } + + //this code is similar to assuming specifications + def instantiateUnaryAxioms(formula: Formula, calls: Set[Call]) = { + val axioms = calls.collect { + case call @ _ if axiomFactory.hasUnaryAxiom(call) => { + val (ant, conseq) = axiomFactory.unaryAxiom(call) + val axiomInst = Implies(ant, conseq) + val nnfAxiom = ExpressionTransformer.normalizeExpr(axiomInst, ctx.multOp) + val cdata = formula.callData(call) + formula.conjoinWithDisjunct(cdata.guard, nnfAxiom, cdata.parents) + axiomInst + } + } + axioms.toSeq + } + + /** + * Here, we assume that axioms do not introduce calls. + * If this does not hold, 'guards' have to be used while instantiating axioms so as + * to compute correct verification conditions. + * TODO: Use least common ancestor etc. to avoid axiomatizing calls along different disjuncts + * TODO: can we avoid axioms like (a <= b ^ x<=y => p <= q), (x <= y ^ a<=b => p <= q), ... + * TODO: can we have axiomatic specifications relating two different functions ? + */ + protected var binaryAxiomCalls = Map[FunDef, Set[Call]]() //calls with axioms so far seen + def getBinaxCalls(fd: FunDef) = if (binaryAxiomCalls.contains(fd)) binaryAxiomCalls(fd) else Set[Call]() + def appendBinaxCalls(fd: FunDef, calls: Set[Call]) = { + if (binaryAxiomCalls.contains(fd)) { + val oldcalls = binaryAxiomCalls(fd) + binaryAxiomCalls -= fd + binaryAxiomCalls += (fd -> (oldcalls ++ calls)) + } else { + binaryAxiomCalls += (fd -> calls) + } + } + + def instantiateBinaryAxioms(formula: Formula, calls: Set[Call]) = { + + val newCallsWithAxioms = calls.filter(axiomFactory.hasBinaryAxiom _) + + def isInstantiable(call1: Call, call2: Call): Boolean = { + //important: check if the two calls refer to the same function + (call1.fi.tfd.id == call2.fi.tfd.id) && (call1 != call2) + } + + val product = Util.cross[Call, Call](newCallsWithAxioms, getBinaxCalls(formula.fd), Some(isInstantiable)).flatMap( + p => Seq((p._1, p._2), (p._2, p._1))) ++ + Util.cross[Call, Call](newCallsWithAxioms, newCallsWithAxioms, Some(isInstantiable)).map(p => (p._1, p._2)) + + //ctx.reporter.info("# of pairs with axioms: "+product.size) + //Stats.updateCumStats(product.size, "Call-pairs-with-axioms") + + val addedAxioms = product.flatMap(pair => { + //union the parents of the two calls + val cdata1 = formula.callData(pair._1) + val cdata2 = formula.callData(pair._2) + val parents = cdata1.parents ++ cdata2.parents + val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) + + axiomInsts.foldLeft(Seq[Expr]())((acc, inst) => { + val (ant, conseq) = inst + val axiom = Implies(ant, conseq) + val nnfAxiom = ExpressionTransformer.normalizeExpr(axiom, ctx.multOp) + val (axroot, _) = formula.conjoinWithRoot(nnfAxiom, parents) + //important: here we need to update the axiom roots + axiomRoots += (Seq(pair._1, pair._2) -> axroot) + acc :+ axiom + }) + }) + appendBinaxCalls(formula.fd, newCallsWithAxioms) + addedAxioms + } + + /** + * Note: taking a formula as input may not be necessary. We can store it as a part of the state + * TODO: can we use transitivity here to optimize ? + */ + def axiomsForCalls(formula: Formula, calls: Set[Call], model: Model): Seq[Constraint] = { + //note: unary axioms need not be instantiated + //consider only binary axioms + (for (x <- calls; y <- calls) yield (x, y)).foldLeft(Seq[Constraint]())((acc, pair) => { + val (c1, c2) = pair + if (c1 != c2) { + val axRoot = axiomRoots.get(Seq(c1, c2)) + if (axRoot.isDefined) + acc ++ formula.pickSatDisjunct(axRoot.get, model) + else acc + } else acc + }) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala b/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala new file mode 100644 index 0000000000000000000000000000000000000000..17bfdee4b4d761f3d974abac06cc637166f46e98 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala @@ -0,0 +1,195 @@ +package leon +package invariant.engine + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Set => MutableSet } +import java.io._ +import scala.collection.mutable.{ Set => MutableSet } +import scala.collection.mutable.{Set => MutableSet} + +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.structure._ + +/** + * An enumeration based template generator. + * Enumerates all numerical terms in some order (this enumeration is incomplete for termination). + * TODO: Feature: + * (a) allow template functions and functions with template variables ? + * (b) should we unroll algebraic data types ? + * + * The following function may potentially have complexity O(n^i) where 'n' is the number of functions + * and 'i' is the increment step + * TODO: optimize the running and also reduce the size of the input templates + * + * For now this is incomplete + */ +class TemplateEnumerator(ctx: InferenceContext) extends TemplateGenerator { + val prog = ctx.program + val reporter = ctx.reporter + + //create a call graph for the program + //Caution: this call-graph could be modified later while call the 'getNextTemplate' method + private val callGraph = { + val cg = CallGraphUtil.constructCallGraph(prog) + cg + } + + private var tempEnumMap = Map[FunDef, FunctionTemplateEnumerator]() + + def getNextTemplate(fd : FunDef) : Expr = { + if(tempEnumMap.contains(fd)) tempEnumMap(fd).getNextTemplate() + else { + val enumerator = new FunctionTemplateEnumerator(fd, prog, ctx.enumerationRelation, callGraph, reporter) + tempEnumMap += (fd -> enumerator) + enumerator.getNextTemplate() + } + } +} + +/** + * This class manages templates for the given function + * 'op' is a side-effects parameter + * Caution: The methods of this class has side-effects on the 'callGraph' parameter + */ +class FunctionTemplateEnumerator(rootFun: FunDef, prog: Program, op: (Expr,Expr) => Expr, + callGraph : CallGraph, reporter: Reporter) { + private val MAX_INCREMENTS = 2 + private val zero = InfiniteIntegerLiteral(0) + //using default op as <= or == (manually adjusted) + //private val op = LessEquals + //LessThan + //LessEquals + //Equals.apply _ + private var currTemp: Expr = null + private var incrStep: Int = 0 + private var typeTermMap = Map[TypeTree, MutableSet[Expr]]() + private var ttCurrent = Map[TypeTree, MutableSet[Expr]]() + + //get all functions that are not the current function. + //the value of the current function is given by res and its body + //itself characterizes how it is defined recursively w.r.t itself. + //Need to also avoid mutual recursion as it may lead to proving of invalid facts + private val fds = prog.definedFunctions.filter(_ != rootFun) + + def getNextTemplate(): Expr = { + //println("Getting next template for function: "+fd.id) + + if (incrStep == MAX_INCREMENTS){ + //exhausted the templates, so return + op(currTemp, zero) + } + else { + + incrStep += 1 + + var newTerms = Map[TypeTree, MutableSet[Expr]]() + if (currTemp == null) { + //initialize + //add all the arguments and results of fd to 'typeTermMap' + rootFun.params.foreach((vardecl) => { + val tpe = vardecl.getType + val v = vardecl.id.toVariable + if (newTerms.contains(tpe)) { + newTerms(tpe).add(v) + } else { + newTerms += (tpe -> MutableSet(v)) + } + }) + + val resVar = Util.getFunctionReturnVariable(rootFun) + if (newTerms.contains(rootFun.returnType)) { + newTerms(rootFun.returnType).add(resVar) + } else { + newTerms += (rootFun.returnType -> MutableSet(resVar)) + } + + //also 'assignCurrTemp' to a template variable + currTemp = TemplateIdFactory.freshTemplateVar() + } else { + + //apply the user-defined functions to the compatible terms in typeTermMap + //Important: Make sure that the recursive calls are not introduced in the templates + //TODO: this is a hack to prevent infinite recursion in specification. However, it is not clear if this will prevent inferrence of + //any legitimate specifications (however this can be modified). + fds.foreach((fun) => { + //Check if adding a call from 'rootFun' to 'fun' creates a mutual recursion by checking if + //'fun' transitively calls 'rootFun' + if (fun != rootFun && !callGraph.transitivelyCalls(fun, rootFun)) { + + //check if every argument has at least one satisfying assignment? + if (fun.params.filter((vardecl) => !ttCurrent.contains(vardecl.getType)).isEmpty) { + + //here compute all the combinations + val newcalls = generateFunctionCalls(fun) + if (newTerms.contains(fun.returnType)) { + newTerms(fun.returnType) ++= newcalls + } else { + var muset = MutableSet[Expr]() + muset ++= newcalls + newTerms += (fun.returnType -> muset) + } + } + } + + }) + + } + //add all the newly generated expression to the typeTermMap + ttCurrent = newTerms + typeTermMap ++= newTerms + + //statistics + reporter.info("- Number of new terms enumerated: " + newTerms.size) + + //return all the integer valued terms of newTerms + //++ newTerms.getOrElse(Int32Type, Seq[Expr]()) (for now not handling int 32 terms) + val numericTerms = (newTerms.getOrElse(RealType, Seq[Expr]()) ++ newTerms.getOrElse(IntegerType, Seq[Expr]())).toSeq + if(!numericTerms.isEmpty) { + //create a linear combination of intTerms + val newTemp = numericTerms.foldLeft(null: Expr)((acc, t: Expr) => { + val summand = Times(t, TemplateIdFactory.freshTemplateVar() : Expr) + if (acc == null) summand + else + Plus(summand, acc) + }) + //add newTemp to currTemp + currTemp = Plus(newTemp, currTemp) + + //get all the calls in the 'newTemp' and add edges from 'rootFun' to the callees to the call-graph + val callees = CallGraphUtil.getCallees(newTemp) + callees.foreach(callGraph.addEdgeIfNotPresent(rootFun, _)) + } + op(currTemp, zero) + } + } + + /** + * Generate a set of function calls of fun using the terms in ttCurrent + */ + def generateFunctionCalls(fun: FunDef): Set[Expr] = { + /** + * To be called with argIndex of zero and an empty argList + */ + def genFunctionCallsRecur(argIndex: Int, argList: Seq[Expr]): Set[Expr] = { + if (argIndex == fun.params.size) { + //create a call using argList + //TODO: how should we handle generics + Set(FunctionInvocation(TypedFunDef(fun, fun.tparams.map(_.tp)), argList)) + } else { + val arg = fun.params(argIndex) + val tpe = arg.getType + ttCurrent(tpe).foldLeft(Set[Expr]())((acc, term) => acc ++ genFunctionCallsRecur(argIndex + 1, argList :+ term)) + } + } + + genFunctionCallsRecur(0, Seq()) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..a933be56a28cde78336b336c829a3504d6fc66c2 --- /dev/null +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -0,0 +1,259 @@ +package leon +package invariant.engine + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import solvers._ +import solvers.z3.FairZ3Solver +import java.io._ +import purescala.ScalaPrinter +import verification._ +import scala.reflect.runtime.universe +import invariant.templateSolvers._ +import invariant.factories._ +import invariant.util._ +import invariant.util.Util._ +import invariant.structure._ +import transformations._ +import FunctionUtils._ +import leon.invariant.templateSolvers.ExtendedUFSolver + +/** + * @author ravi + * This phase performs automatic invariant inference. + * TODO: Do we need to also assert that time is >= 0 + */ +case class InferResult(res: Boolean, model: Option[Model], inferredFuncs: List[FunDef]) { +} + +trait FunctionTemplateSolver { + def apply() : Option[InferResult] +} + +class UnfoldingTemplateSolver(ctx: InferenceContext, rootFd: FunDef) extends FunctionTemplateSolver { + + val reporter = ctx.reporter + val program = ctx.program + val debugVCs = false + + lazy val constTracker = new ConstraintTracker(ctx, rootFd) + lazy val templateSolver = TemplateSolverFactory.createTemplateSolver(ctx, constTracker, rootFd) + + def constructVC(funDef: FunDef): (Expr, Expr) = { + val body = funDef.body.get + val Lambda(Seq(ValDef(resid, _)), _) = funDef.postcondition.get + val resvar = resid.toVariable + + val simpBody = matchToIfThenElse(body) + val plainBody = Equals(resvar, simpBody) + val bodyExpr = if (funDef.hasPrecondition) { + And(matchToIfThenElse(funDef.precondition.get), plainBody) + } else plainBody + + val fullPost = matchToIfThenElse(if (funDef.hasTemplate) + if (ctx.toVerifyPostFor.contains(funDef.id.name)) + And(funDef.getPostWoTemplate, funDef.getTemplate) + else + funDef.getTemplate + else + if (ctx.toVerifyPostFor.contains(funDef.id.name)) + funDef.getPostWoTemplate + else + BooleanLiteral(true)) + + (bodyExpr, fullPost) + } + + def solveParametricVC(vc: Expr) = { + val vcExpr = ExpressionTransformer.normalizeExpr(vc, ctx.multOp) + //for debugging + if (debugVCs) reporter.info("flattened VC: " + ScalaPrinter(vcExpr)) + + // initialize the constraint tracker + constTracker.addVC(rootFd, vcExpr) + + var refinementStep: Int = 0 + var toRefineCalls: Option[Set[Call]] = None + var infRes: Option[InferResult] = None + do { + Stats.updateCounter(1, "VC-refinement") + /* uncomment if we want to bound refinements + * if (refinementStep >= 5) + throw new IllegalStateException("Done 4 refinements")*/ + val refined = + if (refinementStep >= 1) { + reporter.info("- More unrollings for invariant inference") + + val toUnrollCalls = if (ctx.targettedUnroll) toRefineCalls else None + val unrolledCalls = constTracker.refineVCs(toUnrollCalls) + if (unrolledCalls.isEmpty) { + reporter.info("- Cannot do more unrollings, reached unroll bound") + false + } else true + } else { + constTracker.initialize + true + } + refinementStep += 1 + infRes = + if (!refined) + Some(InferResult(false, None, List())) + else { + //solve for the templates in this unroll step + templateSolver.solveTemplates() match { + case (Some(model), callsInPath) => + toRefineCalls = callsInPath + //Validate the model here + instantiateAndValidateModel(model, constTracker.getFuncs.toSeq) + Some(InferResult(true, Some(model), + constTracker.getFuncs.toList)) + case (None, callsInPath) => + toRefineCalls = callsInPath + //here, we do not know if the template is solvable or not, we need to do more unrollings. + None + } + } + } while (!infRes.isDefined) + infRes + } + + def apply() = { + //create a body and post of the function + val (bodyExpr, fullPost) = constructVC(rootFd) + if (fullPost == tru) + Some(InferResult(true, Some(Model.empty), List())) + else + solveParametricVC(And(bodyExpr, Not(fullPost))) + } + + def instantiateModel(model: Model, funcs: Seq[FunDef]) = { + funcs.collect { + case fd if fd.hasTemplate => + fd -> fd.getTemplate + }.toMap + } + + def instantiateAndValidateModel(model: Model, funcs: Seq[FunDef]) = { + val templates = instantiateModel(model, funcs) + val sols = TemplateInstantiator.getAllInvariants(model, templates) + + var output = "Invariants for Function: " + rootFd.id + "\n" + sols foreach { + case (fd, inv) => + val simpInv = simplifyArithmetic(InstUtil.replaceInstruVars(multToTimes(inv), fd)) + reporter.info("- Found inductive invariant: " + fd.id + " --> " + ScalaPrinter(simpInv)) + output += fd.id + " --> " + simpInv + "\n" + } + SpecificStats.addOutput(output) + + reporter.info("- Verifying Invariants... ") + val verifierRes = verifyInvariant(sols) + val finalRes = verifierRes._1 match { + case Some(false) => + reporter.info("- Invariant verified") + sols + case Some(true) => + reporter.error("- Invalid invariant, model: " + verifierRes._2) + throw new IllegalStateException("") + case _ => + //the solver timed out here + reporter.error("- Unable to prove or disprove invariant, the invariant is probably true") + sols + } + finalRes + } + + /** + * This function creates a new program with each function postcondition strengthened by + * the inferred postcondition + */ + def verifyInvariant(newposts: Map[FunDef, Expr]): (Option[Boolean], Model) = { + //create a fundef for each function in the program + //note: mult functions are also copied + val newFundefs = program.definedFunctions.collect { + case fd @ _ => { //if !Util.isMultFunctions(fd) + val newfd = new FunDef(FreshIdentifier(fd.id.name, Untyped, false), + fd.tparams, fd.returnType, fd.params) + (fd, newfd) + } + }.toMap + //note: we are not replacing "mult" function by "Times" + val replaceFun = (e: Expr) => e match { + case fi @ FunctionInvocation(tfd1, args) if newFundefs.contains(tfd1.fd) => + FunctionInvocation(TypedFunDef(newFundefs(tfd1.fd), tfd1.tps), args) + case _ => e + } + //create a body, pre, post for each newfundef + newFundefs.foreach((entry) => { + val (fd, newfd) = entry + //add a new precondition + newfd.precondition = + if (fd.precondition.isDefined) + Some(simplePostTransform(replaceFun)(fd.precondition.get)) + else None + + //add a new body + newfd.body = if (fd.hasBody) + Some(simplePostTransform(replaceFun)(fd.body.get)) + else None + + //add a new postcondition + val newpost = if (newposts.contains(fd)) { + val inv = newposts(fd) + if (fd.postcondition.isDefined) { + val Lambda(resultBinder, _) = fd.postcondition.get + Some(Lambda(resultBinder, And(fd.getPostWoTemplate, inv))) + } else { + //replace #res in the invariant by a new result variable + val resvar = FreshIdentifier("res", fd.returnType, true) + // FIXME: Is this correct (ResultVariable(fd.returnType) -> resvar.toVariable)) + val ninv = replace(Map(ResultVariable(fd.returnType) -> resvar.toVariable), inv) + Some(Lambda(Seq(ValDef(resvar, Some(fd.returnType))), ninv)) + } + } else if(fd.postcondition.isDefined) { + val Lambda(resultBinder, _) = fd.postcondition.get + Some(Lambda(resultBinder, fd.getPostWoTemplate)) + } else None + + newfd.postcondition = if (newpost.isDefined) { + val Lambda(resultBinder, pexpr) = newpost.get + // Some((resvar, simplePostTransform(replaceFun)(pexpr))) + Some(Lambda(resultBinder, simplePostTransform(replaceFun)(pexpr))) + } else None + newfd.addFlags(fd.flags) + }) + + val augmentedProg = Util.copyProgram(program, (defs: Seq[Definition]) => defs.collect { + case fd: FunDef if (newFundefs.contains(fd)) => newFundefs(fd) + case d if (!d.isInstanceOf[FunDef]) => d + }) + //convert the program back to an integer program if necessary + val (newprog, newroot) = if (ctx.usereals) { + val realToIntconverter = new RealToIntProgram() + val intProg = realToIntconverter(augmentedProg) + (intProg, realToIntconverter.mappedFun(newFundefs(rootFd))) + } else { + (augmentedProg, newFundefs(rootFd)) + } + //println("New Root: "+newroot) + import leon.solvers.z3._ + val dummySolFactory = new leon.solvers.SolverFactory[ExtendedUFSolver] { + def getNewSolver() = new ExtendedUFSolver(ctx.leonContext, program) + } + val vericontext = VerificationContext(ctx.leonContext, newprog, dummySolFactory, reporter) + val defaultTactic = new DefaultTactic(vericontext) + val vc = defaultTactic.generatePostconditions(newroot)(0) + + val verifyTimeout = 5 + val fairZ3 = new SimpleSolverAPI( + new TimeoutSolverFactory(SolverFactory(() => new FairZ3Solver(ctx.leonContext, newprog) with TimeoutSolver), + verifyTimeout * 1000)) + val sat = fairZ3.solveSAT(Not(vc.condition)) + sat + } +} diff --git a/src/main/scala/leon/invariant/factories/AxiomFactory.scala b/src/main/scala/leon/invariant/factories/AxiomFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..27c7e44c708e2d3126166afc7417ddfab858bc1d --- /dev/null +++ b/src/main/scala/leon/invariant/factories/AxiomFactory.scala @@ -0,0 +1,100 @@ +package leon +package invariant.factories + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import leon.invariant._ +import scala.util.control.Breaks._ +import scala.concurrent._ +import scala.concurrent.duration._ + +import invariant.engine._ +import invariant.util._ +import invariant.structure._ +import FunctionUtils._ + +class AxiomFactory(ctx : InferenceContext) { + + val tru = BooleanLiteral(true) + + //Add more axioms here, if necessary + def hasUnaryAxiom(call: Call) : Boolean = { + //important: here we need to avoid applying commutativity on the calls produced by axioms instantiation + call.fi.tfd.fd.isCommutative + } + + def hasBinaryAxiom(call: Call) : Boolean = { + val callee = call.fi.tfd.fd + (callee.isMonotonic || callee.isDistributive) + } + + def unaryAxiom(call: Call) : (Expr,Expr) = { + val callee = call.fi.tfd.fd + val tfd = call.fi.tfd + + if(callee.isCommutative) { + //note: commutativity is defined only for binary operations + val Seq(a1, a2) = call.fi.args + val newret = TVarFactory.createTemp("cm").toVariable + val newfi = FunctionInvocation(tfd,Seq(a2,a1)) + val newcall = Call(newret,newfi) + (tru, And(newcall.toExpr, Equals(newret, call.retexpr))) + } else + throw new IllegalStateException("Call does not have unary axiom: "+call) + } + + def binaryAxiom(call1: Call, call2: Call): Seq[(Expr,Expr)] = { + + if (call1.fi.tfd.id != call2.fi.tfd.id) + throw new IllegalStateException("Instantiating binary axiom on calls to different functions: " + call1 + "," + call2) + + if (!hasBinaryAxiom(call1)) + throw new IllegalStateException("Call does not have binary axiom: " + call1) + + val callee = call1.fi.tfd.fd + //monotonicity + var axioms = if (callee.isMonotonic) { + Seq(monotonizeCalls(call1, call2)) + } else Seq() + + //distributivity + axioms ++= (if (callee.isDistributive) { + //println("Applying distributivity on: "+(call1,call2)) + Seq(undistributeCalls(call1, call2)) + } else Seq()) + + axioms + } + + def monotonizeCalls(call1: Call, call2: Call): (Expr,Expr) = { + val ants = (call1.fi.args zip call2.fi.args).foldLeft(Seq[Expr]())((acc, pair) => { + val lesse = LessEquals(pair._1, pair._2) + lesse +: acc + }) + val conseq = LessEquals(call1.retexpr, call2.retexpr) + (Util.createAnd(ants), conseq) + } + + //this is applicable only to binary operations + def undistributeCalls(call1: Call, call2: Call): (Expr,Expr) = { + val fd = call1.fi.tfd.fd + val tfd = call1.fi.tfd + + val Seq(a1,b1) = call1.fi.args + val Seq(a2,b2) = call2.fi.args + val r1 = call1.retexpr + val r2 = call2.retexpr + + val dret1 = TVarFactory.createTemp("dt", IntegerType).toVariable + val dret2 = TVarFactory.createTemp("dt", IntegerType).toVariable + val dcall1 = Call(dret1, FunctionInvocation(tfd,Seq(a2,Plus(b1,b2)))) + val dcall2 = Call(dret2, FunctionInvocation(tfd,Seq(Plus(a1,a2),b2))) + (LessEquals(b1,b2), And(LessEquals(Plus(r1,r2),dret2), dcall2.toExpr)) + } +} diff --git a/src/main/scala/leon/invariant/factories/TemplateFactory.scala b/src/main/scala/leon/invariant/factories/TemplateFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..68080a79e8106decdeb6c809d0601c10be637e68 --- /dev/null +++ b/src/main/scala/leon/invariant/factories/TemplateFactory.scala @@ -0,0 +1,154 @@ +package leon +package invariant.factories + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import scala.collection.mutable.{ Map => MutableMap } +import invariant._ +import scala.collection.mutable.{Map => MutableMap} + +import invariant.engine._ +import invariant.util._ +import invariant.structure._ +import FunctionUtils._ + +object TemplateIdFactory { + //a set of template ids + private var ids = Set[Identifier]() + + def getTemplateIds : Set[Identifier] = ids + + def freshIdentifier(name : String = "", idType: TypeTree = RealType) : Identifier = { + val idname = if(name.isEmpty()) "a?" + else name + "?" + val freshid = FreshIdentifier(idname, idType, true) + ids += freshid + freshid + } + + def copyIdentifier(id: Identifier) : Identifier = { + val freshid = FreshIdentifier(id.name, RealType, false) + ids += freshid + freshid + } + + /** + * Template variables have real type + */ + def IsTemplateIdentifier(id : Identifier) : Boolean = { + ids.contains(id) + } + + def IsTemplateVar(v : Variable) : Boolean = { + IsTemplateIdentifier(v.id) + } + + def freshTemplateVar(name : String= "") : Variable = { + Variable(freshIdentifier(name)) + } +} + +trait TemplateGenerator { + def getNextTemplate(fd : FunDef): Expr +} + +/** + * Templates are expressions with template variables. + * The program variables that can be free in the templates are only the arguments and + * the result variable. + * Note: the program logic depends on the mutability here. + */ +class TemplateFactory(tempGen : Option[TemplateGenerator], prog: Program, reporter : Reporter) { + + //a mapping from function definition to the template + private var templateMap = { + //initialize the template map with predefined user maps + var muMap = MutableMap[FunDef, Expr]() + Util.functionsWOFields(prog.definedFunctions).foreach { fd => + val tmpl = fd.template + if (tmpl.isDefined) { + muMap.update(fd, tmpl.get) + } + } + muMap + } + + def setTemplate(fd:FunDef, tempExpr :Expr) = { + templateMap += (fd -> tempExpr) + } + + /** + * This is the default template generator. + * + */ + def getDefaultTemplate(fd : FunDef): Expr = { + + //just consider all the arguments, return values that are integers + val baseTerms = fd.params.filter((vardecl) => Util.isNumericType(vardecl.getType)).map(_.toVariable) ++ + (if(Util.isNumericType(fd.returnType)) Seq(Util.getFunctionReturnVariable(fd)) + else Seq()) + + val lhs = baseTerms.foldLeft(TemplateIdFactory.freshTemplateVar() : Expr)((acc, t)=> { + Plus(Times(TemplateIdFactory.freshTemplateVar(),t),acc) + }) + val tempExpr = LessEquals(lhs,InfiniteIntegerLiteral(0)) + tempExpr + } + + + /** + * Constructs a template using a mapping from the formals to actuals. + * Uses default template if a template does not exist for the function and no template generator is provided. + * Otherwise, use the provided template generator + */ + var refinementSet = Set[FunDef]() + def constructTemplate(argmap: Map[Expr,Expr], fd: FunDef): Expr = { + + //initialize the template for the function + if (!templateMap.contains(fd)) { + if(!tempGen.isDefined) templateMap += (fd -> getDefaultTemplate(fd)) + else { + templateMap += (fd -> tempGen.get.getNextTemplate(fd)) + refinementSet += fd + //for information + reporter.info("- Template generated for function "+fd.id+" : "+templateMap(fd)) + } + } + replace(argmap,templateMap(fd)) + } + + + /** + * Refines the templates of the functions that were assigned templates using the template generator. + */ + def refineTemplates(): Boolean = { + + if(tempGen.isDefined) { + var modifiedTemplate = false + refinementSet.foreach((fd) => { + val oldTemp = templateMap(fd) + val newTemp = tempGen.get.getNextTemplate(fd) + + if (oldTemp != newTemp) { + modifiedTemplate = true + templateMap.update(fd, newTemp) + reporter.info("- New template for function " + fd.id + " : " + newTemp) + } + }) + modifiedTemplate + } else false + } + + def getTemplate(fd : FunDef) : Option[Expr] = { + templateMap.get(fd) + } + + def getFunctionsWithTemplate : Seq[FunDef] = templateMap.keys.toSeq + +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala new file mode 100644 index 0000000000000000000000000000000000000000..5944f68bd9faa11ddfa5d26d8d057a086aea080c --- /dev/null +++ b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala @@ -0,0 +1,134 @@ +package leon +package invariant.factories + +import z3.scala._ +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import invariant.engine._ +import invariant.util._ +import invariant.structure._ +import leon.solvers.Model +import leon.invariant.util.RealValuedExprEvaluator + +object TemplateInstantiator { + /** + * Computes the invariant for all the procedures given a mapping for the + * template variables. + * (Undone) If the mapping does not have a value for an id, then the id is bound to the simplest value + */ + def getAllInvariants(model: Model, templates :Map[FunDef, Expr]): Map[FunDef, Expr] = { + val invs = templates.map((pair) => { + val (fd, t) = pair + //flatten the template + val freevars = variablesOf(t) + val template = ExpressionTransformer.FlattenFunction(t) + + val tempvars = Util.getTemplateVars(template) + val tempVarMap: Map[Expr, Expr] = tempvars.map((v) => { + (v, model(v.id)) + }).toMap + + val instTemplate = instantiate(template, tempVarMap) + //now unflatten it + val comprTemp = ExpressionTransformer.unFlatten(instTemplate, freevars) + (fd, comprTemp) + }) + invs + } + + /** + * Instantiates templated subexpressions of the given expression (expr) using the given mapping for the template variables. + * The instantiation also takes care of converting the rational coefficients to integer coefficients. + */ + def instantiate(expr: Expr, tempVarMap: Map[Expr, Expr]): Expr = { + //do a simple post transform and replace the template vars by their values + val inv = simplePostTransform((tempExpr: Expr) => tempExpr match { + case e @ Operator(Seq(lhs, rhs), op) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] + || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] + || e.isInstanceOf[GreaterEquals]) + && + !Util.getTemplateVars(tempExpr).isEmpty) => { + + //println("Template Expression: "+tempExpr) + val linearTemp = LinearConstraintUtil.exprToTemplate(tempExpr) + // println("MODEL\n" + tempVarMap) + instantiateTemplate(linearTemp, tempVarMap) + } + case _ => tempExpr + })(expr) + inv + } + + + def validateLiteral(e : Expr) = e match { + case FractionalLiteral(num, denom) => { + if (denom == 0) + throw new IllegalStateException("Denominator is zero !! " +e) + if (denom < 0) + throw new IllegalStateException("Denominator is negative: " + denom) + true + } + case IntLiteral(_) => true + case InfiniteIntegerLiteral(_) => true + case _ => throw new IllegalStateException("Not a real literal: " + e) + } + + def instantiateTemplate(linearTemp: LinearTemplate, tempVarMap: Map[Expr, Expr]): Expr = { + val bigone = BigInt(1) + val coeffMap = linearTemp.coeffTemplate.map((entry) => { + val (term, coeffTemp) = entry + val coeffE = replace(tempVarMap, coeffTemp) + val coeff = RealValuedExprEvaluator.evaluate(coeffE) + + validateLiteral(coeff) + + (term -> coeff) + }) + val const = if (linearTemp.constTemplate.isDefined){ + val constE = replace(tempVarMap, linearTemp.constTemplate.get) + val constV = RealValuedExprEvaluator.evaluate(constE) + + validateLiteral(constV) + Some(constV) + } + else None + + val realValues: Seq[Expr] = coeffMap.values.toSeq ++ { if (const.isDefined) Seq(const.get) else Seq() } + //the coefficients could be fractions ,so collect all the denominators + val getDenom = (t: Expr) => t match { + case FractionalLiteral(num, denum) => denum + case _ => bigone + } + + val denoms = realValues.foldLeft(Set[BigInt]())((acc, entry) => { acc + getDenom(entry) }) + + //compute the LCM of the denominators + val gcd = denoms.foldLeft(bigone)((acc, d) => acc.gcd(d)) + val lcm = denoms.foldLeft(BigInt(1))((acc, d) => { + val product = (acc * d) + if(product % gcd == 0) + product/ gcd + else product + }) + + //scale the numerator by lcm + val scaleNum = (t: Expr) => t match { + case FractionalLiteral(num, denum) => + InfiniteIntegerLiteral(num * (lcm / denum)) + case InfiniteIntegerLiteral(n) => + InfiniteIntegerLiteral(n * lcm) + case _ => throw new IllegalStateException("Coefficient not assigned to any value") + } + val intCoeffMap = coeffMap.map((entry) => (entry._1, scaleNum(entry._2))) + val intConst = if (const.isDefined) Some(scaleNum(const.get)) else None + + val linearCtr = new LinearConstraint(linearTemp.op, intCoeffMap, intConst) + linearCtr.toExpr + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala b/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..25abefa22e51479c6d56fffd2c9e8a91a8b0bfec --- /dev/null +++ b/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala @@ -0,0 +1,44 @@ +package leon +package invariant.factories + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import invariant._ +import invariant.engine._ +import invariant.util._ +import invariant.structure._ +import FunctionUtils._ +import templateSolvers._ +import leon.solvers.Model + +object TemplateSolverFactory { + + def createTemplateSolver(ctx: InferenceContext, ctrack: ConstraintTracker, rootFun: FunDef, + // options to solvers + minopt: Option[(Expr, Model) => Model] = None, + bound: Option[Int] = None): TemplateSolver = { + if (ctx.useCegis) { + // TODO: find a better way to specify CEGIS total time bound + new CegisSolver(ctx, rootFun, ctrack, 10000, bound) + } else { + val minimizer = if (ctx.tightBounds && rootFun.hasTemplate) { + if (minopt.isDefined) + minopt + else { + //TODO: need to assert that the templates are resource templates + Some((new Minimizer(ctx)).tightenTimeBounds(rootFun.getTemplate) _) + } + } else + None + if (ctx.withmult) { + new NLTemplateSolverWithMult(ctx, rootFun, ctrack, minimizer) + } else { + new NLTemplateSolver(ctx, rootFun, ctrack, minimizer) + } + } + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a6e32ce40f8882ec4e240578bd1a60b35eaee0e --- /dev/null +++ b/src/main/scala/leon/invariant/structure/Constraint.scala @@ -0,0 +1,272 @@ +package leon +package invariant.structure + +import z3.scala._ +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import solvers.{ Solver, TimeoutSolver } +import solvers.z3.FairZ3Solver +import scala.collection.mutable.{ Set => MutableSet } +import scala.collection.mutable.{ Map => MutableMap } +import evaluators._ +import java.io._ +import solvers.z3.UninterpretedZ3Solver +import invariant.util._ + +trait Constraint { + def toExpr: Expr +} +/** + * Class representing linear templates which is a constraint of the form + * a1*v1 + a2*v2 + .. + an*vn + a0 <= 0 or = 0 or < 0 where ai's are unknown coefficients + * which could be any arbitrary expression with template variables as free variables + * and vi's are variables. + * Note: we need atleast one coefficient or one constant to be defined. + * Otherwise a NPE will be thrown (in the computation of 'template') + */ +class LinearTemplate(oper: Seq[Expr] => Expr, + coeffTemp: Map[Expr, Expr], + constTemp: Option[Expr]) extends Constraint { + + val zero = InfiniteIntegerLiteral(0) + + val op = { + oper + } + val coeffTemplate = { + //assert if the coefficients are templated expressions + assert(coeffTemp.values.foldLeft(true)((acc, e) => { + acc && Util.isTemplateExpr(e) + })) + + //print the template mapping + /*println("Coeff Mapping: "+coeffTemp) + println("Constant: "+constTemplate)*/ + coeffTemp + } + + val constTemplate = { + assert(constTemp match { + case None => true + case Some(e) => Util.isTemplateExpr(e) + }) + constTemp + } + + val template = { + //construct the expression corresponding to the template here + var lhs = coeffTemp.foldLeft(null: Expr)((acc, entry) => { + val (term, coeff) = entry + val minterm = Times(coeff, term) + if (acc == null) minterm else Plus(acc, minterm) + }) + lhs = if (constTemp.isDefined) { + if (lhs == null) constTemp.get + else Plus(lhs, constTemp.get) + } else lhs + val expr = oper(Seq(lhs, zero)) + //assert(expr.isInstanceOf[Equals] || expr.isInstanceOf[LessThan] || expr.isInstanceOf[GreaterThan] + // || expr.isInstanceOf[LessEquals] || expr.isInstanceOf[GreaterEquals]) + expr + } + + def templateVars: Set[Variable] = { + Util.getTemplateVars(template) + } + + def coeffEntryToString(coeffEntry: (Expr, Expr)): String = { + val (e, i) = coeffEntry + i match { + case InfiniteIntegerLiteral(x) if (x == 1) => e.toString + case InfiniteIntegerLiteral(x) if (x == -1) => "-" + e.toString + case InfiniteIntegerLiteral(v) => v + e.toString + case IntLiteral(1) => e.toString + case IntLiteral(-1) => "-" + e.toString + case IntLiteral(v) => v + e.toString + case _ => i + " * " + e.toString + } + } + + override def toExpr: Expr = template + + override def toString(): String = { + + val coeffStr = if (coeffTemplate.isEmpty) "" + else { + val (head :: tail) = coeffTemplate.toList + tail.foldLeft(coeffEntryToString(head))((str, pair) => { + + val termStr = coeffEntryToString(pair) + (str + " + " + termStr) + }) + } + val constStr = if (constTemplate.isDefined) constTemplate.get.toString else "" + val str = if (!coeffStr.isEmpty() && !constStr.isEmpty()) coeffStr + " + " + constStr + else coeffStr + constStr + str + (template match { + case t: Equals => " = " + case t: LessThan => " < " + case t: GreaterThan => " > " + case t: LessEquals => " <= " + case t: GreaterEquals => " >= " + }) + "0" + } + + override def hashCode(): Int = { + template.hashCode() + } + + override def equals(obj: Any): Boolean = obj match { + case lit: LinearTemplate => { + if (!lit.template.equals(this.template)) { + //println(lit.template + " and " + this.template+ " are not equal ") + false + } else true + } + case _ => false + } +} + +/** + * class representing a linear constraint. This is a linear template wherein the coefficients are constants + */ +class LinearConstraint(opr: Seq[Expr] => Expr, cMap: Map[Expr, Expr], constant: Option[Expr]) + extends LinearTemplate(opr, cMap, constant) { + + val coeffMap = { + //assert if the coefficients are only constant expressions + assert(cMap.values.foldLeft(true)((acc, e) => { + acc && variablesOf(e).isEmpty + })) + + //TODO: here we should try to simplify the constant expressions + cMap + } + + val const = constant.map((c) => { + //check if constant does not have any variables + assert(variablesOf(c).isEmpty) + c + }) +} + +case class BoolConstraint(e: Expr) extends Constraint { + val expr = { + assert(e match { + case Variable(_) => true + case Not(Variable(_)) => true + case t: BooleanLiteral => true + case Not(t: BooleanLiteral) => true + case _ => false + }) + e + } + + override def toString(): String = { + expr.toString + } + + def toExpr: Expr = expr +} + +object ADTConstraint { + + def apply(e: Expr): ADTConstraint = e match { + + //is this a tuple or case class select ? + // case Equals(Variable(_), CaseClassSelector(_, _, _)) | Iff(Variable(_), CaseClassSelector(_, _, _)) => { + case Equals(Variable(_), CaseClassSelector(_, _, _)) => { + val ccExpr = ExpressionTransformer.classSelToCons(e) + new ADTConstraint(ccExpr, Some(ccExpr)) + } + // case Equals(Variable(_),TupleSelect(_,_)) | Iff(Variable(_),TupleSelect(_,_)) => { + case Equals(Variable(_), TupleSelect(_, _)) => { + val tpExpr = ExpressionTransformer.tupleSelToCons(e) + new ADTConstraint(tpExpr, Some(tpExpr)) + } + //is this a tuple or case class def ? + case Equals(Variable(_), CaseClass(_, _)) | Equals(Variable(_), Tuple(_)) => { + new ADTConstraint(e, Some(e)) + } + //is this an instanceOf ? + case Equals(v @ Variable(_), ci @ IsInstanceOf(_, _)) => { + new ADTConstraint(e, None, Some(e)) + } + // considering asInstanceOf as equalities + case Equals(lhs @ Variable(_), ci @ AsInstanceOf(rhs @ Variable(_), _)) => { + val eq = Equals(lhs, rhs) + new ADTConstraint(eq, None, None, Some(eq)) + } + //equals and disequalities betweeen variables + case Equals(lhs @ Variable(_), rhs @ Variable(_)) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { + new ADTConstraint(e, None, None, Some(e)) + } + case Not(Equals(lhs @ Variable(_), rhs @ Variable(_))) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { + new ADTConstraint(e, None, None, Some(e)) + } + case _ => { + throw new IllegalStateException("Expression not an ADT constraint: " + e) + } + } +} + +class ADTConstraint(val expr: Expr, + val cons: Option[Expr] = None, + val inst: Option[Expr] = None, + val comp: Option[Expr] = None) extends Constraint { + + override def toString(): String = { + expr.toString + } + + override def toExpr = expr +} + +case class Call(retexpr: Expr, fi: FunctionInvocation) extends Constraint { + val expr = Equals(retexpr, fi) + + override def toExpr = expr +} + +object ConstraintUtil { + + def createConstriant(ie: Expr): Constraint = { + ie match { + case Variable(_) | Not(Variable(_)) | BooleanLiteral(_) | Not(BooleanLiteral(_)) => BoolConstraint(ie) + case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => Call(v, fi) + case Equals(Variable(_), CaseClassSelector(_, _, _)) + | Equals(Variable(_), CaseClass(_, _)) + | Equals(Variable(_), TupleSelect(_, _)) + | Equals(Variable(_), Tuple(_)) + | Equals(Variable(_), IsInstanceOf(_, _)) => { + + ADTConstraint(ie) + } + case Equals(lhs, rhs) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { + //println("ADT constraint: "+ie) + ADTConstraint(ie) + } + case Not(Equals(lhs, rhs)) if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) => { + ADTConstraint(ie) + } + case _ => { + val simpe = simplifyArithmetic(ie) + simpe match { + case b: BooleanLiteral => BoolConstraint(b) + case _ => { + val template = LinearConstraintUtil.exprToTemplate(ie) + LinearConstraintUtil.evaluate(template) match { + case Some(v) => BoolConstraint(BooleanLiteral(v)) + case _ => template + } + } + } + } + } + } +} diff --git a/src/main/scala/leon/invariant/structure/Formula.scala b/src/main/scala/leon/invariant/structure/Formula.scala new file mode 100644 index 0000000000000000000000000000000000000000..0fedbc7057b6172b43c405caa6a976a24541c81a --- /dev/null +++ b/src/main/scala/leon/invariant/structure/Formula.scala @@ -0,0 +1,281 @@ +package leon +package invariant.structure + +import z3.scala._ +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import solvers.{ Solver, TimeoutSolver } +import solvers.z3.FairZ3Solver +import java.io._ +import solvers.z3._ +import invariant.engine._ +import invariant.util._ +import leon.solvers.Model + +/** + * Data associated with a call + */ +class CallData(val guard : Variable, val parents: List[FunDef]) { +} + +/** + * Representation of an expression as a set of implications. + * 'initexpr' is required to be in negation normal form and And/Ors have been pulled up + * TODO: optimize the representation so that we use fewer guards. + */ +class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { + + val fls = BooleanLiteral(false) + val tru = BooleanLiteral(true) + val useImplies = false + + val combiningOp = if(useImplies) Implies.apply _ else Equals.apply _ + protected var disjuncts = Map[Variable, Seq[Constraint]]() //a mapping from guards to conjunction of atoms + protected var conjuncts = Map[Variable, Expr]() //a mapping from guards to disjunction of atoms + private var callDataMap = Map[Call, CallData]() //a mapping from a 'call' to the 'guard' guarding the call plus the list of transitive callers of 'call' + + val firstRoot : Variable = addConstraints(initexpr, List(fd))._1 + protected var roots : Seq[Variable] = Seq(firstRoot) //a list of roots, the formula is a conjunction of formula of each root + + def disjunctsInFormula = disjuncts + + def callData(call: Call) : CallData = callDataMap(call) + + //return the root variable and the sequence of disjunct guards added + //(which includes the root variable incase it respresents a disjunct) + def addConstraints(ine: Expr, callParents : List[FunDef]) : (Variable, Seq[Variable]) = { + + var newDisjGuards = Seq[Variable]() + + def getCtrsFromExprs(guard: Variable, exprs: Seq[Expr]) : Seq[Constraint] = { + var break = false + exprs.foldLeft(Seq[Constraint]())((acc, e) => { + if (break) acc + else { + val ctr = ConstraintUtil.createConstriant(e) + ctr match { + case BoolConstraint(BooleanLiteral(true)) => acc + case BoolConstraint(BooleanLiteral(false)) => { + break = true + Seq(ctr) + } + case call@Call(_,_) => { + + if(callParents.isEmpty) + throw new IllegalArgumentException("Parent not specified for call: "+ctr) + else { + callDataMap += (call -> new CallData(guard, callParents)) + } + acc :+ call + } + case _ => acc :+ ctr + } + } + }) + } + + val f1 = simplePostTransform((e: Expr) => e match { + case Or(args) => { + val newargs = args.map(arg => arg match { + case v: Variable if (disjuncts.contains(v)) => arg + case v: Variable if (conjuncts.contains(v)) => throw new IllegalStateException("or gaurd inside conjunct: "+e+" or-guard: "+v) + case _ => { + val atoms = arg match { + case And(atms) => atms + case _ => Seq(arg) + } + val g = TVarFactory.createTemp("b", BooleanType).toVariable + newDisjGuards :+= g + //println("atoms: "+atoms) + val ctrs = getCtrsFromExprs(g, atoms) + disjuncts += (g -> ctrs) + g + } + }) + //create a temporary for Or + val gor = TVarFactory.createTemp("b", BooleanType).toVariable + val newor = Util.createOr(newargs) + conjuncts += (gor -> newor) + gor + } + case And(args) => { + val newargs = args.map(arg => if (Util.getTemplateVars(e).isEmpty) { + arg + } else { + //if the expression has template variables then we separate it using guards + val g = TVarFactory.createTemp("b", BooleanType).toVariable + newDisjGuards :+= g + val ctrs = getCtrsFromExprs(g, Seq(arg)) + disjuncts += (g -> ctrs) + g + }) + Util.createAnd(newargs) + } + case _ => e + })(ExpressionTransformer.simplify(simplifyArithmetic( + //TODO: this is a hack as of now. Fix this. + //Note: it is necessary to convert real literals to integers since the linear constraint cannot handle real literals + if(ctx.usereals) ExpressionTransformer.FractionalLiteralToInt(ine) + else ine + ))) + + val rootvar = f1 match { + case v: Variable if(conjuncts.contains(v)) => v + case v: Variable if(disjuncts.contains(v)) => throw new IllegalStateException("f1 is a disjunct guard: "+v) + case _ => { + val atoms = f1 match { + case And(atms) => atms + case _ => Seq(f1) + } + val g = TVarFactory.createTemp("b", BooleanType).toVariable + val ctrs = getCtrsFromExprs(g, atoms) + newDisjGuards :+= g + disjuncts += (g -> ctrs) + g + } + } + (rootvar, newDisjGuards) + } + + //'satGuard' is required to a guard variable + def pickSatDisjunct(startGaurd : Variable, model: Model): Seq[Constraint] = { + + def traverseOrs(gd: Variable, model: Model): Seq[Variable] = { + val e @ Or(guards) = conjuncts(gd) + //pick one guard that is true + val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } + if (!guard.isDefined) + throw new IllegalStateException("No satisfiable guard found: " + e) + guard.get +: traverseAnds(guard.get, model) + } + + def traverseAnds(gd: Variable, model: Model): Seq[Variable] = { + val ctrs = disjuncts(gd) + val guards = ctrs.collect { + case BoolConstraint(v @ Variable(_)) if (conjuncts.contains(v) || disjuncts.contains(v)) => v + } + if (guards.isEmpty) Seq() + else { + guards.foldLeft(Seq[Variable]())((acc, g) => { + if (model(g.id) != tru) + throw new IllegalStateException("Not a satisfiable guard: " + g) + + if (conjuncts.contains(g)) + acc ++ traverseOrs(g, model) + else { + acc ++ (g +: traverseAnds(g, model)) + } + }) + } + } + //if startGuard is unsat return empty + if (model(startGaurd.id) == fls) Seq() + else { + val satGuards = if (conjuncts.contains(startGaurd)) traverseOrs(startGaurd, model) + else (startGaurd +: traverseAnds(startGaurd, model)) + satGuards.flatMap(g => disjuncts(g)) + } + } + + /** + * 'neweexpr' is required to be in negation normal form and And/Ors have been pulled up + */ + def conjoinWithDisjunct(guard: Variable, newexpr: Expr, callParents: List[FunDef]) : (Variable, Seq[Variable]) = { + val (exprRoot, newGaurds) = addConstraints(newexpr, callParents) + //add 'newguard' in conjunction with 'disjuncts(guard)' + val ctrs = disjuncts(guard) + disjuncts -= guard + disjuncts += (guard -> (BoolConstraint(exprRoot) +: ctrs)) + (exprRoot, newGaurds) + } + + def conjoinWithRoot(newexpr: Expr, callParents: List[FunDef]): (Variable, Seq[Variable]) = { + val (exprRoot, newGaurds) = addConstraints(newexpr, callParents) + roots :+= exprRoot + (exprRoot, newGaurds) + } + + /** + * The first return value is param part and the second one is the + * non-parametric part + */ + def splitParamPart : (Expr, Expr) = { + var paramPart = Seq[Expr]() + var rest = Seq[Expr]() + disjuncts.foreach(entry => { + val (g,ctrs) = entry + val ctrExpr = combiningOp(g,Util.createAnd(ctrs.map(_.toExpr))) + if(Util.getTemplateVars(ctrExpr).isEmpty) + rest :+= ctrExpr + else + paramPart :+= ctrExpr + + }) + val conjs = conjuncts.map((entry) => combiningOp(entry._1, entry._2)).toSeq ++ roots + (Util.createAnd(paramPart), Util.createAnd(rest ++ conjs ++ roots)) + } + + def toExpr : Expr={ + val disjs = disjuncts.map((entry) => { + val (g,ctrs) = entry + combiningOp(g, Util.createAnd(ctrs.map(_.toExpr))) + }).toSeq + val conjs = conjuncts.map((entry) => combiningOp(entry._1, entry._2)).toSeq + Util.createAnd(disjs ++ conjs ++ roots) + } + + //unpack the disjunct and conjuncts by removing all guards + def unpackedExpr : Expr = { + //replace all conjunct guards in disjuncts by their mapping + val disjs : Map[Expr,Expr] = disjuncts.map((entry) => { + val (g,ctrs) = entry + val newctrs = ctrs.map(_ match { + case BoolConstraint(g@Variable(_)) if conjuncts.contains(g) => conjuncts(g) + case ctr@_ => ctr.toExpr + }) + (g, Util.createAnd(newctrs)) + }) + val rootexprs = roots.map(_ match { + case g@Variable(_) if conjuncts.contains(g) => conjuncts(g) + case e@_ => e + }) + //replace every guard in the 'disjs' by its disjunct. DO this as long as every guard is replaced in every disjunct + var unpackedDisjs = disjs + var replacedGuard = true + //var removeGuards = Seq[Variable]() + while(replacedGuard) { + replacedGuard = false + + val newDisjs = unpackedDisjs.map(entry => { + val (g,d) = entry + val guards = variablesOf(d).collect{ case id@_ if disjuncts.contains(id.toVariable) => id.toVariable } + if (guards.isEmpty) entry + else { + /*println("Disunct: "+d) + println("guard replaced: "+guards)*/ + replacedGuard = true + //removeGuards ++= guards + (g, replace(unpackedDisjs, d)) + } + }) + unpackedDisjs = newDisjs + } + //replace all the 'guards' in root using 'unpackedDisjs' + replace(unpackedDisjs, Util.createAnd(rootexprs)) + } + + override def toString : String = { + val disjStrs = disjuncts.map((entry) => { + val (g,ctrs) = entry + simplifyArithmetic(combiningOp(g, Util.createAnd(ctrs.map(_.toExpr)))).toString + }).toSeq + val conjStrs = conjuncts.map((entry) => combiningOp(entry._1, entry._2).toString).toSeq + val rootStrs = roots.map(_.toString) + (disjStrs ++ conjStrs ++ rootStrs).foldLeft("")((acc,str) => acc + "\n" + str) + } +} diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..9bdc6692c3440698d7db8feaf2d214aa144ebf02 --- /dev/null +++ b/src/main/scala/leon/invariant/structure/FunctionUtils.scala @@ -0,0 +1,158 @@ +package leon +package invariant.structure + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import invariant.factories._ +import invariant.util._ +import Util._ +import scala.language.implicitConversions + +/** + * Some utiliy methods for functions. + * This also does caching to improve performance. + */ +object FunctionUtils { + + class FunctionInfo(fd: FunDef) { + //flags + lazy val isTheoryOperation = fd.annotations.contains("theoryop") + lazy val isMonotonic = fd.annotations.contains("monotonic") + lazy val isCommutative = fd.annotations.contains("commutative") + lazy val isDistributive = fd.annotations.contains("distributive") + lazy val compose = fd.annotations.contains("compose") + + //the template function + lazy val tmplFunctionName = "tmpl" + /** + * checks if the function name is 'tmpl' and there is only one argument + * if not, type checker would anyway throw an error if leon.invariant._ is included + */ + def isTemplateInvocation(finv: Expr) = { + finv match { + case FunctionInvocation(funInv, args) => + (funInv.id.name == "tmpl" && funInv.returnType == BooleanType && + args.size == 1 && args(0).isInstanceOf[Lambda]) + case _ => + false + } + } + + def isQMark(e: Expr) = e match { + case FunctionInvocation(TypedFunDef(fd, Seq()), args) => + (fd.id.name == "?" && fd.returnType == IntegerType && + args.size <= 1) + case _ => false + } + + def extractTemplateFromLambda(tempLambda: Lambda): Expr = { + val Lambda(vdefs, body) = tempLambda + val vars = vdefs.map(_.id.toVariable) + val tempVars = vars.map { // reuse template variables if possible + case v if TemplateIdFactory.IsTemplateIdentifier(v.id) => v + case v => + TemplateIdFactory.freshIdentifier(v.id.name).toVariable + } + val repmap = (vars zip tempVars).toMap[Expr, Expr] + replace(repmap, body) + } + + def tmplFunction(paramTypes: Seq[TypeTree]) = { + val lambdaType = FunctionType(paramTypes, BooleanType) + val paramid = FreshIdentifier("lamb", lambdaType) + new FunDef(FreshIdentifier("tmpl", BooleanType), Seq(), BooleanType, Seq(ValDef(paramid))) + } + + /** + * Repackages '?' mark expression into tmpl functions + */ + def qmarksToTmplFunction(ine: Expr) = { + var tempIds = Seq[Identifier]() + var indexToId = Map[BigInt, Identifier]() + val lambBody = simplePostTransform { + case q @ FunctionInvocation(_, Seq()) if isQMark(q) => // question mark with zero args + val freshid = TemplateIdFactory.freshIdentifier("q") + tempIds :+= freshid + freshid.toVariable + + case q @ FunctionInvocation(_, Seq(InfiniteIntegerLiteral(index))) if isQMark(q) => //question mark with one arg + indexToId.getOrElse(index, { + val freshid = TemplateIdFactory.freshIdentifier("q" + index) + tempIds :+= freshid + indexToId += (index -> freshid) + freshid + }).toVariable + + case other => other + }(ine) + FunctionInvocation(TypedFunDef(tmplFunction(tempIds.map(_.getType)), Seq()), + Seq(Lambda(tempIds.map(id => ValDef(id)), lambBody))) + } + + /** + * Does not support mixing of tmpl exprs and '?'. + * Need to check that tmpl functions are not nested. + */ + lazy val (postWoTemplate, templateExpr) = { + if (fd.postcondition.isDefined) { + val Lambda(_, postBody) = fd.postcondition.get + // collect all terms with question marks and convert them to a template + val postWoQmarks = postBody match { + case And(args) if args.exists(exists(isQMark) _) => + val (tempExprs, otherPreds) = args.partition { + case a if exists(isQMark)(a) => true + case _ => false + } + //println(s"Otherpreds: $otherPreds ${qmarksToTmplFunction(Util.createAnd(tempExprs))}") + Util.createAnd(otherPreds :+ qmarksToTmplFunction(Util.createAnd(tempExprs))) + case pb if exists(isQMark)(pb) => + qmarksToTmplFunction(pb) + case other => other + } + //the 'body' could be a template or 'And(pred, template)' + postWoQmarks match { + case finv @ FunctionInvocation(_, args) if isTemplateInvocation(finv) => + (None, Some(finv)) + case And(args) if args.exists(isTemplateInvocation) => + val (tempFuns, otherPreds) = args.partition { + case a if isTemplateInvocation(a) => true + case _ => false + } + if (tempFuns.size > 1) { + throw new IllegalStateException("Multiple template functions used in the postcondition: " + postBody) + } else { + val rest = if (otherPreds.size <= 1) otherPreds(0) else And(otherPreds) + (Some(rest), Some(tempFuns(0).asInstanceOf[FunctionInvocation])) + } + case pb => + (Some(pb), None) + } + } else { + (None, None) + } + } + + lazy val template = templateExpr map (finv => extractTemplateFromLambda(finv.args(0).asInstanceOf[Lambda])) + + def hasTemplate: Boolean = templateExpr.isDefined + def getPostWoTemplate = postWoTemplate match { + case None => tru + case Some(expr) => expr + } + def getTemplate = template.get + } + + // a cache for function infos + private var functionInfos = Map[FunDef, FunctionInfo]() + implicit def funDefToFunctionInfo(fd: FunDef): FunctionInfo = { + functionInfos.getOrElse(fd, { + val info = new FunctionInfo(fd) + functionInfos += (fd -> info) + info + }) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..102c2a088446102f6d0f9a1403cd106009fdfb11 --- /dev/null +++ b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala @@ -0,0 +1,489 @@ +package leon +package invariant.structure + +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Set => MutableSet } +import scala.collection.mutable.{ Map => MutableMap } +import java.io._ +import invariant.util._ +import BigInt._ +import Constructors._ + +class NotImplementedException(message: String) extends RuntimeException(message) { + +} + +//a collections of utility methods that manipulate the templates +object LinearConstraintUtil { + val zero = InfiniteIntegerLiteral(0) + val one = InfiniteIntegerLiteral(1) + val mone = InfiniteIntegerLiteral(-1) + val tru = BooleanLiteral(true) + val fls = BooleanLiteral(false) + + //some utility methods + def getFIs(ctr: LinearConstraint): Set[FunctionInvocation] = { + val fis = ctr.coeffMap.keys.collect((e) => e match { + case fi: FunctionInvocation => fi + }) + fis.toSet + } + + def evaluate(lt: LinearTemplate): Option[Boolean] = lt match { + case lc: LinearConstraint if (lc.coeffMap.size == 0) => + ExpressionTransformer.simplify(lt.toExpr) match { + case BooleanLiteral(v) => Some(v) + case _ => None + } + case _ => None + } + + /** + * the expression 'Expr' is required to be a linear atomic predicate (or a template), + * if not, an exception would be thrown. + * For now some of the constructs are not handled. + * The function returns a linear template or a linear constraint depending + * on whether the expression has template variables or not + */ + def exprToTemplate(expr: Expr): LinearTemplate = { + + //println("Expr: "+expr) + //these are the result values + var coeffMap = MutableMap[Expr, Expr]() + var constant: Option[Expr] = None + var isTemplate : Boolean = false + + def addCoefficient(term: Expr, coeff: Expr) = { + if (coeffMap.contains(term)) { + val value = coeffMap(term) + val newcoeff = simplifyArithmetic(Plus(value, coeff)) + + //if newcoeff becomes zero remove it from the coeffMap + if(newcoeff == zero) { + coeffMap.remove(term) + } else{ + coeffMap.update(term, newcoeff) + } + } else coeffMap += (term -> simplifyArithmetic(coeff)) + + if (!variablesOf(coeff).isEmpty) { + isTemplate = true + } + } + + def addConstant(coeff: Expr) ={ + if (constant.isDefined) { + val value = constant.get + constant = Some(simplifyArithmetic(Plus(value, coeff))) + } else + constant = Some(simplifyArithmetic(coeff)) + + if (!variablesOf(coeff).isEmpty) { + isTemplate = true + } + } + + //recurse into plus and get all minterms + def getMinTerms(lexpr: Expr): Seq[Expr] = lexpr match { + case Plus(e1, e2) => getMinTerms(e1) ++ getMinTerms(e2) + case _ => Seq(lexpr) + } + + val linearExpr = MakeLinear(expr) + //the top most operator should be a relation + val Operator(Seq(lhs, InfiniteIntegerLiteral(x)), op) = linearExpr + /*if (lhs.isInstanceOf[InfiniteIntegerLiteral]) + throw new IllegalStateException("relation on two integers, not in canonical form: " + linearExpr)*/ + + val minterms = getMinTerms(lhs) + + //handle each minterm + minterms.foreach((minterm: Expr) => minterm match { + case _ if (Util.isTemplateExpr(minterm)) => { + addConstant(minterm) + } + case Times(e1, e2) => { + e2 match { + case Variable(_) => ; + case ResultVariable(_) => ; + case FunctionInvocation(_, _) => ; + case _ => throw new IllegalStateException("Multiplicand not a constraint variable: " + e2) + } + e1 match { + //case c @ InfiniteIntegerLiteral(_) => addCoefficient(e2, c) + case _ if (Util.isTemplateExpr(e1)) => { + addCoefficient(e2, e1) + } + case _ => throw new IllegalStateException("Coefficient not a constant or template expression: " + e1) + } + } + case Variable(_) => { + //here the coefficient is 1 + addCoefficient(minterm, one) + } + case ResultVariable(_) => { + addCoefficient(minterm, one) + } + case _ => throw new IllegalStateException("Unhandled min term: " + minterm) + }) + + if(coeffMap.isEmpty && constant.isEmpty) { + //here the generated template the constant term is zero. + new LinearConstraint(op, Map.empty, Some(zero)) + } else if(isTemplate) { + new LinearTemplate(op, coeffMap.toMap, constant) + } else{ + new LinearConstraint(op, coeffMap.toMap,constant) + } + } + + /** + * This method may have to do all sorts of transformation to make the expressions linear constraints. + * This assumes that the input expression is an atomic predicate (i.e, without and, or and nots) + * This is subjected to constant modification. + */ + def MakeLinear(atom: Expr): Expr = { + + //pushes the minus inside the arithmetic terms + //we assume that inExpr is in linear form + def PushMinus(inExpr: Expr): Expr = { + inExpr match { + case IntLiteral(v) => IntLiteral(-v) + case InfiniteIntegerLiteral(v) => InfiniteIntegerLiteral(-v) + case t: Terminal => Times(mone, t) + case fi @ FunctionInvocation(fdef, args) => Times(mone, fi) + case UMinus(e1) => e1 + case RealUMinus(e1) => e1 + case Minus(e1, e2) => Plus(PushMinus(e1), e2) + case RealMinus(e1, e2) => Plus(PushMinus(e1), e2) + case Plus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2)) + case RealPlus(e1, e2) => Plus(PushMinus(e1), PushMinus(e2)) + case Times(e1, e2) => { + //here push the minus in to the coefficient which is the first argument + Times(PushMinus(e1), e2) + } + case RealTimes(e1, e2) => Times(PushMinus(e1), e2) + case _ => throw new NotImplementedException("PushMinus -- Operators not yet handled: " + inExpr) + } + } + + //we assume that ine is in linear form + def PushTimes(mul: Expr, ine: Expr): Expr = { + ine match { + case t: Terminal => Times(mul, t) + case fi @ FunctionInvocation(fdef, args) => Times(mul, fi) + case Plus(e1, e2) => Plus(PushTimes(mul, e1), PushTimes(mul, e2)) + case RealPlus(e1, e2) => Plus(PushTimes(mul, e1), PushTimes(mul, e2)) + case Times(e1, e2) => { + //here push the times into the coefficient which should be the first expression + Times(PushTimes(mul, e1), e2) + } + case RealTimes(e1, e2) => Times(PushTimes(mul, e1), e2) + case _ => throw new NotImplementedException("PushTimes -- Operators not yet handled: " + ine) + } + } + + //collect all the constants in addition and simplify them + //we assume that ine is in linear form and also that all constants are integers + def simplifyConsts(ine: Expr): (Option[Expr], BigInt) = { + ine match { + case IntLiteral(v) => (None, v) + case InfiniteIntegerLiteral(v) => (None, v) + case Plus(e1, e2) => { + val (r1, c1) = simplifyConsts(e1) + val (r2, c2) = simplifyConsts(e2) + + val newe = (r1, r2) match { + case (None, None) => None + case (Some(t), None) => Some(t) + case (None, Some(t)) => Some(t) + case (Some(t1), Some(t2)) => Some(Plus(t1, t2)) + } + (newe, c1 + c2) + } + case _ => (Some(ine), 0) + } + } + + def mkLinearRecur(inExpr: Expr): Expr = { + inExpr match { + case e @ Operator(Seq(e1, e2), op) + if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] + || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] + || e.isInstanceOf[GreaterEquals])) => { + + //check if the expression has real valued sub-expressions + val isReal = Util.hasReals(e1) || Util.hasReals(e2) + //doing something else ... ? + // println("[DEBUG] Expr 1 " + e1 + " of type " + e1.getType + " and Expr 2 " + e2 + " of type" + e2.getType) + val (newe, newop) = e match { + case t: Equals => (Minus(e1, e2), Equals) + case t: LessEquals => (Minus(e1, e2), LessEquals) + case t: GreaterEquals => (Minus(e2, e1), LessEquals) + case t: LessThan => { + if (isReal) + (Minus(e1, e2), LessThan) + else + (Plus(Minus(e1, e2), one), LessEquals) + } + case t: GreaterThan => { + if(isReal) + (Minus(e2,e1),LessThan) + else + (Plus(Minus(e2, e1), one), LessEquals) + } + } + val r = mkLinearRecur(newe) + //simplify the resulting constants + val (r2, const) = simplifyConsts(r) + val finale = if (r2.isDefined) { + if (const != 0) Plus(r2.get, InfiniteIntegerLiteral(const)) + else r2.get + } else InfiniteIntegerLiteral(const) + //println(r + " simplifies to "+finale) + newop(finale, zero) + } + case Minus(e1, e2) => Plus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2))) + case RealMinus(e1, e2) => RealPlus(mkLinearRecur(e1), PushMinus(mkLinearRecur(e2))) + case UMinus(e1) => PushMinus(mkLinearRecur(e1)) + case RealUMinus(e1) => PushMinus(mkLinearRecur(e1)) + case Times(_, _) | RealTimes(_, _) => { + val Operator(Seq(e1, e2), op) = inExpr + val (r1, r2) = (mkLinearRecur(e1), mkLinearRecur(e2)) + if(Util.isTemplateExpr(r1)) { + PushTimes(r1, r2) + } else if(Util.isTemplateExpr(r2)){ + PushTimes(r2, r1) + } else + throw new IllegalStateException("Expression not linear: " + Times(r1, r2)) + } + case Plus(e1, e2) => Plus(mkLinearRecur(e1), mkLinearRecur(e2)) + case RealPlus(e1, e2) => RealPlus(mkLinearRecur(e1), mkLinearRecur(e2)) + case t: Terminal => t + case fi: FunctionInvocation => fi + case _ => throw new IllegalStateException("Expression not linear: " + inExpr) + } + } + val rese = mkLinearRecur(atom) + rese + } + + /** + * Replaces an expression by another expression in the terms of the given linear constraint. + */ + def replaceInCtr(replaceMap: Map[Expr, Expr], lc: LinearConstraint): Option[LinearConstraint] = { + + //println("Replacing in "+lc+" repMap: "+replaceMap) + val newexpr = ExpressionTransformer.simplify(simplifyArithmetic(replace(replaceMap, lc.toExpr))) + //println("new expression: "+newexpr) + if (newexpr == tru) None + else if(newexpr == fls) throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) + else { + val res = exprToTemplate(newexpr) + //check if res is true or false + evaluate(res) match { + case Some(false) => throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) + case Some(true) => None //constraint reduced to true + case _ => + val resctr = res.asInstanceOf[LinearConstraint] + Some(resctr) + } + } + } + + /** + * Eliminates the specified variables from a conjunction of linear constraints (a disjunct) (that is satisfiable) + * We assume that the disjunct is in nnf form + * + * debugger is a function used for debugging + */ + val debugElimination = false + def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVars: Set[Identifier], + debugger: Option[(Seq[LinearConstraint] => Unit)]): Seq[LinearConstraint] = { + //eliminate one variable at a time + //each iteration produces a new set of linear constraints + elimVars.foldLeft(linearCtrs)((acc, elimVar) => { + val newdisj = apply1PRuleOnDisjunct(acc, elimVar) + + if(debugElimination) { + if(debugger.isDefined) { + debugger.get(newdisj) + } + } + + newdisj + }) + } + + def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], elimVar: Identifier): Seq[LinearConstraint] = { + + if(debugElimination) + println("Trying to eliminate: "+elimVar) + + //collect all relevant constraints + val emptySeq = Seq[LinearConstraint]() + val (relCtrs, rest) = linearCtrs.foldLeft((emptySeq,emptySeq))((acc,lc) => { + if(variablesOf(lc.toExpr).contains(elimVar)) { + (lc +: acc._1,acc._2) + } else { + (acc._1,lc +: acc._2) + } + }) + + //now consider each constraint look for (a) equality involving the elimVar or (b) check if all bounds are lower + //or (c) if all bounds are upper. + var elimExpr : Option[Expr] = None + var bestExpr = false + var elimCtr : Option[LinearConstraint] = None + var allUpperBounds : Boolean = true + var allLowerBounds : Boolean = true + var foundEquality : Boolean = false + var skippingEquality : Boolean = false + + relCtrs.foreach((lc) => { + //check for an equality + if (lc.toExpr.isInstanceOf[Equals] && lc.coeffMap.contains(elimVar.toVariable)) { + foundEquality = true + + //here, sometimes we replace an existing expression with a better one if available + if (!elimExpr.isDefined || shouldReplace(elimExpr.get, lc, elimVar)) { + //if the coeffcient of elimVar is +ve the the sign of the coeff of every other term should be changed + val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) + //make sure the value of the coefficient is 1 or -1 + //TODO: handle cases wherein the coefficient is not 1 or -1 + if (elimCoeff == 1 || elimCoeff == -1) { + val changeSign = if (elimCoeff > 0) true else false + + val startval = if (lc.const.isDefined) { + val InfiniteIntegerLiteral(cval) = lc.const.get + val newconst = if (changeSign) -cval else cval + InfiniteIntegerLiteral(newconst) + + } else zero + + val substExpr = lc.coeffMap.foldLeft(startval: Expr)((acc, summand) => { + val (term, InfiniteIntegerLiteral(coeff)) = summand + if (term != elimVar.toVariable) { + + val newcoeff = if (changeSign) -coeff else coeff + val newsummand = if (newcoeff == 1) term else Times(term, InfiniteIntegerLiteral(newcoeff)) + if (acc == zero) newsummand + else Plus(acc, newsummand) + + } else acc + }) + + elimExpr = Some(simplifyArithmetic(substExpr)) + elimCtr = Some(lc) + + if (debugElimination) { + println("Using ctr: " + lc + " found mapping: " + elimVar + " --> " + substExpr) + } + } else { + skippingEquality = true + } + } + } else if ((lc.toExpr.isInstanceOf[LessEquals] || lc.toExpr.isInstanceOf[LessThan]) + && lc.coeffMap.contains(elimVar.toVariable)) { + + val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) + if (elimCoeff > 0) { + //here, we have found an upper bound + allLowerBounds = false + } else { + //here, we have found a lower bound + allUpperBounds = false + } + } else { + //here, we assume that the operators are normalized to Equals, LessThan and LessEquals + throw new IllegalStateException("LinearConstraint not in expeceted form : " + lc.toExpr) + } + }) + + val newctrs = if (elimExpr.isDefined) { + + val elimMap = Map[Expr, Expr](elimVar.toVariable -> elimExpr.get) + var repCtrs = Seq[LinearConstraint]() + relCtrs.foreach((ctr) => { + if (ctr != elimCtr.get) { + //replace 'elimVar' by 'elimExpr' in ctr + val repCtr = this.replaceInCtr(elimMap, ctr) + if (repCtr.isDefined) + repCtrs +:= repCtr.get + } + }) + repCtrs + + } else if (!foundEquality && (allLowerBounds || allUpperBounds)) { + //here, drop all relCtrs. None of them are important + Seq() + } else { + //for stats + if(skippingEquality) { + Stats.updateCumStats(1,"SkippedVar") + } + //cannot eliminate the variable + relCtrs + } + val resctrs = (newctrs ++ rest) + //println("After eliminating: "+elimVar+" : "+resctrs) + resctrs + } + + def sizeExpr(ine: Expr): Int = { + val simpe = simplifyArithmetic(ine) + var size = 0 + simplePostTransform((e: Expr) => { + size += 1 + e + })(simpe) + size + } + + def sizeCtr(ctr : LinearConstraint) : Int = { + val coeffSize = ctr.coeffMap.foldLeft(0)((acc, pair) => { + val (term, coeff) = pair + if(coeff == one) acc + 1 + else acc + sizeExpr(coeff) + 2 + }) + if(ctr.const.isDefined) coeffSize + 1 + else coeffSize + } + + def shouldReplace(currExpr : Expr, candidateCtr : LinearConstraint, elimVar: Identifier) : Boolean = { + if(!currExpr.isInstanceOf[InfiniteIntegerLiteral]) { + //is the candidate a constant + if(candidateCtr.coeffMap.size == 1) true + else{ + //computing the size of currExpr + if(sizeExpr(currExpr) > (sizeCtr(candidateCtr) - 1)) true + else false + } + } else false + } + + //remove transitive axioms + + /** + * Checks if the expression is linear i.e, + * is only conjuntion and disjunction of linear atomic predicates + */ + def isLinear(e: Expr) : Boolean = { + e match { + case And(args) => args forall isLinear + case Or(args) => args forall isLinear + case Not(arg) => isLinear(arg) + case Implies(e1, e2) => isLinear(e1) && isLinear(e2) + case t : Terminal => true + case atom => + exprToTemplate(atom).isInstanceOf[LinearConstraint] + } + } +} + diff --git a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..305bd98e34383a572c20618c1e429722f0c91341 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala @@ -0,0 +1,408 @@ +package leon +package invariant.templateSolvers +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import scala.util.control.Breaks._ +import solvers._ +import solvers.z3._ +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import leon.invariant.util.RealValuedExprEvaluator._ + +class CegisSolver(ctx: InferenceContext, + rootFun: FunDef, + ctrTracker: ConstraintTracker, + timeout: Int, + bound: Option[Int] = None) extends TemplateSolver(ctx, rootFun, ctrTracker) { + + override def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) = { + + val initCtr = if (bound.isDefined) { + //use a predefined bound on the template variables + Util.createAnd(tempIds.map((id) => { + val idvar = id.toVariable + And(Implies(LessThan(idvar, realzero), GreaterEquals(idvar, InfiniteIntegerLiteral(-bound.get))), + Implies(GreaterEquals(idvar, realzero), LessEquals(idvar, InfiniteIntegerLiteral(bound.get)))) + }).toSeq) + + } else tru + + val funcs = funcVCs.keys + val formula = Util.createOr(funcs.map(funcVCs.apply _).toSeq) + + //using reals with bounds does not converge and also results in overflow + val (res, _, model) = (new CegisCore(ctx, timeout, this)).solve(tempIds, formula, initCtr, solveAsInt = true) + res match { + case Some(true) => (Some(model), None) + case Some(false) => (None, None) //no solution exists + case _ => //timed out + throw new IllegalStateException("Timeout!!") + } + } +} + +class CegisCore(ctx: InferenceContext, + timeout: Int, + cegisSolver: TemplateSolver) { + + val fls = BooleanLiteral(false) + val tru = BooleanLiteral(true) + val zero = InfiniteIntegerLiteral(0) + val timeoutMillis = timeout.toLong * 1000 + val dumpCandidateInvs = true + val minimizeSum = false + val program = ctx.program + val context = ctx.leonContext + val reporter = context.reporter + + /** + * Finds a model for the template variables in the 'formula' so that 'formula' is falsified + * subject to the constraints on the template variables given by the 'envCtrs' + * + * The parameter solveAsInt when set to true will convert the template constraints + * to integer constraints and solve. This should be enabled when bounds are used to constrain the variables + */ + def solve(tempIds: Set[Identifier], formula: Expr, initCtr: Expr, solveAsInt: Boolean, + initModel: Option[Model] = None): (Option[Boolean], Expr, Model) = { + + //start a timer + val startTime = System.currentTimeMillis() + + //for some sanity checks + var oldModels = Set[Expr]() + def addModel(m: Model) = { + val mexpr = Util.modelToExpr(m) + if (oldModels.contains(mexpr)) + throw new IllegalStateException("repeating model !!:" + m) + else oldModels += mexpr + } + + //add the initial model + val simplestModel = if (initModel.isDefined) initModel.get else { + new Model(tempIds.map((id) => (id -> simplestValue(id.getType))).toMap) + } + addModel(simplestModel) + + val tempVarSum = if (minimizeSum) { + //compute the sum of the tempIds + val rootTempIds = Util.getTemplateVars(cegisSolver.rootFun.getTemplate) + if (rootTempIds.size >= 1) { + rootTempIds.tail.foldLeft(rootTempIds.head.asInstanceOf[Expr])((acc, tvar) => Plus(acc, tvar)) + } else zero + } else zero + + //convert initCtr to a real-constraint + val initRealCtr = ExpressionTransformer.IntLiteralToReal(initCtr) + if (Util.hasInts(initRealCtr)) + throw new IllegalStateException("Initial constraints have integer terms: " + initRealCtr) + + def cegisRec(model: Model, prevctr: Expr): (Option[Boolean], Expr, Model) = { + + val elapsedTime = (System.currentTimeMillis() - startTime) + if (elapsedTime >= timeoutMillis - 100) { + //if we have timed out return the present set of constrains and the current model we have + (None, prevctr, model) + } else { + + //println("elapsedTime: "+elapsedTime / 1000+" timeout: "+timeout) + Stats.updateCounter(1, "CegisIters") + + if (dumpCandidateInvs) { + reporter.info("Candidate invariants") + val candInvs = cegisSolver.getAllInvariants(model) + candInvs.foreach((entry) => println(entry._1.id + "-->" + entry._2)) + } + val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap + val instFormula = simplifyArithmetic(TemplateInstantiator.instantiate(formula, tempVarMap)) + + //sanity checks + val spuriousTempIds = variablesOf(instFormula).intersect(TemplateIdFactory.getTemplateIds) + if (!spuriousTempIds.isEmpty) + throw new IllegalStateException("Found a template variable in instFormula: " + spuriousTempIds) + if (Util.hasReals(instFormula)) + throw new IllegalStateException("Reals in instFormula: " + instFormula) + + //println("solving instantiated vcs...") + val t1 = System.currentTimeMillis() + val solver1 = new ExtendedUFSolver(context, program) + solver1.assertCnstr(instFormula) + val res = solver1.check + val t2 = System.currentTimeMillis() + println("1: " + (if (res.isDefined) "solved" else "timedout") + "... in " + (t2 - t1) / 1000.0 + "s") + + res match { + case Some(true) => { + //simplify the tempctrs, evaluate every atom that does not involve a template variable + //this should get rid of all functions + val satctrs = + simplePreTransform((e) => e match { + //is 'e' free of template variables ? + case _ if (variablesOf(e).filter(TemplateIdFactory.IsTemplateIdentifier _).isEmpty) => { + //evaluate the term + val value = solver1.evalExpr(e) + if (value.isDefined) value.get + else throw new IllegalStateException("Cannot evaluate expression: " + e) + } + case _ => e + })(Not(formula)) + solver1.free() + + //sanity checks + val spuriousProgIds = variablesOf(satctrs).filterNot(TemplateIdFactory.IsTemplateIdentifier _) + if (!spuriousProgIds.isEmpty) + throw new IllegalStateException("Found a progam variable in tempctrs: " + spuriousProgIds) + + val tempctrs = if (!solveAsInt) ExpressionTransformer.IntLiteralToReal(satctrs) else satctrs + val newctr = And(tempctrs, prevctr) + //println("Newctr: " +newctr) + + if (ctx.dumpStats) { + Stats.updateCounterStats(Util.atomNum(newctr), "CegisTemplateCtrs", "CegisIters") + } + + //println("solving template constraints...") + val t3 = System.currentTimeMillis() + val elapsedTime = (t3 - startTime) + val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), + timeoutMillis - elapsedTime)) + + val (res1, newModel) = if (solveAsInt) { + //convert templates to integers and solve. Finally, re-convert integer models for templates to real models + val rti = new RealToInt() + val intctr = rti.mapRealToInt(And(newctr, initRealCtr)) + val intObjective = rti.mapRealToInt(tempVarSum) + val (res1, intModel) = if (minimizeSum) { + minimizeIntegers(intctr, intObjective) + } else { + solver2.solveSAT(intctr) + } + (res1, rti.unmapModel(intModel)) + } else { + + /*if(InvariantUtil.hasInts(tempctrs)) + throw new IllegalStateException("Template constraints have integer terms: " + tempctrs)*/ + if (minimizeSum) { + minimizeReals(And(newctr, initRealCtr), tempVarSum) + } else { + solver2.solveSAT(And(newctr, initRealCtr)) + } + } + + val t4 = System.currentTimeMillis() + println("2: " + (if (res1.isDefined) "solved" else "timed out") + "... in " + (t4 - t3) / 1000.0 + "s") + + if (res1.isDefined) { + if (res1.get == false) { + //there exists no solution for templates + (Some(false), newctr, Model.empty) + + } else { + //this is for sanity check + addModel(newModel) + //generate more constraints + cegisRec(newModel, newctr) + } + } else { + //we have timed out + (None, prevctr, model) + } + } + case Some(false) => { + solver1.free() + //found a model for disabling the formula + (Some(true), prevctr, model) + } case _ => { + solver1.free() + throw new IllegalStateException("Cannot solve instFormula: " + instFormula) + } + } + } + } + //note: initRealCtr is used inside 'cegisRec' + cegisRec(simplestModel, tru) + } + + /** + * Performs minimization + */ + val MaxIter = 16 //note we may not be able to represent anything beyond 2^16 + val MaxInt = Int.MaxValue + val sqrtMaxInt = 45000 + val half = FractionalLiteral(1, 2) + val two = FractionalLiteral(2, 1) + val rzero = FractionalLiteral(0, 1) + val mone = FractionalLiteral(-1, 1) + val debugMinimization = false + + def minimizeReals(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { + //val t1 = System.currentTimeMillis() + val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) + val (res, model1) = sol.solveSAT(inputCtr) + res match { + case Some(true) => { + //do a binary search on sequentially on each of these tempvars + println("minimizing " + objective + " ...") + val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> model1(id))).toMap + var upperBound: FractionalLiteral = evaluate(replace(idMap, objective)) + var lowerBound: Option[FractionalLiteral] = None + var currentModel = model1 + var continue = true + var iter = 0 + do { + iter += 1 + //here we perform some sanity checks to prevent overflow + if (!boundSanityChecks(upperBound, lowerBound)) { + continue = false + } else { + if (lowerBound.isDefined && evaluateRealPredicate(GreaterEquals(lowerBound.get, upperBound))) { + continue = false + } else { + + val currval = if (lowerBound.isDefined) { + val midval = evaluate(Times(half, Plus(upperBound, lowerBound.get))) + floor(midval) + + } else { + val rlit @ FractionalLiteral(n, d) = upperBound + if (isGEZ(rlit)) { + if (n == 0) { + //make the upper bound negative + mone + } else { + floor(evaluate(Times(half, upperBound))) + } + } else floor(evaluate(Times(two, upperBound))) + + } + val boundCtr = LessEquals(objective, currval) + //val t1 = System.currentTimeMillis() + val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) + val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) + //val t2 = System.currentTimeMillis() + //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") + res match { + case Some(true) => { + //here we have a new upper bound + currentModel = newModel + val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> newModel(id))).toMap + val value = evaluate(replace(idMap, objective)) + upperBound = value + if (this.debugMinimization) + reporter.info("Found new upper bound: " + upperBound) + } + case _ => { + //here we have a new lower bound : currval + lowerBound = Some(currval) + if (this.debugMinimization) + reporter.info("Found new lower bound: " + currval) + } + } + } + } + } while (continue && iter < MaxIter) + //here, we found a best-effort minimum + reporter.info("Minimization complete...") + (Some(true), currentModel) + } + case _ => (res, model1) + } + } + + def boundSanityChecks(ub: FractionalLiteral, lb: Option[FractionalLiteral]): Boolean = { + val FractionalLiteral(n, d) = ub + if (n <= (MaxInt / 2)) { + if (lb.isDefined) { + val FractionalLiteral(n2, _) = lb.get + (n2 <= sqrtMaxInt && d <= sqrtMaxInt) + } else { + (d <= (MaxInt / 2)) + } + } else false + } + + def minimizeIntegers(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { + //val t1 = System.currentTimeMillis() + val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) + val (res, model1) = sol.solveSAT(inputCtr) + res match { + case Some(true) => { + //do a binary search on sequentially on each of these tempvars + reporter.info("minimizing " + objective + " ...") + val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> model1(id))).toMap + var upperBound = simplifyArithmetic(replace(idMap, objective)).asInstanceOf[InfiniteIntegerLiteral].value + var lowerBound: Option[BigInt] = None + var currentModel = model1 + var continue = true + var iter = 0 + do { + iter += 1 + if (lowerBound.isDefined && lowerBound.get >= upperBound - 1) { + continue = false + } else { + + val currval = if (lowerBound.isDefined) { + val sum = (upperBound + lowerBound.get) + floorDiv(sum, 2) + } else { + if (upperBound >= 0) { + if (upperBound == 0) { + //make the upper bound negative + BigInt(-1) + } else { + floorDiv(upperBound, 2) + } + } else 2 * upperBound + } + val boundCtr = LessEquals(objective, InfiniteIntegerLiteral(currval)) + //val t1 = System.currentTimeMillis() + val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) + val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) + //val t2 = System.currentTimeMillis() + //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") + res match { + case Some(true) => { + //here we have a new upper bound + currentModel = newModel + val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> newModel(id))).toMap + val value = simplifyArithmetic(replace(idMap, objective)).asInstanceOf[InfiniteIntegerLiteral].value + upperBound = value + if (this.debugMinimization) + reporter.info("Found new upper bound: " + upperBound) + } + case _ => { + //here we have a new lower bound : currval + lowerBound = Some(currval) + if (this.debugMinimization) + reporter.info("Found new lower bound: " + currval) + } + } + } + } while (continue && iter < MaxIter) + //here, we found a best-effort minimum + reporter.info("Minimization complete...") + (Some(true), currentModel) + } + case _ => (res, model1) + } + } + + def floorDiv(did: BigInt, div: BigInt): BigInt = { + if (div <= 0) throw new IllegalStateException("Invalid divisor") + if (did < 0) { + if (did % div != 0) did / div - 1 + else did / div + } else { + did / div + } + } + +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..8bc93564c35076430ce1b96a74029fc7ff7f494b --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala @@ -0,0 +1,79 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package invariant.templateSolvers + +import z3.scala._ +import leon.solvers._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.LeonContext +import leon.solvers.z3.UninterpretedZ3Solver + +/** + * A uninterpreted solver extended with additional functionalities. + * TODO: need to handle bit vectors + */ +class ExtendedUFSolver(context : LeonContext, program: Program) + extends UninterpretedZ3Solver(context, program) { + + override val name = "Z3-eu" + override val description = "Extended UF-ADT Z3 Solver" + + /** + * This uses z3 methods to evaluate the model + */ + def evalExpr(expr: Expr): Option[Expr] = { + val ast = toZ3Formula(expr) + val model = solver.getModel + val res = model.eval(ast, true) + if (res.isDefined) + Some(fromZ3Formula(model, res.get, null)) + else None + } + + def getAssertions : Expr = { + val assers = solver.getAssertions.map((ast) => fromZ3Formula(null, ast, null)) + And(assers) + } + + /** + * Uses z3 to convert a formula to SMTLIB. + */ + def ctrsToString(logic: String, unsatcore: Boolean = false): String = { + z3.setAstPrintMode(Z3Context.AstPrintMode.Z3_PRINT_SMTLIB2_COMPLIANT) + var seenHeaders = Set[String]() + var headers = Seq[String]() + var asserts = Seq[String]() + solver.getAssertions().toSeq.foreach((asser) => { + val str = z3.benchmarkToSMTLIBString("benchmark", logic, "unknown", "", Seq(), asser) + //remove from the string the headers and also redeclaration of template variables + //split based on newline to get a list of strings + val strs = str.split("\n") + val newstrs = strs.filter((line) => !seenHeaders.contains(line)) + var newHeaders = Seq[String]() + newstrs.foreach((line) => { + if (line == "; benchmark") newHeaders :+= line + else if (line.startsWith("(set")) newHeaders :+= line + else if (line.startsWith("(declare")) newHeaders :+= line + else if(line.startsWith("(check-sat)")) {} //do nothing + else asserts :+= line + }) + headers ++= newHeaders + seenHeaders ++= newHeaders + }) + val initstr = if (unsatcore) { + "(set-option :produce-unsat-cores true)" + } else "" + val smtstr = headers.foldLeft(initstr)((acc, hdr) => acc + "\n" + hdr) + "\n" + + asserts.foldLeft("")((acc, asrt) => acc + "\n" + asrt) + "\n" + + "(check-sat)" + "\n" + + (if (!unsatcore) "(get-model)" + else "(get-unsat-core)") + smtstr + } +} diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..13e8727a4f204ce04a2dae1634057f1723385b80 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala @@ -0,0 +1,332 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import solvers.SimpleSolverAPI +import scala.collection.mutable.{ Map => MutableMap } +import invariant.engine._ +import invariant.factories._ +import invariant.util.Util._ +import invariant.util._ +import invariant.structure._ +import leon.solvers.TimeoutSolver +import leon.solvers.SolverFactory +import leon.solvers.TimeoutSolverFactory +import leon.solvers.Model +import leon.invariant.util.RealValuedExprEvaluator._ + +class FarkasLemmaSolver(ctx: InferenceContext) { + + //debug flags + val verbose = true + val verifyModel = false + val dumpNLCtrsAsSMTLIB = false + val dumpNLCtrs = false + val debugNLCtrs = false + + // functionality flags + val solveAsBitvectors = false + val bvsize = 5 + val useIncrementalSolvingForNLctrs = false //note: NLsat doesn't support incremental solving. It starts from sratch even in incremental solving. + + val leonctx = ctx.leonContext + val program = ctx.program + val reporter = ctx.reporter + val timeout = ctx.timeout + /** + * This procedure produces a set of constraints that need to be satisfiable for the + * conjunction ants and conseqs to be false + * antsSimple - antecedents without template variables + * antsTemp - antecedents with template variables + * Similarly for conseqsSimple and conseqsTemp + * + * Let A,A' and C,C' denote the simple and templated portions of the antecedent and the consequent respectively. + * We need to check \exists a, \forall x, A[x] ^ A'[x,a] ^ C[x] ^ C'[x,a] = false + * + */ + def constraintsForUnsat(linearCtrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): Expr = { + + //for debugging + /*println("#" * 20) + println(allAnts + " ^ " + allConseqs) + println("#" * 20)*/ + this.applyFarkasLemma(linearCtrs ++ temps, Seq(), true) + } + + /** + * This procedure produces a set of constraints that need to be satisfiable for the implication to hold + * antsSimple - antecedents without template variables + * antsTemp - antecedents with template variables + * Similarly for conseqsSimple and conseqsTemp + * + * Let A,A' and C,C' denote the simple and templated portions of the antecedent and the consequent respectively. + * We need to check \exists a, \forall x, A[x] ^ A'[x,a] => C[x] ^ C'[x,a] + * + */ + def constraintsForImplication(antsSimple: Seq[LinearConstraint], antsTemp: Seq[LinearTemplate], + conseqsSimple: Seq[LinearConstraint], conseqsTemp: Seq[LinearTemplate], + uisolver: SimpleSolverAPI): Expr = { + + val allAnts = antsSimple ++ antsTemp + val allConseqs = conseqsSimple ++ conseqsTemp + //for debugging + println("#" * 20) + println(allAnts + " => " + allConseqs) + println("#" * 20) + + //Optimization 1: Check if ants are unsat (already handled) + val pathVC = createAnd(antsSimple.map(_.toExpr).toSeq ++ conseqsSimple.map(_.toExpr).toSeq) + val notPathVC = And(createAnd(antsSimple.map(_.toExpr).toSeq), Not(createAnd(conseqsSimple.map(_.toExpr).toSeq))) + val (satVC, _) = uisolver.solveSAT(pathVC) + val (satNVC, _) = uisolver.solveSAT(notPathVC) + + //Optimization 2: use the unsatisfiability of VC and not VC to simplify the constraint generation + //(a) if A => C is false and A' is true then the entire formula is unsat + //(b) if A => C is false and A' is not true then we need to ensure A^A' is unsat (i.e, disable Ant) + //(c) if A => C is true (i.e, valid) then it suffices to ensure A^A' => C' is valid + //(d) if A => C is neither true nor false then we cannot do any simplification + //TODO: Food for thought: + //(a) can we do any simplification for case (d) with the model + //(b) could the linearity in the disabled case be exploited + val (ants, conseqs, disableFlag) = (satVC, satNVC) match { + case (Some(false), _) if (antsTemp.isEmpty) => (Seq(), Seq(), false) + case (Some(false), _) => (allAnts, Seq(), true) //here only disable the antecedents + case (_, Some(false)) => (allAnts, conseqsTemp, false) //here we need to only check the inductiveness of the templates + case _ => (allAnts, allConseqs, false) + } + if (ants.isEmpty) { + BooleanLiteral(false) + } else { + this.applyFarkasLemma(ants, conseqs, disableFlag) + } + } + + /** + * This procedure uses Farka's lemma to generate a set of non-linear constraints for the input implication. + * Note that these non-linear constraints are in real arithmetic. + * TODO: Correctness issue: need to handle strict inequalities in consequent + * Do we really need the consequent ?? + */ + def applyFarkasLemma(ants: Seq[LinearTemplate], conseqs: Seq[LinearTemplate], disableAnts: Boolean): Expr = { + + //compute the set of all constraint variables in ants + val antCVars = ants.foldLeft(Set[Expr]())((acc, ant) => acc ++ ant.coeffTemplate.keySet) + + //the creates constraints for a single consequent + def createCtrs(conseq: Option[LinearTemplate]): Expr = { + //create a set of identifiers one for each ants + val lambdas = ants.map((ant) => (ant -> Variable(FreshIdentifier("l", RealType, true)))).toMap + val lambda0 = Variable(FreshIdentifier("l", RealType, true)) + + //add a bunch of constraints on lambdas + var strictCtrLambdas = Seq[Variable]() + val lambdaCtrs = (ants.collect((ant) => ant.template match { + case t: LessEquals => GreaterEquals(lambdas(ant), zero) + case t: LessThan => { + val l = lambdas(ant) + strictCtrLambdas :+= l + GreaterEquals(l, zero) + } + }).toSeq :+ GreaterEquals(lambda0, zero)) + + //add the constraints on constant terms + val sumConst = ants.foldLeft(UMinus(lambda0): Expr)((acc, ant) => ant.constTemplate match { + case Some(d) => Plus(acc, Times(lambdas(ant), d)) + case None => acc + }) + + val cvars = antCVars ++ (if (conseq.isDefined) conseq.get.coeffTemplate.keys else Seq()) + //initialize enabled and disabled parts + var enabledPart: Expr = if (conseq.isDefined) { + conseq.get.constTemplate match { + case Some(d) => Equals(d, sumConst) + case None => Equals(zero, sumConst) + } + } else null + //the disabled part handles strict inequalities as well using Motzkin's transposition + var disabledPart: Expr = + if (strictCtrLambdas.isEmpty) Equals(one, sumConst) + else Or(Equals(one, sumConst), + And(Equals(zero, sumConst), createOr(strictCtrLambdas.map(GreaterThan(_, zero))))) + + for (cvar <- cvars) { + //compute the linear combination of all the coeffs of antCVars + //println("Processing cvar: "+cvar) + var sumCoeff: Expr = zero + for (ant <- ants) { + //handle coefficients here + if (ant.coeffTemplate.contains(cvar)) { + val addend = Times(lambdas(ant), ant.coeffTemplate.get(cvar).get) + if (sumCoeff == zero) + sumCoeff = addend + else + sumCoeff = Plus(sumCoeff, addend) + } + } + //println("sum coeff: "+sumCoeff) + //make the sum equal to the coeff. of cvar in conseq + if (conseq.isDefined) { + enabledPart = And(enabledPart, + (if (conseq.get.coeffTemplate.contains(cvar)) + Equals(conseq.get.coeffTemplate.get(cvar).get, sumCoeff) + else Equals(zero, sumCoeff))) + } + + disabledPart = And(disabledPart, Equals(zero, sumCoeff)) + } //end of cvars loop + + //the final constraint is a conjunction of lambda constraints and disjunction of enabled and disabled parts + if (disableAnts) And(createAnd(lambdaCtrs), disabledPart) + else { + //And(And(lambdaCtrs), enabledPart) + And(createAnd(lambdaCtrs), Or(enabledPart, disabledPart)) + } + } + + val ctrs = if (disableAnts) { + //here conseqs are empty + createCtrs(None) + } else { + val Seq(head, tail @ _*) = conseqs + val nonLinearCtrs = tail.foldLeft(createCtrs(Some(head)))((acc, conseq) => And(acc, createCtrs(Some(conseq)))) + nonLinearCtrs + } + ExpressionTransformer.IntLiteralToReal(ctrs) + } + + def solveFarkasConstraints(nlctrs: Expr): (Option[Boolean], Model) = { + + // factor out common nonlinear terms and create an equiv-satisfiable constraint + def reduceCommonNLTerms(ctrs: Expr) = { + var nlUsage = new CounterMap[Expr]() + postTraversal{ + case t: Times => nlUsage.inc(t) + case e => ; + }(ctrs) + val repMap = nlUsage.collect{ + case (k, v) if v > 1 => + (k -> FreshIdentifier("t", RealType, true).toVariable) + }.toMap + createAnd(replace(repMap, ctrs) +: repMap.map { + case (k, v) => Equals(v, k) + }.toSeq) + } + + // try eliminate nonlinearity to whatever extent possible + var elimMap = Map[Identifier, (Identifier, Identifier)]() // maps the fresh identifiers to the product of the identifiers they represent. + def reduceNonlinearity(farkasctrs: Expr): Expr = { + var varCounts = new CounterMap[Identifier]() + // collect # of uses of each variable + postTraversal { + case Variable(id) => varCounts.inc(id) + case _ => ; + }(farkasctrs) + var adnlCtrs = Seq[Expr]() + val simpCtrs = simplePostTransform { + case Times(vlb @ Variable(lb), va @ Variable(a)) if (varCounts(lb) == 1 || varCounts(a) == 1) => // is lb or a used only once ? + // stats + Stats.updateCumStats(1, "NonlinearMultEliminated") + val freshid = FreshIdentifier(lb.name + a.name, RealType, true) + val freshvar = freshid.toVariable + elimMap += (freshid -> (lb, a)) + if (varCounts(lb) == 1) + // va = 0 ==> freshvar = 0 + adnlCtrs :+= Implies(Equals(va, realzero), Equals(freshvar, realzero)) + else // here varCounts(a) == 1 + adnlCtrs :+= Implies(Equals(vlb, realzero), Equals(freshvar, realzero)) + freshvar + case e => + e + }(farkasctrs) + createAnd(simpCtrs +: adnlCtrs) + } + val simpctrs = (reduceCommonNLTerms _ andThen + reduceNonlinearity)(nlctrs) + + //for debugging nonlinear constraints + if (this.debugNLCtrs && Util.hasInts(simpctrs)) { + throw new IllegalStateException("Nonlinear constraints have integers: " + simpctrs) + } + if (verbose && LinearConstraintUtil.isLinear(simpctrs)) { + reporter.info("Constraints reduced to linear !") + } + if (this.dumpNLCtrs) { + reporter.info("InputCtrs: " + nlctrs) + reporter.info("SimpCtrs: " + simpctrs) + if (this.dumpNLCtrsAsSMTLIB) { + val filename = ctx.program.modules.last.id + "-nlctr" + FileCountGUID.getID + ".smt2" + if (Util.atomNum(simpctrs) >= 5) { + if (solveAsBitvectors) + Util.toZ3SMTLIB(simpctrs, filename, "QF_BV", leonctx, program, useBitvectors = true, bitvecSize = bvsize) + else + Util.toZ3SMTLIB(simpctrs, filename, "QF_NRA", leonctx, program) + reporter.info("NLctrs dumped to: " + filename) + } + } + } + + // solve the resulting constraints using solver + val innerSolver = if (solveAsBitvectors) { + throw new IllegalStateException("Not supported now. Will be in the future!") + //new ExtendedUFSolver(leonctx, program, useBitvectors = true, bitvecSize = bvsize) with TimeoutSolver + } else { + new ExtendedUFSolver(leonctx, program) with TimeoutSolver + } + val solver = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory(() => innerSolver), timeout * 1000)) + if (verbose) reporter.info("solving...") + val t1 = System.currentTimeMillis() + val (res, model) = solver.solveSAT(simpctrs) + val t2 = System.currentTimeMillis() + if (verbose) reporter.info((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") + Stats.updateCounterTime((t2 - t1), "NL-solving-time", "disjuncts") + + res match { + case Some(true) => + // construct assignments for the variables that were removed during nonlinearity reduction + def divide(dividend: Expr, divisor: Expr) = { + divisor match { + case `realzero` => + assert(dividend == realzero) + // here result can be anything. So make it zero + realzero + case _ => + val res = evaluate(Division(dividend, divisor)) + res + } + } + val newassignments = elimMap.flatMap { + case (k, (v1, v2)) => + val kval = evaluate(model(k)) + if (model.isDefinedAt(v1) && model.isDefinedAt(v2)) + throw new IllegalStateException( + s"Variables $v1 and $v2 in an eliminated nonlinearity have models") + else if (model.isDefinedAt(v1)) { + val v2val = divide(kval, evaluate(model(v1))) + Seq((v2 -> v2val)) + } else if (model.isDefinedAt(v2)) + Seq((v1 -> divide(kval, evaluate(model(v2))))) + else + // here v1 * v2 = k. Therefore make v1 = k and v2 = 1 + Seq((v1 -> kval), (v2 -> FractionalLiteral(1, 1))) + } + val fullmodel = model ++ newassignments + if (this.verifyModel) { + //println("Fullmodel: "+fullmodel) + assert(evaluateRealFormula(replace( + fullmodel.map { case (k, v) => (k.toVariable, v) }.toMap, + nlctrs))) + } + (res, fullmodel) + case _ => + (res, model) + } + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..2e01e4682beec850227aaa6663619ea6ed4cb371 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala @@ -0,0 +1,717 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import evaluators._ +import scala.collection.mutable.{ Map => MutableMap } +import java.io._ +import solvers._ +import solvers.z3._ +import scala.util.control.Breaks._ +import purescala.ScalaPrinter +import scala.collection.mutable.{ Map => MutableMap } +import scala.reflect.runtime.universe +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import leon.invariant.util.RealValuedExprEvaluator._ + +class NLTemplateSolver(ctx: InferenceContext, rootFun: FunDef, ctrTracker: ConstraintTracker, + minimizer: Option[(Expr, Model) => Model]) + extends TemplateSolver(ctx, rootFun, ctrTracker) { + + //flags controlling debugging + val debugIncrementalVC = false + val debugElimination = false + val debugChooseDisjunct = false + val debugTheoryReduction = false + val debugAxioms = false + val verifyInvariant = false + val debugReducedFormula = false + val trackUnpackedVCCTime = false + + //print flags + val verbose = false + val printCounterExample = false + val printPathToConsole = false + val dumpPathAsSMTLIB = false + val printCallConstriants = false + val dumpInstantiatedVC = false + + private val program = ctx.program + private val timeout = ctx.timeout + private val leonctx = ctx.leonContext + + //flag controlling behavior + private val farkasSolver = new FarkasLemmaSolver(ctx) + private val startFromEarlierModel = true + private val disableCegis = true + private val useIncrementalSolvingForVCs = true + + //this is private mutable state used by initialized during every call to 'solve' and used by 'solveUNSAT' + protected var funcVCs = Map[FunDef, Expr]() + //TODO: can incremental solving be trusted ? There were problems earlier. + protected var vcSolvers = Map[FunDef, ExtendedUFSolver]() + protected var paramParts = Map[FunDef, Expr]() + private var lastFoundModel: Option[Model] = None + + //for miscellaneous things + val trackNumericalDisjuncts = false + var numericalDisjuncts = List[Expr]() + + protected def splitVC(fd: FunDef): (Expr, Expr) = { + ctrTracker.getVC(fd).splitParamPart + } + + def initVCSolvers { + funcVCs.keys.foreach(fd => { + val (paramPart, rest) = if (ctx.usereals) { + val (pp, r) = splitVC(fd) + (IntLiteralToReal(pp), IntLiteralToReal(r)) + } else + splitVC(fd) + + if (Util.hasReals(rest) && Util.hasInts(rest)) + throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) + + val vcSolver = new ExtendedUFSolver(leonctx, program) + vcSolver.assertCnstr(rest) + + if (debugIncrementalVC) { + assert(Util.getTemplateVars(rest).isEmpty) + println("For function: " + fd.id) + println("Param part: " + paramPart) + /*vcSolver.check match { + case Some(false) => throw new IllegalStateException("Non param-part is unsat "+rest) + case _ => ; + }*/ + } + vcSolvers += (fd -> vcSolver) + paramParts += (fd -> paramPart) + }) + } + + def freeVCSolvers { + vcSolvers.foreach(entry => entry._2.free) + } + + /** + * This function computes invariants belonging to the given templates incrementally. + * The result is a mapping from function definitions to the corresponding invariants. + */ + override def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) = { + //initialize vcs of functions + this.funcVCs = funcVCs + if (useIncrementalSolvingForVCs) { + initVCSolvers + } + val initModel = if (this.startFromEarlierModel && lastFoundModel.isDefined) { + val candModel = lastFoundModel.get + new Model(tempIds.map(id => + (id -> candModel.getOrElse(id, simplestValue(id.getType)))).toMap) + } else { + new Model(tempIds.map((id) => + (id -> simplestValue(id.getType))).toMap) + } + val sol = solveUNSAT(initModel, tru, Seq(), Set()) + + if (useIncrementalSolvingForVCs) { + freeVCSolvers + } + //set lowerbound map + //TODO: find a way to record lower bound stats + /*if (ctx.tightBounds) + SpecificStats.addLowerBoundStats(rootFun, minimizer.lowerBoundMap, "")*/ + //miscellaneous stuff + if (trackNumericalDisjuncts) { + this.numericalDisjuncts = List[Expr]() + } + sol + } + + //state for minimization + var minStarted = false + var minStartTime: Long = 0 + var minimized = false + + def minimizationInProgress { + if (!minStarted) { + minStarted = true + minStartTime = System.currentTimeMillis() + } + } + + def minimizationCompleted { + minStarted = false + val mintime = (System.currentTimeMillis() - minStartTime) + /*Stats.updateCounterTime(mintime, "minimization-time", "procs") + Stats.updateCumTime(mintime, "Total-Min-Time")*/ + } + + def solveUNSAT(model: Model, inputCtr: Expr, solvedDisjs: Seq[Expr], seenCalls: Set[Call]): (Option[Model], Option[Set[Call]]) = { + + if (verbose) { + reporter.info("Candidate invariants") + val candInvs = getAllInvariants(model) + candInvs.foreach((entry) => reporter.info(entry._1.id + "-->" + entry._2)) + } + + if (this.startFromEarlierModel) this.lastFoundModel = Some(model) + + val (res, newCtr, newModel, newdisjs, newcalls) = invalidateSATDisjunct(inputCtr, model) + res match { + case None => { + //here, we cannot proceed and have to return unknown + //However, we can return the calls that need to be unrolled + (None, Some(seenCalls ++ newcalls)) + } + case Some(false) => { + //here, the vcs are unsatisfiable when instantiated with the invariant + if (minimizer.isDefined) { + //for stats + minimizationInProgress + if (minimized) { + minimizationCompleted + (Some(model), None) + } else { + val minModel = minimizer.get(inputCtr, model) + minimized = true + if (minModel == model) { + minimizationCompleted + (Some(model), None) + } else { + solveUNSAT(minModel, inputCtr, solvedDisjs, seenCalls) + } + } + } else { + (Some(model), None) + } + } + case Some(true) => { + //here, we have found a new candidate invariant. Hence, the above process needs to be repeated + minimized = false + solveUNSAT(newModel, newCtr, solvedDisjs ++ newdisjs, seenCalls ++ newcalls) + } + } + } + + //TODO: this code does too much imperative update. + //TODO: use guards to block a path and not use the path itself + def invalidateSATDisjunct(inputCtr: Expr, model: Model): (Option[Boolean], Expr, Model, Seq[Expr], Set[Call]) = { + + val tempIds = model.map(_._1) + val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap + val inputSize = Util.atomNum(inputCtr) + + var disjsSolvedInIter = Seq[Expr]() + var callsInPaths = Set[Call]() + var conflictingFuns = funcVCs.keySet + //mapping from the functions to the counter-example paths that were seen + var seenPaths = MutableMap[FunDef, Seq[Expr]]() + def updateSeenPaths(fd: FunDef, cePath: Expr): Unit = { + if (seenPaths.contains(fd)) { + seenPaths.update(fd, cePath +: seenPaths(fd)) + } else { + seenPaths += (fd -> Seq(cePath)) + } + } + + def invalidateDisjRecr(prevCtr: Expr): (Option[Boolean], Expr, Model) = { + + Stats.updateCounter(1, "disjuncts") + + var blockedCEs = false + var confFunctions = Set[FunDef]() + var confDisjuncts = Seq[Expr]() + + val newctrs = conflictingFuns.foldLeft(Seq[Expr]())((acc, fd) => { + + val disableCounterExs = if (seenPaths.contains(fd)) { + blockedCEs = true + Not(Util.createOr(seenPaths(fd))) + } else tru + val (data, ctrsForFun) = getUNSATConstraints(fd, model, disableCounterExs) + val (disjunct, callsInPath) = data + if (ctrsForFun == tru) acc + else { + confFunctions += fd + confDisjuncts :+= disjunct + callsInPaths ++= callsInPath + //instantiate the disjunct + val cePath = simplifyArithmetic(TemplateInstantiator.instantiate(disjunct, tempVarMap)) + + //some sanity checks + if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) + throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) + + updateSeenPaths(fd, cePath) + acc :+ ctrsForFun + } + }) + //update conflicting functions + conflictingFuns = confFunctions + if (newctrs.isEmpty) { + + if (!blockedCEs) { + //yes, hurray,found an inductive invariant + (Some(false), prevCtr, model) + } else { + //give up, only hard paths remaining + reporter.info("- Exhausted all easy paths !!") + reporter.info("- Number of remaining hard paths: " + seenPaths.values.foldLeft(0)((acc, elem) => acc + elem.size)) + //TODO: what to unroll here ? + (None, tru, Model.empty) + } + } else { + + //check that the new constraints does not have any reals + val newPart = Util.createAnd(newctrs) + val newSize = Util.atomNum(newPart) + Stats.updateCounterStats((newSize + inputSize), "NLsize", "disjuncts") + if (verbose) + reporter.info("# of atomic predicates: " + newSize + " + " + inputSize) + + /*if (this.debugIncremental) + solverWithCtr.assertCnstr(newPart)*/ + + //here we need to solve for the newctrs + inputCtrs + val combCtr = And(prevCtr, newPart) + val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) + + res match { + case None => { + //here we have timed out while solving the non-linear constraints + if (verbose) + if (!this.disableCegis) + reporter.info("NLsolver timed-out on the disjunct... starting cegis phase...") + else + reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") + + if (!this.disableCegis) { + val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, Util.createOr(confDisjuncts), inputCtr, Some(model)) + cres match { + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + (Some(true), And(inputCtr, cctr), cmodel) + } + case Some(false) => { + disjsSolvedInIter ++= confDisjuncts + //here also return the calls that needs to be unrolled + (None, fls, Model.empty) + } + case _ => { + if (verbose) reporter.info("retrying...") + Stats.updateCumStats(1, "retries") + //disable this disjunct and retry but, use the inputCtrs + the constraints generated by cegis from the next iteration + invalidateDisjRecr(And(inputCtr, cctr)) + } + } + } else { + if (verbose) reporter.info("retrying...") + Stats.updateCumStats(1, "retries") + invalidateDisjRecr(inputCtr) + } + } + case Some(false) => { + //reporter.info("- Number of explored paths (of the DAG) in this unroll step: " + exploredPaths) + disjsSolvedInIter ++= confDisjuncts + (None, fls, Model.empty) + } + case Some(true) => { + disjsSolvedInIter ++= confDisjuncts + //new model may not have mappings for all the template variables, hence, use the mappings from earlier models + val compModel = new Model(tempIds.map((id) => { + if (newModel.isDefinedAt(id)) + (id -> newModel(id)) + else + (id -> model(id)) + }).toMap) + (Some(true), combCtr, compModel) + } + } + } + } + val (res, newctr, newmodel) = invalidateDisjRecr(inputCtr) + (res, newctr, newmodel, disjsSolvedInIter, callsInPaths) + } + + def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { + + val cegisSolver = new CegisCore(ctx, timeout, this) + val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) + if (!res.isDefined) + reporter.info("cegis timed-out on the disjunct...") + (res, ctr, model) + } + + protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { + if (ctx.usereals) replace(tempVarMap, e) + else + simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) + } + + /** + * Constructs a quantifier-free non-linear constraint for unsatisfiability + */ + def getUNSATConstraints(fd: FunDef, inModel: Model, disableCounterExs: Expr): ((Expr, Set[Call]), Expr) = { + + val tempVarMap: Map[Expr, Expr] = inModel.map((elem) => (elem._1.toVariable, elem._2)).toMap + val innerSolver = if (this.useIncrementalSolvingForVCs) vcSolvers(fd) + else new ExtendedUFSolver(leonctx, program) + val instExpr = if (this.useIncrementalSolvingForVCs) { + val instParamPart = instantiateTemplate(this.paramParts(fd), tempVarMap) + And(instParamPart, disableCounterExs) + } else { + val instVC = instantiateTemplate(funcVCs(fd), tempVarMap) + And(instVC, disableCounterExs) + } + //For debugging + if (this.dumpInstantiatedVC) { + // println("Plain vc: "+funcVCs(fd)) + val wr = new PrintWriter(new File("formula-dump.txt")) + val fullExpr = if (this.useIncrementalSolvingForVCs) { + And(innerSolver.getAssertions, instExpr) + } else + instExpr + // println("Instantiated VC of " + fd.id + " is: " + fullExpr) + wr.println("Function name: " + fd.id) + wr.println("Formula expr: ") + ExpressionTransformer.PrintWithIndentation(wr, fullExpr) + wr.flush() + wr.close() + } + //throw an exception if the candidate expression has reals + if (Util.hasMixedIntReals(instExpr)) { + //variablesOf(instExpr).foreach(id => println("Id: "+id+" type: "+id.getType)) + throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) + } + + //reporter.info("checking VC inst ...") + var t1 = System.currentTimeMillis() + val (res, model) = if (this.useIncrementalSolvingForVCs) { + innerSolver.push + innerSolver.assertCnstr(instExpr) + //dump the inst VC as SMTLIB + /*val filename = "vc" + FileCountGUID.getID + ".smt2" + Util.toZ3SMTLIB(innerSolver.getAssertions, filename, "", leonctx, program) + val writer = new PrintWriter(filename) + writer.println(innerSolver.ctrsToString("")) + writer.close() + println("vc dumped to: " + filename)*/ + + val solRes = innerSolver.check + innerSolver.pop() + solRes match { + case Some(true) => (solRes, innerSolver.getModel) + case _ => (solRes, Model.empty) + } + } else { + val solver = SimpleSolverAPI(SolverFactory(() => innerSolver)) + solver.solveSAT(instExpr) + } + val vccTime = (System.currentTimeMillis() - t1) + + if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") + Stats.updateCounterTime(vccTime, "VC-check-time", "disjuncts") + Stats.updateCumTime(vccTime, "TotalVCCTime") + + //for debugging + if (this.trackUnpackedVCCTime) { + val upVCinst = simplifyArithmetic(TemplateInstantiator.instantiate(ctrTracker.getVC(fd).unpackedExpr, tempVarMap)) + Stats.updateCounterStats(Util.atomNum(upVCinst), "UP-VC-size", "disjuncts") + + t1 = System.currentTimeMillis() + val (res2, _) = SimpleSolverAPI(SolverFactory(() => new ExtendedUFSolver(leonctx, program))).solveSAT(upVCinst) + val unpackedTime = System.currentTimeMillis() - t1 + if (res != res2) { + throw new IllegalStateException("Unpacked VC produces different result: " + upVCinst) + } + Stats.updateCumTime(unpackedTime, "TotalUPVCCTime") + reporter.info("checked UP-VC inst... in " + unpackedTime / 1000.0 + "s") + } + + t1 = System.currentTimeMillis() + res match { + case None => { + throw new IllegalStateException("cannot check the satisfiability of " + funcVCs(fd)) + } + case Some(false) => { + //do not generate any constraints + ((fls, Set()), tru) + } + case Some(true) => { + //For debugging purposes. + if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") + if (this.printCounterExample) { + reporter.info("Model: " + model) + } + + //get the disjuncts that are satisfied + val (data, newctr) = generateCtrsFromDisjunct(fd, model) + if (newctr == tru) + throw new IllegalStateException("Cannot find a counter-example path!!") + + val t2 = System.currentTimeMillis() + Stats.updateCounterTime((t2 - t1), "Disj-choosing-time", "disjuncts") + Stats.updateCumTime((t2 - t1), "Total-Choose-Time") + + (data, newctr) + } + } + } + + val evaluator = new DefaultEvaluator(leonctx, program) //as of now used only for debugging + //a helper method + //TODO: this should also handle reals + protected def doesSatisfyModel(expr: Expr, model: Model): Boolean = { + evaluator.eval(expr, model).result match { + case Some(BooleanLiteral(true)) => true + case _ => false + } + } + + /** + * Evaluator for a predicate that is a simple equality/inequality between two variables + */ + protected def predEval(model: Model): (Expr => Boolean) = { + if (ctx.usereals) realEval(model) + else intEval(model) + } + + protected def intEval(model: Model): (Expr => Boolean) = { + def modelVal(id: Identifier): BigInt = { + val InfiniteIntegerLiteral(v) = model(id) + v + } + def eval: (Expr => Boolean) = e => e match { + case And(args) => args.forall(eval) + // case Iff(Variable(id1), Variable(id2)) => model(id1) == model(id2) + case Equals(Variable(id1), Variable(id2)) => model(id1) == model(id2) //note: ADTs can also be compared for equality + case LessEquals(Variable(id1), Variable(id2)) => modelVal(id1) <= modelVal(id2) + case GreaterEquals(Variable(id1), Variable(id2)) => modelVal(id1) >= modelVal(id2) + case GreaterThan(Variable(id1), Variable(id2)) => modelVal(id1) > modelVal(id2) + case LessThan(Variable(id1), Variable(id2)) => modelVal(id1) < modelVal(id2) + case _ => throw new IllegalStateException("Predicate not handled: " + e) + } + eval + } + + protected def realEval(model: Model): (Expr => Boolean) = { + def modelVal(id: Identifier): FractionalLiteral = { + //println("Identifier: "+id) + model(id).asInstanceOf[FractionalLiteral] + } + (e: Expr) => e match { + case Equals(Variable(id1), Variable(id2)) => model(id1) == model(id2) //note: ADTs can also be compared for equality + case Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] + || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] + || e.isInstanceOf[GreaterEquals]) => { + evaluateRealPredicate(op(Seq(modelVal(id1), modelVal(id2)))) + } + case _ => throw new IllegalStateException("Predicate not handled: " + e) + } + } + + /** + * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. + * This method can overloaded by the subclasses. + */ + protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: Model): Seq[Constraint] = Seq() + + protected def generateCtrsFromDisjunct(fd: FunDef, model: Model): ((Expr, Set[Call]), Expr) = { + + val formula = ctrTracker.getVC(fd) + //this picks the satisfiable disjunct of the VC modulo axioms + val satCtrs = formula.pickSatDisjunct(formula.firstRoot, model) + //for debugging + if (this.debugChooseDisjunct || this.printPathToConsole || this.dumpPathAsSMTLIB || this.verifyInvariant) { + val pathctrs = satCtrs.map(_.toExpr) + val plainFormula = Util.createAnd(pathctrs) + val pathcond = simplifyArithmetic(plainFormula) + + if (this.debugChooseDisjunct) { + satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyModel(ctr, model)) + throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) + }) + } + + if (this.verifyInvariant) { + println("checking invariant for path...") + val sat = Util.checkInvariant(pathcond, leonctx, program) + } + + if (this.printPathToConsole) { + //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) + val simpcond = pathcond + println("Full-path: " + ScalaPrinter(simpcond)) + val filename = "full-path-" + FileCountGUID.getID + ".txt" + val wr = new PrintWriter(new File(filename)) + ExpressionTransformer.PrintWithIndentation(wr, simpcond) + println("Printed to file: " + filename) + wr.flush() + wr.close() + } + + if (this.dumpPathAsSMTLIB) { + val filename = "pathcond" + FileCountGUID.getID + ".smt2" + Util.toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) + println("Path dumped to: " + filename) + } + } + + var calls = Set[Call]() + var cons = Set[Expr]() + satCtrs.foreach(ctr => ctr match { + case t: Call => calls += t + case t: ADTConstraint if (t.cons.isDefined) => cons += t.cons.get + case _ => ; + }) + val callExprs = calls.map(_.toExpr) + + var t1 = System.currentTimeMillis() + val axiomCtrs = ctrTracker.specInstantiator.axiomsForCalls(formula, calls, model) + var t2 = System.currentTimeMillis() + Stats.updateCumTime((t2 - t1), "Total-AxiomChoose-Time") + + //here, handle theory operations by reducing them to axioms. + //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle + //other theory axioms like: multiplication, sets, arrays, maps etc. + t1 = System.currentTimeMillis() + val theoryCtrs = axiomsForTheory(formula, calls, model) + t2 = System.currentTimeMillis() + Stats.updateCumTime((t2 - t1), "Total-TheoryAxiomatization-Time") + + //Finally, eliminate UF/ADT + t1 = System.currentTimeMillis() + val callCtrs = (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), + predEval(model)).map(ConstraintUtil.createConstriant _) + t2 = System.currentTimeMillis() + Stats.updateCumTime((t2 - t1), "Total-ElimUF-Time") + + //exclude guards, separate calls and cons from the rest + var lnctrs = Set[LinearConstraint]() + var temps = Set[LinearTemplate]() + (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach(ctr => ctr match { + case t: LinearConstraint => lnctrs += t + case t: LinearTemplate => temps += t + case _ => ; + }) + + if (this.debugChooseDisjunct) { + lnctrs.map(_.toExpr).foreach((ctr) => { + if (!doesSatisfyModel(ctr, model)) + throw new IllegalStateException("Ctr not satisfied by model: " + ctr) + }) + } + + if (this.debugTheoryReduction) { + val simpPathCond = Util.createAnd((lnctrs ++ temps).map(_.template).toSeq) + if (this.verifyInvariant) { + println("checking invariant for simp-path...") + Util.checkInvariant(simpPathCond, leonctx, program) + } + } + + if (this.trackNumericalDisjuncts) { + numericalDisjuncts :+= Util.createAnd((lnctrs ++ temps).map(_.template).toSeq) + } + + val (data, nlctr) = processNumCtrs(lnctrs.toSeq, temps.toSeq) + ((data, calls), nlctr) + } + + /** + * Endpoint of the pipeline. Invokes the Farkas Lemma constraint generation. + */ + def processNumCtrs(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): (Expr, Expr) = { + //here we are invalidating A^~(B) + if (temps.isEmpty) { + //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path + (Util.createAnd(lnctrs.map(_.toExpr)), fls) + } else { + + if (this.debugElimination) { + //println("Path Constraints (before elim): "+(lnctrs ++ temps)) + if (this.verifyInvariant) { + println("checking invariant for disjunct before elimination...") + Util.checkInvariant(Util.createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) + } + } + //compute variables to be eliminated + val t1 = System.currentTimeMillis() + val ctrVars = lnctrs.foldLeft(Set[Identifier]())((acc, lc) => acc ++ variablesOf(lc.toExpr)) + val tempVars = temps.foldLeft(Set[Identifier]())((acc, lt) => acc ++ variablesOf(lt.template)) + val elimVars = ctrVars.diff(tempVars) + + val debugger = if (debugElimination && verifyInvariant) { + Some((ctrs: Seq[LinearConstraint]) => { + //println("checking disjunct before elimination...") + //println("ctrs: "+ctrs) + val debugRes = Util.checkInvariant(Util.createAnd((ctrs ++ temps).map(_.template)), leonctx, program) + }) + } else None + val elimLnctrs = LinearConstraintUtil.apply1PRuleOnDisjunct(lnctrs, elimVars, debugger) + val t2 = System.currentTimeMillis() + + if (this.debugElimination) { + println("Path constriants (after elimination): " + elimLnctrs) + if (this.verifyInvariant) { + println("checking invariant for disjunct after elimination...") + Util.checkInvariant(Util.createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) + } + } + //for stats + if (ctx.dumpStats) { + var elimCtrCount = 0 + var elimCtrs = Seq[LinearConstraint]() + var elimRems = Set[Identifier]() + elimLnctrs.foreach((lc) => { + val evars = variablesOf(lc.toExpr).intersect(elimVars) + if (!evars.isEmpty) { + elimCtrs :+= lc + elimCtrCount += 1 + elimRems ++= evars + } + }) + Stats.updateCounterStats((elimVars.size - elimRems.size), "Eliminated-Vars", "disjuncts") + Stats.updateCounterStats((lnctrs.size - elimLnctrs.size), "Eliminated-Atoms", "disjuncts") + Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") + Stats.updateCounterStats(lnctrs.size, "NonParam-Atoms", "disjuncts") + Stats.updateCumTime((t2 - t1), "ElimTime") + } + val newLnctrs = elimLnctrs.toSet.toSeq + + //TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c + //TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) + //e.g, remove: a <= b if we have a = b or if a < b + //Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. + + //TODO: Use the dependence chains in the formulas to identify what to assertionize + // and what can never be implied by solving for the templates + + val disjunct = Util.createAnd((newLnctrs ++ temps).map(_.template)) + val implCtrs = farkasSolver.constraintsForUnsat(newLnctrs, temps) + + //for debugging + if (this.debugReducedFormula) { + println("Final Path Constraints: " + disjunct) + if (this.verifyInvariant) { + println("checking invariant for final disjunct... ") + Util.checkInvariant(disjunct, leonctx, program) + } + } + + (disjunct, implCtrs) + } + } +} diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala new file mode 100644 index 0000000000000000000000000000000000000000..10a232e3fcb4730a4e2d3026ebd5f92db2f51ee5 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala @@ -0,0 +1,98 @@ +package leon +package invariant.templateSolvers +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import leon.invariant._ +import scala.util.control.Breaks._ +import solvers._ + +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.structure._ + +class NLTemplateSolverWithMult(ctx : InferenceContext, rootFun: FunDef, + ctrTracker: ConstraintTracker, minimizer: Option[(Expr, Model) => Model]) + extends NLTemplateSolver(ctx, rootFun, ctrTracker, minimizer) { + + val axiomFactory = new AxiomFactory(ctx) + + override def getVCForFun(fd: FunDef): Expr = { + val plainvc = ctrTracker.getVC(fd).toExpr + val nlvc = Util.multToTimes(plainvc) + nlvc + } + + override def splitVC(fd: FunDef) : (Expr,Expr) = { + val (paramPart, rest) = ctrTracker.getVC(fd).splitParamPart + (Util.multToTimes(paramPart),Util.multToTimes(rest)) + } + + override def axiomsForTheory(formula : Formula, calls: Set[Call], model: Model) : Seq[Constraint] = { + + //in the sequel we instantiate axioms for multiplication + val inst1 = unaryMultAxioms(formula, calls, predEval(model)) + val inst2 = binaryMultAxioms(formula,calls, predEval(model)) + val multCtrs = (inst1 ++ inst2).flatMap(_ match { + case And(args) => args.map(ConstraintUtil.createConstriant _) + case e => Seq(ConstraintUtil.createConstriant(e)) + }) + + Stats.updateCounterStats(multCtrs.size, "MultAxiomBlowup", "disjuncts") + ctx.reporter.info("Number of multiplication induced predicates: "+multCtrs.size) + multCtrs + } + + def chooseSATPredicate(expr: Expr, predEval: (Expr => Boolean)): Expr = { + val norme = ExpressionTransformer.normalizeExpr(expr,ctx.multOp) + val preds = norme match { + case Or(args) => args + case Operator(_, _) => Seq(norme) + case _ => throw new IllegalStateException("Not(ant) is not in expected format: " + norme) + } + //pick the first predicate that holds true + preds.collectFirst { case pred @ _ if predEval(pred) => pred }.get + } + + def isMultOp(call : Call) : Boolean = { + Util.isMultFunctions(call.fi.tfd.fd) + } + + def unaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Boolean)) : Seq[Expr] = { + val axioms = calls.flatMap { + case call@_ if (isMultOp(call) && axiomFactory.hasUnaryAxiom(call)) => { + val (ant,conseq) = axiomFactory.unaryAxiom(call) + if(predEval(ant)) + Seq(ant,conseq) + else + Seq(chooseSATPredicate(Not(ant), predEval)) + } + case _ => Seq() + } + axioms.toSeq + } + + def binaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Boolean)) : Seq[Expr] = { + + val mults = calls.filter(call => isMultOp(call) && axiomFactory.hasBinaryAxiom(call)) + val product = Util.cross(mults,mults).collect{ case (c1,c2) if c1 != c2 => (c1,c2) } + + ctx.reporter.info("Theory axioms: "+product.size) + Stats.updateCumStats(product.size, "-Total-theory-axioms") + + val newpreds = product.flatMap(pair => { + val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) + axiomInsts.flatMap { + case (ant,conseq) if predEval(ant) => Seq(ant,conseq) //if axiom-pre holds. + case (ant,_) => Seq(chooseSATPredicate(Not(ant), predEval)) //if axiom-pre does not hold. + } + }) + newpreds.toSeq + } +} diff --git a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..9b360c2db7d2f4442fb253e7fed58a0079b44c84 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala @@ -0,0 +1,117 @@ +package leon +package invariant.templateSolvers + +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import leon.invariant._ +import scala.util.control.Breaks._ +import scala.concurrent._ +import scala.concurrent.duration._ +import invariant.engine._ +import invariant.factories._ +import invariant.util._ +import invariant.structure._ +import invariant.structure.FunctionUtils._ +import leon.solvers.Model + +abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, + ctrTracker: ConstraintTracker) { + + protected val reporter = ctx.reporter + //protected val cg = CallGraphUtil.constructCallGraph(program) + + //some constants + protected val fls = BooleanLiteral(false) + protected val tru = BooleanLiteral(true) + //protected val zero = IntLiteral(0) + + private val dumpVCtoConsole = false + private val dumpVCasText = false + private val dumpVCasSMTLIB = false + + /** + * Completes a model by adding mapping to new template variables + */ + def completeModel(model: Map[Identifier, Expr], tempIds: Set[Identifier]): Map[Identifier, Expr] = { + tempIds.map((id) => { + if (!model.contains(id)) { + (id, simplestValue(id.getType)) + } else (id, model(id)) + }).toMap + } + + /** + * Computes the invariant for all the procedures given a mapping for the + * template variables. + */ + def getAllInvariants(model: Model): Map[FunDef, Expr] = { + val templates = ctrTracker.getFuncs.collect { + case fd if fd.hasTemplate => + fd -> fd.getTemplate + } + TemplateInstantiator.getAllInvariants(model, templates.toMap) + } + + protected def getVCForFun(fd: FunDef): Expr = { + ctrTracker.getVC(fd).toExpr + } + + /** + * This function computes invariants belonging to the given templates incrementally. + * The result is a mapping from function definitions to the corresponding invariants. + */ + def solveTemplates(): (Option[Model], Option[Set[Call]]) = { + //traverse each of the functions and collect the VCs + val funcs = ctrTracker.getFuncs + val funcExprs = funcs.map((fd) => { + val vc = if (ctx.usereals) + ExpressionTransformer.IntLiteralToReal(getVCForFun(fd)) + else getVCForFun(fd) + if (dumpVCtoConsole || dumpVCasText || dumpVCasSMTLIB) { + //val simpForm = simplifyArithmetic(vc) + val filename = "vc-" + FileCountGUID.getID + if (dumpVCtoConsole) { + println("Func: " + fd.id + " VC: " + vc) + } + if (dumpVCasText) { + val wr = new PrintWriter(new File(filename + ".txt")) + //ExpressionTransformer.PrintWithIndentation(wr, vcstr) + println("Printed VC of " + fd.id + " to file: " + filename) + wr.println(vc.toString) + wr.flush() + wr.close() + } + if (dumpVCasSMTLIB) { + Util.toZ3SMTLIB(vc, filename + ".smt2", "QF_LIA", ctx.leonContext, ctx.program) + println("Printed VC of " + fd.id + " to file: " + filename) + } + } + + if (ctx.dumpStats) { + Stats.updateCounterStats(Util.atomNum(vc), "VC-size", "VC-refinement") + Stats.updateCounterStats(Util.numUIFADT(vc), "UIF+ADT", "VC-refinement") + } + (fd -> vc) + }).toMap + //Assign some values for the template variables at random (actually use the simplest value for the type) + val tempIds = funcExprs.foldLeft(Set[Identifier]()) { + case (acc, (_, vc)) => + //val tempOption = if (fd.hasTemplate) Some(fd.getTemplate) else None + //if (!tempOption.isDefined) acc + //else + acc ++ Util.getTemplateIds(vc) + } + + Stats.updateCounterStats(tempIds.size, "TemplateIds", "VC-refinement") + val solution = solve(tempIds, funcExprs) + solution + } + + def solve(tempIds: Set[Identifier], funcVCs: Map[FunDef, Expr]): (Option[Model], Option[Set[Call]]) +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala new file mode 100644 index 0000000000000000000000000000000000000000..fb7519850bcebaafab437f1d270f2f9d37417dc9 --- /dev/null +++ b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala @@ -0,0 +1,288 @@ +package leon +package invariant.templateSolvers +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import leon.invariant.util.UndirectedGraph +import scala.util.control.Breaks._ +import invariant.util._ +import leon.purescala.TypeOps + +class UFADTEliminator(ctx: LeonContext, program: Program) { + + val debugAliases = false + val makeEfficient = true //this will happen at the expense of completeness + val reporter = ctx.reporter + val verbose = false + + def collectCompatibleCalls(calls: Set[Expr]) = { + //compute the cartesian product of the calls and select the pairs having the same function symbol and also implied by the precond + val vec = calls.toArray + val size = calls.size + var j = 0 + //for stats + var tuples = 0 + var functions = 0 + var adts = 0 + val product = vec.foldLeft(Set[(Expr, Expr)]())((acc, call) => { + + //an optimization: here we can exclude calls to maxFun from axiomatization, they will be inlined anyway + /*val shouldConsider = if(InvariantUtil.isCallExpr(call)) { + val BinaryOperator(_,FunctionInvocation(calledFun,_), _) = call + if(calledFun == DepthInstPhase.maxFun) false + else true + } else true*/ + var pairs = Set[(Expr, Expr)]() + for (i <- j + 1 until size) { + val call2 = vec(i) + if (mayAlias(call, call2)) { + + call match { + case Equals(_, fin : FunctionInvocation) => functions += 1 + case Equals(_, tup : Tuple) => tuples += 1 + case _ => adts += 1 + } + if (debugAliases) + println("Aliases: " + call + "," + call2) + + pairs ++= Set((call, call2)) + + } else { + if (debugAliases) { + (call, call2) match { + case (Equals(_, t1 @ Tuple(_)), Equals(_, t2 @ Tuple(_))) => + println("No Aliases: " + t1.getType + "," + t2.getType) + case _ => println("No Aliases: " + call + "," + call2) + } + } + } + } + j += 1 + acc ++ pairs + }) + if(verbose) reporter.info("Number of compatible calls: " + product.size) + /*reporter.info("Compatible Tuples: "+tuples) + reporter.info("Compatible Functions+ADTs: "+(functions+adts))*/ + Stats.updateCounterStats(product.size, "Compatible-Calls", "disjuncts") + Stats.updateCumStats(functions, "Compatible-functioncalls") + Stats.updateCumStats(adts, "Compatible-adtcalls") + Stats.updateCumStats(tuples, "Compatible-tuples") + product + } + + /** + * Convert the theory formula into linear arithmetic formula. + * The calls could be functions calls or ADT constructor calls. + * 'predEval' is an evaluator that evaluates a predicate to a boolean value + */ + def constraintsForCalls(calls: Set[Expr], predEval: (Expr => Boolean)): Seq[Expr] = { + + //check if two calls (to functions or ADT cons) have the same value in the model + def doesAlias(call1: Expr, call2: Expr): Boolean = { + val Operator(Seq(r1 @ Variable(_), _), _) = call1 + val Operator(Seq(r2 @ Variable(_), _), _) = call2 + val resEquals = predEval(Equals(r1, r2)) + if (resEquals) { + if (Util.isCallExpr(call1)) { + val (ants, _) = axiomatizeCalls(call1, call2) + val antsHold = ants.forall(ant => { + val Operator(Seq(lvar @ Variable(_), rvar @ Variable(_)), _) = ant + //(model(lid) == model(rid)) + predEval(Equals(lvar, rvar)) + }) + antsHold + } else true + } else false + } + + def predForEquality(call1: Expr, call2: Expr): Seq[Expr] = { + + val eqs = if (Util.isCallExpr(call1)) { + val (_, rhs) = axiomatizeCalls(call1, call2) + Seq(rhs) + } else { + val (lhs, rhs) = axiomatizeADTCons(call1, call2) + lhs :+ rhs + } + //remove self equalities. + val preds = eqs.filter(_ match { + case Operator(Seq(Variable(lid), Variable(rid)), _) => { + if (lid == rid) false + else { + if (lid.getType == Int32Type || lid.getType == RealType || lid.getType == IntegerType) true + else false + } + } + case e @ _ => throw new IllegalStateException("Not an equality or Iff: " + e) + }) + preds + } + + def predForDisequality(call1: Expr, call2: Expr): Seq[Expr] = { + + val (ants, _) = if (Util.isCallExpr(call1)) { + axiomatizeCalls(call1, call2) + } else { + axiomatizeADTCons(call1, call2) + } + + if (makeEfficient && ants.exists(_ match { + case Equals(l, r) if (l.getType != RealType && l.getType != BooleanType && l.getType != IntegerType) => true + case _ => false + })) { + Seq() + } else { + var unsatIntEq: Option[Expr] = None + var unsatOtherEq: Option[Expr] = None + ants.foreach(eq => + if (!unsatOtherEq.isDefined) { + eq match { + case Equals(lhs @ Variable(_), rhs @ Variable(_)) if !predEval(Equals(lhs, rhs)) => { + if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) + unsatOtherEq = Some(eq) + else if (!unsatIntEq.isDefined) + unsatIntEq = Some(eq) + } + case _ => ; + } + }) + if (unsatOtherEq.isDefined) Seq() //need not add any constraint + else if (unsatIntEq.isDefined) { + //pick the constraint a < b or a > b that is satisfied + val Equals(lhs @ Variable(_), rhs @ Variable(_)) = unsatIntEq.get + val lLTr = LessThan(lhs, rhs) + val atom = if (predEval(lLTr)) lLTr + else GreaterThan(lhs, rhs) + /*val InfiniteIntegerLiteral(lval) = model(lid) + val InfiniteIntegerLiteral(rval) = model(rid) + val atom = if (lval < rval) LessThan(lhs, rhs) + else if (lval > rval) GreaterThan(lhs, rhs) + else throw new IllegalStateException("Models are equal!!")*/ + + /*if (ants.exists(_ match { + case Equals(l, r) if (l.getType != Int32Type && l.getType != RealType && l.getType != BooleanType && l.getType != IntegerType) => true + case _ => false + })) { + Stats.updateCumStats(1, "Diseq-blowup") + }*/ + Seq(atom) + } else throw new IllegalStateException("All arguments are equal: " + (call1, call2)) + } + } + + var eqGraph = new UndirectedGraph[Expr]() //an equality graph + var neqSet = Set[(Expr, Expr)]() + val product = collectCompatibleCalls(calls) + val newctrs = product.foldLeft(Seq[Expr]())((acc, pair) => { + val (call1, call2) = (pair._1, pair._2) + //println("Assertionizing "+call1+" , call2: "+call2) + if (!eqGraph.BFSReach(call1, call2) && !neqSet.contains((call1, call2)) && !neqSet.contains((call2, call1))) { + if (doesAlias(call1, call2)) { + eqGraph.addEdge(call1, call2) + //note: here it suffices to check for adjacency and not reachability of calls (i.e, exprs). + //This is because the transitive equalities (corresponding to rechability) are encoded by the generated equalities. + acc ++ predForEquality(call1, call2) + + } else { + neqSet ++= Set((call1, call2)) + acc ++ predForDisequality(call1, call2) + } + } else acc + }) + + //reporter.info("Number of equal calls: " + eqGraph.getEdgeCount) + newctrs + } + + /** + * This function actually checks if two non-primitive expressions could have the same value + * (when some constraints on their arguments hold). + * Remark: notice that when the expressions have ADT types, then this is basically a form of may-alias check. + * TODO: handling generic can become very trickier here. + */ + def mayAlias(e1: Expr, e2: Expr): Boolean = { + //check if call and call2 are compatible + /*(e1, e2) match { + case (Equals(_, FunctionInvocation(fd1, _)), Equals(_, FunctionInvocation(fd2, _))) if (fd1.id == fd2.id) => true + case (Equals(_, CaseClass(cd1, _)), Equals(_, CaseClass(cd2, _))) if (cd1.id == cd2.id) => true + case (Equals(_, tp1 @ Tuple(e1)), Equals(_, tp2 @ Tuple(e2))) => { + //get the types and check if the types are compatible + val TupleType(tps1) = tp1.getType + val TupleType(tps2) = tp2.getType + (tps1 zip tps2).forall(pair => { + val (t1, t2) = pair + val lub = TypeOps.leastUpperBound(t1, t2) + (lub == Some(t1) || lub == Some(t2)) + }) + } + case _ => false + }*/ + (e1, e2) match { + case (Equals(_, FunctionInvocation(fd1, _)), Equals(_, FunctionInvocation(fd2, _))) => { + (fd1.id == fd2.id && fd1.fd.tparams == fd2.fd.tparams) + } + case (Equals(_, CaseClass(cd1, _)), Equals(_, CaseClass(cd2, _))) => { + // if (cd1.id == cd2.id && cd1.tps != cd2.tps) println("Invalidated the classes " + e1 + " " + e2) + (cd1.id == cd2.id && cd1.tps == cd2.tps) + } + case (Equals(_, tp1 @ Tuple(e1)), Equals(_, tp2 @ Tuple(e2))) => { + //get the types and check if the types are compatible + val TupleType(tps1) = tp1.getType + val TupleType(tps2) = tp2.getType + (tps1 zip tps2).forall(pair => { + val (t1, t2) = pair + val lub = TypeOps.leastUpperBound(t1, t2) + (lub == Some(t1) || lub == Some(t2)) + }) + } + case _ => false + } + } + + /** + * This procedure generates constraints for the calls to be equal + */ + def axiomatizeCalls(call1: Expr, call2: Expr): (Seq[Expr], Expr) = { + val (v1, fi1, v2, fi2) = { + val Equals(r1, f1 @ FunctionInvocation(_, _)) = call1 + val Equals(r2, f2 @ FunctionInvocation(_, _)) = call2 + (r1, f1, r2, f2) + } + + val ants = (fi1.args.zip(fi2.args)).foldLeft(Seq[Expr]())((acc, pair) => { + val (arg1, arg2) = pair + acc :+ Equals(arg1, arg2) + }) + val conseq = Equals(v1, v2) + (ants, conseq) + } + + /** + * The returned pairs should be interpreted as a bidirectional implication + */ + def axiomatizeADTCons(sel1: Expr, sel2: Expr): (Seq[Expr], Expr) = { + + val (v1, args1, v2, args2) = sel1 match { + case Equals(r1 @ Variable(_), CaseClass(_, a1)) => { + val Equals(r2 @ Variable(_), CaseClass(_, a2)) = sel2 + (r1, a1, r2, a2) + } + case Equals(r1 @ Variable(_), Tuple(a1)) => { + val Equals(r2 @ Variable(_), Tuple(a2)) = sel2 + (r1, a1, r2, a2) + } + } + + val ants = (args1.zip(args2)).foldLeft(Seq[Expr]())((acc, pair) => { + val (arg1, arg2) = pair + acc :+ Equals(arg1, arg2) + }) + val conseq = Equals(v1, v2) + (ants, conseq) + } +} diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala new file mode 100644 index 0000000000000000000000000000000000000000..73ac3b099389eb33223d83302bc74b82d3b1547f --- /dev/null +++ b/src/main/scala/leon/invariant/util/CallGraph.scala @@ -0,0 +1,139 @@ +package leon +package invariant.util + +import purescala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import Util._ +import invariant.structure.FunctionUtils._ + +/** + * This represents a call graph of the functions in the program + */ +class CallGraph { + val graph = new DirectedGraph[FunDef]() + + def addFunction(fd: FunDef) = graph.addNode(fd) + + def addEdgeIfNotPresent(src: FunDef, callee: FunDef): Unit = { + if (!graph.containsEdge(src, callee)) + graph.addEdge(src, callee) + } + + def callees(src: FunDef): Set[FunDef] = { + graph.getSuccessors(src) + } + + def transitiveCallees(src: FunDef): Set[FunDef] = { + graph.BFSReachables(src) + } + + def isRecursive(fd: FunDef): Boolean = { + transitivelyCalls(fd, fd) + } + + /** + * Checks if the src transitively calls the procedure proc + */ + def transitivelyCalls(src: FunDef, proc: FunDef): Boolean = { + //important: We cannot say that src calls it self even though source is reachable from itself in the callgraph + graph.BFSReach(src, proc, excludeSrc = true) + } + + def calls(src: FunDef, proc: FunDef): Boolean = { + graph.containsEdge(src, proc) + } + + /** + * sorting functions in ascending topological order + */ + def topologicalOrder: Seq[FunDef] = { + + def insert(index: Int, l: Seq[FunDef], fd: FunDef): Seq[FunDef] = { + var i = 0 + var head = Seq[FunDef]() + l.foreach((elem) => { + if (i == index) + head :+= fd + head :+= elem + i += 1 + }) + head + } + + var funcList = Seq[FunDef]() + graph.getNodes.toList.foreach((f) => { + var inserted = false + var index = 0 + for (i <- 0 to funcList.length - 1) { + if (!inserted && this.transitivelyCalls(funcList(i), f)) { + index = i + inserted = true + } + } + if (!inserted) + funcList :+= f + else funcList = insert(index, funcList, f) + }) + + funcList + } + + override def toString: String = { + val procs = graph.getNodes + procs.foldLeft("")((acc, proc) => { + acc + proc.id + " --calls--> " + + graph.getSuccessors(proc).foldLeft("")((acc, succ) => acc + "," + succ.id) + "\n" + }) + } +} + +object CallGraphUtil { + + def constructCallGraph(prog: Program, onlyBody: Boolean = false, withTemplates: Boolean = false): CallGraph = { +// + // println("Constructing call graph") + val cg = new CallGraph() + functionsWOFields(prog.definedFunctions).foreach((fd) => { + if (fd.hasBody) { + // println("Adding func " + fd.id.uniqueName) + var funExpr = fd.body.get + if (!onlyBody) { + if (fd.hasPrecondition) + funExpr = Tuple(Seq(funExpr, fd.precondition.get)) + if (fd.hasPostcondition) + funExpr = Tuple(Seq(funExpr, fd.postcondition.get)) + } + if (withTemplates && fd.hasTemplate) { + funExpr = Tuple(Seq(funExpr, fd.getTemplate)) + } + + //introduce a new edge for every callee + val callees = getCallees(funExpr) + if (callees.isEmpty) + cg.addFunction(fd) + else + callees.foreach(cg.addEdgeIfNotPresent(fd, _)) + } + }) + cg + } + + def getCallees(expr: Expr): Set[FunDef] = { + var callees = Set[FunDef]() + simplePostTransform((expr) => expr match { + //note: do not consider field invocations + case FunctionInvocation(TypedFunDef(callee, _), args) + if callee.isRealFunction => { + callees += callee + expr + } + case _ => expr + })(expr) + callees + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala new file mode 100644 index 0000000000000000000000000000000000000000..e38bdf3aa5cc3dd7d1f591bbee464e5dd19bd1a5 --- /dev/null +++ b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala @@ -0,0 +1,660 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import java.io._ +import purescala.ScalaPrinter +import invariant.structure.Call +import invariant.structure.FunctionUtils._ +import leon.invariant.factories.TemplateIdFactory + +/** + * A collection of transformation on expressions and some utility methods. + * These operations are mostly semantic preserving (specific assumptions/requirements are specified on the operations) + */ +object ExpressionTransformer { + + val zero = InfiniteIntegerLiteral(0) + val one = InfiniteIntegerLiteral(1) + val mone = InfiniteIntegerLiteral(-1) + val tru = BooleanLiteral(true) + val fls = BooleanLiteral(false) + val bone = BigInt(1) + + /** + * This function conjoins the conjuncts created by 'transfomer' within the clauses containing Expr. + * This is meant to be used by operations that may flatten subexpression using existential quantifiers. + * @param insideFunction when set to true indicates that the newConjuncts (second argument) + * should not conjoined to the And(..) / Or(..) expressions found because they + * may be called inside a function. + */ + def conjoinWithinClause(e: Expr, transformer: (Expr, Boolean) => (Expr, Set[Expr]), + insideFunction: Boolean): (Expr, Set[Expr]) = { + e match { + case And(args) if !insideFunction => { + val newargs = args.map((arg) => { + val (nexp, ncjs) = transformer(arg, false) + Util.createAnd(nexp +: ncjs.toSeq) + }) + (Util.createAnd(newargs), Set()) + } + + case Or(args) if !insideFunction => { + val newargs = args.map((arg) => { + val (nexp, ncjs) = transformer(arg, false) + Util.createAnd(nexp +: ncjs.toSeq) + }) + (Util.createOr(newargs), Set()) + } + + case t: Terminal => (t, Set()) + + /*case BinaryOperator(e1, e2, op) => { + val (nexp1, ncjs1) = transformer(e1, true) + val (nexp2, ncjs2) = transformer(e2, true) + (op(nexp1, nexp2), ncjs1 ++ ncjs2) + } + + case u @ UnaryOperator(e1, op) => { + val (nexp, ncjs) = transformer(e1, true) + (op(nexp), ncjs) + }*/ + + case n @ Operator(args, op) => { + var ncjs = Set[Expr]() + val newargs = args.map((arg) => { + val (nexp, js) = transformer(arg, true) + ncjs ++= js + nexp + }) + (op(newargs), ncjs) + } + case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + e) + } + } + + /** + * Assumed that that given expression has boolean type + * converting if-then-else and let into a logical formula + */ + def reduceLangBlocks(inexpr: Expr, multop: (Expr, Expr) => Expr) = { + + def transform(e: Expr, insideFunction: Boolean): (Expr, Set[Expr]) = { + e match { + // Handle asserts here. Return flattened body as the result + case as @ Assert(pred, _, body) => { + val freshvar = TVarFactory.createTemp("asrtres", e.getType).toVariable + val newexpr = Equals(freshvar, body) + val resset = transform(newexpr, insideFunction) + (freshvar, resset._2 + resset._1) + } + //handles division by constant + case Division(lhs, rhs @ InfiniteIntegerLiteral(v)) => { + //this models floor and not integer division + val quo = TVarFactory.createTemp("q", IntegerType).toVariable + var possibs = Seq[Expr]() + for (i <- v - 1 to 0 by -1) { + if (i == 0) possibs :+= Equals(lhs, Times(rhs, quo)) + else possibs :+= Equals(lhs, Plus(Times(rhs, quo), InfiniteIntegerLiteral(i))) + } + //compute the disjunction of all possibs + val newexpr = Or(possibs) + //println("newexpr: "+newexpr) + val resset = transform(newexpr, true) + (quo, resset._2 + resset._1) + } + //handles division by variables + case Division(lhs, rhs) => { + //this models floor and not integer division + val quo = TVarFactory.createTemp("q", IntegerType).toVariable + val rem = TVarFactory.createTemp("r", IntegerType).toVariable + val mult = multop(quo, rhs) + val divsem = Equals(lhs, Plus(mult, rem)) + //TODO: here, we have to use |rhs| + val newexpr = Util.createAnd(Seq(divsem, LessEquals(zero, rem), LessEquals(rem, Minus(rhs, one)))) + val resset = transform(newexpr, true) + (quo, resset._2 + resset._1) + } + case err @ Error(_, msg) => { + //replace this by a fresh variable of the error type + (TVarFactory.createTemp("err", err.getType).toVariable, Set[Expr]()) + } + case Equals(lhs, rhs) => { + val (nexp1, ncjs1) = transform(lhs, true) + val (nexp2, ncjs2) = transform(rhs, true) + (Equals(nexp1, nexp2), ncjs1 ++ ncjs2) + } + case IfExpr(cond, thn, elze) => { + val freshvar = TVarFactory.createTemp("ifres", e.getType).toVariable + val newexpr = Or(And(cond, Equals(freshvar, thn)), And(Not(cond), Equals(freshvar, elze))) + val resset = transform(newexpr, insideFunction) + (freshvar, resset._2 + resset._1) + } + case Let(binder, value, body) => { + //TODO: do we have to consider reuse of let variables ? + val (resbody, bodycjs) = transform(body, true) + val (resvalue, valuecjs) = transform(value, true) + + (resbody, (valuecjs + Equals(binder.toVariable, resvalue)) ++ bodycjs) + } + //the value is a tuple in the following case + case LetTuple(binders, value, body) => { + //TODO: do we have to consider reuse of let variables ? + val (resbody, bodycjs) = transform(body, true) + val (resvalue, valuecjs) = transform(value, true) + + //here we optimize the case where resvalue itself has tuples + val newConjuncts = resvalue match { + case Tuple(args) => { + binders.zip(args).map((elem) => { + val (bind, arg) = elem + Equals(bind.toVariable, arg) + }) + } + case _ => { + //may it is better to assign resvalue to a temporary variable (if it is not already a variable) + val (resvalue2, cjs) = resvalue match { + case t: Terminal => (t, Seq()) + case _ => { + val freshres = TVarFactory.createTemp("tres", resvalue.getType).toVariable + (freshres, Seq(Equals(freshres, resvalue))) + } + } + var i = 0 + val cjs2 = binders.map((bind) => { + i += 1 + Equals(bind.toVariable, TupleSelect(resvalue2, i)) + }) + (cjs ++ cjs2) + } + } + + (resbody, (valuecjs ++ newConjuncts) ++ bodycjs) + } + case _ => { + conjoinWithinClause(e, transform, false) + } + } + } + val (nexp, ncjs) = transform(inexpr, false) + val res = if (!ncjs.isEmpty) { + Util.createAnd(nexp +: ncjs.toSeq) + } else nexp + res + } + + /** + * Requires: The expression has to be in NNF form and without if-then-else and let constructs + * Assumed that that given expression has boolean type + * (a) the function replaces every function call by a variable and creates a new equality + * (b) it also replaces arguments that are not variables by fresh variables and creates + * a new equality mapping the fresh variable to the argument expression + */ + def FlattenFunction(inExpr: Expr): Expr = { + + /** + * First return value is the new expression. The second return value is the + * set of new conjuncts + * @param insideFunction when set to true indicates that the newConjuncts (second argument) + * should not conjoined to the And(..) / Or(..) expressions found because they + * may be called inside a function. + */ + def flattenFunc(e: Expr, insideFunction: Boolean): (Expr, Set[Expr]) = { + e match { + case fi @ FunctionInvocation(fd, args) => { + //now also flatten the args. The following is slightly tricky + val (newargs, newConjuncts) = flattenArgs(args, true) + //create a new equality in UIFs + val newfi = FunctionInvocation(fd, newargs) + //create a new variable to represent the function + val freshResVar = Variable(TVarFactory.createTemp("r", fi.getType)) + val res = (freshResVar, newConjuncts + Equals(freshResVar, newfi)) + res + } + case inst @ IsInstanceOf(e1, cd) => { + //replace e by a variable + val (newargs, newcjs) = flattenArgs(Seq(e1), true) + var newConjuncts = newcjs + + val freshArg = newargs(0) + val newInst = IsInstanceOf(freshArg, cd) + val freshResVar = Variable(TVarFactory.createTemp("ci", inst.getType)) + newConjuncts += Equals(freshResVar, newInst) + (freshResVar, newConjuncts) + } + case cs @ CaseClassSelector(cd, e1, sel) => { + val (newargs, newcjs) = flattenArgs(Seq(e1), true) + var newConjuncts = newcjs + + val freshArg = newargs(0) + val newCS = CaseClassSelector(cd, freshArg, sel) + val freshResVar = Variable(TVarFactory.createTemp("cs", cs.getType)) + newConjuncts += Equals(freshResVar, newCS) + + (freshResVar, newConjuncts) + } + case ts @ TupleSelect(e1, index) => { + val (newargs, newcjs) = flattenArgs(Seq(e1), true) + var newConjuncts = newcjs + + val freshArg = newargs(0) + val newTS = TupleSelect(freshArg, index) + val freshResVar = Variable(TVarFactory.createTemp("ts", ts.getType)) + newConjuncts += Equals(freshResVar, newTS) + + (freshResVar, newConjuncts) + } + case cc @ CaseClass(cd, args) => { + + val (newargs, newcjs) = flattenArgs(args, true) + var newConjuncts = newcjs + + val newCC = CaseClass(cd, newargs) + val freshResVar = Variable(TVarFactory.createTemp("cc", cc.getType)) + newConjuncts += Equals(freshResVar, newCC) + + (freshResVar, newConjuncts) + } + case tp @ Tuple(args) => { + val (newargs, newcjs) = flattenArgs(args, true) + var newConjuncts = newcjs + + val newTP = Tuple(newargs) + val freshResVar = Variable(TVarFactory.createTemp("tp", tp.getType)) + // if(freshResVar.id.toString == "tp6"){ + // println("Creating temporary tp6 type: "+tp.getType+" expr: "+tp) + // throw new IllegalStateException("") + // } + newConjuncts += Equals(freshResVar, newTP) + + (freshResVar, newConjuncts) + } + case _ => conjoinWithinClause(e, flattenFunc, insideFunction) + } + } + + def flattenArgs(args: Seq[Expr], insideFunction: Boolean): (Seq[Expr], Set[Expr]) = { + var newConjuncts = Set[Expr]() + val newargs = args.map((arg) => + arg match { + case v: Variable => v + case r: ResultVariable => r + case _ => { + val (nexpr, ncjs) = flattenFunc(arg, insideFunction) + + newConjuncts ++= ncjs + + nexpr match { + case v: Variable => v + case r: ResultVariable => r + case _ => { + val freshArgVar = Variable(TVarFactory.createTemp("arg", arg.getType)) + newConjuncts += Equals(freshArgVar, nexpr) + freshArgVar + } + } + } + }) + (newargs, newConjuncts) + } + + val (nexp, ncjs) = flattenFunc(inExpr, false) + if (!ncjs.isEmpty) { + Util.createAnd(nexp +: ncjs.toSeq) + } else nexp + } + + /** + * The following procedure converts the formula into negated normal form by pushing all not's inside. + * It also handles disequality constraints. + * Assumption: + * (a) the formula does not have match constructs + * Some important features. + * (a) For a strict inequality with real variables/constants, the following produces a strict inequality + * (b) Strict inequalities with only integer variables/constants are reduced to non-strict inequalities + */ + def TransformNot(expr: Expr, retainNEQ: Boolean = false): Expr = { // retainIff : Boolean = false + def nnf(inExpr: Expr): Expr = { + + if (inExpr.getType != BooleanType) inExpr + else inExpr match { + case Not(Not(e1)) => nnf(e1) + case e @ Not(t: Terminal) => e + case e @ Not(FunctionInvocation(_, _)) => e + case Not(And(args)) => Util.createOr(args.map(arg => nnf(Not(arg)))) + case Not(Or(args)) => Util.createAnd(args.map(arg => nnf(Not(arg)))) + case Not(e @ Operator(Seq(e1, e2), op)) => { + //matches integer binary relation or a boolean equality + if (e1.getType == BooleanType || e1.getType == Int32Type || e1.getType == RealType || e1.getType == IntegerType) { + e match { + case e: Equals => { + if (e1.getType == BooleanType && e2.getType == BooleanType) { + Or(And(nnf(e1), nnf(Not(e2))), And(nnf(e2), nnf(Not(e1)))) + } else { + if (retainNEQ) Not(Equals(e1, e2)) + else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2))) + } + } + case e: LessThan => GreaterEquals(nnf(e1), nnf(e2)) + case e: LessEquals => GreaterThan(nnf(e1), nnf(e2)) + case e: GreaterThan => LessEquals(nnf(e1), nnf(e2)) + case e: GreaterEquals => LessThan(nnf(e1), nnf(e2)) + case e: Implies => And(nnf(e1), nnf(Not(e2))) + case _ => throw new IllegalStateException("Unknown binary operation: " + e) + } + } else { + //in this case e is a binary operation over ADTs + e match { + case ninst @ Not(IsInstanceOf(e1, cd)) => Not(IsInstanceOf(nnf(e1), cd)) + case e: Equals => Not(Equals(nnf(e1), nnf(e2))) + case _ => throw new IllegalStateException("Unknown operation on algebraic data types: " + e) + } + } + } + case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs)) + case e @ Equals(lhs, IsInstanceOf(_, _) | CaseClassSelector(_, _, _) | TupleSelect(_, _) | FunctionInvocation(_, _)) => + //all case where rhs could use an ADT tree e.g. instanceOF, tupleSelect, fieldSelect, function invocation + e + case Equals(lhs, rhs) if (lhs.getType == BooleanType && rhs.getType == BooleanType) => { + nnf(And(Implies(lhs, rhs), Implies(rhs, lhs))) + } + case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze))) + case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e))) + //note that Not(LetTuple) is not possible + case t: Terminal => t + /*case u @ UnaryOperator(e1, op) => op(nnf(e1)) + case b @ BinaryOperator(e1, e2, op) => op(nnf(e1), nnf(e2))*/ + case n @ Operator(args, op) => op(args.map(nnf(_))) + + case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr) + } + } + val nnfvc = nnf(expr) + nnfvc + } + + /** + * Eliminates redundant nesting of ORs and ANDs. + * This is supposed to be a semantic preserving transformation + */ + def pullAndOrs(expr: Expr): Expr = { + + simplePostTransform((e: Expr) => e match { + case Or(args) => { + val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { + case Or(inArgs) => acc ++ inArgs + case _ => acc :+ arg + }) + Util.createOr(newArgs) + } + case And(args) => { + val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { + case And(inArgs) => acc ++ inArgs + case _ => acc :+ arg + }) + Util.createAnd(newArgs) + } + case _ => e + })(expr) + } + + def classSelToCons(e: Expr): Expr = { + val (r, cd, ccvar, ccfld) = e match { + case Equals(r0 @ Variable(_), CaseClassSelector(cd0, ccvar0, ccfld0)) => (r0, cd0, ccvar0, ccfld0) + case _ => throw new IllegalStateException("Not a case-class-selector call") + } + //convert this to a cons by creating dummy variables + val args = cd.fields.map((fld) => { + if (fld.id == ccfld) r + else { + //create a dummy identifier there + TVarFactory.createDummy(fld.getType).toVariable + } + }) + Equals(ccvar, CaseClass(cd, args)) + } + + def tupleSelToCons(e: Expr): Expr = { + val (r, tpvar, index) = e match { + case Equals(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) + // case Iff(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) + case _ => throw new IllegalStateException("Not a tuple-selector call") + } + //convert this to a Tuple by creating dummy variables + val tupleType = tpvar.getType.asInstanceOf[TupleType] + val args = (1 until tupleType.dimension + 1).map((i) => { + if (i == index) r + else { + //create a dummy identifier there (note that here we have to use i-1) + TVarFactory.createDummy(tupleType.bases(i - 1)).toVariable + } + }) + Equals(tpvar, Tuple(args)) + } + + /** + * Normalizes the expressions + */ + def normalizeExpr(expr: Expr, multOp: (Expr, Expr) => Expr): Expr = { + //reduce the language before applying flatten function + // println("Normalizing " + ScalaPrinter(expr) + "\n") + val redex = reduceLangBlocks(expr, multOp) + // println("Redex: "+ScalaPrinter(redex) + "\n") + val nnfExpr = TransformNot(redex) + // println("NNFexpr: "+ScalaPrinter(nnfExpr) + "\n") + //flatten all function calls + val flatExpr = FlattenFunction(nnfExpr) + // println("Flatexpr: "+ScalaPrinter(flatExpr) + "\n") + //perform additional simplification + val simpExpr = pullAndOrs(TransformNot(flatExpr)) + simpExpr + } + + /** + * This is the inverse operation of flattening, this is mostly + * used to produce a readable formula. + * Freevars is a set of identifiers that are program variables + * This assumes that temporary identifiers (which are not freevars) are not reused across clauses. + */ + def unFlatten(ine: Expr, freevars: Set[Identifier]): Expr = { + var tempMap = Map[Expr, Expr]() + val newinst = simplePostTransform((e: Expr) => e match { + case Equals(v @ Variable(id), rhs @ _) if !freevars.contains(id) => + if (tempMap.contains(v)) e + else { + tempMap += (v -> rhs) + tru + } + case _ => e + })(ine) + val closure = (e: Expr) => replace(tempMap, e) + Util.fix(closure)(newinst) + } + + /** + * convert all integer constants to real constants + */ + def IntLiteralToReal(inexpr: Expr): Expr = { + val transformer = (e: Expr) => e match { + case InfiniteIntegerLiteral(v) => FractionalLiteral(v, 1) + case IntLiteral(v) => FractionalLiteral(v, 1) + case _ => e + } + simplePostTransform(transformer)(inexpr) + } + + /** + * convert all real constants to integers + */ + def FractionalLiteralToInt(inexpr: Expr): Expr = { + val transformer = (e: Expr) => e match { + case FractionalLiteral(v, `bone`) => InfiniteIntegerLiteral(v) + case FractionalLiteral(_, _) => throw new IllegalStateException("cannot convert real literal to integer: " + e) + case _ => e + } + simplePostTransform(transformer)(inexpr) + } + + /** + * A hacky way to implement subexpression check. + * TODO: fix this + */ + def isSubExpr(key: Expr, expr: Expr): Boolean = { + + var found = false + simplePostTransform((e: Expr) => e match { + case _ if (e == key) => + found = true; e + case _ => e + })(expr) + found + } + + /** + * Some simplification rules (keep adding more and more rules) + */ + def simplify(expr: Expr): Expr = { + + //Note: some simplification are already performed by the class constructors (see Tree.scala) + simplePostTransform((e: Expr) => e match { + case Equals(lhs, rhs) if (lhs == rhs) => tru + case LessEquals(lhs, rhs) if (lhs == rhs) => tru + case GreaterEquals(lhs, rhs) if (lhs == rhs) => tru + case LessThan(lhs, rhs) if (lhs == rhs) => fls + case GreaterThan(lhs, rhs) if (lhs == rhs) => fls + case UMinus(InfiniteIntegerLiteral(v)) => InfiniteIntegerLiteral(-v) + case Equals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 == v2) + case LessEquals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 <= v2) + case LessThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 < v2) + case GreaterEquals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 >= v2) + case GreaterThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 > v2) + case _ => e + })(expr) + } + + /** + * Input expression is assumed to be in nnf form + * Note: (a) Not(Equals()) and Not(Variable) is allowed + */ + def isDisjunct(e: Expr): Boolean = e match { + case And(args) => args.foldLeft(true)((acc, arg) => acc && isDisjunct(arg)) + case Not(Equals(_, _)) | Not(Variable(_)) => true + case Or(_) | Implies(_, _) | Not(_) | Equals(_, _) => false + case _ => true + } + + /** + * assuming that the expression is in nnf form + * Note: (a) Not(Equals()) and Not(Variable) is allowed + */ + def isConjunct(e: Expr): Boolean = e match { + case Or(args) => args.foldLeft(true)((acc, arg) => acc && isConjunct(arg)) + case Not(Equals(_, _)) | Not(Variable(_)) => true + case And(_) | Implies(_, _) | Not(_) | Equals(_, _) => false + case _ => true + } + + def PrintWithIndentation(wr: PrintWriter, expr: Expr): Unit = { + + def uniOP(e: Expr, seen: Int): Boolean = e match { + case And(args) => { + //have we seen an or ? + if (seen == 2) false + else args.foldLeft(true)((acc, arg) => acc && uniOP(arg, 1)) + } + case Or(args) => { + //have we seen an And ? + if (seen == 1) false + else args.foldLeft(true)((acc, arg) => acc && uniOP(arg, 2)) + } + case t: Terminal => true + /*case u @ UnaryOperator(e1, op) => uniOP(e1, seen) + case b @ BinaryOperator(e1, e2, op) => uniOP(e1, seen) && uniOP(e2, seen)*/ + case n @ Operator(args, op) => args.foldLeft(true)((acc, arg) => acc && uniOP(arg, seen)) + } + + def printRec(e: Expr, indent: Int): Unit = { + if (uniOP(e, 0)) wr.println(ScalaPrinter(e)) + else { + wr.write("\n" + " " * indent + "(\n") + e match { + case And(args) => { + var start = true + args.map((arg) => { + wr.print(" " * (indent + 1)) + if (!start) wr.print("^") + printRec(arg, indent + 1) + start = false + }) + } + case Or(args) => { + var start = true + args.map((arg) => { + wr.print(" " * (indent + 1)) + if (!start) wr.print("v") + printRec(arg, indent + 1) + start = false + }) + } + case _ => throw new IllegalStateException("how can this happen ? " + e) + } + wr.write(" " * indent + ")\n") + } + } + printRec(expr, 0) + } + + /** + * Converts to sum of products form by distributing + * multiplication over addition + */ + def normalizeMultiplication(e: Expr, multop: (Expr, Expr) => Expr): Expr = { + + def isConstantOrTemplateVar(e: Expr) = { + e match { + case l: Literal[_] => true + case Variable(id) if TemplateIdFactory.IsTemplateIdentifier(id) => true + case _ => false + } + } + + def distribute(e: Expr): Expr = { + simplePreTransform(_ match { + case e @ FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if Util.isMultFunctions(fd) => + val newe = (e1, e2) match { + case (Plus(sum1, sum2), _) => + // distribute e2 over e1 + Plus(multop(sum1, e2), multop(sum2, e2)) + case (_, Plus(sum1, sum2)) => + // distribute e1 over e2 + Plus(multop(e1, sum1), multop(e1, sum2)) + case (Times(arg1, arg2), _) => + // pull the constants out of multiplication (note: times is used when one of the arguments is a literal or template id + if (isConstantOrTemplateVar(arg1)) { + Times(arg1, multop(arg2, e2)) + } else + Times(arg2, multop(arg1, e2)) // here using commutativity axiom + case (_, Times(arg1, arg2)) => + if (isConstantOrTemplateVar(arg1)) + Times(arg1, multop(e1, arg2)) + else + Times(arg2, multop(e1, arg1)) + case _ if isConstantOrTemplateVar(e1) || isConstantOrTemplateVar(e2) => + // here one of the operands is a literal or template var, so convert mult to times and continue + Times(e1, e2) + case _ => + e + } + newe + case other => other + })(e) + } + distribute(e) + } +} diff --git a/src/main/scala/leon/invariant/util/Graph.scala b/src/main/scala/leon/invariant/util/Graph.scala new file mode 100644 index 0000000000000000000000000000000000000000..22f4e83d12bb96407f37cb548b71482c38376914 --- /dev/null +++ b/src/main/scala/leon/invariant/util/Graph.scala @@ -0,0 +1,170 @@ +package leon +package invariant.util + +class DirectedGraph[T] { + + var adjlist = scala.collection.mutable.Map[T, Set[T]]() + var edgeCount: Int = 0 + + def addNode(n: T) { + if (!adjlist.contains(n)) { + adjlist.update(n, Set()) + } + } + + def addEdge(src: T, dest: T): Unit = { + val newset = if (adjlist.contains(src)) adjlist(src) + dest + else Set(dest) + + //this has some side-effects + adjlist.update(src, newset) + + edgeCount += 1 + } + + def BFSReach(src: T, dest: T, excludeSrc: Boolean = false): Boolean = { + var queue = List[T]() + var visited = Set[T]() + visited += src + + //TODO: is there a better (and efficient) way to implement BFS without using side-effects + def BFSReachRecur(cur: T): Boolean = { + var found: Boolean = false + if (adjlist.contains(cur)) { + adjlist(cur).foreach((fi) => { + if (fi == dest) found = true + else if (!visited.contains(fi)) { + visited += fi + queue ::= fi + } + }) + } + if (found) true + else if (queue.isEmpty) false + else { + val (head :: tail) = queue + queue = tail + BFSReachRecur(head) + } + } + + if (!excludeSrc && src == dest) true + else BFSReachRecur(src) + } + + def BFSReachables(src: T): Set[T] = { + var queue = List[T]() + var visited = Set[T]() + visited += src + + def BFSReachRecur(cur: T): Unit = { + if (adjlist.contains(cur)) { + adjlist(cur).foreach((neigh) => { + if (!visited.contains(neigh)) { + visited += neigh + queue ::= neigh + } + }) + } + if (!queue.isEmpty) { + val (head :: tail) = queue + queue = tail + BFSReachRecur(head) + } + } + + BFSReachRecur(src) + visited + } + + def containsEdge(src: T, dest: T): Boolean = { + if (adjlist.contains(src)) { + adjlist(src).contains(dest) + } else false + } + + def getEdgeCount: Int = edgeCount + def getNodes: Set[T] = adjlist.keySet.toSet + def getSuccessors(src: T): Set[T] = adjlist(src) + + /** + * Change this to the verified component + */ + def sccs: List[List[T]] = { + + type Component = List[T] + + case class State(count: Int, + visited: Map[T, Boolean], + dfNumber: Map[T, Int], + lowlinks: Map[T, Int], + stack: List[T], + components: List[Component]) + + def search(vertex: T, state: State): State = { + val newState = state.copy(visited = state.visited.updated(vertex, true), + dfNumber = state.dfNumber.updated(vertex, state.count), + count = state.count + 1, + lowlinks = state.lowlinks.updated(vertex, state.count), + stack = vertex :: state.stack) + + def processVertex(st: State, w: T): State = { + if (!st.visited(w)) { + val st1 = search(w, st) + val min = Math.min(st1.lowlinks(w), st1.lowlinks(vertex)) + st1.copy(lowlinks = st1.lowlinks.updated(vertex, min)) + } else { + if ((st.dfNumber(w) < st.dfNumber(vertex)) && st.stack.contains(w)) { + val min = Math.min(st.dfNumber(w), st.lowlinks(vertex)) + st.copy(lowlinks = st.lowlinks.updated(vertex, min)) + } else st + } + } + + val strslt = getSuccessors(vertex).foldLeft(newState)(processVertex) + + if (strslt.lowlinks(vertex) == strslt.dfNumber(vertex)) { + + val index = strslt.stack.indexOf(vertex) + val (comp, rest) = strslt.stack.splitAt(index + 1) + strslt.copy(stack = rest, + components = strslt.components :+ comp) + } else strslt + } + + val initial = State( + count = 1, + visited = getNodes.map { (_, false) }.toMap, + dfNumber = Map(), + lowlinks = Map(), + stack = Nil, + components = Nil) + + var state = initial + while (state.visited.exists(_._2 == false)) { + state.visited.find(_._2 == false).foreach { tuple => + val (vertex, _) = tuple + state = search(vertex, state) + } + } + state.components + } + +} + +class UndirectedGraph[T] extends DirectedGraph[T] { + + override def addEdge(src: T, dest: T): Unit = { + val newset1 = if (adjlist.contains(src)) adjlist(src) + dest + else Set(dest) + + val newset2 = if (adjlist.contains(dest)) adjlist(dest) + src + else Set(src) + + //this has some side-effects + adjlist.update(src, newset1) + adjlist.update(dest, newset2) + + edgeCount += 1 + } +} diff --git a/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala b/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala new file mode 100644 index 0000000000000000000000000000000000000000..966121756fd62e0ef2d517f8ad99153b161858fb --- /dev/null +++ b/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala @@ -0,0 +1,465 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import java.io._ +import java.io._ +import purescala.ScalaPrinter +import leon.utils._ + +import invariant.structure.Call +import invariant.structure.FunctionUtils._ +import leon.transformations.InstUtil._ + +/** + * A collection of transformation on expressions and some utility methods. + * These operations are mostly semantic preserving (specific assumptions/requirements are specified on the operations) + */ +object LetTupleSimplification { + + val zero = InfiniteIntegerLiteral(0) + val one = InfiniteIntegerLiteral(1) + val mone = InfiniteIntegerLiteral(-1) + val tru = BooleanLiteral(true) + val fls = BooleanLiteral(false) + val bone = BigInt(1) + + def letSanityChecks(ine: Expr) = { + simplePostTransform(_ match { + case letExpr @ Let(binderId, letValue, body) + if (binderId.getType != letValue.getType) => + throw new IllegalStateException("Binder and value type mismatch: "+ + s"(${binderId.getType},${letValue.getType})") + case e => e + })(ine) + } + + /** + * This function simplifies lets of the form <Var> = <TupleType Expr> by replacing + * uses of the <Var>._i by the approppriate expression in the let body or by + * introducing a new let <Var'> = <Var>._i and using <Var'> in place of <Var>._i + * in the original let body. + * Caution: this function may not be idempotent. + */ + def simplifyTuples(ine: Expr): Expr = { + + var processedLetBinders = Set[Identifier]() + def recSimplify(e: Expr, replaceMap: Map[Expr, Expr]): Expr = { + + //println("Before: "+e) + val transe = e match { + case letExpr @ Let(binderId, letValue, body) if !processedLetBinders(binderId) => + processedLetBinders += binderId + // transform the 'letValue' with the current map + val nvalue = recSimplify(letValue, replaceMap) + // enrich the map if letValue is of tuple type + nvalue.getType match { + case TupleType(argTypes) => + var freshBinders = Set[Identifier]() + def freshBinder(typ: TypeTree) = { + val freshid = TVarFactory.createTemp(binderId.name, typ) + freshBinders += freshid + freshid.toVariable + } + val newmap: Map[Expr, Expr] = nvalue match { + case Tuple(args) => // this is an optimization for the case where nvalue is a tuple + args.zipWithIndex.map { + case (t: Terminal, index) => + (TupleSelect(binderId.toVariable, index + 1) -> t) + case (_, index) => + (TupleSelect(binderId.toVariable, index + 1) -> freshBinder(argTypes(index))) + }.toMap + case _ => + argTypes.zipWithIndex.map { + case (argtype, index) => + (TupleSelect(binderId.toVariable, index + 1) -> freshBinder(argtype)) + }.toMap + } + // transform the body using the new map + old map + val nbody = recSimplify(body, replaceMap ++ newmap) + val bodyFreevars = variablesOf(nbody) + // create a sequence of lets for the freshBinders + val nletBody = newmap.foldLeft(nbody) { + case (acc, (k, Variable(id))) if freshBinders(id) && bodyFreevars(id) => + // here, the 'id' is a newly created binder and is also used in the transformed body + Let(id, k, acc) + case (acc, _) => + acc + } + Let(binderId, nvalue, nletBody) + case _ => + // no simplification can be done in this step + Let(binderId, nvalue, recSimplify(body, replaceMap)) + } + case ts @ TupleSelect(_, _) if replaceMap.contains(ts) => + postMap(replaceMap.lift, true)(e) //perform recursive replacements to handle nested tuple selects + //replaceMap(ts) //replace tuple-selects in the map with the new identifier + + case t: Terminal => t + + /*case UnaryOperator(sube, op) => + op(recSimplify(sube, replaceMap)) + + case BinaryOperator(e1, e2, op) => + op(recSimplify(e1, replaceMap), recSimplify(e2, replaceMap))*/ + + case Operator(subes, op) => + op(subes.map(recSimplify(_, replaceMap))) + } + //println("After: "+e) + transe + } + fixpoint((e: Expr) => simplifyArithmetic(recSimplify(e, Map())))(ine) + } + + // sanity checks + def checkTupleSelectInsideMax(e: Expr): Boolean = { + //exists( predicate: Expr => Expr) (e) + var error = false + def helper(e: Expr): Unit = { + e match { + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { + + val Seq(arg1: Expr, arg2: Expr) = args + (arg1, arg2) match { + case (_: TupleSelect, _) => error = true + case (_, _: TupleSelect) => error = true + case _ => { ; } + } + } + + case _ => { ; } + } + } + + postTraversal(helper)(e) + error + } + + def simplifyMax(ine: Expr): Expr = { + val debugMaxSimplify = false + //computes a lower bound value, assuming that every sub-term used in the term is positive + //Note: this is applicable only to expressions involving depth + def positiveTermLowerBound(e: Expr): Int = e match { + case IntLiteral(v) => v + case Plus(l, r) => positiveTermLowerBound(l) + positiveTermLowerBound(r) + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { + val Seq(arg1, arg2) = args + val lb1 = positiveTermLowerBound(arg1) + val lb2 = positiveTermLowerBound(arg2) + if (lb1 >= lb2) lb1 else lb2 + } + case _ => 0 //other case are not handled as they do not appear + } + + //checks if 'sub' is subsumed by 'e' i.e, 'e' will always take a value + // greater than or equal to 'sub'. + //Assuming that every sub-term used in the term is positive + def subsumedBy(sub: Expr, e: Expr): Boolean = e match { + case _ if (sub == e) => true + case Plus(l, r) => subsumedBy(sub, l) || subsumedBy(sub, r) + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => + val Seq(l, r) = args + subsumedBy(sub, l) || subsumedBy(sub, r) + case _ => false + } + + // in the sequel, we are using the fact that 'depth' is positive and + // 'ine' contains only 'depth' variables + val simpe = simplePostTransform((e: Expr) => e match { + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { + if (debugMaxSimplify) { + println("Simplifying: " + e) + } + val newargs: Seq[Expr] = args.map(simplifyArithmetic) + val Seq(arg1: Expr, arg2: Expr) = newargs + val simpval = if (!Util.hasCalls(arg1) && !Util.hasCalls(arg2)) { + import invariant.structure.LinearConstraintUtil._ + val lt = exprToTemplate(LessEquals(Minus(arg1, arg2), InfiniteIntegerLiteral(0))) + //now, check if all the variables in 'lt' have only positive coefficients + val allPositive = lt.coeffTemplate.forall(entry => entry match { + case (k, IntLiteral(v)) if (v >= 0) => true + case _ => false + }) && (lt.constTemplate match { + case None => true + case Some(IntLiteral(v)) if (v >= 0) => true + case _ => false + }) + if (allPositive) arg1 + else { + val allNegative = lt.coeffTemplate.forall(entry => entry match { + case (k, IntLiteral(v)) if (v <= 0) => true + case _ => false + }) && (lt.constTemplate match { + case None => true + case Some(IntLiteral(v)) if (v <= 0) => true + case _ => false + }) + if (allNegative) arg2 + else FunctionInvocation(tfd, newargs) //here we cannot do any simplification. + } + + } else { + (arg1, arg2) match { + case (IntLiteral(v), r) if (v <= positiveTermLowerBound(r)) => r + case (l, IntLiteral(v)) if (v <= positiveTermLowerBound(l)) => l + case (l, r) if subsumedBy(l, r) => r + case (l, r) if subsumedBy(r, l) => l + case _ => FunctionInvocation(tfd, newargs) + } + } + if (debugMaxSimplify) { + println("Simplified value: " + simpval) + } + simpval + } + // case FunctionInvocation(tfd, args) if(tfd.fd.id.name == "max") => { + // throw new IllegalStateException("Found just max in expression " + e + "\n") + // } + case _ => e + })(ine) + simpe + } + + def inlineMax(ine: Expr): Expr = { + //inline 'max' operations here + simplePostTransform((e: Expr) => e match { + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => + val Seq(arg1, arg2) = args + val bindWithLet = (value: Expr, body: (Expr with Terminal) => Expr) => { + value match { + case t: Terminal => body(t) + case Let(id, v, b: Terminal) => + //here we can use 'b' in 'body' + Let(id, v, body(b)) + case _ => + val mt = TVarFactory.createTemp("mt", value.getType) + Let(mt, value, body(mt.toVariable)) + } + } + bindWithLet(arg1, a1 => bindWithLet(arg2, a2 => + IfExpr(GreaterEquals(a1, a2), a1, a2))) + case _ => e + })(ine) + } + + def removeLetsFromLetValues(ine: Expr): Expr = { + + /** + * Navigates through the sequence of lets in 'e' + * and replaces its 'let' free part by subst. + * Assuming that 'e' has only lets at the top and no nested lets in the value + */ + def replaceLetBody(e: Expr, subst: Expr => Expr): Expr = e match { + case Let(binder, letv, letb) => + Let(binder, letv, replaceLetBody(letb, subst)) + case _ => + subst(e) + } + + // the function removes the lets from the let values + // by pulling them out + def pullLetToTop(e: Expr): Expr = { + val transe = e match { + //note: do not pull let's out of ensuring or requires + case Ensuring(body, pred) => + Ensuring(pullLetToTop(body), pred) + case Require(pre, body) => + Require(pre, pullLetToTop(body)) + + case letExpr @ Let(binder, letValue, body) => + // transform the 'letValue' with the current map + pullLetToTop(letValue) match { + case sublet @ Let(binder2, subvalue, subbody) => + // transforming "let v = (let v1 = e1 in e2) in e3" + // to "let v1 = e1 in (let v = e2 in e3)" + // here, subvalue is free of lets, but subbody may again be a let + val newbody = replaceLetBody(subbody, Let(binder, _, pullLetToTop(body))) + Let(binder2, subvalue, newbody) + case nval => + // here, there is no let in the value + Let(binder, nval, pullLetToTop(body)) + } + case t: Terminal => t + case Operator(Seq(sube), op) => + replaceLetBody(pullLetToTop(sube), e => op(Seq(e))) + + case Operator(Seq(e1, e2), op) => + replaceLetBody(pullLetToTop(e1), te1 => + replaceLetBody(pullLetToTop(e2), te2 => op(Seq(te1, te2)))) + + //don't pull things out of if-then-else and match (don't know why this is a problem) + case IfExpr(c, th, elze) => + IfExpr(pullLetToTop(c), pullLetToTop(th), pullLetToTop(elze)) + + case Operator(Seq(), op) => + op(Seq()) + + case Operator(subes, op) => + // transform all the sub-expressions + val nsubes = subes map pullLetToTop + //collects all the lets and makes the bodies a tuple + var i = -1 + val transLet = nsubes.tail.foldLeft(nsubes.head) { + case (acc, nsube) => + i += 1 + replaceLetBody(acc, e1 => + replaceLetBody(nsube, e2 => e1 match { + case _ if i == 0 => + Tuple(Seq(e1, e2)) + case Tuple(args) => + Tuple(args :+ e2) + })) + } + replaceLetBody(transLet, (e: Expr) => e match { + case Tuple(args) => + op(args) + case _ => op(Seq(e)) //here, there was only one argument + }) + } + transe + } + val res = pullLetToTop(matchToIfThenElse(ine)) + // println("After Pulling lets to top : \n" + ScalaPrinter.apply(res)) + res + } + + def simplifyLetsAndLetsWithTuples(ine: Expr) = { + + def simplerLet(t: Expr): Option[Expr] = { + val res = t match { + case letExpr @ Let(i, t: Terminal, b) => + Some(replace(Map(Variable(i) -> t), b)) + + // check if the let can be completely removed + case letExpr @ Let(i, e, b) => { + val occurrences = count { + case Variable(x) if x == i => 1 + case _ => 0 + }(b) + + if (occurrences == 0) { + Some(b) + } else if (occurrences == 1) { + Some(replace(Map(Variable(i) -> e), b)) + } else { + //TODO: we can also remove zero occurrences and compress the tuples + // this may be necessary when instrumentations are combined. + letExpr match { + case letExpr @ Let(binder, lval @ Tuple(subes), b) => + def occurrences(index: Int) = { + val res = count { + case TupleSelect(sel, i) if sel == binder.toVariable && i == index => 1 + case _ => 0 + }(b) + res + } + val repmap: Map[Expr, Expr] = subes.zipWithIndex.collect { + case (sube, i) if occurrences(i + 1) == 1 => + (TupleSelect(binder.toVariable, i + 1) -> sube) + }.toMap + Some(Let(binder, lval, replace(repmap, b))) + //note: here, we cannot remove the let, + //if it is not used it will be removed in the next iteration + + case _ => None + } + } + } + + case _ => None + } + res + } + + val transforms = removeLetsFromLetValues _ andThen fixpoint(postMap(simplerLet)) _ andThen simplifyArithmetic + transforms(ine) + } + + /* + This function tries to simplify a part of the expression tree consisting of the same operation. + The operatoin needs to be associative and commutative for this simplification to work . + Arguments: + op: An implementation of the opertaion to be simplified + getLeaves: Gets all the operands from the AST (if the argument is not of + the form currently being simplified, this is required to return an empty set) + identity: The identity element for the operation + makeTree: Makes an AST from the operands + */ + def simplifyConstantsGeneral(e: Expr, op: (BigInt, BigInt) => BigInt, + getLeaves: (Expr, Boolean) => Seq[Expr], identity: BigInt, + makeTree: (Expr, Expr) => Expr): Expr = { + + val allLeaves = getLeaves(e, true) + // Here the expression is not of the form we are currently simplifying + if (allLeaves.size == 0) e + else { + // fold constants here + val allConstantsOpped = allLeaves.foldLeft(identity)((acc, e) => e match { + case InfiniteIntegerLiteral(x) => op(acc, x) + case _ => acc + }) + + val allNonConstants = allLeaves.filter((e) => e match { + case _: InfiniteIntegerLiteral => false + case _ => true + }) + + // Reconstruct the expressin tree with the non-constants and the result of constant evaluation above + if (allConstantsOpped != identity) { + allNonConstants.foldLeft(InfiniteIntegerLiteral(allConstantsOpped): Expr)((acc: Expr, currExpr) => makeTree(acc, currExpr)) + } + else { + if (allNonConstants.size == 0) InfiniteIntegerLiteral(identity) + else { + allNonConstants.tail.foldLeft(allNonConstants.head)((acc: Expr, currExpr) => makeTree(acc, currExpr)) + } + } + } + } + + //Use the above function to simplify additions and maximums interleaved + def simplifyAdditionsAndMax(e: Expr): Expr = { + def getAllSummands(e: Expr, isTopLevel: Boolean): Seq[Expr] = { + e match { + case Plus(e1, e2) => { + getAllSummands(e1, false) ++ getAllSummands(e2, false) + } + case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) + } + } + + def getAllMaximands(e: Expr, isTopLevel: Boolean): Seq[Expr] = { + e match { + case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { + args.foldLeft(Seq[Expr]())((accSet, e) => accSet ++ getAllMaximands(e, false)) + } + case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) + } + } + + simplePostTransform(e => { + val plusSimplifiedExpr = + simplifyConstantsGeneral(e, _ + _, getAllSummands, 0, ((e1, e2) => Plus(e1, e2))) + + // Maximum simplification assumes all arguments to max + // are non-negative (and hence 0 is the identity) + val maxSimplifiedExpr = + simplifyConstantsGeneral(plusSimplifiedExpr, + ((a: BigInt, b: BigInt) => if (a > b) a else b), + getAllMaximands, + 0, + ((e1, e2) => { + val typedMaxFun = TypedFunDef(maxFun, Seq()) + FunctionInvocation(typedMaxFun, Seq(e1, e2)) + })) + + maxSimplifiedExpr + })(e) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/Minimizer.scala b/src/main/scala/leon/invariant/util/Minimizer.scala new file mode 100644 index 0000000000000000000000000000000000000000..7e44eaaa78990f0ad9d97ac01e7bb90c569883d6 --- /dev/null +++ b/src/main/scala/leon/invariant/util/Minimizer.scala @@ -0,0 +1,200 @@ +package leon +package invariant.util +import z3.scala._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import solvers._ +import solvers.z3._ +import leon.invariant._ +import scala.util.control.Breaks._ +import invariant.engine.InferenceContext +import invariant.factories._ +import leon.invariant.templateSolvers.ExtendedUFSolver +import leon.invariant.util.RealValuedExprEvaluator._ + +class Minimizer(ctx: InferenceContext) { + + val verbose = false + val debugMinimization = false + /** + * Here we are assuming that that initModel is a model for ctrs + * TODO: make sure that the template for rootFun is the time template + */ + val MaxIter = 16 //note we may not be able to represent anything beyond 2^16 + /*val MaxInt = Int.MaxValue + val sqrtMaxInt = 45000 //this is a number that is close a sqrt of 2^31 +*/ val half = FractionalLiteral(1, 2) + val two = FractionalLiteral(2, 1) + val rzero = FractionalLiteral(0, 1) + val mone = FractionalLiteral(-1, 1) + + private val program = ctx.program + private val leonctx = ctx.leonContext + val reporter = leonctx.reporter + + //for statistics and output + //store the lowerbounds for each template variables in the template of the rootFun provided it is a time template + var lowerBoundMap = Map[Variable, FractionalLiteral]() + def updateLowerBound(tvar: Variable, rval: FractionalLiteral) = { + //record the lower bound if it exist + if (lowerBoundMap.contains(tvar)) { + lowerBoundMap -= tvar + } + lowerBoundMap += (tvar -> rval) + } + + def tightenTimeBounds(timeTemplate: Expr)(inputCtr: Expr, initModel: Model) = { + //the order in which the template variables are minimized is based on the level of nesting of the terms + minimizeBounds(computeCompositionLevel(timeTemplate))(inputCtr, initModel) + } + + def minimizeBounds(nestMap: Map[Variable, Int])(inputCtr: Expr, initModel: Model): Model = { + val orderedTempVars = nestMap.toSeq.sortWith((a, b) => a._2 >= b._2).map(_._1) + //do a binary search sequentially on each of these tempvars + val solver = SimpleSolverAPI( + new TimeoutSolverFactory(SolverFactory(() => + new ExtendedUFSolver(leonctx, program) with TimeoutSolver), ctx.timeout * 1000)) + + reporter.info("minimizing...") + var currentModel = initModel + orderedTempVars.foldLeft(inputCtr: Expr)((acc, tvar) => { + var upperBound = if (currentModel.isDefinedAt(tvar.id)) { + currentModel(tvar.id).asInstanceOf[FractionalLiteral] + } else { + initModel(tvar.id).asInstanceOf[FractionalLiteral] + } + //note: the lower bound is an integer by construction (and is by default zero) + var lowerBound: FractionalLiteral = + if (tvar == orderedTempVars(0) && lowerBoundMap.contains(tvar)) + lowerBoundMap(tvar) + else realzero + //a helper method + def updateState(nmodel: Model) = { + upperBound = nmodel(tvar.id).asInstanceOf[FractionalLiteral] + currentModel = nmodel + if (this.debugMinimization) { + reporter.info("Found new upper bound: " + upperBound) + //reporter.info("Model: "+currentModel) + } + } + + if (this.debugMinimization) + reporter.info(s"Minimizing variable: $tvar Initial Bounds: [$upperBound,$lowerBound]") + //TODO: use incremental solving of z3 when it is supported in nlsat + var continue = true + var iter = 0 + do { + iter += 1 + if (continue) { + //we make sure that curr val is an integer + val currval = floor(evaluate(Times(half, Plus(upperBound, lowerBound)))) + //check if the lowerbound, if it exists, is < currval + if (evaluateRealPredicate(GreaterEquals(lowerBound, currval))) + continue = false + else { + val boundCtr = And(LessEquals(tvar, currval), GreaterEquals(tvar, lowerBound)) + //val t1 = System.currentTimeMillis() + val (res, newModel) = solver.solveSAT(And(acc, boundCtr)) + //val t2 = System.currentTimeMillis() + //println((if (res.isDefined) "solved" else "timed out") + "... in " + (t2 - t1) / 1000.0 + "s") + res match { + case Some(true) => + updateState(newModel) + case _ => + //here we have a new lower bound: currval + lowerBound = currval + if (this.debugMinimization) + reporter.info("Found new lower bound: " + currval) + } + } + } + } while (continue && iter < MaxIter) + //this is the last ditch effort to make the upper bound constant smaller. + //check if the floor of the upper-bound is a solution + val currval @ FractionalLiteral(n, d) = + if (currentModel.isDefinedAt(tvar.id)) { + currentModel(tvar.id).asInstanceOf[FractionalLiteral] + } else { + initModel(tvar.id).asInstanceOf[FractionalLiteral] + } + if (d != 1) { + val (res, newModel) = solver.solveSAT(And(acc, Equals(tvar, floor(currval)))) + if (res == Some(true)) + updateState(newModel) + } + //here, we found a best-effort minimum + if (lowerBound != realzero) { + updateLowerBound(tvar, lowerBound) + } + And(acc, Equals(tvar, currval)) + }) + new Model(initModel.map { + case (id, e) => + if (currentModel.isDefinedAt(id)) + (id -> currentModel(id)) + else + (id -> initModel(id)) + }.toMap) + } + + def checkBoundingInteger(tvar: Variable, rl: FractionalLiteral, nlctr: Expr, solver: SimpleSolverAPI): Option[Model] = { + val nl @ FractionalLiteral(n, d) = normalizeFraction(rl) + if (d != 1) { + val flval = floor(nl) + val (res, newModel) = solver.solveSAT(And(nlctr, Equals(tvar, flval))) + res match { + case Some(true) => Some(newModel) + case _ => None + } + } else None + } + + /** + * The following code is little tricky + */ + def computeCompositionLevel(template: Expr): Map[Variable, Int] = { + var nestMap = Map[Variable, Int]() + + def updateMax(v: Variable, level: Int) = { + if (verbose) reporter.info("Nesting level: " + v + "-->" + level) + if (nestMap.contains(v)) { + if (nestMap(v) < level) { + nestMap -= v + nestMap += (v -> level) + } + } else + nestMap += (v -> level) + } + + def functionNesting(e: Expr): Int = { + e match { + + case Times(e1, v @ Variable(id)) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { + val nestLevel = functionNesting(e1) + updateMax(v, nestLevel) + nestLevel + } + case Times(v @ Variable(id), e2) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { + val nestLevel = functionNesting(e2) + updateMax(v, nestLevel) + nestLevel + } + case v @ Variable(id) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { + updateMax(v, 0) + 0 + } + case FunctionInvocation(_, args) => 1 + args.foldLeft(0)((acc, arg) => acc + functionNesting(arg)) + case t: Terminal => 0 + /*case UnaryOperator(arg, _) => functionNesting(arg) + case BinaryOperator(a1, a2, _) => functionNesting(a1) + functionNesting(a2)*/ + case Operator(args, _) => args.foldLeft(0)((acc, arg) => acc + functionNesting(arg)) + } + } + functionNesting(template) + nestMap + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/RealExprEvaluator.scala b/src/main/scala/leon/invariant/util/RealExprEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..d591de24533d0f26fdfaae35386a87d1b936b617 --- /dev/null +++ b/src/main/scala/leon/invariant/util/RealExprEvaluator.scala @@ -0,0 +1,100 @@ +package leon +package invariant.util + +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import scala.math.BigInt.int2bigInt + +object RealValuedExprEvaluator { + + /** + * Requires that the input expression is ground + */ + def evaluate(expr: Expr): FractionalLiteral = { + plainEvaluate(expr) + } + + def plainEvaluate(expr: Expr): FractionalLiteral = expr match { + + case UMinus(e) => { + val FractionalLiteral(num, denom) = plainEvaluate(e) + FractionalLiteral(-num, denom) + } + case Minus(lhs, rhs) => { + plainEvaluate(Plus(lhs, UMinus(rhs))) + } + case Plus(_, _) | RealPlus(_, _) => { + val Operator(Seq(lhs, rhs), op) = expr + val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) + val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) + normalizeFraction(FractionalLiteral((lnum * rdenom + rnum * ldenom), (ldenom * rdenom))) + } + case Times(_, _) | RealTimes(_, _) => { + val Operator(Seq(lhs, rhs), op) = expr + val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) + val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) + normalizeFraction(FractionalLiteral((lnum * rnum), (ldenom * rdenom))) + } + case Division(_, _) | RealDivision(_, _) => { + val Operator(Seq(lhs, rhs), op) = expr + val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) + val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) + plainEvaluate(Times(FractionalLiteral(lnum, ldenom), FractionalLiteral(rdenom, rnum))) + } + case il @ InfiniteIntegerLiteral(v) => FractionalLiteral(v, 1) + case rl @ FractionalLiteral(_, _) => normalizeFraction(rl) + case _ => throw new IllegalStateException("Not an evaluatable expression: " + expr) + } + + def evaluateRealPredicate(expr: Expr): Boolean = expr match { + case Equals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isEQZ(evaluate(Minus(a, b))) + case LessEquals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isLEZ(evaluate(Minus(a, b))) + case LessThan(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isLTZ(evaluate(Minus(a, b))) + case GreaterEquals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isGEZ(evaluate(Minus(a, b))) + case GreaterThan(a @ FractionalLiteral(n1, d1), b @ FractionalLiteral(n2, d2)) => isGTZ(evaluate(Minus(a, b))) + } + + def isEQZ(rlit: FractionalLiteral): Boolean = { + val FractionalLiteral(n, d) = rlit + if (d == 0) throw new IllegalStateException("denominator zero") + (n == 0) + } + + def isLEZ(rlit: FractionalLiteral): Boolean = { + val FractionalLiteral(n, d) = rlit + if (d == 0) throw new IllegalStateException("denominator zero") + if (d < 0) throw new IllegalStateException("denominator negative: " + d) + (n <= 0) + } + + def isLTZ(rlit: FractionalLiteral): Boolean = { + val FractionalLiteral(n, d) = rlit + if (d == 0) throw new IllegalStateException("denominator zero") + if (d < 0) throw new IllegalStateException("denominator negative: " + d) + (n < 0) + } + + def isGEZ(rlit: FractionalLiteral): Boolean = { + val FractionalLiteral(n, d) = rlit + if (d == 0) throw new IllegalStateException("denominator zero") + if (d < 0) throw new IllegalStateException("denominator negative: " + d) + (n >= 0) + } + + def isGTZ(rlit: FractionalLiteral): Boolean = { + val FractionalLiteral(n, d) = rlit + if (d == 0) throw new IllegalStateException("denominator zero") + if (d < 0) throw new IllegalStateException("denominator negative: " + d) + (n > 0) + } + + def evaluateRealFormula(expr: Expr): Boolean = expr match { + case And(args) => args forall evaluateRealFormula + case Or(args) => args exists evaluateRealFormula + case Not(arg) => !evaluateRealFormula(arg) + case BooleanLiteral(b) => b + case Operator(args, op) => + evaluateRealPredicate(op(args map evaluate)) + } +} diff --git a/src/main/scala/leon/invariant/util/Stats.scala b/src/main/scala/leon/invariant/util/Stats.scala new file mode 100644 index 0000000000000000000000000000000000000000..76a1f12ee4eede0ba631969ebb7bf3076bb3bf08 --- /dev/null +++ b/src/main/scala/leon/invariant/util/Stats.scala @@ -0,0 +1,130 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Map => MutableMap } +import java.io._ +import leon.invariant._ +import java.io._ +import scala.collection.mutable.{Map => MutableMap} + + +/** + * A generic statistics object that provides: + * (a) Temporal variables that change over time. We track the total sum and max of the values the variable takes over time + * (b) Counters that are incremented over time. Variables can be associated with counters. + * We track the averages value of a variable over time w.r.t to the counters with which it is associated. + */ +object Stats { + val keystats = MutableMap[String, (Long, Long)]() + val counterMap = MutableMap[String, Seq[String]]() + var cumKeys = Seq[String]() + var timekeys = Set[String]() //this may be inner, outer or cumkey + + private def updateStats(newval: Long, key: String, cname: Option[String]) = { + val (cum, max) = keystats.getOrElse(key, { + val init = (0: Long, 0: Long) + keystats += (key -> (0, 0)) + + if (cname.isDefined) { + val presentKeys = counterMap(cname.get) + counterMap.update(cname.get, presentKeys :+ key) + } else { + cumKeys :+= key + } + init + }) + val newcum = cum + newval + val newmax = if (max < newval) newval else max + keystats.update(key, (newcum, newmax)) + } + //a special method for adding times + private def updateTimeStats(newval: Long, key: String, cname: Option[String]) = { + if (!timekeys.contains(key)) + timekeys += key + updateStats(newval, key, cname) + } + + def updateCumStats(newval: Long, key: String) = updateStats(newval, key, None) + def updateCumTime(newval: Long, key: String) = updateTimeStats(newval, key, None) + def updateCounter(incr: Long, key: String) = { + if (!counterMap.contains(key)) { + counterMap.update(key, Seq()) + } + //counters are considered as cumulative stats + updateStats(incr, key, None) + } + def updateCounterStats(newval: Long, key: String, cname: String) = updateStats(newval, key, Some(cname)) + def updateCounterTime(newval: Long, key: String, cname: String) = updateTimeStats(newval, key, Some(cname)) + + private def getCum(key: String): Long = keystats(key)._1 + private def getMax(key: String): Long = keystats(key)._2 + + def dumpStats(pr: PrintWriter) = { + //Print cumulative stats + cumKeys.foreach(key => { + if (timekeys.contains(key)) { + pr.println(key + ": " + (getCum(key).toDouble / 1000.0) + "s") + } else + pr.println(key + ": " + getCum(key)) + }) + + //dump the averages and maximum of all stats associated with counters + counterMap.keys.foreach((ckey) => { + pr.println("### Statistics for counter: " + ckey + " ####") + val counterval = getCum(ckey) + val assocKeys = counterMap(ckey) + assocKeys.foreach((key) => { + if (timekeys.contains(key)) { + pr.println("Avg." + key + ": " + (getCum(key).toDouble / (counterval * 1000.0)) + "s") + pr.println("Max." + key + ": " + (getMax(key).toDouble / 1000.0) + "s") + } else { + pr.println("Avg." + key + ": " + (getCum(key).toDouble / counterval)) + pr.println("Max." + key + ": " + getMax(key)) + } + }) + }) + } +} + +/** + * Statistics specific for this application + */ +object SpecificStats { + + var output: String = "" + def addOutput(out: String) = { + output += out + "\n" + } + def dumpOutputs(pr: PrintWriter) { + pr.println("########## Outputs ############") + pr.println(output) + pr.flush() + } + + //minimization stats + var lowerBounds = Map[FunDef, Map[Variable, FractionalLiteral]]() + var lowerBoundsOutput = Map[FunDef, String]() + def addLowerBoundStats(fd: FunDef, lbMap: Map[Variable, FractionalLiteral], out: String) = { + lowerBounds += (fd -> lbMap) + lowerBoundsOutput += (fd -> out) + } + def dumpMinimizationStats(pr: PrintWriter) { + pr.println("########## Lower Bounds ############") + lowerBounds.foreach((pair) => { + val (fd, lbMap) = pair + pr.print(fd.id + ": \t") + lbMap.foreach((entry) => { + pr.print("(" + entry._1 + "->" + entry._2 + "), ") + }) + pr.print("\t Test results: " + lowerBoundsOutput(fd)) + pr.println() + }) + pr.flush() + } +} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala new file mode 100644 index 0000000000000000000000000000000000000000..fdd70bf9a5bb21155d7269279d39be0e0559d5a1 --- /dev/null +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -0,0 +1,717 @@ +package leon +package invariant.util + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } +import scala.collection.immutable.Stack +import java.io._ +import leon.invariant._ +import java.io._ +import solvers.z3._ +import solvers._ +import invariant.engine._ +import invariant.factories._ +import invariant.structure._ +import leon.purescala.PrettyPrintable +import leon.purescala.PrinterContext +import purescala.PrinterHelpers._ +import FunctionUtils._ +import leon.invariant.templateSolvers.ExtendedUFSolver +import scala.annotation.tailrec + +object FileCountGUID { + var fileCount = 0 + def getID: Int = { + var oldcnt = fileCount + fileCount += 1 + oldcnt + } +} + +//three valued logic +object TVL { + abstract class Value + object FALSE extends Value + object TRUE extends Value + object MAYBE extends Value +} + +//this is used as a place holder for result +case class ResultVariable(tpe: TypeTree) extends Expr with Terminal with PrettyPrintable { + val getType = tpe + override def toString: String = "#res" + + def printWith(implicit pctx: PrinterContext) { + p"#res" + } +} + +//this used to refer to the time steps of a procedure +case class TimeVariable() extends Expr with Terminal with PrettyPrintable { + val getType = IntegerType + override def toString: String = "#time" + def printWith(implicit pctx: PrinterContext) { + p"#time" + } +} + +//this used to refer to the depth of a procedure +case class DepthVariable() extends Expr with Terminal with PrettyPrintable { + val getType = IntegerType + override def toString: String = "#depth" + def printWith(implicit pctx: PrinterContext) { + p"#time" + } +} + +object TVarFactory { + + val temporaries = MutableSet[Identifier]() + //these are dummy identifiers used in 'CaseClassSelector' conversion + val dummyIds = MutableSet[Identifier]() + + def createTemp(name: String, tpe: TypeTree = Untyped): Identifier = { + val freshid = FreshIdentifier(name, tpe, true) + temporaries.add(freshid) + freshid + } + + def createDummy(tpe: TypeTree): Identifier = { + val freshid = FreshIdentifier("dy", tpe, true) + dummyIds.add(freshid) + freshid + } + + def isTemporary(id: Identifier): Boolean = temporaries.contains(id) + def isDummy(id: Identifier): Boolean = dummyIds.contains(id) +} + +object Util { + + val zero = InfiniteIntegerLiteral(0) + val one = InfiniteIntegerLiteral(1) + val tru = BooleanLiteral(true) + val fls = BooleanLiteral(false) + + /** + * Here, we exclude empty units that do not have any modules and empty + * modules that do not have any definitions + */ + def copyProgram(prog: Program, mapdefs: (Seq[Definition] => Seq[Definition])): Program = { + prog.copy(units = prog.units.collect { + case unit if (!unit.defs.isEmpty) => unit.copy(defs = unit.defs.collect { + case module : ModuleDef if (!module.defs.isEmpty) => + module.copy(defs = mapdefs(module.defs)) + case other => other + }) + }) + } + + def createTemplateFun(plainTemp: Expr): FunctionInvocation = { + val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) + val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), + Seq(), BooleanType, Seq(ValDef(FreshIdentifier("arg", tmpl.getType), + Some(tmpl.getType)))) + tmplFd.body = Some(BooleanLiteral(true)) + FunctionInvocation(TypedFunDef(tmplFd, Seq()), Seq(tmpl)) + } + + /** + * This is the default template generator. + * Note: we are not creating template for libraries. + */ + def getOrCreateTemplateForFun(fd: FunDef): Expr = { + val plainTemp = if (fd.hasTemplate) fd.getTemplate + else if (fd.annotations.contains("library")) BooleanLiteral(true) + else { + //just consider all the arguments, return values that are integers + val baseTerms = fd.params.filter((vardecl) => isNumericType(vardecl.getType)).map(_.toVariable) ++ + (if (isNumericType(fd.returnType)) Seq(Util.getFunctionReturnVariable(fd)) + else Seq()) + val lhs = baseTerms.foldLeft(TemplateIdFactory.freshTemplateVar(): Expr)((acc, t) => { + Plus(Times(TemplateIdFactory.freshTemplateVar(), t), acc) + }) + val tempExpr = LessEquals(lhs, InfiniteIntegerLiteral(0)) + tempExpr + } + plainTemp + } + + def mapFunctionsInExpr(funmap: Map[FunDef, FunDef])(ine: Expr): Expr = { + simplePostTransform((e: Expr) => e match { + case FunctionInvocation(tfd, args) if funmap.contains(tfd.fd) => + FunctionInvocation(TypedFunDef(funmap(tfd.fd), tfd.tps), args) + case _ => e + })(ine) + } + + def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, funToPost: Map[FunDef, Expr] = Map()): Program = { + + val funMap = Util.functionsWOFields(prog.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { + case (accMap, fd) if fd.isTheoryOperation => + accMap + (fd -> fd) + case (accMap, fd) => { + val freshId = FreshIdentifier(fd.id.name, fd.returnType, true) + val newfd = new FunDef(freshId, fd.tparams, fd.returnType, fd.params) + accMap.updated(fd, newfd) + } + } + + // FIXME: This with createAnd (which performs simplifications) gives an error during composition. + val mapExpr = mapFunctionsInExpr(funMap) _ + for ((from, to) <- funMap) { + to.fullBody = if (!funToTmpl.contains(from)) { + mapExpr { + from.fullBody match { + case Ensuring(b, post) => + Ensuring(b, + Lambda(Seq(ValDef(Util.getResId(from).get)), + createAnd(Seq(from.getPostWoTemplate, funToPost.getOrElse(from, tru))))) + case fb => + fb + } + } + } else { + val newTmpl = createTemplateFun(funToTmpl(from)) + mapExpr { + from.fullBody match { + case Require(pre, body) => + val toPost = + Lambda(Seq(ValDef(FreshIdentifier("res", from.returnType))), + createAnd(Seq(newTmpl, funToPost.getOrElse(from, tru)))) + Ensuring(Require(pre, body), toPost) + + case Ensuring(Require(pre, body), post) => + Ensuring(Require(pre, body), + Lambda(Seq(ValDef(Util.getResId(from).get)), + createAnd(Seq(from.getPostWoTemplate, newTmpl, funToPost.getOrElse(from, tru))))) + + case Ensuring(body, post) => + Ensuring(body, + Lambda(Seq(ValDef(Util.getResId(from).get)), + createAnd(Seq(from.getPostWoTemplate, newTmpl, funToPost.getOrElse(from, tru))))) + + case body => + val toPost = + Lambda(Seq(ValDef(FreshIdentifier("res", from.returnType))), + createAnd(Seq(newTmpl, funToPost.getOrElse(from, tru)))) + Ensuring(body, toPost) + } + } + } + //copy annotations + from.flags.foreach(to.addFlag(_)) + } + val newprog = Util.copyProgram(prog, (defs: Seq[Definition]) => defs.map { + case fd: FunDef if funMap.contains(fd) => + funMap(fd) + case d => d + }) + newprog + } + + def functionByName(nm: String, prog: Program) = { + prog.definedFunctions.find(fd => fd.id.name == nm) + } + + def functionsWOFields(fds: Seq[FunDef]): Seq[FunDef] = { + fds.filter(_.isRealFunction) + } + + def isNumericExpr(expr: Expr): Boolean = { + expr.getType == IntegerType || + expr.getType == RealType + } + + def getFunctionReturnVariable(fd: FunDef) = { + if (fd.hasPostcondition) getResId(fd).get.toVariable + else ResultVariable(fd.returnType) /*FreshIdentifier("res", fd.returnType).toVariable*/ + } + + //compute the formal to the actual argument mapping + def formalToActual(call: Call): Map[Expr, Expr] = { + val fd = call.fi.tfd.fd + val resvar = getFunctionReturnVariable(fd) + val argmap: Map[Expr, Expr] = Map(resvar -> call.retexpr) ++ fd.params.map(_.id.toVariable).zip(call.fi.args) + argmap + } + + /** + * Checks if the input expression has only template variables as free variables + */ + def isTemplateExpr(expr: Expr): Boolean = { + var foundVar = false + simplePostTransform((e: Expr) => e match { + case Variable(id) => { + if (!TemplateIdFactory.IsTemplateIdentifier(id)) + foundVar = true + e + } + case ResultVariable(_) => { + foundVar = true + e + } + case _ => e + })(expr) + + !foundVar + } + + def getTemplateIds(expr: Expr) = { + variablesOf(expr).filter(TemplateIdFactory.IsTemplateIdentifier) + } + + def getTemplateVars(expr: Expr): Set[Variable] = { + /*var tempVars = Set[Variable]() + postTraversal(e => e match { + case t @ Variable(id) => + if (TemplateIdFactory.IsTemplateIdentifier(id)) + tempVars += t + case _ => + })(expr) + tempVars*/ + getTemplateIds(expr).map(_.toVariable) + } + + /** + * Checks if the expression has real valued sub-expressions. + */ + def hasReals(expr: Expr): Boolean = { + var foundReal = false + simplePostTransform((e: Expr) => e match { + case _ => { + if (e.getType == RealType) + foundReal = true; + e + } + })(expr) + foundReal + } + + /** + * Checks if the expression has real valued sub-expressions. + * Note: important, <, <=, > etc have default int type. + * However, they can also be applied over real arguments + * So check only if all terminals are real + */ + def hasInts(expr: Expr): Boolean = { + var foundInt = false + simplePostTransform((e: Expr) => e match { + case e: Terminal if (e.getType == Int32Type || e.getType == IntegerType) => { + foundInt = true; + e + } + case _ => e + })(expr) + foundInt + } + + def hasMixedIntReals(expr: Expr): Boolean = { + hasInts(expr) && hasReals(expr) + } + + def fix[A](f: (A) => A)(a: A): A = { + val na = f(a) + if (a == na) a else fix(f)(na) + } + + def atomNum(e: Expr): Int = { + var count: Int = 0 + simplePostTransform((e: Expr) => e match { + case And(args) => { + count += args.size + e + } + case Or(args) => { + count += args.size + e + } + case _ => e + })(e) + count + } + + def numUIFADT(e: Expr): Int = { + var count: Int = 0 + simplePostTransform((e: Expr) => e match { + case FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_) => { + count += 1 + e + } + case _ => e + })(e) + count + } + + def hasCalls(e: Expr) = numUIFADT(e) >= 1 + + def getCallExprs(ine: Expr): Set[Expr] = { + var calls = Set[Expr]() + simplePostTransform((e: Expr) => e match { + case call @ _ if Util.isCallExpr(e) => { + calls += e + call + } + case _ => e + })(ine) + calls + } + + def isCallExpr(e: Expr): Boolean = e match { + case Equals(Variable(_), FunctionInvocation(_, _)) => true + // case Iff(Variable(_),FunctionInvocation(_,_)) => true + case _ => false + } + + def isADTConstructor(e: Expr): Boolean = e match { + case Equals(Variable(_), CaseClass(_, _)) => true + case Equals(Variable(_), Tuple(_)) => true + case _ => false + } + + def modelToExpr(model: Model): Expr = { + model.foldLeft(tru: Expr)((acc, elem) => { + val (k, v) = elem + val eq = Equals(k.toVariable, v) + if (acc == tru) eq + else And(acc, eq) + }) + } + + def gcd(x: Int, y: Int): Int = { + if (x == 0) y + else gcd(y % x, x) + } + + def toZ3SMTLIB(expr: Expr, filename: String, + theory: String, ctx: LeonContext, pgm: Program, + useBitvectors: Boolean = false, + bitvecSize: Int = 32) = { + //create new solver, assert constraints and print + val printSol = new ExtendedUFSolver(ctx, pgm) + printSol.assertCnstr(expr) + val writer = new PrintWriter(filename) + writer.println(printSol.ctrsToString(theory)) + printSol.free() + writer.flush() + writer.close() + } + + /** + * A helper function that can be used to hardcode an invariant and see if it unsatifies the paths + */ + def checkInvariant(expr: Expr, ctx: LeonContext, prog: Program): Option[Boolean] = { + val idmap: Map[Expr, Expr] = variablesOf(expr).collect { + case id @ _ if (id.name.toString == "a?") => id.toVariable -> InfiniteIntegerLiteral(6) + case id @ _ if (id.name.toString == "c?") => id.toVariable -> InfiniteIntegerLiteral(2) + }.toMap + //println("found ids: " + idmap.keys) + if (!idmap.keys.isEmpty) { + val newpathcond = replace(idmap, expr) + //check if this is solvable + val solver = SimpleSolverAPI(SolverFactory(() => new ExtendedUFSolver(ctx, prog))) + solver.solveSAT(newpathcond)._1 match { + case Some(true) => { + println("Path satisfiable for a?,c? -->6,2 ") + Some(true) + } + case _ => { + println("Path unsat for a?,c? --> 6,2") + Some(false) + } + } + } else None + } + + def collectUNSATCores(ine: Expr, ctx: LeonContext, prog: Program): Expr = { + var controlVars = Map[Variable, Expr]() + var newEqs = Map[Expr, Expr]() + val solver = new ExtendedUFSolver(ctx, prog) + val newe = simplePostTransform((e: Expr) => e match { + case And(_) | Or(_) => { + val v = TVarFactory.createTemp("a", BooleanType).toVariable + newEqs += (v -> e) + val newe = Equals(v, e) + + //create new variable and add it in disjunction + val cvar = FreshIdentifier("ctrl", BooleanType, true).toVariable + controlVars += (cvar -> newe) + solver.assertCnstr(Or(newe, cvar)) + v + } + case _ => e + })(ine) + //create new variable and add it in disjunction + val cvar = FreshIdentifier("ctrl", BooleanType, true).toVariable + controlVars += (cvar -> newe) + solver.assertCnstr(Or(newe, cvar)) + + val res = solver.checkAssumptions(controlVars.keySet.map(Not.apply _)) + println("Result: " + res) + val coreExprs = solver.getUnsatCore + val simpcores = coreExprs.foldLeft(Seq[Expr]())((acc, coreExp) => { + val Not(cvar @ Variable(_)) = coreExp + val newexp = controlVars(cvar) + //println("newexp: "+newexp) + newexp match { + // case Iff(v@Variable(_),rhs) if(newEqs.contains(v)) => acc + case Equals(v @ Variable(_), rhs) if (v.getType == BooleanType && rhs.getType == BooleanType && newEqs.contains(v)) => acc + case _ => { + acc :+ newexp + } + } + }) + val cores = Util.fix((e: Expr) => replace(newEqs, e))(Util.createAnd(simpcores.toSeq)) + + solver.free + //cores + ExpressionTransformer.unFlatten(cores, + variablesOf(ine).filterNot(TVarFactory.isTemporary _)) + } + + def isMultFunctions(fd: FunDef) = { + (fd.id.name == "mult" || fd.id.name == "pmult") && + fd.isTheoryOperation + } + //replaces occurrences of mult by Times + def multToTimes(ine: Expr): Expr = { + simplePostTransform((e: Expr) => e match { + case FunctionInvocation(TypedFunDef(fd, _), args) if isMultFunctions(fd) => { + Times(args(0), args(1)) + } + case _ => e + })(ine) + } + + /** + * A cross product with an optional filter + */ + def cross[U, V](a: Set[U], b: Set[V], selector: Option[(U, V) => Boolean] = None): Set[(U, V)] = { + + val product = (for (x <- a; y <- b) yield (x, y)) + if (selector.isDefined) + product.filter(pair => selector.get(pair._1, pair._2)) + else + product + } + + def getResId(funDef: FunDef): Option[Identifier] = { + funDef.fullBody match { + case Ensuring(_, post) => { + post match { + case Lambda(Seq(ValDef(fromRes, _)), _) => Some(fromRes) + case _ => + throw new IllegalStateException("Postcondition with multiple return values!") + } + } + case _ => None + } + } + + def createAnd(exprs: Seq[Expr]): Expr = { + val newExprs = exprs.filterNot(conj => conj == tru) + newExprs match { + case Seq() => tru + case Seq(e) => e + case _ => And(newExprs) + } + } + + def createOr(exprs: Seq[Expr]): Expr = { + val newExprs = exprs.filterNot(disj => disj == fls) + newExprs match { + case Seq() => fls + case Seq(e) => e + case _ => Or(newExprs) + } + } + + def isNumericType(t: TypeTree) = t match { + case IntegerType | RealType => true + case _ => false + } + + //tests if the solver uses nlsat + def usesNLSat(solver: AbstractZ3Solver) = { + //check for nlsat + val x = FreshIdentifier("x", RealType).toVariable + val testExpr = Equals(Times(x, x), FractionalLiteral(2, 1)) + solver.assertCnstr(testExpr) + solver.check match { + case Some(true) => true + case _ => false + } + } +} + +/** + * maps all real valued variables and literals to new integer variables/literals and + * performs the reverse mapping + * Note: this should preserve the template identifier property + */ +class RealToInt { + + val bone = BigInt(1) + var realToIntId = Map[Identifier, Identifier]() + var intToRealId = Map[Identifier, Identifier]() + + def mapRealToInt(inexpr: Expr): Expr = { + val transformer = (e: Expr) => e match { + case FractionalLiteral(num, `bone`) => InfiniteIntegerLiteral(num) + case FractionalLiteral(_, _) => throw new IllegalStateException("Real literal with non-unit denominator") + case v @ Variable(realId) if (v.getType == RealType) => { + val newId = realToIntId.getOrElse(realId, { + //note: the fresh identifier has to be a template identifier if the original one is a template identifier + val freshId = if (TemplateIdFactory.IsTemplateIdentifier(realId)) + TemplateIdFactory.freshIdentifier(realId.name, IntegerType) + else + FreshIdentifier(realId.name, IntegerType, true) + + realToIntId += (realId -> freshId) + intToRealId += (freshId -> realId) + freshId + }) + Variable(newId) + } + case _ => e + } + simplePostTransform(transformer)(inexpr) + } + + def unmapModel(model: Model): Model = { + new Model(model.map(pair => { + val (key, value) = if (intToRealId.contains(pair._1)) { + (intToRealId(pair._1), + pair._2 match { + case InfiniteIntegerLiteral(v) => FractionalLiteral(v.toInt, 1) + case _ => pair._2 + }) + } else pair + (key -> value) + }).toMap) + } + + def mapModel(model: Model): Model = { + new Model(model.collect { + case (k, FractionalLiteral(n, bone)) => + (realToIntId(k), InfiniteIntegerLiteral(n)) + case (k, v) => + if (realToIntId.contains(k)) { + (realToIntId(k), v) + } else { + (k, v) + } + }.toMap) + } +} + +class MultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.Set[B]] with scala.collection.mutable.MultiMap[A, B] { + /** + * Creates a new map and does not change the existing map + */ + def append(that: MultiMap[A, B]): MultiMap[A, B] = { + val newmap = new MultiMap[A, B]() + this.foreach { case (k, vset) => newmap += (k -> vset) } + that.foreach { + case (k, vset) => vset.foreach(v => newmap.addBinding(k, v)) + } + newmap + } +} + +/** + * A multimap that allows duplicate entries + */ +class OrderedMultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.ListBuffer[B]] { + + def addBinding(key: A, value: B): this.type = { + get(key) match { + case None => + val list = new scala.collection.mutable.ListBuffer[B]() + list += value + this(key) = list + case Some(list) => + list += value + } + this + } + + /** + * Creates a new map and does not change the existing map + */ + def append(that: OrderedMultiMap[A, B]): OrderedMultiMap[A, B] = { + val newmap = new OrderedMultiMap[A, B]() + this.foreach { case (k, vlist) => newmap += (k -> vlist) } + that.foreach { + case (k, vlist) => vlist.foreach(v => newmap.addBinding(k, v)) + } + newmap + } + + /** + * Make the value of every key distinct + */ + def distinct: OrderedMultiMap[A, B] = { + val newmap = new OrderedMultiMap[A, B]() + this.foreach { case (k, vlist) => newmap += (k -> vlist.distinct) } + newmap + } +} + +/** + * Implements a mapping from Seq[A] to B where Seq[A] + * is stored as a Trie + */ +final class TrieMap[A, B] { + var childrenMap = Map[A, TrieMap[A, B]]() + var dataMap = Map[A, B]() + + @tailrec def addBinding(key: Seq[A], value: B) { + key match { + case Seq() => + throw new IllegalStateException("Key is empty!!") + case Seq(x) => + //add the value to the dataMap + if (dataMap.contains(x)) + throw new IllegalStateException("A mapping for key already exists: " + x + " --> " + dataMap(x)) + else + dataMap += (x -> value) + case head +: tail => //here, tail has at least one element + //check if we have an entry for seq(0) if yes go to the children, if not create one + val child = childrenMap.getOrElse(head, { + val ch = new TrieMap[A, B]() + childrenMap += (head -> ch) + ch + }) + child.addBinding(tail, value) + } + } + + @tailrec def lookup(key: Seq[A]): Option[B] = { + key match { + case Seq() => + throw new IllegalStateException("Key is empty!!") + case Seq(x) => + dataMap.get(x) + case head +: tail => //here, tail has at least one element + childrenMap.get(head) match { + case Some(child) => + child.lookup(tail) + case _ => None + } + } + } +} + +class CounterMap[T] extends scala.collection.mutable.HashMap[T, Int] { + def inc(v: T) = { + if (this.contains(v)) + this(v) += 1 + else this += (v -> 1) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 7747ec680948859ea6099c595578705ad33f27ed..913e6efbd22bfe8479be9e2b5ea067227bd418e8 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -19,7 +19,7 @@ import Types._ * */ object Constructors { - /** If `isTuple`, the whole expression is returned. This is to avoid a situation like + /** If `isTuple`, the whole expression is returned. This is to avoid a situation like * `tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> x`, which is not expected. * Instead, * `tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> Tuple(x,y)`. @@ -71,10 +71,10 @@ object Constructors { */ def tupleWrap(es: Seq[Expr]): Expr = es match { case Seq() => UnitLiteral() - case Seq(elem) => elem + case Seq(elem) => elem case more => Tuple(more) } - + /** Wraps the sequence of patterns as a tuple. If the sequence contains a single pattern, it is returned instead. * If the sequence is empty, [[purescala.Expressions.LiteralPattern `LiteralPattern`]]`(None, `[[purescala.Expressions.UnitLiteral `UnitLiteral`]]`())` is returned. * @see [[purescala.Expressions.TuplePattern]] @@ -85,9 +85,9 @@ object Constructors { case Seq(elem) => elem case more => TuplePattern(None, more) } - + /** Wraps the sequence of types as a tuple. If the sequence contains a single type, it is returned instead. - * If the sequence is empty, the [[purescala.Types.UnitType UnitType]] is returned. + * If the sequence is empty, the [[purescala.Types.UnitType UnitType]] is returned. * @see [[purescala.Types.TupleType]] */ def tupleTypeWrap(tps : Seq[TypeTree]) = tps match { @@ -101,12 +101,12 @@ object Constructors { * @see [[purescala.Expressions.FunctionInvocation]] */ def functionInvocation(fd : FunDef, args : Seq[Expr]) = { - + require(fd.params.length == args.length, "Invoking function with incorrect number of arguments") - + val formalType = tupleTypeWrap(fd.params map { _.getType }) val actualType = tupleTypeWrap(args map { _.getType }) - + canBeSubtypeOf(actualType, typeParamsOf(formalType).toSeq, formalType) match { case Some(tmap) => FunctionInvocation(fd.typed(fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }), args) @@ -168,13 +168,13 @@ object Constructors { val filtered = filterCases(scrutinee.getType, None, cases) if (filtered.nonEmpty) MatchExpr(scrutinee, filtered) - else + else Error( cases.headOption.map{ _.rhs.getType }.getOrElse(Untyped), "No case matches the scrutinee" ) - } - + } + /** $encodingof `&&`-expressions with arbitrary number of operands, and simplified. * @see [[purescala.Expressions.And And]] */ @@ -251,7 +251,7 @@ object Constructors { */ def finiteArray(els: Seq[Expr]): Expr = { require(els.nonEmpty) - finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway + finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway } /** $encodingof Simplified `Array[...](...)` (array length and default element defined at run-time) with type information * @see [[purescala.Constructors#finiteArray(els:Map* finiteArray]] @@ -320,11 +320,11 @@ object Constructors { case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => lhs case (IntLiteral(0), _) => rhs case (_, IntLiteral(0)) => lhs - case (RealLiteral(d), _) if d == 0 => rhs - case (_, RealLiteral(d)) if d == 0 => lhs - case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Plus(lhs, rhs) + case (FractionalLiteral(n, d), _) if n == 0 => rhs + case (_, FractionalLiteral(n, d)) if n == 0 => lhs case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVPlus(lhs, rhs) case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealPlus(lhs, rhs) + case _ => Plus(lhs, rhs) } /** $encodingof simplified `... - ...` (minus). @@ -337,9 +337,9 @@ object Constructors { case (_, IntLiteral(0)) => lhs case (InfiniteIntegerLiteral(bi), _) if bi == 0 => UMinus(rhs) case (IntLiteral(0), _) => BVUMinus(rhs) - case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Minus(lhs, rhs) case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVMinus(lhs, rhs) case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealMinus(lhs, rhs) + case _ => Minus(lhs, rhs) } /** $encodingof simplified `... * ...` (times). @@ -356,9 +356,9 @@ object Constructors { case (_, IntLiteral(1)) => lhs case (IntLiteral(0), _) => IntLiteral(0) case (_, IntLiteral(0)) => IntLiteral(0) - case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Times(lhs, rhs) case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVTimes(lhs, rhs) case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealTimes(lhs, rhs) + case _ => Times(lhs, rhs) } /** $encodingof expr.asInstanceOf[tpe], returns `expr` it it already is of type `tpe`. */ @@ -366,7 +366,6 @@ object Constructors { if (isSubtypeOf(expr.getType, tpe)) { expr } else { - //println(s"$expr:${expr.getType} is not a subtype of $tpe") AsInstanceOf(expr, tpe) } } diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 1d98f57b87c86d91e61bee3e68c0894feac6780f..d8aadc52a499218e1f71e9e1eac3d204c3feb1ab 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -14,7 +14,7 @@ import solvers._ /** Provides functions to manipulate [[purescala.Expressions]]. * - * This object provides a few generic operations on Leon expressions, + * This object provides a few generic operations on Leon expressions, * as well as some common operations. * * The generic operations lets you apply operations on a whole tree @@ -47,7 +47,7 @@ object ExprOps { * A right tree fold applies the input function to the subnodes first (from left * to right), and combine the results along with the current node value. * - * @param f a function that takes the current node and the seq + * @param f a function that takes the current node and the seq * of results form the subtrees. * @param e The Expr on which to apply the fold. * @return The expression after applying `f` on all subtrees. @@ -115,10 +115,10 @@ object ExprOps { * Takes a partial function of replacements and substitute * '''before''' recursing down the trees. * - * Supports two modes : - * + * Supports two modes : + * * - If applyRec is false (default), will only substitute once on each level. - * + * * e.g. * {{{ * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f @@ -127,18 +127,18 @@ object ExprOps { * {{{ * Add(a, d) // And not Add(a, f) because it only substitute once for each level. * }}} - * + * * - If applyRec is true, it will substitute multiple times on each level: - * + * * e.g. * {{{ * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f * }}} * will yield: * {{{ - * Add(a, f) + * Add(a, f) * }}} - * + * * @note The mode with applyRec true can diverge if f is not well formed */ def preMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { @@ -146,7 +146,7 @@ object ExprOps { val newV = if (applyRec) { // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (e) + fixpoint { e : Expr => f(e) getOrElse e } (e) } else { f(e) getOrElse e } @@ -166,7 +166,7 @@ object ExprOps { * Takes a partial function of replacements. * Substitutes '''after''' recursing down the trees. * - * Supports two modes : + * Supports two modes : * * - If applyRec is false (default), will only substitute once on each level. * e.g. @@ -177,7 +177,7 @@ object ExprOps { * {{{ * Add(a, Minus(e, c)) * }}} - * + * * - If applyRec is true, it will substitute multiple times on each level: * e.g. * {{{ @@ -185,7 +185,7 @@ object ExprOps { * }}} * will yield: * {{{ - * Add(a, f) + * Add(a, f) * }}} * * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) @@ -205,7 +205,7 @@ object ExprOps { if (applyRec) { // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (newV) + fixpoint { e : Expr => f(e) getOrElse e } (newV) } else { f(newV) getOrElse newV } @@ -222,7 +222,7 @@ object ExprOps { * @param pre a function applied on a node before doing a recursion in the children * @param post a function applied to the node built from the recursive application to all children - * @param combiner a function to combine the resulting values from all children with + * @param combiner a function to combine the resulting values from all children with the current node * @param init the initial value * @param expr the expression on which to apply the transform @@ -270,7 +270,7 @@ object ExprOps { def collect[T](matcher: Expr => Set[T])(e: Expr): Set[T] = { foldRight[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) } - + def collectPreorder[T](matcher: Expr => Seq[T])(e: Expr): Seq[T] = { foldRight[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) } @@ -341,15 +341,15 @@ object ExprOps { case _ => Set() }(expr) } - + /** Returns functions in directly nested LetDefs */ def directlyNestedFunDefs(e: Expr): Set[FunDef] = { - foldRight[Set[FunDef]]{ + foldRight[Set[FunDef]]{ case (LetDef(fd,bd), _) => Set(fd) case (_, subs) => subs.flatten.toSet }(e) } - + /** Computes the negation of a boolean formula, with some simplifications. */ def negate(expr: Expr) : Expr = { //require(expr.getType == BooleanType) @@ -388,13 +388,13 @@ object ExprOps { def freshenLocals(expr: Expr) : Expr = { def freshenCase(cse: MatchCase) : MatchCase = { val allBinders: Set[Identifier] = cse.pattern.binders - val subMap: Map[Identifier,Identifier] = + val subMap: Map[Identifier,Identifier] = Map(allBinders.map(i => (i, FreshIdentifier(i.name, i.getType, true))).toSeq : _*) val subVarMap: Map[Expr,Expr] = subMap.map(kv => Variable(kv._1) -> Variable(kv._2)) - + MatchCase( replacePatternBinders(cse.pattern, subMap), - cse.optGuard map { replace(subVarMap, _)}, + cse.optGuard map { replace(subVarMap, _)}, replace(subVarMap,cse.rhs) ) } @@ -418,7 +418,7 @@ object ExprOps { def depth(e: Expr): Int = { foldRight[Int]{ (e, sub) => 1 + (0 +: sub).max }(e) } - + /** Applies the function to the I/O constraint and simplifies the resulting constraint */ def applyAsMatches(p : Passes, f : Expr => Expr) = { f(p.asConstraint) match { @@ -494,7 +494,7 @@ object ExprOps { val grouped : Map[TypeTree, Seq[Identifier]] = allVars.groupBy(_.getType) val subst = grouped.foldLeft(Map.empty[Identifier, Identifier]) { case (subst, (tpe, ids)) => val currentVars = typedIds(tpe) - + val freshCount = ids.size - currentVars.size val typedVars = if (freshCount > 0) { val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", tpe, true)) @@ -538,7 +538,7 @@ object ExprOps { import evaluators._ val eval = new DefaultEvaluator(ctx, program) - + def rec(e: Expr): Option[Expr] = e match { case l: Terminal => None case e if isGround(e) => eval.eval(e) match { @@ -583,7 +583,7 @@ object ExprOps { case letTuple @ LetTuple(ids, Tuple(exprs), body) if isDeterministic(body) => var newBody = body - val (remIds, remExprs) = (ids zip exprs).filter { + val (remIds, remExprs) = (ids zip exprs).filter { case (id, value: Terminal) => newBody = replace(Map(Variable(id) -> value), newBody) //we replace, so we drop old @@ -603,7 +603,7 @@ object ExprOps { true } }.unzip - + Some(Constructors.letTuple(remIds, tupleWrap(remExprs), newBody)) case l @ LetTuple(ids, tExpr: Terminal, body) if isDeterministic(body) => @@ -805,14 +805,14 @@ object ExprOps { * case m @ MyCaseClass(t: B, (_, 7)) => * }}} * will yield the following condition before simplification (to give some flavour) - * + * * {{{and(IsInstanceOf(MyCaseClass, i), and(Equals(m, i), InstanceOfClass(B, i.t), equals(i.k.arity, 2), equals(i.k._2, 7))) }}} - * + * * Pretty-printed, this would be: * {{{ * i.instanceOf[MyCaseClass] && m == i && i.t.instanceOf[B] && i.k.instanceOf[Tuple2] && i.k._2 == 7 * }}} - * + * * @see [[purescala.Expressions.Pattern]] */ def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Expr = { @@ -937,9 +937,9 @@ object ExprOps { } /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], concatenates the path condition with the newly induced conditions. - * + * * Each case holds the conditions on other previous cases as negative. - * + * * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] * @see [[purescala.ExprOps#mapForPattern mapForPattern]] */ @@ -951,7 +951,7 @@ object ExprOps { val g = c.optGuard getOrElse BooleanLiteral(true) val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) val localCond = pcSoFar :+ cond :+ g - + // These contain no binders defined in this MatchCase val condSafe = conditionForPattern(scrut, c.pattern) val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) @@ -984,7 +984,7 @@ object ExprOps { def passesPathConditions(p : Passes, pathCond: List[Expr]) : Seq[List[Expr]] = { matchExprCaseConditions(MatchExpr(p.in, p.cases), pathCond) } - + /** * Returns a pattern from an expression, and a guard if any. */ @@ -995,7 +995,7 @@ object ExprOps { case Tuple(subs) => TuplePattern(None, subs map rec) case l : Literal[_] => LiteralPattern(None, l) case Variable(i) => WildcardPattern(Some(i)) - case other => + case other => val id = FreshIdentifier("other", other.getType, true) guard = and(guard, Equals(Variable(id), other)) WildcardPattern(Some(id)) @@ -1003,9 +1003,9 @@ object ExprOps { (rec(e), guard) } - /** + /** * Takes a pattern and returns an expression that corresponds to it. - * Also returns a sequence of `Identifier -> Expr` pairs which + * Also returns a sequence of `Identifier -> Expr` pairs which * represent the bindings for intermediate binders (from outermost to innermost) */ def patternToExpression(p: Pattern, expectedType: TypeTree): (Expr, Seq[(Identifier, Expr)]) = { @@ -1036,12 +1036,12 @@ object ExprOps { } case TuplePattern(b, subs) => val TupleType(subTypes) = expectedType - val e = Tuple(subs zip subTypes map { + val e = Tuple(subs zip subTypes map { case (sub, subType) => rec(sub, subType) }) addBinding(b, e) e - case CaseClassPattern(b, cct, subs) => + case CaseClassPattern(b, cct, subs) => val e = CaseClass(cct, subs zip cct.fieldsTypes map { case (sub,tp) => rec(sub,tp) }) addBinding(b, e) e @@ -1070,6 +1070,7 @@ object ExprOps { /** Returns simplest value of a given type */ def simplestValue(tpe: TypeTree) : Expr = tpe match { case Int32Type => IntLiteral(0) + case RealType => FractionalLiteral(0, 1) case IntegerType => InfiniteIntegerLiteral(0) case CharType => CharLiteral('a') case BooleanType => BooleanLiteral(false) @@ -1172,7 +1173,7 @@ object ExprOps { def pre(e : Expr) = e match { case LetDef(fd, expr) if fd.hasPrecondition => - val pre = fd.precondition.get + val pre = fd.precondition.get solver.solveVALID(pre) match { case Some(true) => @@ -1188,7 +1189,7 @@ object ExprOps { e - case IfExpr(cond, thenn, elze) => + case IfExpr(cond, thenn, elze) => try { solver.solveVALID(cond) match { case Some(true) => thenn @@ -1200,7 +1201,7 @@ object ExprOps { } } catch { // let's give up when the solver crashes - case _ : Exception => e + case _ : Exception => e } case _ => e @@ -1297,7 +1298,7 @@ object ExprOps { preTraversal{ case Choose(_) => return false case Hole(_, _) => return false - //@EK FIXME: do we need it? + //@EK FIXME: do we need it? //case Error(_, _) => return false case _ => }(e) @@ -1310,7 +1311,7 @@ object ExprOps { } /** Substitute (free) variables in an expression with values form a model. - * + * * Complete with simplest values in case of incomplete model. */ def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Model): Expr = { @@ -1364,7 +1365,7 @@ object ExprOps { //btw, I know those are not the most general rules, but they lead to good optimizations :) case Plus(UMinus(Plus(e1, e2)), e3) if e1 == e3 => UMinus(e2) case Plus(UMinus(Plus(e1, e2)), e3) if e2 == e3 => UMinus(e1) - case Minus(e1, e2) if e1 == e2 => InfiniteIntegerLiteral(0) + case Minus(e1, e2) if e1 == e2 => InfiniteIntegerLiteral(0) case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => InfiniteIntegerLiteral(0) case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5) @@ -1375,6 +1376,41 @@ object ExprOps { fixpoint(simplePostTransform(simplify0))(expr) } + /** + * Some helper methods for FractionalLiterals + */ + def normalizeFraction(fl: FractionalLiteral) = { + val FractionalLiteral(num, denom) = fl + val modNum = if (num < 0) -num else num + val modDenom = if (denom < 0) -denom else denom + val divisor = modNum.gcd(modDenom) + val simpNum = num / divisor + val simpDenom = denom / divisor + if (simpDenom < 0) + FractionalLiteral(-simpNum, -simpDenom) + else + FractionalLiteral(simpNum, simpDenom) + } + + val realzero = FractionalLiteral(0, 1) + def floor(fl: FractionalLiteral): FractionalLiteral = { + val FractionalLiteral(n, d) = normalizeFraction(fl) + if (d == 0) throw new IllegalStateException("denominator zero") + if (n == 0) realzero + else if (n > 0) { + //perform integer division + FractionalLiteral(n / d, 1) + } else { + //here the number is negative + if (n % d == 0) + FractionalLiteral(n / d, 1) + else { + //perform integer division and subtract 1 + FractionalLiteral(n / d - 1, 1) + } + } + } + /** Checks whether a predicate is inductive on a certain identfier. * * isInductive(foo(a, b), a) where a: List will check whether @@ -1390,7 +1426,7 @@ object ExprOps { val isType = IsInstanceOf(Variable(on), cct) - val recSelectors = cct.fields.collect { + val recSelectors = cct.fields.collect { case vd if vd.getType == on.getType => vd.id } @@ -1512,11 +1548,11 @@ object ExprOps { g && e && h } - + } import synthesis.Witnesses.Terminating - + val res = (t1, t2) match { case (Variable(i1), Variable(i2)) => idHomo(i1, i2) @@ -1531,7 +1567,7 @@ object ExprOps { case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2) - + case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) => cs1.size == cs2.size && isHomo(in1,in2) && isHomo(out1,out2) && casesMatch(cs1,cs2) @@ -1539,7 +1575,7 @@ object ExprOps { // TODO: Check type params fdHomo(tfd1.fd, tfd2.fd) && (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } - + case (Terminating(tfd1, args1), Terminating(tfd2, args2)) => // TODO: Check type params fdHomo(tfd1.fd, tfd2.fd) && @@ -1596,7 +1632,7 @@ object ExprOps { * * TODO: We ignore type parameters here, we might want to make sure it's * valid. What's Leon's semantics w.r.t. erasure? - */ + */ def areExaustive(pss: Seq[(TypeTree, Seq[Pattern])]): Boolean = pss.forall { case (tpe, ps) => tpe match { @@ -1653,24 +1689,24 @@ object ExprOps { } } - case BooleanType => - // make sure ps contains either + case BooleanType => + // make sure ps contains either // - Wildcard or - // - both true and false + // - both true and false (ps exists { _.isInstanceOf[WildcardPattern] }) || { var found = Set[Boolean]() - ps foreach { + ps foreach { case LiteralPattern(_, BooleanLiteral(b)) => found += b case _ => () } (found contains true) && (found contains false) } - case UnitType => + case UnitType => // Anything matches () ps.nonEmpty - case Int32Type => + case Int32Type => // Can't possibly pattern match against all Ints one by one ps exists (_.isInstanceOf[WildcardPattern]) @@ -1702,7 +1738,7 @@ object ExprOps { * } * }}} * becomes - * {{{ + * {{{ * def foo(a, b) { * if (..) { foo(b, a) } else { .. } * } @@ -1751,7 +1787,7 @@ object ExprOps { case (Some(oe), Some(ie)) => val res = FreshIdentifier("res", fdOuter.returnType, true) Some(Lambda(Seq(ValDef(res)), and( - application(oe, Seq(Variable(res))), + application(oe, Seq(Variable(res))), application(simplePreTransform(pre)(ie), Seq(Variable(res))) ))) } @@ -1791,12 +1827,12 @@ object ExprOps { * Body manipulation * ================= */ - + /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one. - * + * * If no precondition is provided, removes any existing precondition. * Else, wraps the expression with a [[Expressions.Require]] clause referring to the new precondition. - * + * * @param expr The current expression * @param pred An optional precondition. Setting it to None removes any precondition. * @see [[Expressions.Ensuring]] @@ -1813,10 +1849,10 @@ object ExprOps { } /** Replaces the postcondition of an existing [[Expressions.Expr]] with a new one. - * + * * If no postcondition is provided, removes any existing postcondition. * Else, wraps the expression with a [[Expressions.Ensuring]] clause referring to the new postcondition. - * + * * @param expr The current expression * @param oie An optional postcondition. Setting it to None removes any postcondition. * @see [[Expressions.Ensuring]] @@ -1830,7 +1866,7 @@ object ExprOps { } /** Adds a body to a specification - * + * * @param expr The specification expression [[Expressions.Ensuring]] or [[Expressions.Require]]. If none of these, the argument is discarded. * @param body An option of [[Expressions.Expr]] possibly containing an expression body. * @return The post/pre condition with the body. If no body is provided, returns [[Expressions.NoTree]] @@ -1845,10 +1881,10 @@ object ExprOps { } /** Extracts the body without its specification - * + * * [[Expressions.Expr]] trees contain its specifications as part of certain nodes. * This function helps extracting only the body part of an expression - * + * * @return An option type with the resulting expression if not [[Expressions.NoTree]] * @see [[Expressions.Ensuring]] * @see [[Expressions.Require]] @@ -1875,7 +1911,7 @@ object ExprOps { /** Returns a tuple of precondition, the raw body and the postcondition of an expression */ def breakDownSpecs(e : Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) - + def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = { val rec = preTraversalWithParent(f, Some(e)) _ @@ -1952,7 +1988,7 @@ object ExprOps { */ def liftClosures(e: Expr): (Set[FunDef], Expr) = { var fds: Map[FunDef, FunDef] = Map() - + import synthesis.Witnesses.Terminating val res1 = preMap({ case LetDef(fd, b) => @@ -1968,7 +2004,7 @@ object ExprOps { } else { None } - + case Terminating(tfd, args) => if (fds contains tfd.fd) { Some(Terminating(fds(tfd.fd).typed(tfd.tps), args)) diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index c807ef81601c672ffb335013eae264724116b9d6..bfe78acb38b7d33fb25cb44046f0bdfed8016bcc 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -10,7 +10,7 @@ import Extractors._ import Constructors._ import ExprOps.replaceFromIDs -/** Expression definitions for Pure Scala. +/** Expression definitions for Pure Scala. * * If you are looking for things such as function or class definitions, * please have a look in [[purescala.Definitions]]. @@ -25,7 +25,7 @@ import ExprOps.replaceFromIDs * optimization opportunities. Unless you need exact control on the structure * of the trees, you should use constructors in [[purescala.Constructors]], that * simplify the trees they produce. - * + * * @define encodingof Encoding of * @define noteBitvector (32-bit vector) * @define noteReal (Real) @@ -77,7 +77,7 @@ object Expressions { } /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* - * + * * @param pred The precondition formula inside ``require(...)`` * @param body The body following the ``require(...)`` */ @@ -90,7 +90,7 @@ object Expressions { } /** Postcondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *ensuring* - * + * * @param body The body of the expression. It can contain at most one [[Expressions.Require]] sub-expression. * @param pred The predicate to satisfy. It should be a function whose argument's type can handle the type of the body */ @@ -114,7 +114,7 @@ object Expressions { } /** Local assertions with customizable error message - * + * * @param pred The predicate, first argument of `assert(..., ...)` * @param error An optional error string to display if the assert fails. Second argument of `assert(..., ...)` * @param body The expression following `assert(..., ...)` @@ -156,7 +156,7 @@ object Expressions { } /** $encodingof `def ... = ...; ...` (local function definition) - * + * * @param fd The function definition. * @param body The body of the expression after the function */ @@ -172,7 +172,7 @@ object Expressions { * Both [[Expressions.MethodInvocation]] and [[Expressions.This]] get removed by phase [[MethodLifting]]. * Methods become functions, [[Expressions.This]] becomes first argument, * and [[Expressions.MethodInvocation]] becomes [[Expressions.FunctionInvocation]]. - * + * * @param rec The expression evaluating to an object * @param cd The class definition typing `rec` * @param tfd The typed function definition of the method @@ -203,7 +203,7 @@ object Expressions { /* Higher-order Functions */ - + /** $encodingof `callee(args...)`, where [[callee]] is an expression of a function type (not a method) */ case class Application(callee: Expr, args: Seq[Expr]) extends Expr { val getType = callee.getType match { @@ -251,12 +251,12 @@ object Expressions { } /** $encodingof `... match { ... }` - * + * * '''cases''' should be nonempty. If you are not sure about this, you should use * [[purescala.Constructors#matchExpr purescala's constructor matchExpr]] - * + * * @param scrutinee Expression to the left of the '''match''' keyword - * @param cases A sequence of cases to match `scrutinee` against + * @param cases A sequence of cases to match `scrutinee` against */ case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { require(cases.nonEmpty) @@ -264,7 +264,7 @@ object Expressions { } /** $encodingof `case pattern [if optGuard] => rhs` - * + * * @param pattern The pattern just to the right of the '''case''' keyword * @param optGuard An optional if-condition just to the left of the `=>` * @param rhs The expression to the right of `=>` @@ -275,7 +275,7 @@ object Expressions { } /** $encodingof a pattern after a '''case''' keyword. - * + * * @see [[Expressions.MatchCase]] */ sealed abstract class Pattern extends Tree { @@ -301,7 +301,7 @@ object Expressions { /** Pattern encoding `case _ => `, or `case binder => ` if identifier [[binder]] is present */ case class WildcardPattern(binder: Option[Identifier]) extends Pattern { // c @ _ val subPatterns = Seq() - } + } /** Pattern encoding `case binder @ ct(subPatterns...) =>` * * If [[binder]] is empty, consider a wildcard `_` in its place. @@ -319,7 +319,7 @@ object Expressions { * If [[binder]] is empty, consider a wildcard `_` in its place. */ case class LiteralPattern[+T](binder: Option[Identifier], lit : Literal[T]) extends Pattern { - val subPatterns = Seq() + val subPatterns = Seq() } /** A custom pattern defined through an object's `unapply` function */ @@ -362,11 +362,11 @@ object Expressions { /** Symbolic I/O examples as a match/case. * $encodingof `out == (in match { cases; case _ => out })` - * + * * [[cases]] should be nonempty. If you are not sure about this, you should use * [[purescala.Constructors#passes purescala's constructor passes]] - * - * @param in + * + * @param in * @param out * @param cases */ @@ -402,8 +402,9 @@ object Expressions { case class InfiniteIntegerLiteral(value: BigInt) extends Literal[BigInt] { val getType = IntegerType } - /** $encodingof a real number literal */ - case class RealLiteral(value: BigDecimal) extends Literal[BigDecimal] { + /** $encodingof a fraction literal */ + case class FractionalLiteral(numerator: BigInt, denominator: BigInt) extends Literal[(BigInt, BigInt)] { + val value = (numerator, denominator) val getType = RealType } /** $encodingof a boolean literal '''true''' or '''false''' */ @@ -500,7 +501,7 @@ object Expressions { } /** $encodingof `... || ...` - * + * * [[exprs]] must contain at least two elements; if you are not sure about this, * you should use [[purescala.Constructors#or purescala's constructor or]] or * [[purescala.Constructors#orJoin purescala's constructor orJoin]] @@ -574,7 +575,7 @@ object Expressions { } } /** $encodingof `... / ...` - * + * * Division and Remainder follows Java/Scala semantics. Division corresponds * to / operator on BigInt and Remainder corresponds to %. Note that in * Java/Scala % is called remainder and the "mod" operator (Modulo in Leon) is also @@ -591,7 +592,7 @@ object Expressions { } } /** $encodingof `... % ...` (can return negative numbers) - * + * * @see [[Expressions.Division]] */ case class Remainder(lhs: Expr, rhs: Expr) extends Expr { @@ -601,7 +602,7 @@ object Expressions { } } /** $encodingof `... mod ...` (cannot return negative numbers) - * + * * @see [[Expressions.Division]] */ case class Modulo(lhs: Expr, rhs: Expr) extends Expr { @@ -615,11 +616,11 @@ object Expressions { val getType = BooleanType } /** $encodingof `... > ...`*/ - case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { + case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { val getType = BooleanType } /** $encodingof `... <= ...`*/ - case class LessEquals(lhs: Expr, rhs: Expr) extends Expr { + case class LessEquals(lhs: Expr, rhs: Expr) extends Expr { val getType = BooleanType } /** $encodingof `... >= ...`*/ @@ -635,32 +636,32 @@ object Expressions { val getType = Int32Type } /** $encodingof `... - ...` $noteBitvector*/ - case class BVMinus(lhs: Expr, rhs: Expr) extends Expr { + case class BVMinus(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == Int32Type && rhs.getType == Int32Type) val getType = Int32Type } /** $encodingof `- ...` $noteBitvector*/ - case class BVUMinus(expr: Expr) extends Expr { + case class BVUMinus(expr: Expr) extends Expr { require(expr.getType == Int32Type) val getType = Int32Type } /** $encodingof `... * ...` $noteBitvector*/ - case class BVTimes(lhs: Expr, rhs: Expr) extends Expr { + case class BVTimes(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == Int32Type && rhs.getType == Int32Type) val getType = Int32Type } /** $encodingof `... / ...` $noteBitvector*/ - case class BVDivision(lhs: Expr, rhs: Expr) extends Expr { + case class BVDivision(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == Int32Type && rhs.getType == Int32Type) val getType = Int32Type } /** $encodingof `... % ...` $noteBitvector*/ - case class BVRemainder(lhs: Expr, rhs: Expr) extends Expr { + case class BVRemainder(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == Int32Type && rhs.getType == Int32Type) val getType = Int32Type } /** $encodingof `! ...` $noteBitvector */ - case class BVNot(expr: Expr) extends Expr { + case class BVNot(expr: Expr) extends Expr { val getType = Int32Type } /** $encodingof `... & ...` $noteBitvector */ @@ -696,22 +697,22 @@ object Expressions { val getType = RealType } /** $encodingof `... - ...` $noteReal */ - case class RealMinus(lhs: Expr, rhs: Expr) extends Expr { + case class RealMinus(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == RealType && rhs.getType == RealType) val getType = RealType } /** $encodingof `- ...` $noteReal */ - case class RealUMinus(expr: Expr) extends Expr { + case class RealUMinus(expr: Expr) extends Expr { require(expr.getType == RealType) val getType = RealType } /** $encodingof `... * ...` $noteReal */ - case class RealTimes(lhs: Expr, rhs: Expr) extends Expr { + case class RealTimes(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == RealType && rhs.getType == RealType) val getType = RealType } /** $encodingof `... / ...` $noteReal */ - case class RealDivision(lhs: Expr, rhs: Expr) extends Expr { + case class RealDivision(lhs: Expr, rhs: Expr) extends Expr { require(lhs.getType == RealType && rhs.getType == RealType) val getType = RealType } @@ -720,11 +721,11 @@ object Expressions { /* Tuple operations */ /** $encodingof `(..., ....)` (tuple) - * + * * [[exprs]] should always contain at least 2 elements. * If you are not sure about this requirement, you should use * [[purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] - * + * * @param exprs The expressions in the tuple */ case class Tuple (exprs: Seq[Expr]) extends Expr { @@ -733,7 +734,7 @@ object Expressions { } /** $encodingof `(tuple)._i` - * + * * Index is 1-based, first element of tuple is 1. * If you are not sure that [[tuple]] is indeed of a TupleType, * you should use [[purescala.Constructors$.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] @@ -839,7 +840,7 @@ object Expressions { } /** $encodingof Array(elems...) with predetermined elements - * @param elems The map from the position to the elements. + * @param elems The map from the position to the elements. * @param defaultLength An optional pair where the first element is the default value * and the second is the size of the array. Set this for big arrays * with a default value (as genereted with `Array.fill` in Scala). diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 57618655f1e3f793030f9cc37f920bf628870a2f..2fccfbdaf8afd779677392e5bb94a3a42816b64e 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -160,7 +160,9 @@ class PrettyPrinter(opts: PrinterOptions, case Equals(l,r) => optP { p"$l == $r" } case IntLiteral(v) => p"$v" case InfiniteIntegerLiteral(v) => p"$v" - case RealLiteral(d) => p"$d" + case FractionalLiteral(n, d) => + if (d == 1) p"$n" + else p"$n/$d" case CharLiteral(v) => p"$v" case BooleanLiteral(v) => p"$v" case UnitLiteral() => p"()" @@ -286,7 +288,7 @@ class PrettyPrinter(opts: PrinterOptions, val orderedElements = es.toSeq.sortWith((e1, e2) => e1._1 < e2._1).map(el => el._2) p"Array($orderedElements)" } else if(length < 10) { - val elems = (0 until length).map(i => + val elems = (0 until length).map(i => es.find(el => el._1 == i).map(el => el._2).getOrElse(d.get) ) p"Array($elems)" @@ -418,7 +420,7 @@ class PrettyPrinter(opts: PrinterOptions, |${nary(defs,"\n\n")} |""" - case Import(path, isWild) => + case Import(path, isWild) => if (isWild) { p"import ${nary(path,".")}._" } else { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 17faa873772df58a8382a3b3b7ab02f29c0cb63c..eee03153e680124194b3f7443d6f5111be1962e3 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -363,7 +363,7 @@ trait SMTLIBTarget extends Interruptible { case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) - case RealLiteral(d) => if (d >= 0) Reals.DecimalLit(d) else Reals.Neg(Reals.DecimalLit(-d)) + case FractionalLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) case BooleanLiteral(v) => Core.BoolConst(v) case Let(b,d,e) => @@ -642,10 +642,14 @@ trait SMTLIBTarget extends Interruptible { IntLiteral(hexa.toInt) case (SDecimal(d), Some(RealType)) => - RealLiteral(d) + // converting bigdecimal to a fraction + val scale = d.scale + val num = BigInt(d.bigDecimal.scaleByPowerOfTen(scale).toBigInteger()) + val denom = BigInt(new java.math.BigDecimal(1).scaleByPowerOfTen(-scale).toBigInteger()) + FractionalLiteral(num, denom) case (SNumeral(n), Some(RealType)) => - RealLiteral(BigDecimal(n)) + FractionalLiteral(n, 1) case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => IfExpr( diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 44df7eb678380b27c93439d19eb65b3626202a80..9c6979ca1992728b4aaefbaa0d9d3465a34558bb 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -80,7 +80,7 @@ trait AbstractZ3Solver extends Solver { z3.mkFreshFuncDecl(gv.tp.id.uniqueName+"#"+gv.id+"!val", Seq(), typeToSort(gv.tp)) } } - + // ADT Manager protected val adtManager = new ADTManager(context) @@ -272,7 +272,7 @@ trait AbstractZ3Solver extends Solver { } def rec(ex: Expr): Z3AST = ex match { - + // TODO: Leave that as a specialization? case LetTuple(ids, e, b) => { z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => @@ -283,7 +283,7 @@ trait AbstractZ3Solver extends Solver { z3Vars = z3Vars -- ids rb } - + case p @ Passes(_, _, _) => rec(p.asConstraint) @@ -326,7 +326,7 @@ trait AbstractZ3Solver extends Solver { case Not(e) => z3.mkNot(rec(e)) case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case RealLiteral(v) => z3.mkNumeral(v.toString, typeToSort(RealType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) @@ -381,6 +381,7 @@ trait AbstractZ3Solver extends Solver { case RealType => z3.mkLE(rec(l), rec(r)) case Int32Type => z3.mkBVSle(rec(l), rec(r)) case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") } case GreaterThan(l, r) => l.getType match { case IntegerType => z3.mkGT(rec(l), rec(r)) @@ -583,9 +584,7 @@ trait AbstractZ3Solver extends Solver { } } } - case Z3NumeralRealAST(num: BigInt, den: BigInt) => { - RealLiteral(BigDecimal(num) / BigDecimal(den)) - } + case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d) case Z3AppAST(decl, args) => val argsSize = args.size if(argsSize == 0 && (variables containsB t)) { diff --git a/src/main/scala/leon/transformations/DepthInstPhase.scala b/src/main/scala/leon/transformations/DepthInstPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..a505929d0dee541803d921fb4d0047c869bc28f1 --- /dev/null +++ b/src/main/scala/leon/transformations/DepthInstPhase.scala @@ -0,0 +1,104 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.utils._ +import invariant.util.Util._ + +object DepthCostModel { + val typedMaxFun = TypedFunDef(InstUtil.maxFun, Seq()) + + def costOf(e: Expr): Int = + e match { + case FunctionInvocation(fd, args) => 1 + case t: Terminal => 0 + case _ => 1 + } + + def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) +} + +class DepthInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { + import DepthCostModel._ + + def inst = Depth + + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { + //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) + val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)) + instFunSet.map(x => (x, List(Depth))).toMap + } + + def additionalfunctionsToAdd(): Seq[FunDef] = Seq()// - max functions are inlined, so they need not be added + + def instrumentMatchCase(me: MatchExpr, mc: MatchCase, + caseExprCost: Expr, scrutineeCost: Expr): Expr = { + val costMatch = costOfExpr(me) + def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = 0 + combineDepthIds(costMatch, List(caseExprCost, scrutineeCost)) + } + + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None)(implicit fd: FunDef, letIdMap: Map[Identifier, Identifier]): Expr = { + val costOfParent = costOfExpr(e) + e match { + case Variable(id) if letIdMap.contains(id) => + // add the cost of instrumentation here + Plus(costOfParent, si.selectInst(fd)(letIdMap(id).toVariable, inst)) + + case t: Terminal => costOfParent + case FunctionInvocation(tfd, args) => + val depthvar = subInsts.last + val remSubInsts = subInsts.slice(0, subInsts.length - 1) + val costofOp = { + costOfParent match { + case InfiniteIntegerLiteral(x) if (x == 0) => depthvar + case _ => Plus(costOfParent, depthvar) + } + } + combineDepthIds(costofOp, remSubInsts) + case e : Let => + //in this case, ignore the depth of the value, it will included if the bounded variable is + // used in the body + combineDepthIds(costOfParent, subInsts.tail) + case _ => + val costofOp = costOfParent + combineDepthIds(costofOp, subInsts) + } + } + + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], + thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { + + val cinst = condInst.toList + val tinst = thenInst.toList + val einst = elzeInst.toList + + (combineDepthIds(zero, cinst ++ tinst), combineDepthIds(zero, cinst ++ einst)) + } + + def combineDepthIds(costofOp: Expr, subeInsts: Seq[Expr]): Expr = { + if (subeInsts.size == 0) costofOp + else if (subeInsts.size == 1) Plus(costofOp, subeInsts(0)) + else { + //optimization: remove duplicates from 'subeInsts' as 'max' is an idempotent operation + val head +: tail = subeInsts.distinct + val summand = tail.foldLeft(head: Expr)((acc, id) => { + (acc, id) match { + case (InfiniteIntegerLiteral(x), _) if (x == 0) => id + case (_, InfiniteIntegerLiteral(x)) if (x == 0) => acc + case _ => + FunctionInvocation(typedMaxFun, Seq(acc, id)) + } + }) + costofOp match { + case InfiniteIntegerLiteral(x) if (x == 0) => summand + case _ => Plus(costofOp, summand) + } + } + } +} diff --git a/src/main/scala/leon/transformations/InstProgSimplifier.scala b/src/main/scala/leon/transformations/InstProgSimplifier.scala new file mode 100644 index 0000000000000000000000000000000000000000..037a444354fe1ab91566c701d009c29eee2ad2b6 --- /dev/null +++ b/src/main/scala/leon/transformations/InstProgSimplifier.scala @@ -0,0 +1,87 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.purescala.ScalaPrinter +import leon.utils._ +import invariant.util.Util._ +import invariant.util.ExpressionTransformer._ +import invariant.structure.FunctionUtils._ +import invariant.util.LetTupleSimplification._ + +/** + * A simplifier phase that eliminates tuples that are not needed + * from function bodies, and also performs other simplifications. + * Note: performing simplifications during instrumentation + * will affect the validity of the information stored in function info. + */ +object ProgramSimplifier { + val debugSimplify = false + + def mapProgram(funMap: Map[FunDef, FunDef]): Map[FunDef, FunDef] = { + + def mapExpr(ine: Expr): Expr = { + val replaced = simplePostTransform((e: Expr) => e match { + case FunctionInvocation(tfd, args) if funMap.contains(tfd.fd) => + FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args) + case _ => e + })(ine) + + // One might want to add the maximum function to the program in the stack + // and depth instrumentation phases if inlineMax is removed from here + val allSimplifications = + simplifyTuples _ andThen + simplifyMax _ andThen + simplifyLetsAndLetsWithTuples _ andThen + simplifyAdditionsAndMax _ andThen + inlineMax _ + + allSimplifications(replaced) + } + + for ((from, to) <- funMap) { + to.fullBody = mapExpr(from.fullBody) + //copy annotations + from.flags.foreach(to.addFlag(_)) + } + funMap + } + + def createNewFunDefs(program: Program): Map[FunDef, FunDef] = { + val allFuncs = functionsWOFields(program.definedFunctions) + + allFuncs.foldLeft(Map[FunDef, FunDef]()) { + case (accMap, fd) if fd.isTheoryOperation => + accMap + (fd -> fd) + case (accMap, fd) => { + //here we need not augment the return types + val freshId = FreshIdentifier(fd.id.name, fd.returnType) + val newfd = new FunDef(freshId, fd.tparams, fd.returnType, fd.params) + accMap.updated(fd, newfd) + } + } + } + + def createNewProg(mappedFuncs: Map[FunDef, FunDef], prog: Program): Program = { + val newprog = copyProgram(prog, (defs: Seq[Definition]) => defs.map { + case fd: FunDef if mappedFuncs.contains(fd) => + mappedFuncs(fd) + case d => d + }) + + if (debugSimplify) + println("After Simplifications: \n" + ScalaPrinter.apply(newprog)) + newprog + } + + def apply(program: Program): Program = { + val newFuncs = createNewFunDefs(program) + val mappedFuncs = mapProgram(newFuncs) + createNewProg(mappedFuncs, program) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/InstrumentationUtil.scala b/src/main/scala/leon/transformations/InstrumentationUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..fa35cc3f56cfe1dd37f2b013337373e619a5b086 --- /dev/null +++ b/src/main/scala/leon/transformations/InstrumentationUtil.scala @@ -0,0 +1,110 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.utils.Library + +sealed abstract class Instrumentation { + val getType: TypeTree + val name: String + def isInstVariable(e: Expr): Boolean = { + e match { + case FunctionInvocation(TypedFunDef(fd, _), _) if (fd.id.name == name && fd.annotations("library")) => + true + case _ => false + } + } + override def toString = name +} + +object Time extends Instrumentation { + override val getType = IntegerType + override val name = "time" +} +object Depth extends Instrumentation { + override val getType = IntegerType + override val name = "depth" +} +object Rec extends Instrumentation { + override val getType = IntegerType + override val name = "rec" +} + +/** + * time per recursive step. + */ +object TPR extends Instrumentation { + override val getType = IntegerType + override val name = "tpr" +} + +object Stack extends Instrumentation { + override val getType = IntegerType + override val name = "stack" +} +//add more instrumentation variables + +object InstUtil { + + val maxFun = { + val xid = FreshIdentifier("x", IntegerType) + val yid = FreshIdentifier("y", IntegerType) + val varx = xid.toVariable + val vary = yid.toVariable + val args = Seq(xid, yid) + val maxType = FunctionType(Seq(IntegerType, IntegerType), IntegerType) + val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), IntegerType, + args.map((arg) => ValDef(arg, Some(arg.getType)))) + + val cond = GreaterEquals(varx, vary) + mfd.body = Some(IfExpr(cond, varx, vary)) + mfd.addFlag(Annotation("theoryop", Seq())) + mfd + } + + def userFunctionName(fd: FunDef) = fd.id.name.split("-")(0) + + def getInstMap(fd: FunDef) = { + val resvar = invariant.util.Util.getResId(fd).get.toVariable // note: every instrumented function has a postcondition + val insts = fd.id.name.split("-").tail // split the name of the function w.r.t '-' + (insts.zipWithIndex).foldLeft(Map[Expr, String]()) { + case (acc, (instName, i)) => + acc + (TupleSelect(resvar, i + 2) -> instName) + } + } + + def getInstExpr(fd: FunDef, inst: Instrumentation) = { + val resvar = invariant.util.Util.getResId(fd).get.toVariable // note: every instrumented function has a postcondition + val insts = fd.id.name.split("-").tail // split the name of the function w.r.t '-' + val index = insts.indexOf(inst.name) + if (index >= 0) + Some(TupleSelect(resvar, index + 2)) + else None + } + + def getInstVariableMap(fd: FunDef) = { + getInstMap(fd).map { + case (ts, instName) => + (ts -> Variable(FreshIdentifier(instName, IntegerType))) + } + } + + def isInstrumented(fd: FunDef, instType: Instrumentation) = { + fd.id.name.split("-").contains(instType.toString) + } + + def resultExprForInstVariable(fd: FunDef, instType: Instrumentation) = { + getInstVariableMap(fd).collectFirst { + case (k, Variable(id)) if (id.name == instType.toString) => k + } + } + + def replaceInstruVars(e: Expr, fd: FunDef): Expr = { + replace(getInstVariableMap(fd), e) + } +} diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala new file mode 100644 index 0000000000000000000000000000000000000000..d24b4c87aba7356ba091e43286a9e9ffc2a64130 --- /dev/null +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -0,0 +1,234 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ +import leon.purescala.ScalaPrinter + +import invariant.factories._ +import invariant.util.Util._ +import invariant.structure._ + +abstract class ProgramTypeTransformer { + protected var defmap = Map[ClassDef, ClassDef]() + protected var idmap = Map[Identifier, Identifier]() + protected var newFundefs = Map[FunDef, FunDef]() + + def mapField(cdef: CaseClassDef, fieldId: Identifier): Identifier = { + (cdef.fieldsIds.collectFirst { + case fid @ _ if (fid.name == fieldId.name) => fid + }).get + } + + def mapClass[T <: ClassDef](cdef: T): T = { + if (defmap.contains(cdef)) { + defmap(cdef).asInstanceOf[T] + } else { + cdef match { + case ccdef: CaseClassDef => + val newparent = if (ccdef.hasParent) { + val absType = ccdef.parent.get + Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) + } else None + val newclassDef = ccdef.copy(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) + + //important: register a child if a parent was newly created. + if (newparent.isDefined) + newparent.get.classDef.registerChild(newclassDef) + + defmap += (ccdef -> newclassDef) + newclassDef.setFields(ccdef.fields.map(mapDecl)) + newclassDef.asInstanceOf[T] + + case acdef: AbstractClassDef => + val newparent = if (acdef.hasParent) { + val absType = acdef.parent.get + Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) + } else None + val newClassDef = acdef.copy(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) + defmap += (acdef -> newClassDef) + newClassDef.asInstanceOf[T] + } + } + } + + def mapId(id: Identifier): Identifier = { + val newtype = mapType(id.getType) + val newId = idmap.getOrElse(id, { + //important need to preserve distinction between template variables and ordinary variables + val freshId = if (TemplateIdFactory.IsTemplateIdentifier(id)) TemplateIdFactory.copyIdentifier(id) + else FreshIdentifier(id.name, newtype, true) + idmap += (id -> freshId) + freshId + }) + newId + } + + def mapDecl(decl: ValDef): ValDef = { + val newtpe = mapType(decl.getType) + new ValDef(mapId(decl.id), Some(newtpe)) + } + + def mapType(tpe: TypeTree): TypeTree = { + tpe match { + case t @ RealType => mapNumericType(t) + case t @ IntegerType => mapNumericType(t) + case AbstractClassType(adef, tps) => AbstractClassType(mapClass(adef), tps) + case CaseClassType(cdef, tps) => CaseClassType(mapClass(cdef), tps) + case TupleType(bases) => TupleType(bases.map(mapType)) + case _ => tpe + } + } + + def mapNumericType(tpe: TypeTree): TypeTree + + def mapLiteral(lit: Literal[_]): Literal[_] + + def transform(program: Program): Program = { + //create a new fundef for each function in the program + //Unlike functions, classes are created lazily as required. + newFundefs = program.definedFunctions.map((fd) => { + val newFunType = FunctionType(fd.tparams.map((currParam) => currParam.tp), fd.returnType) + val newfd = new FunDef(FreshIdentifier(fd.id.name, newFunType, true), fd.tparams, + mapType(fd.returnType), fd.params.map(mapDecl)) + (fd, newfd) + }).toMap + + /** + * Here, we assume that tuple-select and case-class-select have been reduced + */ + def transformExpr(e: Expr): Expr = e match { + case l: Literal[_] => mapLiteral(l) + case v @ Variable(inId) => mapId(inId).toVariable + case FunctionInvocation(TypedFunDef(intfd, tps), args) => FunctionInvocation(TypedFunDef(newFundefs(intfd), tps), args.map(transformExpr)) + case CaseClass(CaseClassType(classDef, tps), args) => CaseClass(CaseClassType(mapClass(classDef), tps), args.map(transformExpr)) + case IsInstanceOf(expr, CaseClassType(classDef, tps)) => IsInstanceOf(transformExpr(expr), CaseClassType(mapClass(classDef), tps)) + case CaseClassSelector(CaseClassType(classDef, tps), expr, fieldId) => { + val newtype = CaseClassType(mapClass(classDef), tps) + CaseClassSelector(newtype, transformExpr(expr), mapField(newtype.classDef, fieldId)) + } + //need to handle 'let' and 'letTuple' specially + case Let(binder, value, body) => Let(mapId(binder), transformExpr(value), transformExpr(body)) + case t: Terminal => t + /*case UnaryOperator(arg, op) => op(transformExpr(arg)) + case BinaryOperator(arg1, arg2, op) => op(transformExpr(arg1), transformExpr(arg2))*/ + case Operator(args, op) => op(args.map(transformExpr)) + } + + //create a body, pre, post for each newfundef + newFundefs.foreach((entry) => { + val (fd, newfd) = entry + + //add a new precondition + newfd.precondition = + if (fd.precondition.isDefined) + Some(transformExpr(fd.precondition.get)) + else None + + //add a new body + newfd.body = if (fd.hasBody) { + //replace variables by constants if possible + val simpBody = matchToIfThenElse(fd.body.get) + Some(transformExpr(simpBody)) + } else Some(NoTree(fd.returnType)) + + // FIXME + //add a new postcondition + newfd.fullBody = if (fd.postcondition.isDefined && newfd.body.isDefined) { + val Lambda(Seq(ValDef(resid, _)), pexpr) = fd.postcondition.get + val tempRes = mapId(resid).toVariable + Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id, Some(tempRes.getType))), transformExpr(pexpr))) + // Some(mapId(resid), transformExpr(pexpr)) + } else NoTree(fd.returnType) + + fd.flags.foreach(newfd.addFlag(_)) + }) + + val newprog = copyProgram(program, (defs: Seq[Definition]) => defs.map { + case fd: FunDef => newFundefs(fd) + case cd: ClassDef => mapClass(cd) + case d @ _ => throw new IllegalStateException("Unknown Definition: " + d) + }) + newprog + } +} + +class IntToRealProgram extends ProgramTypeTransformer { + + private var realToIntId = Map[Identifier, Identifier]() + + def mapNumericType(tpe: TypeTree) = { + require(isNumericType(tpe)) + tpe match { + case IntegerType => RealType + case _ => tpe + } + } + + def mapLiteral(lit: Literal[_]): Literal[_] = lit match { + case IntLiteral(v) => FractionalLiteral(v, 1) + case _ => lit + } + + def apply(program: Program): Program = { + + val newprog = transform(program) + //reverse the map + realToIntId = idmap.map(entry => (entry._2 -> entry._1)) + //println("After Real Program Conversion: \n" + ScalaPrinter.apply(newprog)) + //print all the templates + /*newprog.definedFunctions.foreach((fd) => { + val funinfo = FunctionInfoFactory.getFunctionInfo(fd) + if (funinfo.isDefined && funinfo.get.hasTemplate) + println("Function: " + fd.id + " template --> " + funinfo.get.getTemplate) + })*/ + newprog + } + + /** + * Assuming that the model maps only variables + */ + def unmapModel(model: Map[Identifier, Expr]): Map[Identifier, Expr] = { + model.map((pair) => { + val (key, value) = if (realToIntId.contains(pair._1)) { + (realToIntId(pair._1), pair._2) + } else pair + (key -> value) + }) + } +} + +class RealToIntProgram extends ProgramTypeTransformer { + val debugIntToReal = false + val bone = BigInt(1) + + def mapNumericType(tpe: TypeTree) = { + require(isNumericType(tpe)) + tpe match { + case RealType => IntegerType + case _ => tpe + } + } + + def mapLiteral(lit: Literal[_]): Literal[_] = lit match { + case FractionalLiteral(v, `bone`) => InfiniteIntegerLiteral(v) + case FractionalLiteral(_, _) => throw new IllegalStateException("Cannot convert real to integer: " + lit) + case _ => lit + } + + def apply(program: Program): Program = { + + val newprog = transform(program) + + if (debugIntToReal) + println("Program to Verify: \n" + ScalaPrinter.apply(newprog)) + + newprog + } + + def mappedFun(fd: FunDef): FunDef = newFundefs(fd) +} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala b/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..f234c8afeed166e10c629cf3008567a031158f60 --- /dev/null +++ b/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala @@ -0,0 +1,120 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.utils._ +import leon.invariant.util.Util._ + +import scala.collection.mutable.{Map => MutableMap} + +object tprCostModel { + def costOf(e: Expr): Int = e match { + case FunctionInvocation(fd, args) => 1 + case t: Terminal => 0 + case _ => 1 + } + def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) +} + +class TPRInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { + import tprCostModel._ + + def inst = TPR + + val sccs = cg.graph.sccs.flatMap { scc => + scc.map(fd => (fd -> scc.toSet)) + }.toMap + + //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) + val tprFuncs = getRootFuncs() + val timeFuncs = tprFuncs.foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)) + + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { + var emap = MutableMap[FunDef,List[Instrumentation]]() + def update(fd: FunDef, inst: Instrumentation) { + if (emap.contains(fd)) + emap(fd) :+= inst + else emap.update(fd, List(inst)) + } + tprFuncs.map(fd => update(fd, TPR)) + timeFuncs.map(fd => update(fd, Time)) + emap.toMap + } + + def additionalfunctionsToAdd() = Seq() + + def instrumentMatchCase( + me: MatchExpr, + mc: MatchCase, + caseExprCost: Expr, + scrutineeCost: Expr): Expr = { + val costMatch = costOfExpr(me) + + def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = { + def patCostRecur(pattern: Pattern, innerPat: Boolean, countLeafs: Boolean): Int = { + pattern match { + case InstanceOfPattern(_, _) => { + if (innerPat) 2 else 1 + } + case WildcardPattern(None) => 0 + case WildcardPattern(Some(id)) => { + if (countLeafs && innerPat) 1 + else 0 + } + case CaseClassPattern(_, _, subPatterns) => { + (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => + acc + patCostRecur(subPat, true, countLeafs)) + } + case TuplePattern(_, subPatterns) => { + (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => + acc + patCostRecur(subPat, true, countLeafs)) + } + case LiteralPattern(_, _) => if (innerPat) 2 else 1 + case _ => + throw new NotImplementedError(s"Pattern $pattern not handled yet!") + } + } + me.cases.take(me.cases.indexOf(mc)).foldLeft(0)( + (acc, currCase) => acc + patCostRecur(currCase.pattern, false, false)) + + patCostRecur(mc.pattern, false, true) + } + Plus(costMatch, Plus( + Plus(InfiniteIntegerLiteral(totalCostOfMatchPatterns(me, mc)), + caseExprCost), + scrutineeCost)) + } + + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) + (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = e match { + case t: Terminal => costOfExpr(t) + case FunctionInvocation(tfd, args) => { + val remSubInsts = if (tprFuncs.contains(tfd.fd)) + subInsts.slice(0, subInsts.length - 1) + else subInsts + if (sccs(fd)(tfd.fd)) { + remSubInsts.foldLeft(costOfExpr(e) : Expr)( + (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) + } + else { + val allSubInsts = remSubInsts :+ si.selectInst(tfd.fd)(funInvResVar.get, Time) + allSubInsts.foldLeft(costOfExpr(e) : Expr)( + (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) + } + } + case _ => + subInsts.foldLeft(costOfExpr(e) : Expr)( + (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) + } + + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], + thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { + val costIf = costOfExpr(e) + (Plus(costIf, Plus(condInst.get, thenInst.get)), + Plus(costIf, Plus(condInst.get, elzeInst.get))) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..2f41a5ccb55aef436d5a69289ab473cc392e1fda --- /dev/null +++ b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala @@ -0,0 +1,188 @@ +package leon +package transformations + +import invariant.factories._ +import invariant.util.Util._ +import invariant.structure.FunctionUtils._ + +import purescala.ScalaPrinter +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Types._ + +object MultFuncs { + def getMultFuncs(domain: TypeTree): (FunDef, FunDef) = { + //a recursive function that represents multiplication of two positive arguments + val pivMultFun = { + val xid = FreshIdentifier("x", domain) + val yid = FreshIdentifier("y", domain) + val varx = xid.toVariable + val vary = yid.toVariable + val args = Seq(xid, yid) + val funcType = FunctionType(Seq(domain, domain), domain) + val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), domain, + args.map((arg) => ValDef(arg, Some(arg.getType)))) + val tmfd = TypedFunDef(mfd, Seq()) + + //define a body (a) using mult(x,y) = if(x == 0 || y ==0) 0 else mult(x-1,y) + y + val cond = Or(Equals(varx, zero), Equals(vary, zero)) + val xminus1 = Minus(varx, one) + val yminus1 = Minus(vary, one) + val elze = Plus(FunctionInvocation(tmfd, Seq(xminus1, vary)), vary) + mfd.body = Some(IfExpr(cond, zero, elze)) + + //add postcondition + val resvar = FreshIdentifier("res", domain).toVariable + val post0 = GreaterEquals(resvar, zero) + + //define alternate definitions of multiplication as postconditions + //(a) res = !(x==0 || y==0) => mult(x,y-1) + x + val guard = Not(cond) + val defn2 = Equals(resvar, Plus(FunctionInvocation(tmfd, Seq(varx, yminus1)), varx)) + val post1 = Implies(guard, defn2) + + // mfd.postcondition = Some((resvar.id, And(Seq(post0, post1)))) + mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id, Some(resvar.getType))), And(Seq(post0, post1)))) + //set function properties (for now, only monotonicity) + mfd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) //"distributive" ? + mfd + } + + //a function that represents multiplication, this transitively calls pmult + val multFun = { + val xid = FreshIdentifier("x", domain) + val yid = FreshIdentifier("y", domain) + val args = Seq(xid, yid) + val funcType = FunctionType(Seq(domain, domain), domain) + val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), domain, args.map((arg) => ValDef(arg, Some(arg.getType)))) + val tpivMultFun = TypedFunDef(pivMultFun, Seq()) + + //the body is defined as mult(x,y) = val px = if(x < 0) -x else x; + //val py = if(y<0) -y else y; val r = pmult(px,py); + //if(x < 0 && y < 0 || x >= 0 && y >= 0) r else -r + val varx = xid.toVariable + val vary = yid.toVariable + val modx = IfExpr(LessThan(varx, zero), UMinus(varx), varx) + val mody = IfExpr(LessThan(vary, zero), UMinus(vary), vary) + val px = FreshIdentifier("px", domain, false) + val py = FreshIdentifier("py", domain, false) + val call = Let(px, modx, Let(py, mody, FunctionInvocation(tpivMultFun, Seq(px, py).map(_.toVariable)))) + val bothPive = And(GreaterEquals(varx, zero), GreaterEquals(vary, zero)) + val bothNive = And(LessThan(varx, zero), LessThan(vary, zero)) + val res = FreshIdentifier("r", domain, false) + val body = Let(res, call, IfExpr(Or(bothPive, bothNive), res.toVariable, UMinus(res.toVariable))) + fd.body = Some(body) + //set function properties + fd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) + fd + } + + (pivMultFun, multFun) + } +} + +class NonlinearityEliminator(skipAxioms: Boolean, domain: TypeTree) { + import MultFuncs._ + require(isNumericType(domain)) + + val debugNLElim = false + + val one = InfiniteIntegerLiteral(1) + val zero = InfiniteIntegerLiteral(0) + + val (pivMultFun, multFun) = getMultFuncs(domain) + + //TOOD: note associativity property of multiplication is not taken into account + def apply(program: Program): Program = { + + //create a fundef for each function in the program + val newFundefs = program.definedFunctions.map((fd) => { + val newFunType = FunctionType(fd.tparams.map((currParam) => currParam.tp), fd.returnType) + val newfd = new FunDef(FreshIdentifier(fd.id.name, newFunType, false), fd.tparams, fd.returnType, fd.params) + (fd, newfd) + }).toMap + + //note, handling templates variables is slightly tricky as we need to preserve a*x as it is + val tmult = TypedFunDef(multFun, Seq()) + var addMult = false + def replaceFun(ine: Expr, allowedVars: Set[Identifier] = Set()): Expr = { + simplePostTransform(e => e match { + case fi @ FunctionInvocation(tfd1, args) if newFundefs.contains(tfd1.fd) => + FunctionInvocation(TypedFunDef(newFundefs(tfd1.fd), tfd1.tps), args) + + case Times(Variable(id), e2) if (allowedVars.contains(id)) => e + case Times(e1, Variable(id)) if (allowedVars.contains(id)) => e + + case Times(e1, e2) if (!e1.isInstanceOf[Literal[_]] && !e2.isInstanceOf[Literal[_]]) => { + //replace times by a mult function + addMult = true + FunctionInvocation(tmult, Seq(e1, e2)) + } + //note: include mult function if division operation is encountered + //division is handled during verification condition generation. + case Division(_, _) => { + addMult = true + e + } + case _ => e + })(ine) + } + + //create a body, pre, post for each newfundef + newFundefs.foreach((entry) => { + val (fd, newfd) = entry + + //add a new precondition + newfd.precondition = + if (fd.precondition.isDefined) + Some(replaceFun(fd.precondition.get)) + else None + + //add a new body + newfd.body = if (fd.hasBody) { + //replace variables by constants if possible + val simpBody = simplifyLets(fd.body.get) + Some(replaceFun(simpBody)) + } else None + + + //add a new postcondition + newfd.postcondition = if (fd.postcondition.isDefined) { + //we need to handle template and postWoTemplate specially + val Lambda(resultBinders, _) = fd.postcondition.get + val tmplExpr = fd.templateExpr + val newpost = if (fd.hasTemplate) { + val FunctionInvocation(tmpfd, Seq(Lambda(tmpvars, tmpbody))) = tmplExpr.get + val newtmp = FunctionInvocation(tmpfd, Seq(Lambda(tmpvars, + replaceFun(tmpbody, tmpvars.map(_.id).toSet)))) + fd.postWoTemplate match { + case None => + newtmp + case Some(postExpr) => + And(replaceFun(postExpr), newtmp) + } + } else + replaceFun(fd.getPostWoTemplate) + + Some(Lambda(resultBinders, newpost)) + } else None + + fd.flags.foreach(newfd.addFlag(_)) + }) + + val newprog = copyProgram(program, (defs: Seq[Definition]) => { + defs.map { + case fd: FunDef => newFundefs(fd) + case d => d + } ++ (if (addMult) Seq(multFun, pivMultFun) else Seq()) + }) + + if (debugNLElim) + println("After Nonlinearity Elimination: \n" + ScalaPrinter.apply(newprog)) + + newprog + } +} diff --git a/src/main/scala/leon/transformations/RecursionCountPhase.scala b/src/main/scala/leon/transformations/RecursionCountPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..73472e272b154aa57d174c467c2d0f4e250c2e77 --- /dev/null +++ b/src/main/scala/leon/transformations/RecursionCountPhase.scala @@ -0,0 +1,71 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.utils._ +import invariant.util.Util._ + +class RecursionCountInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { + + def inst = Rec + + val sccs = cg.graph.sccs.flatMap { scc => + scc.map(fd => (fd -> scc.toSet)) + }.toMap + + /** + * Instrument only those functions that are in the same sccs of the root functions + */ + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { + val instFunSet = getRootFuncs().flatMap(sccs.apply _) + instFunSet.map(x => (x, List(Rec))).toMap + } + + override def additionalfunctionsToAdd(): Seq[FunDef] = Seq.empty[FunDef] + + def addSubInstsIfNonZero(subInsts: Seq[Expr], init: Expr): Expr = { + subInsts.foldLeft(init) { + case (acc, subinst) if subinst != zero => + if (acc == zero) subinst + else Plus(acc, subinst) + } + } + + def instrumentMatchCase(me: MatchExpr, + mc: MatchCase, + caseExprCost: Expr, + scrutineeCost: Expr): Expr = { + Plus(caseExprCost, scrutineeCost) + } + + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) + (implicit fd: FunDef, leIdtMap: Map[Identifier,Identifier]): Expr = e match { + case FunctionInvocation(TypedFunDef(callee, _), _) if sccs(fd)(callee) => + //this is a recursive call + //Note that the last element of subInsts is the instExpr of the invoked function + addSubInstsIfNonZero(subInsts, one) + case FunctionInvocation(TypedFunDef(callee, _), _) if si.funcInsts.contains(callee) && si.funcInsts(callee).contains(this.inst) => + //this is not a recursive call, so do not consider the cost of the callee + //Note that the last element of subInsts is the instExpr of the invoked function + addSubInstsIfNonZero(subInsts.take(subInsts.size - 1), zero) + case _ => + //add the cost of every sub-expression + addSubInstsIfNonZero(subInsts, zero) + } + + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], thenInst: Option[Expr], + elzeInst: Option[Expr]): (Expr, Expr) = { + + val cinst = condInst.toList + val tinst = thenInst.toList + val einst = elzeInst.toList + + (addSubInstsIfNonZero(cinst ++ tinst, zero), + addSubInstsIfNonZero(cinst ++ einst, zero)) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..f6862377d979ad1aeb8345494b5a89113bdc8bd0 --- /dev/null +++ b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala @@ -0,0 +1,489 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.purescala.ScalaPrinter +import leon.utils._ +import invariant.util._ +import invariant.util.CallGraphUtil +import invariant.structure.FunctionUtils._ +import scala.collection.mutable.{Map => MutableMap} + +/** + * An instrumentation phase that performs a sequence of instrumentations + */ + +object InstrumentationPhase extends TransformationPhase { + val name = "Instrumentation Phase" + val description = "Instruments the program for all counters needed" + + def apply(ctx: LeonContext, program: Program): Program = { + val instprog = new SerialInstrumenter(ctx, program) + instprog.apply + } +} + +class SerialInstrumenter(ctx: LeonContext, program: Program) { + val debugInstrumentation = false + + val instToInstrumenter: Map[Instrumentation, Instrumenter] = + Map(Time -> new TimeInstrumenter(program, this), + Depth -> new DepthInstrumenter(program, this), + Rec -> new RecursionCountInstrumenter(program, this), + Stack -> new StackSpaceInstrumenter(program, this), + TPR -> new TPRInstrumenter(program, this)) + + // a map from functions to the list of instrumentations to be performed for the function + lazy val funcInsts = { + var emap = MutableMap[FunDef,List[Instrumentation]]() + def update(fd: FunDef, inst: Instrumentation) { + if (emap.contains(fd)) + emap(fd) = (emap(fd) :+ inst).distinct + else emap.update(fd, List(inst)) + } + instToInstrumenter.values.foreach{ m => + m.functionsToInstrument.foreach({ case (fd, instsToPerform) => + instsToPerform.foreach(instToPerform => update(fd, instToPerform)) }) + } + emap.toMap + } + lazy val instFuncs = funcInsts.keySet //should we exclude theory operations ? + + def instrumenters(fd: FunDef) = funcInsts(fd) map instToInstrumenter.apply _ + def instTypes(fd: FunDef) = funcInsts(fd).map(_.getType) + /** + * Index of the instrumentation 'inst' in result tuple that would be created. + * The return value will be >= 2 as the actual result value would be at index 1 + */ + def instIndex(fd: FunDef)(ins: Instrumentation) = funcInsts(fd).indexOf(ins) + 2 + def selectInst(fd: FunDef)(e: Expr, ins: Instrumentation) = TupleSelect(e, instIndex(fd)(ins)) + + def apply: Program = { + + //create new functions. Augment the return type of a function iff the postcondition uses + //the instrumentation variable or if the function is transitively called from such a function + //note: need not instrument fields + val funMap = Util.functionsWOFields(program.definedFunctions).foldLeft(Map[FunDef, FunDef]()) { + case (accMap, fd: FunDef) if fd.isTheoryOperation => + accMap + (fd -> fd) + case (accMap, fd) => { + if (instFuncs.contains(fd)) { + val newRetType = TupleType(fd.returnType +: instTypes(fd)) + // let the names of the function encode the kind of instrumentations performed + val freshId = FreshIdentifier(fd.id.name + "-" + funcInsts(fd).map(_.name).mkString("-"), newRetType) + val newfd = new FunDef(freshId, fd.tparams, newRetType, fd.params) + accMap + (fd -> newfd) + } else { + //here we need not augment the return types but do need to create a new copy + val freshId = FreshIdentifier(fd.id.name, fd.returnType) + val newfd = new FunDef(freshId, fd.tparams, fd.returnType, fd.params) + accMap + (fd -> newfd) + } + } + } + + def mapExpr(ine: Expr): Expr = { + simplePostTransform((e: Expr) => e match { + case FunctionInvocation(tfd, args) if funMap.contains(tfd.fd) => + if (instFuncs.contains(tfd.fd)) + TupleSelect(FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args), 1) + else + FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args) + case _ => e + })(ine) + } + + def mapBody(body: Expr, from: FunDef, to: FunDef) = { + val res = if (instFuncs.contains(from)) { + (new ExprInstrumenter(funMap)(from)(body)) + } else + mapExpr(body) + res + } + + def mapPost(pred: Expr, from: FunDef, to: FunDef) = { + pred match { + case Lambda(Seq(ValDef(fromRes, _)), postCond) if (instFuncs.contains(from)) => + val toResId = FreshIdentifier(fromRes.name, to.returnType, true) + val newpost = postMap((e: Expr) => e match { + case Variable(`fromRes`) => + Some(TupleSelect(toResId.toVariable, 1)) + + case _ if funcInsts(from).exists(_.isInstVariable(e)) => + val inst = funcInsts(from).find(_.isInstVariable(e)).get + Some(TupleSelect(toResId.toVariable, instIndex(from)(inst))) + + case _ => + None + })(postCond) + Lambda(Seq(ValDef(toResId)), mapExpr(newpost)) + case _ => + mapExpr(pred) + } + } + + // Map the bodies and preconditions + for ((from, to) <- funMap) { + //copy annotations + from.flags.foreach(to.addFlag(_)) + to.fullBody = from.fullBody match { + case Require(pre, body) => + //here 'from' does not have a postcondition but 'to' will always have a postcondition + val toPost = + Lambda(Seq(ValDef(FreshIdentifier("res", to.returnType))), BooleanLiteral(true)) + val bodyPre = + Require(mapExpr(pre), mapBody(body, from, to)) + Ensuring(bodyPre, toPost) + + case Ensuring(Require(pre, body), post) => + Ensuring(Require(mapExpr(pre), mapBody(body, from, to)), + mapPost(post, from, to)) + + case Ensuring(body, post) => + Ensuring(mapBody(body, from, to), mapPost(post, from, to)) + + case body => + val toPost = + Lambda(Seq(ValDef(FreshIdentifier("res", to.returnType))), BooleanLiteral(true)) + Ensuring(mapBody(body, from, to), toPost) + } + } + + val additionalFuncs = funMap.flatMap{ case (k, _) => + if (instFuncs(k)) + instrumenters(k).flatMap(_.additionalfunctionsToAdd) + else List() + }.toList.distinct + + val newprog = Util.copyProgram(program, (defs: Seq[Definition]) => + defs.map { + case fd: FunDef if funMap.contains(fd) => + funMap(fd) + case d => d + } ++ additionalFuncs) + if (debugInstrumentation) + println("After Instrumentation: \n" + ScalaPrinter.apply(newprog)) + + ProgramSimplifier(newprog) + } + + class ExprInstrumenter(funMap: Map[FunDef, FunDef])(implicit currFun: FunDef) { + val retainMatches = true + + val insts = funcInsts(currFun) + val instrumenters = SerialInstrumenter.this.instrumenters(currFun) + val instIndex = SerialInstrumenter.this.instIndex(currFun) _ + val selectInst = SerialInstrumenter.this.selectInst(currFun) _ + val instTypes = SerialInstrumenter.this.instTypes(currFun) + + // Should be called only if 'expr' has to be instrumented + // Returned Expr is always an expr of type tuple (Expr, Int) + def tupleify(e: Expr, subs: Seq[Expr], recons: Seq[Expr] => Expr)(implicit letIdMap: Map[Identifier, Identifier]): Expr = { + // When called for: + // Op(n1,n2,n3) + // e = Op(n1,n2,n3) + // subs = Seq(n1,n2,n3) + // recons = { Seq(newn1,newn2,newn3) => Op(newn1, newn2, newn3) } + // + // This transformation should return, informally: + // + // LetTuple((e1,t1), transform(n1), + // LetTuple((e2,t2), transform(n2), + // ... + // Tuple(recons(e1, e2, ...), t1 + t2 + ... costOfExpr(Op) + // ... + // ) + // ) + // + // You will have to handle FunctionInvocation specially here! + tupleifyRecur(e, subs, recons, List(), Map()) + } + + def tupleifyRecur(e: Expr, subs: Seq[Expr], recons: Seq[Expr] => Expr, subeVals: List[Expr], + subeInsts: Map[Instrumentation, List[Expr]])(implicit letIdMap: Map[Identifier, Identifier]): Expr = { + //note: subs.size should be zero if e is a terminal + if (subs.size == 0) { + e match { + case v @ Variable(id) => + val valPart = if (letIdMap.contains(id)) { + TupleSelect(letIdMap(id).toVariable, 1) //this takes care of replacement + } else v + val instPart = instrumenters map (_.instrument(v, Seq())) + Tuple(valPart +: instPart) + + case t: Terminal => + val instPart = instrumenters map (_.instrument(t, Seq())) + val finalRes = Tuple(t +: instPart) + finalRes + + case f @ FunctionInvocation(tfd, args) if tfd.fd.isRealFunction => + val newfd = funMap(tfd.fd) + val newFunInv = FunctionInvocation(TypedFunDef(newfd, tfd.tps), subeVals) + //create a variables to store the result of function invocation + if (instFuncs(tfd.fd)) { + //this function is also instrumented + val resvar = Variable(FreshIdentifier("e", newfd.returnType, true)) + val valexpr = TupleSelect(resvar, 1) + val instexprs = instrumenters.map { m => + val calleeInst = if (funcInsts(tfd.fd).contains(m.inst)) { + List(SerialInstrumenter.this.selectInst(tfd.fd)(resvar, m.inst)) + } else List() + //Note we need to ensure that the last element of list is the instval of the finv + m.instrument(e, subeInsts.getOrElse(m.inst, List()) ++ calleeInst, Some(resvar)) + } + Let(resvar.id, newFunInv, Tuple(valexpr +: instexprs)) + } else { + val resvar = Variable(FreshIdentifier("e", tfd.fd.returnType, true)) + val instexprs = instrumenters.map { m => + m.instrument(e, subeInsts.getOrElse(m.inst, List())) + } + Let(resvar.id, newFunInv, Tuple(resvar +: instexprs)) + } + + // This case will be taken if the function invocation is actually a val (lazy or otherwise) in the class + case f @ FunctionInvocation(tfd, args) => + val resvar = Variable(FreshIdentifier("e", tfd.fd.returnType, true)) + val instPart = instrumenters map (_.instrument(f, Seq())) + val finalRes = Tuple(f +: instPart) + finalRes + + case _ => + val exprPart = recons(subeVals) + val instexprs = instrumenters.zipWithIndex.map { + case (menter, i) => menter.instrument(e, subeInsts.getOrElse(menter.inst, List())) + } + Tuple(exprPart +: instexprs) + } + } else { + val currExp = subs.head + val resvar = Variable(FreshIdentifier("e", TupleType(currExp.getType +: instTypes), true)) + val eval = TupleSelect(resvar, 1) + val instMap = insts.map { inst => + (inst -> (subeInsts.getOrElse(inst, List()) :+ selectInst(resvar, inst))) + }.toMap + //process the remaining arguments + val recRes = tupleifyRecur(e, subs.tail, recons, subeVals :+ eval, instMap) + //transform the current expression + val newCurrExpr = transform(currExp) + Let(resvar.id, newCurrExpr, recRes) + } + } + + /** + * TODO: need to handle new expression trees + * Match statements without guards are now instrumented directly + */ + def transform(e: Expr)(implicit letIdMap: Map[Identifier, Identifier]): Expr = e match { + // Assume that none of the matchcases has a guard. It has already been converted into an if then else + case me @ MatchExpr(scrutinee, matchCases) => + val containsGuard = matchCases.exists(currCase => currCase.optGuard.isDefined) + if (containsGuard) { + def rewritePM(me: MatchExpr): Option[Expr] = { + val MatchExpr(scrut, cases) = me + val condsAndRhs = for (cse <- cases) yield { + val map = mapForPattern(scrut, cse.pattern) + val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) + val realCond = cse.optGuard match { + case Some(g) => And(patCond, replaceFromIDs(map, g)) + case None => patCond + } + val newRhs = replaceFromIDs(map, cse.rhs) + (realCond, newRhs) + } + val bigIte = condsAndRhs.foldRight[Expr]( + Error(me.getType, "Match is non-exhaustive").copiedFrom(me))((p1, ex) => { + if (p1._1 == BooleanLiteral(true)) { + p1._2 + } else { + IfExpr(p1._1, p1._2, ex) + } + }) + Some(bigIte) + } + transform(rewritePM(me).get) + } else { + val instScrutinee = + Variable(FreshIdentifier("scr", TupleType(scrutinee.getType +: instTypes), true)) + + def transformMatchCaseList(mCases: Seq[MatchCase]): Seq[MatchCase] = { + def transformMatchCase(mCase: MatchCase) = { + val MatchCase(pattern, guard, expr) = mCase + val newExpr = { + val exprVal = + Variable(FreshIdentifier("expr", TupleType(expr.getType +: instTypes), true)) + val newExpr = transform(expr) + val instExprs = instrumenters map { m => + m.instrumentMatchCase(me, mCase, selectInst(exprVal, m.inst), + selectInst(instScrutinee, m.inst)) + } + val letBody = Tuple(TupleSelect(exprVal, 1) +: instExprs) + Let(exprVal.id, newExpr, letBody) + } + MatchCase(pattern, guard, newExpr) + } + if (mCases.length == 0) Seq[MatchCase]() + else { + transformMatchCase(mCases.head) +: transformMatchCaseList(mCases.tail) + } + } + val matchExpr = MatchExpr(TupleSelect(instScrutinee, 1), + transformMatchCaseList(matchCases)) + Let(instScrutinee.id, transform(scrutinee), matchExpr) + } + + case Let(i, v, b) => { + val (ni, nv) = { + val ir = Variable(FreshIdentifier("ir", TupleType(v.getType +: instTypes), true)) + val transv = transform(v) + (ir, transv) + } + val r = Variable(FreshIdentifier("r", TupleType(b.getType +: instTypes), true)) + val transformedBody = transform(b)(letIdMap + (i -> ni.id)) + val instexprs = instrumenters map { m => + m.instrument(e, List(selectInst(ni, m.inst), selectInst(r, m.inst))) + } + Let(ni.id, nv, + Let(r.id, transformedBody, Tuple(TupleSelect(r, 1) +: instexprs))) + } + + case ife @ IfExpr(cond, th, elze) => { + val (nifCons, condInsts) = { + val rescond = Variable(FreshIdentifier("c", TupleType(cond.getType +: instTypes), true)) + val condInstPart = insts.map { inst => (inst -> selectInst(rescond, inst)) }.toMap + val recons = (e1: Expr, e2: Expr) => { + Let(rescond.id, transform(cond), IfExpr(TupleSelect(rescond, 1), e1, e2)) + } + (recons, condInstPart) + } + val (nthenCons, thenInsts) = { + val resthen = Variable(FreshIdentifier("th", TupleType(th.getType +: instTypes), true)) + val thInstPart = insts.map { inst => (inst -> selectInst(resthen, inst)) }.toMap + val recons = (theninsts: List[Expr]) => { + Let(resthen.id, transform(th), Tuple(TupleSelect(resthen, 1) +: theninsts)) + } + (recons, thInstPart) + } + val (nelseCons, elseInsts) = { + val reselse = Variable(FreshIdentifier("el", TupleType(elze.getType +: instTypes), true)) + val elInstPart = insts.map { inst => (inst -> selectInst(reselse, inst)) }.toMap + val recons = (einsts: List[Expr]) => { + Let(reselse.id, transform(elze), Tuple(TupleSelect(reselse, 1) +: einsts)) + } + (recons, elInstPart) + } + val (finalThInsts, finalElInsts) = instrumenters.foldLeft((List[Expr](), List[Expr]())) { + case ((thinsts, elinsts), menter) => + val inst = menter.inst + val (thinst, elinst) = menter.instrumentIfThenElseExpr(ife, + Some(condInsts(inst)), Some(thenInsts(inst)), Some(elseInsts(inst))) + (thinsts :+ thinst, elinsts :+ elinst) + } + val nthen = nthenCons(finalThInsts) + val nelse = nelseCons(finalElInsts) + nifCons(nthen, nelse) + } + + // For all other operations, we go through a common tupleifier. + case n @ Operator(ss, recons) => + tupleify(e, ss, recons) + +/* case b @ BinaryOperator(s1, s2, recons) => + tupleify(e, Seq(s1, s2), { case Seq(s1, s2) => recons(s1, s2) }) + + case u @ UnaryOperator(s, recons) => + tupleify(e, Seq(s), { case Seq(s) => recons(s) }) +*/ + case t: Terminal => + tupleify(e, Seq(), { case Seq() => t }) + } + + def apply(e: Expr): Expr = { + // Apply transformations + val newe = + if (retainMatches) e + else matchToIfThenElse(liftExprInMatch(e)) + val transformed = transform(newe)(Map()) + val bodyId = FreshIdentifier("bd", transformed.getType, true) + val instExprs = instrumenters map { m => + m.instrumentBody(newe, + selectInst(bodyId.toVariable, m.inst)) + + } + Let(bodyId, transformed, + Tuple(TupleSelect(bodyId.toVariable, 1) +: instExprs)) + } + + def liftExprInMatch(ine: Expr): Expr = { + def helper(e: Expr): Expr = { + e match { + case MatchExpr(strut, cases) => strut match { + case t: Terminal => e + case _ => { + val freshid = FreshIdentifier("m", strut.getType, true) + Let(freshid, strut, MatchExpr(freshid.toVariable, cases)) + } + } + case _ => e + } + } + + if (retainMatches) helper(ine) + else simplePostTransform(helper)(ine) + } + } +} + +/** + * Implements procedures for a specific instrumentation + */ +abstract class Instrumenter(program: Program, si: SerialInstrumenter) { + + def inst: Instrumentation + + protected val cg = CallGraphUtil.constructCallGraph(program, onlyBody = true) + + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] + + def additionalfunctionsToAdd(): Seq[FunDef] + + def instrumentBody(bodyExpr: Expr, instExpr: Expr)(implicit fd: FunDef): Expr = instExpr + + def getRootFuncs(prog: Program = program): Set[FunDef] = { + prog.definedFunctions.filter { fd => + (fd.hasPostcondition && exists(inst.isInstVariable)(fd.postcondition.get)) + }.toSet + } + + /** + * Given an expression to be instrumented + * and the instrumentation of each of its subexpressions, + * computes an instrumentation for the procedure. + * The sub-expressions correspond to expressions returned + * by Expression Extractors. + * fd is the function containing the expression `e` + */ + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) + (implicit fd: FunDef, letIdMap: Map[Identifier, Identifier]): Expr + + /** + * Instrument procedure specialized for if-then-else + */ + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], + thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) + + /** + * This function is expected to combine the cost of the scrutinee, + * the pattern matching and the expression. + * The cost model for pattern matching is left to the user. + * As matches with guards are converted to ifThenElse statements, + * the user may want to make sure that the cost model for pattern + * matching across match statements and ifThenElse statements is consistent + */ + def instrumentMatchCase(me: MatchExpr, mc: MatchCase, + caseExprCost: Expr, scrutineeCost: Expr): Expr +} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/StackSpacePhase.scala b/src/main/scala/leon/transformations/StackSpacePhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..f4edb7ee1ae3fa53fc86a16b0f2b7c8c5146284b --- /dev/null +++ b/src/main/scala/leon/transformations/StackSpacePhase.scala @@ -0,0 +1,336 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.utils._ + +class StackSpaceInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { + val typedMaxFun = TypedFunDef(InstUtil.maxFun, Seq()) + val optimiseTailCalls = true + + def inst = Stack + + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { + // find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) + val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)) + instFunSet.map(x => (x, List(Stack))).toMap + } + + def additionalfunctionsToAdd(): Seq[FunDef] = Seq() //Seq(InstUtil.maxFun) - max functions are inlined, so they need not be added + + def addSubInstsIfNonZero(subInsts: Seq[Expr], init: Expr): Expr = { + subInsts.foldLeft(init)((acc: Expr, subeTime: Expr) => { + (subeTime, acc) match { + case (InfiniteIntegerLiteral(x), _) if (x == 0) => acc + case (_, InfiniteIntegerLiteral(x)) if (x == 0) => subeTime + case _ => FunctionInvocation(typedMaxFun, Seq(acc, subeTime)) + } + }) + } + + // Check if a given function call is a tail recursive call + def isTailCall(call: FunctionInvocation, fd: FunDef): Boolean = { + if (fd.body.isDefined) { + def helper(e: Expr): Boolean = { + e match { + case FunctionInvocation(_,_) if (e == call) => true + case Let(binder, value, body) => helper(body) + case LetDef(_,body) => helper(body) + case IfExpr(_,thenExpr, elseExpr) => helper(thenExpr) || helper(elseExpr) + case MatchExpr(_, mCases) => { + mCases.exists(currCase => helper(currCase.rhs)) + } + case _ => false + } + } + helper(fd.body.get) + } + else false + } + + def instrumentMatchCase(me: MatchExpr, mc: MatchCase, + caseExprCost: Expr, scrutineeCost: Expr): Expr = { + + def costOfMatchPattern(me: MatchExpr, mc: MatchCase): Expr = { + val costOfMatchPattern = 1 + InfiniteIntegerLiteral(costOfMatchPattern) + } + + addSubInstsIfNonZero(Seq(costOfMatchPattern(me, mc), caseExprCost, scrutineeCost), InfiniteIntegerLiteral(0)) + } + + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) + (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = { + + e match { + case t: Terminal => InfiniteIntegerLiteral(0) + case FunctionInvocation(callFd, args) => { + // Need to extimate the size of the activation frame of this function. + // #Args + + // #LocalVals + + // #Temporaries created (assume tree-like evaluation of expressions. This will the maximum + // number of temporaries allocated. Also because we assume all the + // temporaries are allocated on the stack and not used only from registers) + + val numTemps = + if (callFd.body.isDefined) { + val (temp, stack) = estimateTemporaries(callFd.body.get) + temp + stack + } else 0 + val retVar = subInsts.last + val remSubInsts = subInsts.slice(0, subInsts.length - 1) + val totalInvocationCost = { + // model scala's tail recursion optimization here + if ((isTailCall(FunctionInvocation(callFd, args), fd) && fd.id == callFd.id) && optimiseTailCalls) + InfiniteIntegerLiteral(0) + else + retVar + } + val subeTimesExpr = addSubInstsIfNonZero(remSubInsts, InfiniteIntegerLiteral(0)) + + subeTimesExpr match { + case InfiniteIntegerLiteral(x) if (x == 0) => totalInvocationCost + case _ => + addSubInstsIfNonZero(remSubInsts :+ totalInvocationCost, InfiniteIntegerLiteral(0)) + } + } + case _ => addSubInstsIfNonZero(subInsts, InfiniteIntegerLiteral(0)) + } + } + + override def instrumentBody(bodyExpr: Expr, instExpr: Expr)(implicit fd: FunDef): Expr = { + val minActivationRecSize = 2 + val (temps, stack) = estimateTemporaries(bodyExpr) + //println(temps + " " + stack) + Plus(instExpr, InfiniteIntegerLiteral(temps + stack + fd.params.length + + 1 /*object ref*/ + + 1 /*return variable before jumping*/ + + minActivationRecSize /*Sometimes for some reason, there are holes in local vars*/)) + } + + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], thenInst: Option[Expr], + elzeInst: Option[Expr]): (Expr, Expr) = { + import invariant.util.Util._ + + val cinst = condInst.toList + val tinst = thenInst.toList + val einst = elzeInst.toList + + (addSubInstsIfNonZero(cinst ++ tinst, zero), + addSubInstsIfNonZero(cinst ++ einst, zero)) + } + + /* Tries to estimate the depth of the operand stack and the temporaries + (excluding arguments) needed by the bytecode. As the JVM might perform + some optimizations when actually executing the bytecode, what we compute + here is an upper bound on the memory needed to evaluate the expression + */ + // (temporaries, stackSize) + def estimateTemporaries(e: Expr): (Int, Int) = { + e match { + /* Like vals */ + case Let(binder: Identifier, value: Expr, body: Expr) => { + // One for the val created + Temps in expr on RHS of initilisation + Rem. body + val (valTemp, valStack) = estimateTemporaries(value) + val (bodyTemp, bodyStack) = estimateTemporaries(body) + (1 + valTemp + bodyTemp, Math.max(valStack, bodyStack)) + } + + case LetDef(fd: FunDef, body: Expr) => { + // The function definition does not take up stack space. Goes into the constant pool + estimateTemporaries(body) + } + + case FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) => { + // One for the object reference. + stack for computing arguments and also the + // fact that the arguments go into the stack + val (temp, stack) = + args.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + + (temp + 1 /*possibly if the ret val needs to be stored for future use*/, stack + 1) + } + + case MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) => { + val (recTemp, recStack) = estimateTemporaries(rec) + val (temp, stack) = + args.foldLeft(((recTemp, Math.max(args.length, recStack)), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + + (temp + 1 /*possibly if the ret val needs to be stored for future use*/, stack + 1) + } + + case Application(caller: Expr, args: Seq[Expr]) => { + val (callerTemp, callerStack) = estimateTemporaries(caller) + args.foldLeft(((callerTemp, Math.max(args.length, callerStack)), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + } + + case IfExpr(cond: Expr, thenn: Expr, elze: Expr) => { + val (condTemp, condStack) = estimateTemporaries(cond) + val (thennTemp, thennStack) = estimateTemporaries(thenn) + val (elzeTemp, elzeStack) = estimateTemporaries(elze) + + (condTemp + thennTemp + elzeTemp, + Math.max(condStack, Math.max(thennStack, elzeStack))) + } + + case Tuple (exprs: Seq[Expr]) => { + val (temp, stack) = + exprs.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + + (temp, stack + 2) + } + + case MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) => { + + // FIXME + def estimateTemporariesMatchPattern(pattern: Pattern): (Int, Int) = { + pattern match { + case InstanceOfPattern(binder: Option[Identifier], ct: ClassType) => { // c: Class + (0,1) + } + + case WildcardPattern(binder: Option[Identifier]) => { // c @ _ + (if (binder.isDefined) 1 else 0, 0) + } + + case CaseClassPattern(binder: Option[Identifier], ct: CaseClassType, subPatterns: Seq[Pattern]) => { + val (temp, stack) = + subPatterns.foldLeft((1 /* create a new var for matching */, 1))((t: (Int, Int), currPattern) => { + t match { + case acc: (Int, Int) => { + val (patTemp, patStack) = estimateTemporariesMatchPattern(currPattern) + (acc._1 + patTemp, Math.max(acc._2, patStack)) + } + } + }) + + (temp, stack) + } + + case TuplePattern(binder: Option[Identifier], subPatterns: Seq[Pattern]) => { + val (temp, stack) = + subPatterns.foldLeft((1 /* create a new var for matching */, 1))((t: (Int, Int), currPattern) => { + t match { + case acc: (Int, Int) => { + val (patTemp, patStack) = estimateTemporariesMatchPattern(currPattern) + (acc._1 + patTemp, Math.max(acc._2, patStack)) + } + } + }) + + (temp, stack) + } + + case LiteralPattern(binder, lit) => { + (0,2) + } + case _ => + throw new NotImplementedError(s"Pattern $pattern not handled yet!") + } + } + + val (scrTemp, scrStack) = estimateTemporaries(scrutinee) + + val res = cases.foldLeft(((scrTemp + 1 /* create a new var for matching */, Math.max(scrStack, 3 /*MatchError*/))))((t: (Int, Int), currCase: MatchCase) => { + t match { + case acc: (Int, Int) => + val (patTemp, patStack) = estimateTemporariesMatchPattern(currCase.pattern) + val (rhsTemp, rhsStack) = estimateTemporaries(currCase.rhs) + val (guardTemp, guardStack) = + if (currCase.optGuard.isDefined) estimateTemporaries(currCase.optGuard.get) else (0,0) + + (patTemp + rhsTemp + guardTemp + acc._1, + Math.max(acc._2, Math.max(patStack, Math.max(guardStack, rhsStack)))) + } + }) + res + } + + /* Propositional logic */ + case Implies(lhs: Expr, rhs: Expr) => { + val (lhsTemp, lhsStack)= estimateTemporaries(lhs) + val (rhsTemp, rhsStack)= estimateTemporaries(rhs) + (rhsTemp + lhsTemp, 1 + Math.max(lhsStack, rhsStack)) + } + + case Not(expr: Expr) => estimateTemporaries(expr) + + case Equals(lhs: Expr, rhs: Expr) => { + val (lhsTemp, lhsStack)= estimateTemporaries(lhs) + val (rhsTemp, rhsStack)= estimateTemporaries(rhs) + (rhsTemp + lhsTemp + + // If object ref, check for non nullity + 1, + //(if (!(lhs.getType == IntegerType && rhs.getType == IntegerType)) 1 else 0), + 1 + Math.max(lhsStack, rhsStack)) + } + + case CaseClass(ct: CaseClassType, args: Seq[Expr]) => { + val (temp, stack) = + args.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + + (temp, stack + 2) + } + + case _: Literal[_] => (0, 1) + + case Variable(id: Identifier) => (0, 1) + + case Lambda(args: Seq[ValDef], body: Expr) => (0, 0) + + case TupleSelect(tuple: Expr, index: Int) => (0, 1) + + /*case BinaryOperator(s1,s2,_) => { + val (s1Temp, s1Stack)= estimateTemporaries(s1) + val (s2Temp, s2Stack)= estimateTemporaries(s2) + (s1Temp + s2Temp, Math.max(s1Stack, 1 + s2Stack)) + }*/ + + case Operator(exprs, _) => { + exprs.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { + t match { + case (acc: (Int, Int), currExprNum: Int) => + val (seTemp, seStack) = estimateTemporaries(currExpr) + ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) + } + })._1 + } + + case _ => (0, 0) + } + } +} diff --git a/src/main/scala/leon/transformations/TimeStepsPhase.scala b/src/main/scala/leon/transformations/TimeStepsPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..4d8613649d3fd4ac7928b68cd8be7c6c3203fb99 --- /dev/null +++ b/src/main/scala/leon/transformations/TimeStepsPhase.scala @@ -0,0 +1,94 @@ +package leon +package transformations + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import leon.utils._ +import leon.invariant.util.Util._ + +object timeCostModel { + def costOf(e: Expr): Int = e match { + case FunctionInvocation(fd, args) => 1 + case t: Terminal => 0 + case _ => 1 + } + + def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) +} + +class TimeInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { + import timeCostModel._ + + def inst = Time + + def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { + //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) + val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)) + instFunSet.map(x => (x, List(Time))).toMap + } + + def additionalfunctionsToAdd() = Seq() + + def instrumentMatchCase( + me: MatchExpr, + mc: MatchCase, + caseExprCost: Expr, + scrutineeCost: Expr): Expr = { + val costMatch = costOfExpr(me) + + def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = { + + def patCostRecur(pattern: Pattern, innerPat: Boolean, countLeafs: Boolean): Int = { + pattern match { + case InstanceOfPattern(_, _) => { + if (innerPat) 2 else 1 + } + case WildcardPattern(None) => 0 + case WildcardPattern(Some(id)) => { + if (countLeafs && innerPat) 1 + else 0 + } + case CaseClassPattern(_, _, subPatterns) => { + (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => + acc + patCostRecur(subPat, true, countLeafs)) + } + case TuplePattern(_, subPatterns) => { + (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => + acc + patCostRecur(subPat, true, countLeafs)) + } + case LiteralPattern(_, _) => if (innerPat) 2 else 1 + case _ => + throw new NotImplementedError(s"Pattern $pattern not handled yet!") + } + } + + me.cases.take(me.cases.indexOf(mc)).foldLeft(0)( + (acc, currCase) => acc + patCostRecur(currCase.pattern, false, false)) + + patCostRecur(mc.pattern, false, true) + } + + Plus(costMatch, Plus( + Plus(InfiniteIntegerLiteral(totalCostOfMatchPatterns(me, mc)), + caseExprCost), + scrutineeCost)) + } + + def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) + (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = e match { + case t: Terminal => costOfExpr(t) + case _ => + subInsts.foldLeft(costOfExpr(e) : Expr)( + (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) + } + + def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], + thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { + val costIf = costOfExpr(e) + (Plus(costIf, Plus(condInst.get, thenInst.get)), + Plus(costIf, Plus(condInst.get, elzeInst.get))) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala index 8244da31a6d9b8b97e559e7e2ca0b172ce4f7dcf..ffa8d13a0f2322a53270a52b3d4a2e82ef75cc8c 100644 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -21,7 +21,7 @@ object InjectAsserts extends LeonPhase[Program, Program] { } pgm.definedFunctions.foreach(fd => { - fd.body = fd.body.map(postMap { + fd.body = fd.body.map(postMap { case e @ ArraySelect(a, i) => Some(Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), e).setPos(e)) case e @ ArrayUpdated(a, i, v) => @@ -69,14 +69,14 @@ object InjectAsserts extends LeonPhase[Program, Program] { ).setPos(e)) case e @ RealDivision(_, d) => - Some(Assert(Not(Equals(d, RealLiteral(0))), + Some(Assert(Not(Equals(d, FractionalLiteral(0, 1))), Some("Division by zero"), e ).setPos(e)) case _ => None - }) + }) }) pgm diff --git a/src/test/resources/regression/orb/combined/InsertionSort.scala b/src/test/resources/regression/orb/combined/InsertionSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..8fd79a2e89f60441fd522584fae4197079f9294e --- /dev/null +++ b/src/test/resources/regression/orb/combined/InsertionSort.scala @@ -0,0 +1,26 @@ +import leon.invariant._ +import leon.instrumentation._ + +object InsertionSort { + sealed abstract class List + case class Cons(head: BigInt, tail:List) extends List + case class Nil() extends List + + def size(l : List) : BigInt = (l match { + case Cons(_, xs) => 1 + size(xs) + case _ => 0 + }) + + def sortedIns(e: BigInt, l: List): List = { + l match { + case Cons(x,xs) => if (x <= e) Cons(x,sortedIns(e, xs)) else Cons(e, l) + case _ => Cons(e,Nil()) + } + } ensuring(res => size(res) == size(l) + 1 && tmpl((a,b) => time <= a*size(l) +b && depth <= a*size(l) +b)) + + def sort(l: List): List = (l match { + case Cons(x,xs) => sortedIns(x, sort(xs)) + case _ => Nil() + + }) ensuring(res => size(res) == size(l) && tmpl((a,b) => time <= a*(size(l)*size(l)) +b && rec <= a*size(l) + b)) +} diff --git a/src/test/resources/regression/orb/depth/BinaryTrie.scala b/src/test/resources/regression/orb/depth/BinaryTrie.scala new file mode 100755 index 0000000000000000000000000000000000000000..562ab19ea8913551bafd45e429ec28c98cbee064 --- /dev/null +++ b/src/test/resources/regression/orb/depth/BinaryTrie.scala @@ -0,0 +1,121 @@ +import leon.instrumentation._ +import leon.invariant._ + +import scala.collection.immutable.Set + +object ParallelBinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) +} diff --git a/src/test/resources/regression/orb/depth/Folds.scala b/src/test/resources/regression/orb/depth/Folds.scala new file mode 100755 index 0000000000000000000000000000000000000000..305446be53b27c976080968b881119ce06de11e6 --- /dev/null +++ b/src/test/resources/regression/orb/depth/Folds.scala @@ -0,0 +1,82 @@ +import leon.instrumentation._ +import leon.invariant._ + + +object TreeMaps { + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def parallelSearch(elem : BigInt, t : Tree) : Boolean = { + t match { + case Leaf() => false + case Node(l, x, r) => + if(x == elem) true + else { + val r1 = parallelSearch(elem, r) + val r2 = parallelSearch(elem, l) + if(r1 || r2) true + else false + } + } + } ensuring(res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + + def squareMap(t : Tree) : Tree = { + t match { + case Leaf() => t + case Node(l, x, r) => + val nl = squareMap(l) + val nr = squareMap(r) + Node(nl, x*x, nr) + } + } ensuring (res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def fact(n : BigInt) : BigInt = { + require(n >= 0) + + if(n == 1 || n == 0) BigInt(1) + else n * fact(n-1) + + } ensuring(res => tmpl((a,b) => depth <= a*n + b)) + + def descending(l: List, k: BigInt) : Boolean = { + l match { + case Nil() => true + case Cons(x, t) => x > 0 && x <= k && descending(t, x-1) + } + } + + def factMap(l: List, k: BigInt): List = { + require(descending(l, k) && k >= 0) + + l match { + case Nil() => Nil() + case Cons(x, t) => { + val f = fact(x) + Cons(f, factMap(t, x-1)) + } + + }} ensuring(res => true && tmpl((a,b) => depth <= a*k + b)) +} \ No newline at end of file diff --git a/src/test/resources/regression/orb/depth/ListOperations.scala b/src/test/resources/regression/orb/depth/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..7d77ba83435f1e5d72951f1c8e0cd72ad8adbb7f --- /dev/null +++ b/src/test/resources/regression/orb/depth/ListOperations.scala @@ -0,0 +1,44 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def reverse(l: List): List = { + reverseRec(l, Nil()) + + } ensuring (res => size(l) == size(res) && tmpl((a,b) => depth <= a*size(l) + b)) + + def remove(elem: BigInt, l: List): List = { + l match { + case Nil() => Nil() + case Cons(hd, tl) => if (hd == elem) remove(elem, tl) else Cons(hd, remove(elem, tl)) + } + } ensuring (res => size(l) >= size(res) && tmpl((a,b) => depth <= a*size(l) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + + }) ensuring (res => true && tmpl((a,b) => depth <= a*size(list) + b)) +} diff --git a/src/test/resources/regression/orb/depth/PropLogicDepth.scala b/src/test/resources/regression/orb/depth/PropLogicDepth.scala new file mode 100644 index 0000000000000000000000000000000000000000..881cd61a4b31ad2b40ab22b84651f4362e00878c --- /dev/null +++ b/src/test/resources/regression/orb/depth/PropLogicDepth.scala @@ -0,0 +1,112 @@ +import scala.collection.immutable.Set +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object PropLogicDepth { + + sealed abstract class Formula + case class And(lhs: Formula, rhs: Formula) extends Formula + case class Or(lhs: Formula, rhs: Formula) extends Formula + case class Implies(lhs: Formula, rhs: Formula) extends Formula + case class Not(f: Formula) extends Formula + case class Literal(id: BigInt) extends Formula + case class True() extends Formula + case class False() extends Formula + + def max(x: BigInt,y: BigInt) = if (x >= y) x else y + + def nestingDepth(f: Formula) : BigInt = (f match { + case And(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Or(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Implies(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Not(f) => nestingDepth(f) + 1 + case _ => 1 + }) + + def removeImplies(f: Formula): Formula = (f match { + case And(lhs, rhs) => And(removeImplies(lhs), removeImplies(rhs)) + case Or(lhs, rhs) => Or(removeImplies(lhs), removeImplies(rhs)) + case Implies(lhs, rhs) => Or(Not(removeImplies(lhs)),removeImplies(rhs)) + case Not(f) => Not(removeImplies(f)) + case _ => f + + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) + + def nnf(formula: Formula): Formula = (formula match { + case And(lhs, rhs) => And(nnf(lhs), nnf(rhs)) + case Or(lhs, rhs) => Or(nnf(lhs), nnf(rhs)) + case Implies(lhs, rhs) => Implies(nnf(lhs), nnf(rhs)) + case Not(And(lhs, rhs)) => Or(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Or(lhs, rhs)) => And(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Implies(lhs, rhs)) => And(nnf(lhs), nnf(Not(rhs))) + case Not(Not(f)) => nnf(f) + case Not(Literal(_)) => formula + case Literal(_) => formula + case Not(True()) => False() + case Not(False()) => True() + case _ => formula + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(formula) + b)) + + def isNNF(f: Formula): Boolean = { f match { + case And(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Or(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Implies(lhs, rhs) => false + case Not(Literal(_)) => true + case Not(_) => false + case _ => true + }} ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) + + def simplify(f: Formula): Formula = (f match { + case And(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is false, return false + //if lhs is true return rhs + //if rhs is true return lhs + (sl,sr) match { + case (False(), _) => False() + case (_, False()) => False() + case (True(), _) => sr + case (_, True()) => sl + case _ => And(sl, sr) + } + } + case Or(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is true, return true + //if lhs is false return rhs + //if rhs is false return lhs + (sl,sr) match { + case (True(), _) => True() + case (_, True()) => True() + case (False(), _) => sr + case (_, False()) => sl + case _ => Or(sl, sr) + } + } + case Implies(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs is false return true + //if rhs is true return true + //if lhs is true return rhs + //if rhs is false return Not(rhs) + (sl,sr) match { + case (False(), _) => True() + case (_, True()) => True() + case (True(), _) => sr + case (_, False()) => Not(sl) + case _ => Implies(sl, sr) + } + } + case Not(True()) => False() + case Not(False()) => True() + case _ => f + + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) +} \ No newline at end of file diff --git a/src/test/resources/regression/orb/numerical/ConcatVariationsAbs.scala b/src/test/resources/regression/orb/numerical/ConcatVariationsAbs.scala new file mode 100644 index 0000000000000000000000000000000000000000..bff880ab30ac1495c7ab6c706265376f1fafe7ad --- /dev/null +++ b/src/test/resources/regression/orb/numerical/ConcatVariationsAbs.scala @@ -0,0 +1,43 @@ +import leon.invariant._ + +object ConcatVariationsAbs { + def genL(n: BigInt): BigInt = { + require(n >= 0) + if (n == 0) + BigInt(2) + else + 4 + genL(n - 1) + } ensuring (res => tmpl((a, b) => res <= a * n + b)) + + def append(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else + append(l1 - 1, l2 + 1) + 5 + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def f_good(m: BigInt, n: BigInt): BigInt = { + require(0 <= m && 0 <= n) + if (m == 0) BigInt(2) + else { + val t1 = genL(n) + val t2 = f_good(m - 1, n) + val t3 = append(n, n * (m - 1)) + (t1 + t2 + t3 + 6) + } + + } ensuring (res => tmpl((a, b, c, d) => res <= a * (n * m) + b * n + c * m + d)) + + def f_worst(m: BigInt, n: BigInt): BigInt = { + require(0 <= m && 0 <= n) + if (m == 0) BigInt(2) + else { + val t1 = genL(n) + val t2 = f_worst(m - 1, n) + val t3 = append(n * (m - 1), n) + (t1 + t2 + t3 + 6) + } + + } ensuring (res => tmpl((a, c, d, e, f) => res <= a * ((n * m) * m) + c * (n * m) + d * n + e * m + f)) +} diff --git a/src/test/resources/regression/orb/numerical/QueueAbs.scala b/src/test/resources/regression/orb/numerical/QueueAbs.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7aee3a93d4ca9ee058b2f0695f001d29fdf3acc --- /dev/null +++ b/src/test/resources/regression/orb/numerical/QueueAbs.scala @@ -0,0 +1,70 @@ +import leon.invariant._ + +object AmortizedQueue { + def concat(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else + concat(l1 - 1, l2 + 1) + 5 + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def reverseRec(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else { + reverseRec(l1 - 1, l2 + 1) + 6 + } + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def reverse(l: BigInt): BigInt = { + require(l >= 0) + reverseRec(l, 0) + 1 + } ensuring (res => tmpl((a, b) => res <= a * l + b)) + + def create(front: BigInt, rear: BigInt): BigInt = { + require(front >= 0 && rear >= 0) + if (rear <= front) + BigInt(4) + else { + val t1 = reverse(rear) + val t2 = concat(front, rear) + t1 + t2 + 7 + } + } + + def enqueue(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 0 && front >= 0 && rear >= 0) + create(front, rear) + 5 + } ensuring (res => tmpl((a, b) => res <= a * q + b)) + + def dequeue(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 1 && front >= rear && rear >= 0) + if (front >= 1) { + create(front - 1, rear) + 4 + } else { + //since front should be greater than rear, here rear should be 0 as well + BigInt(5) + } + } ensuring (res => tmpl((a, b) => res <= a * q + b)) + + def removeLast(l: BigInt): BigInt = { + require(l >= 1) + if (l == 1) { + BigInt(4) + } else { + removeLast(l - 1) + 6 + } + } ensuring (res => tmpl((a, b) => res <= a * l + b)) + + def pop(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 1 && front >= rear && rear >= 0) + if (rear >= 1) { + BigInt(3) + } else { + val t1 = removeLast(front) + t1 + 5 + } + } ensuring (res => tmpl((a, b) => res <= a * q + b)) +} diff --git a/src/test/resources/regression/orb/numerical/SimpleLoop.scala b/src/test/resources/regression/orb/numerical/SimpleLoop.scala new file mode 100755 index 0000000000000000000000000000000000000000..6a2cdb3d9958f5ad41783d3723b1c08ef7c0ba16 --- /dev/null +++ b/src/test/resources/regression/orb/numerical/SimpleLoop.scala @@ -0,0 +1,9 @@ +object SimpleLoop +{ + def s(x: BigInt) : BigInt = { + if(x < 0) + BigInt(0) + else + s(x-1) + 1 + } ensuring(res => res != -1) +} \ No newline at end of file diff --git a/src/test/resources/regression/orb/numerical/see-saw.scala b/src/test/resources/regression/orb/numerical/see-saw.scala new file mode 100644 index 0000000000000000000000000000000000000000..894a8caed2a298a22a89d8bb6925cc03f022407c --- /dev/null +++ b/src/test/resources/regression/orb/numerical/see-saw.scala @@ -0,0 +1,15 @@ +object SeeSaw { + def s(x: BigInt, y: BigInt, z: BigInt): BigInt = { + require(y >= 0) + + if (x >= 100) { + y + } else if (x <= z) { //some condition + s(x + 1, y + 2, z) + } else if (x <= z + 9) { //some condition + s(x + 1, y + 3, z) + } else { + s(x + 2, y + 1, z) + } + } ensuring (res => (100 - x <= 2 * res)) +} \ No newline at end of file diff --git a/src/test/resources/regression/orb/stack/BinaryTrie.scala b/src/test/resources/regression/orb/stack/BinaryTrie.scala new file mode 100644 index 0000000000000000000000000000000000000000..f2dfd876cdbc51c969f4afc7d1548810be111c02 --- /dev/null +++ b/src/test/resources/regression/orb/stack/BinaryTrie.scala @@ -0,0 +1,120 @@ +import leon.invariant._ +import leon.instrumentation._ +//import scala.collection.immutable.Set + +object BinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) +} diff --git a/src/test/resources/regression/orb/stack/ListOperations.scala b/src/test/resources/regression/orb/stack/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..3ae4a2d1705051b185c444c7342bf9b180fc4086 --- /dev/null +++ b/src/test/resources/regression/orb/stack/ListOperations.scala @@ -0,0 +1,35 @@ +import leon.invariant._ +import leon.instrumentation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => stack <= a*size(l1) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + + }) ensuring (res => tmpl((a,b) => stack <= a*size(list) + b)) + + def distinct(l: List): List = ( + l match { + case Nil() => Nil() + case Cons(x, xs) => { + val newl = distinct(xs) + if (contains(newl, x)) newl + else Cons(x, newl) + } + }) ensuring (res => size(l) >= size(res) && tmpl((a,b) => stack <= a*size(l) + b)) +} diff --git a/src/test/resources/regression/orb/stack/SpeedBenchmarks.scala b/src/test/resources/regression/orb/stack/SpeedBenchmarks.scala new file mode 100644 index 0000000000000000000000000000000000000000..c1c59d592b2b0b59cfe79b780aaca479fcbb222d --- /dev/null +++ b/src/test/resources/regression/orb/stack/SpeedBenchmarks.scala @@ -0,0 +1,75 @@ +import leon.invariant._ +import leon.instrumentation._ +import leon.math._ + +object SpeedBenchmarks { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + sealed abstract class StringBuffer + case class Chunk(str: List, next: StringBuffer) extends StringBuffer + case class Empty() extends StringBuffer + + def length(sb: StringBuffer) : BigInt = sb match { + case Chunk(_, next) => 1 + length(next) + case _ => 0 + } + + def sizeBound(sb: StringBuffer, k: BigInt) : Boolean ={ + sb match { + case Chunk(str, next) => size(str) <= k && sizeBound(next, k) + case _ => 0 <= k + } + } + + def sizeBuffer(sb: StringBuffer): BigInt = { + sb match { + case Chunk(str, next) => size(str) + sizeBuffer(sb) + case Empty() => 0 + } + } + + /** + * Fig. 1 of SPEED, POPL'09: The functional version of the implementation. + * Equality check of two string buffers + */ + def Equals(str1: List, str2: List, s1: StringBuffer, s2: StringBuffer, k: BigInt) : Boolean = { + require(sizeBound(s1, k) && sizeBound(s2, k) && size(str1) <= k && size(str2) <= k && k >= 0) + + (str1, str2) match { + case (Cons(h1,t1), Cons(h2,t2)) => { + if(h1 != h2) false + else Equals(t1,t2, s1,s2, k) + } + case (Cons(_,_), Nil()) => { + //load from s2 + s2 match { + case Chunk(str, next) => Equals(str1, str, s1, next, k) + case Empty() => false + } + } + case (Nil(), Cons(_,_)) => { + //load from s1 + s1 match { + case Chunk(str, next) => Equals(str, str2, next, s2, k) + case Empty() => false + } + } + case _ =>{ + //load from both + (s1,s2) match { + case (Chunk(nstr1, next1),Chunk(nstr2, next2)) => Equals(nstr1, nstr2, next1, next2, k) + case (Empty(),Chunk(nstr2, next2)) => Equals(str1, nstr2, s1, next2, k) + case (Chunk(nstr1, next1), Empty()) => Equals(nstr1, str2, next1, s2, k) + case _ => true + } + } + } + } ensuring(res => tmpl((a,b,c,d,e) => stack <= a*max(sizeBuffer(s1), sizeBuffer(s2)) + c*(k+1) + e)) +} diff --git a/src/test/resources/regression/orb/timing/BinaryTrie.scala b/src/test/resources/regression/orb/timing/BinaryTrie.scala new file mode 100644 index 0000000000000000000000000000000000000000..a1de6ee0e13bb53383b1bba6548e9e0fa449a166 --- /dev/null +++ b/src/test/resources/regression/orb/timing/BinaryTrie.scala @@ -0,0 +1,119 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (_ => time <= ? * listSize(inp) + ?) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (_ => time <= ? * listSize(inp) + ?) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (res => true && tmpl((a, c) => time <= a * listSize(inp) + c)) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (_ => time <= ? * listSize(inp) + ?) +} diff --git a/src/test/resources/regression/orb/timing/BinomialHeap.scala b/src/test/resources/regression/orb/timing/BinomialHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..81b990d41323f353098f5cd02feddf3e25ec9264 --- /dev/null +++ b/src/test/resources/regression/orb/timing/BinomialHeap.scala @@ -0,0 +1,181 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BinomialHeap { + //sealed abstract class TreeNode + case class TreeNode(rank: BigInt, elem: Element, children: BinomialHeap) + case class Element(n: BigInt) + + sealed abstract class BinomialHeap + case class ConsHeap(head: TreeNode, tail: BinomialHeap) extends BinomialHeap + case class NilHeap() extends BinomialHeap + + sealed abstract class List + case class NodeL(head: BinomialHeap, tail: List) extends List + case class NilL() extends List + + sealed abstract class OptionalTree + case class Some(t : TreeNode) extends OptionalTree + case class None() extends OptionalTree + + /* Lower or Equal than for Element structure */ + private def leq(a: Element, b: Element) : Boolean = { + a match { + case Element(a1) => { + b match { + case Element(a2) => { + if(a1 <= a2) true + else false + } + } + } + } + } + + /* isEmpty function of the Binomial Heap */ + def isEmpty(t: BinomialHeap) = t match { + case ConsHeap(_,_) => false + case _ => true + } + + /* Helper function to determine rank of a TreeNode */ + def rank(t: TreeNode) : BigInt = t.rank /*t match { + case TreeNode(r, _, _) => r + }*/ + + /* Helper function to get the root element of a TreeNode */ + def root(t: TreeNode) : Element = t.elem /*t match { + case TreeNode(_, e, _) => e + }*/ + + /* Linking trees of equal ranks depending on the root element */ + def link(t1: TreeNode, t2: TreeNode): TreeNode = { + if (leq(t1.elem, t2.elem)) { + TreeNode(t1.rank + 1, t1.elem, ConsHeap(t2, t1.children)) + } else { + TreeNode(t1.rank + 1, t2.elem, ConsHeap(t1, t2.children)) + } + } + + def treeNum(h: BinomialHeap) : BigInt = { + h match { + case ConsHeap(head, tail) => 1 + treeNum(tail) + case _ => 0 + } + } + + /* Insert a tree into a binomial heap. The tree should be correct in relation to the heap */ + def insTree(t: TreeNode, h: BinomialHeap) : BinomialHeap = { + h match { + case ConsHeap(head, tail) => { + if (rank(t) < rank(head)) { + ConsHeap(t, h) + } else if (rank(t) > rank(head)) { + ConsHeap(head, insTree(t,tail)) + } else { + insTree(link(t,head), tail) + } + } + case _ => ConsHeap(t, NilHeap()) + } + } ensuring(_ => time <= ? * treeNum(h) + ?) + + /* Merge two heaps together */ + def merge(h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + ConsHeap(head1, merge(tail1, h2)) + } else if (rank(head2) < rank(head1)) { + ConsHeap(head2, merge(h1, tail2)) + } else { + mergeWithCarry(link(head1, head2), tail1, tail2) + } + } + case _ => h1 + } + } + case _ => h2 + } + } ensuring(_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) + + def mergeWithCarry(t: TreeNode, h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + + if (rank(t) < rank(head1)) + ConsHeap(t, ConsHeap(head1, merge(tail1, h2))) + else + mergeWithCarry(link(t, head1), tail1, h2) + + } else if (rank(head2) < rank(head1)) { + + if (rank(t) < rank(head2)) + ConsHeap(t, ConsHeap(head2, merge(h1, tail2))) + else + mergeWithCarry(link(t, head2), h1, tail2) + + } else { + ConsHeap(t, mergeWithCarry(link(head1, head2), tail1, tail2)) + } + } + case _ => { + insTree(t, h1) + } + } + } + case _ => insTree(t, h2) + } + } ensuring (_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) + + //Auxiliary helper function to simplefy findMin and deleteMin + def removeMinTree(h: BinomialHeap): (OptionalTree, BinomialHeap) = { + h match { + case ConsHeap(head, NilHeap()) => (Some(head), NilHeap()) + case ConsHeap(head1, tail1) => { + val (opthead2, tail2) = removeMinTree(tail1) + opthead2 match { + case Some(head2) => + if (leq(root(head1), root(head2))) { + (Some(head1), tail1) + } else { + (Some(head2), ConsHeap(head1, tail2)) + } + case _ => (Some(head1), tail1) + } + } + case _ => (None(), NilHeap()) + } + } ensuring (res => treeNum(res._2) <= treeNum(h) && time <= ? * treeNum(h) + ?) + + /*def findMin(h: BinomialHeap) : Element = { + val (opt, _) = removeMinTree(h) + opt match { + case Some(TreeNode(_,e,ts1)) => e + case _ => Element(-1) + } + } ensuring(res => true && tmpl((a,b) => time <= a*treeNum(h) + b))*/ + + def minTreeChildren(h: BinomialHeap) : BigInt = { + val (min, _) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ch)) => treeNum(ch) + case _ => 0 + } + } + + // Discard the minimum element of the extracted min tree and put its children back into the heap + def deleteMin(h: BinomialHeap) : BinomialHeap = { + val (min, ts2) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ts1)) => merge(ts1, ts2) + case _ => h + } + } ensuring(_ => time <= ? * minTreeChildren(h) + ? * treeNum(h) + ?) + +} diff --git a/src/test/resources/regression/orb/timing/ConcatVariations.scala b/src/test/resources/regression/orb/timing/ConcatVariations.scala new file mode 100644 index 0000000000000000000000000000000000000000..a94fb418a48db5b27ab88cfbbe84a59233f394b0 --- /dev/null +++ b/src/test/resources/regression/orb/timing/ConcatVariations.scala @@ -0,0 +1,42 @@ +import leon.invariant._ +import leon.instrumentation._ + + +object ConcatVariations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def genL(n: BigInt): List = { + require(n >= 0) + if (n == 0) Nil() + else + Cons(n, genL(n - 1)) + } ensuring (res => size(res) == n && tmpl((a,b) => time <= a*n + b)) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + + def f_good(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + if (m == 0) Nil() + else append(genL(n), f_good(m - 1, n)) + + } ensuring(res => size(res) == n*m && tmpl((a,b,c,d) => time <= a*(n*m) + b*n + c*m +d)) + + def f_worst(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + + if (m == 0) Nil() + else append(f_worst(m - 1, n), genL(n)) + + } ensuring(res => size(res) == n*m && tmpl((a,c,d,e,f) => time <= a*((n*m)*m)+c*(n*m)+d*n+e*m+f)) +} diff --git a/src/test/resources/regression/orb/timing/ListOperations.scala b/src/test/resources/regression/orb/timing/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..13b031b4da2cb89e5e2557af107a5757fef1e7f4 --- /dev/null +++ b/src/test/resources/regression/orb/timing/ListOperations.scala @@ -0,0 +1,40 @@ +import leon.invariant._ +import leon.instrumentation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + + def remove(elem: BigInt, l: List): List = { + l match { + case Nil() => Nil() + case Cons(hd, tl) => if (hd == elem) remove(elem, tl) else Cons(hd, remove(elem, tl)) + } + } ensuring (res => size(l) >= size(res) && tmpl((a,b) => time <= a*size(l) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + }) ensuring (res => true && tmpl((a,b) => time <= a*size(list) + b)) + + def distinct(l: List): List = ( + l match { + case Nil() => Nil() + case Cons(x, xs) => { + val newl = distinct(xs) + if (contains(newl, x)) newl + else Cons(x, newl) + } + }) ensuring (res => size(l) >= size(res) && tmpl((a,b) => time <= a*(size(l)*size(l)) + b)) +} diff --git a/src/test/resources/regression/orb/timing/PropositionalLogic.scala b/src/test/resources/regression/orb/timing/PropositionalLogic.scala new file mode 100644 index 0000000000000000000000000000000000000000..22dfdcdec06cef0760222140d669d11ae134a658 --- /dev/null +++ b/src/test/resources/regression/orb/timing/PropositionalLogic.scala @@ -0,0 +1,115 @@ +import scala.collection.immutable.Set +import leon.invariant._ +import leon.instrumentation._ + +object PropositionalLogic { + + sealed abstract class Formula + case class And(lhs: Formula, rhs: Formula) extends Formula + case class Or(lhs: Formula, rhs: Formula) extends Formula + case class Implies(lhs: Formula, rhs: Formula) extends Formula + case class Not(f: Formula) extends Formula + case class Literal(id: BigInt) extends Formula + case class True() extends Formula + case class False() extends Formula + + case class Pair(f: Formula, b: Boolean) + + sealed abstract class List + case class Cons(x: Pair, xs: List) extends List + case class Nil() extends List + + def size(f : Formula) : BigInt = (f match { + case And(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Or(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Implies(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Not(f) => size(f) + 1 + case _ => 1 + }) + + def removeImplies(f: Formula): Formula = (f match { + case And(lhs, rhs) => And(removeImplies(lhs), removeImplies(rhs)) + case Or(lhs, rhs) => Or(removeImplies(lhs), removeImplies(rhs)) + case Implies(lhs, rhs) => Or(Not(removeImplies(lhs)),removeImplies(rhs)) + case Not(f) => Not(removeImplies(f)) + case _ => f + + }) ensuring(_ => time <= ? * size(f) + ?) + + def nnf(formula: Formula): Formula = (formula match { + case And(lhs, rhs) => And(nnf(lhs), nnf(rhs)) + case Or(lhs, rhs) => Or(nnf(lhs), nnf(rhs)) + case Implies(lhs, rhs) => Implies(nnf(lhs), nnf(rhs)) + case Not(And(lhs, rhs)) => Or(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Or(lhs, rhs)) => And(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Implies(lhs, rhs)) => And(nnf(lhs), nnf(Not(rhs))) + case Not(Not(f)) => nnf(f) + case Not(Literal(_)) => formula + case Literal(_) => formula + case Not(True()) => False() + case Not(False()) => True() + case _ => formula + }) ensuring(_ => time <= ? * size(formula) + ?) + + def isNNF(f: Formula): Boolean = { f match { + case And(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Or(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Implies(lhs, rhs) => false + case Not(Literal(_)) => true + case Not(_) => false + case _ => true + }} ensuring(_ => time <= ? * size(f) + ?) + + def simplify(f: Formula): Formula = (f match { + case And(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is false, return false + //if lhs is true return rhs + //if rhs is true return lhs + (sl,sr) match { + case (False(), _) => False() + case (_, False()) => False() + case (True(), _) => sr + case (_, True()) => sl + case _ => And(sl, sr) + } + } + case Or(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is true, return true + //if lhs is false return rhs + //if rhs is false return lhs + (sl,sr) match { + case (True(), _) => True() + case (_, True()) => True() + case (False(), _) => sr + case (_, False()) => sl + case _ => Or(sl, sr) + } + } + case Implies(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs is false return true + //if rhs is true return true + //if lhs is true return rhs + //if rhs is false return Not(rhs) + (sl,sr) match { + case (False(), _) => True() + case (_, True()) => True() + case (True(), _) => sr + case (_, False()) => Not(sl) + case _ => Implies(sl, sr) + } + } + case Not(True()) => False() + case Not(False()) => True() + case _ => f + + }) ensuring(_ => time <= ? *size(f) + ?) +} diff --git a/src/test/resources/regression/orb/timing/SimpleMap.scala b/src/test/resources/regression/orb/timing/SimpleMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..86e2c7b7b23d2b262379c34ecc78f6b5736086ca --- /dev/null +++ b/src/test/resources/regression/orb/timing/SimpleMap.scala @@ -0,0 +1,25 @@ +import leon.instrumentation._ +import leon.invariant._ + +object SimpleMap { + sealed abstract class List + case class Cons(head: (BigInt, BigInt), tail: List) extends List + case class Nil() extends List + + def size(l : List) : BigInt = (l match { + case Cons(_, xs) => 1 + size(xs) + case _ => 0 + }) + + def insert(l: List, key: BigInt, value: BigInt): List = { + Cons((key, value), l) + } ensuring(res => tmpl((a) => time <= a)) + + def getOrElse(l: List, key: BigInt, elseValue: BigInt): BigInt = { + l match { + case Nil() => elseValue + case Cons((currKey, currValue), _) if (currKey == key) => currValue + case Cons(_, tail) => getOrElse(tail, key, elseValue) + } + } ensuring(res => tmpl((a, b) => time <= a*size(l) + b)) +} \ No newline at end of file diff --git a/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala b/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b6a7d60f0076038c093abf8d6bec7cefe6f44e41 --- /dev/null +++ b/src/test/scala/leon/regression/orb/OrbInstrumentationTestSuite.scala @@ -0,0 +1,48 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.regression.orb +import leon.test._ +import leon._ +import leon.purescala.Definitions._ +import leon.invariant.engine._ +import leon.transformations._ +import java.io.File +import leon.purescala.Types.TupleType + +class OrbInstrumentationTestSuite extends LeonRegressionSuite { + + test("TestInstrumentation") { + val ctx = createLeonContext("--inferInv", "--minbounds", "--timeout=" + 10) + val testFilename = toTempFile( + """ + import leon.annotation._ + import leon.invariant._ + import leon.instrumentation._ + + object Test { + sealed abstract class List + case class Cons(tail: List) extends List + case class Nil() extends List + + // proved with unrolling=0 + def size(l: List) : BigInt = (l match { + case Nil() => BigInt(0) + case Cons(t) => 1 + size(t) + }) ensuring(res => tmpl(a => time <= a)) + }""") + val beginPipe = leon.frontends.scalac.ExtractionPhase andThen + new leon.utils.PreprocessingPhase + val program = beginPipe.run(ctx)(testFilename) + val processPipe = InstrumentationPhase + // check properties. + val instProg = processPipe.run(ctx)(program) + val sizeFun = instProg.definedFunctions.find(_.id.name.startsWith("size")) + if(!sizeFun.isDefined || !sizeFun.get.returnType.isInstanceOf[TupleType]) + fail("Error in instrumentation") + } + + def toTempFile(content: String): List[String] = { + val pipeline = leon.utils.TemporaryInputPhase + pipeline.run(createLeonContext())((List(content), Nil)) + } +} diff --git a/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala b/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..5d6f7c0e728104475e2e29b5a92196f2ee6aca49 --- /dev/null +++ b/src/test/scala/leon/regression/orb/OrbRegressionSuite.scala @@ -0,0 +1,63 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.regression.orb +import leon.test._ + +import leon._ +import leon.purescala.Definitions._ +import leon.invariant.engine._ +import leon.transformations._ + +import java.io.File + +class OrbRegressionSuite extends LeonRegressionSuite { + private def forEachFileIn(path: String)(block: File => Unit) { + val fs = filesInResourceDir(path, _.endsWith(".scala")) + for (f <- fs) { + block(f) + } + } + + private def testInference(f: File, bound: Int) { + + val ctx = createLeonContext("--inferInv", "--minbounds", "--timeout="+bound) + val beginPipe = leon.frontends.scalac.ExtractionPhase andThen + new leon.utils.PreprocessingPhase + val program = beginPipe.run(ctx)(f.getAbsolutePath :: Nil) + val processPipe = InstrumentationPhase andThen InferInvariantsPhase + val report = processPipe.run(ctx)(program) + val fails = report.conditions.filterNot(_.invariant.isDefined) + if (!fails.isEmpty) + fail(s"Inference failed for functions ${fails.map(_.fd).mkString("\n")}") + } + + forEachFileIn("regression/orb/timing") { f => + test("Timing: " + f.getName) { + testInference(f, 50) + } + } + + forEachFileIn("regression/orb/stack/") { f => + test("Stack: " + f.getName) { + testInference(f, 50) + } + } + + forEachFileIn("regression/orb//depth") { f => + test("Depth: " + f.getName) { + testInference(f, 50) + } + } + + forEachFileIn("regression/orb/numerical") { f => + test("Numerical: " + f.getName) { + testInference(f, 50) + } + } + + forEachFileIn("regression/orb/combined/") { f => + test("Multiple Instrumentations: " + f.getName) { + testInference(f, 50) + } + } +} diff --git a/src/test/scala/leon/test/helpers/ExpressionsDSL.scala b/src/test/scala/leon/test/helpers/ExpressionsDSL.scala index 08ab658815eb5cd51ef5b821261d5484ba878112..74cd84c2c1702d1e078ca72785b4722ae002de5c 100644 --- a/src/test/scala/leon/test/helpers/ExpressionsDSL.scala +++ b/src/test/scala/leon/test/helpers/ExpressionsDSL.scala @@ -16,7 +16,7 @@ trait ExpressionsDSL { def bi(x: Int) = InfiniteIntegerLiteral(x) def b(x: Boolean) = BooleanLiteral(x) def i(x: Int) = IntLiteral(x) - def r(x: Double) = RealLiteral(BigDecimal(x)) + def r(n: BigInt, d: BigInt) = FractionalLiteral(n, d) val a = FreshIdentifier("a", Int32Type).toVariable val b = FreshIdentifier("b", Int32Type).toVariable diff --git a/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala index a18a0dca273ce1a53e47ab8cfc67cc93327348ea..4d97487f6d85a50e39efd8277d64c99958a63df4 100644 --- a/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala @@ -48,9 +48,9 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { eval(e, UnitLiteral()) === UnitLiteral() eval(e, InfiniteIntegerLiteral(0)) === InfiniteIntegerLiteral(0) eval(e, InfiniteIntegerLiteral(42)) === InfiniteIntegerLiteral(42) - eval(e, RealLiteral(0)) === RealLiteral(0) - eval(e, RealLiteral(42)) === RealLiteral(42) - eval(e, RealLiteral(13.255)) === RealLiteral(13.255) + eval(e, FractionalLiteral(0 ,1)) === FractionalLiteral(0 ,1) + eval(e, FractionalLiteral(42 ,1)) === FractionalLiteral(42, 1) + eval(e, FractionalLiteral(26, 3)) === FractionalLiteral(26, 3) } } @@ -172,32 +172,29 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { } test("Real Arightmetic") { implicit fix => - for(e <- allEvaluators) { - eval(e, RealPlus(RealLiteral(3), RealLiteral(5))) === RealLiteral(8) - eval(e, RealMinus(RealLiteral(7), RealLiteral(2))) === RealLiteral(5) - eval(e, RealUMinus(RealLiteral(7))) === RealLiteral(-7) - eval(e, RealTimes(RealLiteral(2), RealLiteral(3))) === RealLiteral(6) - eval(e, RealPlus(RealLiteral(2.5), RealLiteral(3.5))) === RealLiteral(6) + for (e <- allEvaluators) { + eval(e, RealPlus(FractionalLiteral(2, 3), FractionalLiteral(1, 3))) === FractionalLiteral(1, 1) + eval(e, RealMinus(FractionalLiteral(1, 1), FractionalLiteral(1, 4))) === FractionalLiteral(3, 4) + eval(e, RealUMinus(FractionalLiteral(7, 1))) === FractionalLiteral(-7, 1) + eval(e, RealTimes(FractionalLiteral(2, 3), FractionalLiteral(1, 3))) === FractionalLiteral(2, 9) } } test("Real Comparisons") { implicit fix => for(e <- allEvaluators) { - eval(e, GreaterEquals(RealLiteral(7), RealLiteral(4))) === BooleanLiteral(true) - eval(e, GreaterEquals(RealLiteral(7), RealLiteral(7))) === BooleanLiteral(true) - eval(e, GreaterEquals(RealLiteral(4), RealLiteral(7))) === BooleanLiteral(false) + eval(e, GreaterEquals(FractionalLiteral(7, 1), FractionalLiteral(4, 2))) === BooleanLiteral(true) + eval(e, GreaterEquals(FractionalLiteral(7, 2), FractionalLiteral(49, 13))) === BooleanLiteral(false) - eval(e, GreaterThan(RealLiteral(7), RealLiteral(4))) === BooleanLiteral(true) - eval(e, GreaterThan(RealLiteral(7), RealLiteral(7))) === BooleanLiteral(false) - eval(e, GreaterThan(RealLiteral(4), RealLiteral(7))) === BooleanLiteral(false) + eval(e, GreaterThan(FractionalLiteral(49, 13), FractionalLiteral(7, 2))) === BooleanLiteral(true) + eval(e, GreaterThan(FractionalLiteral(49, 14), FractionalLiteral(7, 2))) === BooleanLiteral(false) + eval(e, GreaterThan(FractionalLiteral(4, 2), FractionalLiteral(7, 1))) === BooleanLiteral(false) - eval(e, LessEquals(RealLiteral(7), RealLiteral(4))) === BooleanLiteral(false) - eval(e, LessEquals(RealLiteral(7), RealLiteral(7))) === BooleanLiteral(true) - eval(e, LessEquals(RealLiteral(4), RealLiteral(7))) === BooleanLiteral(true) + eval(e, LessEquals(FractionalLiteral(7, 1), FractionalLiteral(4, 2))) === BooleanLiteral(false) + eval(e, LessEquals(FractionalLiteral(7, 2), FractionalLiteral(49, 13))) === BooleanLiteral(true) - eval(e, LessThan(RealLiteral(7), RealLiteral(4))) === BooleanLiteral(false) - eval(e, LessThan(RealLiteral(7), RealLiteral(7))) === BooleanLiteral(false) - eval(e, LessThan(RealLiteral(4), RealLiteral(7))) === BooleanLiteral(true) + eval(e, LessThan(FractionalLiteral(49, 13), FractionalLiteral(7, 2))) === BooleanLiteral(false) + eval(e, LessThan(FractionalLiteral(49, 14), FractionalLiteral(7, 2))) === BooleanLiteral(false) + eval(e, LessThan(FractionalLiteral(4, 2), FractionalLiteral(7, 1))) === BooleanLiteral(true) } } @@ -266,7 +263,7 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { test("Array Default Value") { implicit fix => for (e <- allEvaluators) { val id = FreshIdentifier("id", Int32Type) - eqArray(eval(e, finiteArray(Map[Int, Expr](), Some(Variable(id), IntLiteral(7)), Int32Type), Map(id -> IntLiteral(27))).res, + eqArray(eval(e, finiteArray(Map[Int, Expr](), Some(Variable(id), IntLiteral(7)), Int32Type), Map(id -> IntLiteral(27))).res, finiteArray(Map[Int, Expr](), Some(IntLiteral(27), IntLiteral(7)), Int32Type)) } } diff --git a/testcases/orb-testcases/amortized/BigNums.scala b/testcases/orb-testcases/amortized/BigNums.scala new file mode 100644 index 0000000000000000000000000000000000000000..883d1f45ab8752668070699c840edd7a3de160b9 --- /dev/null +++ b/testcases/orb-testcases/amortized/BigNums.scala @@ -0,0 +1,50 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BigNums { + sealed abstract class BigNum + case class Cons(head: BigInt, tail: BigNum) extends BigNum + case class Nil() extends BigNum + + def incrTime(l: BigNum) : BigInt = { + l match { + case Nil() => 1 + case Cons(x, tail) => + if(x == 0) 1 + else 1 + incrTime(tail) + } + } + + def potentialIncr(l: BigNum) : BigInt = { + l match { + case Nil() => 0 + case Cons(x, tail) => + if(x == 0) potentialIncr(tail) + else 1 + potentialIncr(tail) + } + } + + def increment(l: BigNum) : BigNum = { + l match { + case Nil() => Cons(1,l) + case Cons(x, tail) => + if(x == 0) Cons(1, tail) + else Cons(0, increment(tail)) + } + } ensuring (res => time <= ? * incrTime(l) + ? && incrTime(l) + potentialIncr(res) - potentialIncr(l) <= ?) + + /** + * Nop is the number of operations + */ + def incrUntil(nop: BigInt, l: BigNum) : BigNum = { + if(nop == 0) l + else { + incrUntil(nop-1, increment(l)) + } + } ensuring (res => time <= ? * nop + ? * potentialIncr(l) + ?) + + def count(nop: BigInt) : BigNum = { + incrUntil(nop, Nil()) + } ensuring (res => time <= ? * nop + ?) + +} diff --git a/testcases/orb-testcases/depth/AVLTree.scala b/testcases/orb-testcases/depth/AVLTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9a7cf520f71ec94cbbc682bfc2635fae25a15c9 --- /dev/null +++ b/testcases/orb-testcases/depth/AVLTree.scala @@ -0,0 +1,193 @@ +import leon.instrumentation._ +import leon.invariant._ + + +/** + * created by manos and modified by ravi. + * BST property cannot be verified + */ +object AVLTree { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(left : Tree, value : BigInt, right: Tree, rank : BigInt) extends Tree + + sealed abstract class OptionBigInt + case class None() extends OptionBigInt + case class Some(i: BigInt) extends OptionBigInt + + def min(i1:BigInt, i2:BigInt) : BigInt = if (i1<=i2) i1 else i2 + def max(i1:BigInt, i2:BigInt) : BigInt = if (i1>=i2) i1 else i2 + + /*def twopower(x: BigInt) : BigInt = { + //require(x >= 0) + if(x < 1) 1 + else + 3/2 * twopower(x - 1) + } ensuring(res => res >= 1 && tmpl((a) => a <= 0))*/ + + def rank(t: Tree) : BigInt = { + t match { + case Leaf() => 0 + case Node(_,_,_,rk) => rk + } + } + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r, _) => { + val hl = height(l) + val hr = height(r) + max(hl,hr) + 1 + } + } + } + + def size(t: Tree): BigInt = { + //require(isAVL(t)) + (t match { + case Leaf() => 0 + case Node(l, _, r,_) => size(l) + 1 + size(r) + }) + + } ensuring (res => true && tmpl((a,b) => height(t) <= a*res + b)) + + def rankHeight(t: Tree) : Boolean = t match { + case Leaf() => true + case Node(l,_,r,rk) => rankHeight(l) && rankHeight(r) && rk == height(t) + } + + def balanceFactor(t : Tree) : BigInt = { + t match{ + case Leaf() => 0 + case Node(l, _, r, _) => rank(l) - rank(r) + } + } + + /*def isAVL(t:Tree) : Boolean = { + t match { + case Leaf() => true + case Node(l,_,r,rk) => isAVL(l) && isAVL(r) && balanceFactor(t) >= -1 && balanceFactor(t) <= 1 && rankHeight(t) //isBST(t) && + } + }*/ + + def unbalancedInsert(t: Tree, e : BigInt) : Tree = { + t match { + case Leaf() => Node(Leaf(), e, Leaf(), 1) + case Node(l,v,r,h) => + if (e == v) t + else if (e < v){ + val newl = avlInsert(l,e) + Node(newl, v, r, max(rank(newl), rank(r)) + 1) + } + else { + val newr = avlInsert(r,e) + Node(l, v, newr, max(rank(l), rank(newr)) + 1) + } + } + } + + def avlInsert(t: Tree, e : BigInt) : Tree = { + + balance(unbalancedInsert(t,e)) + + } ensuring(res => true && tmpl((a,b) => depth <= a*height(t) + b)) + //minbound: ensuring(res => time <= 138*height(t) + 19) + + def deleteMax(t: Tree): (Tree, OptionBigInt) = { + + t match { + case Node(Leaf(), v, Leaf(), _) => (Leaf(), Some(v)) + case Node(l, v, Leaf(), _) => { + val (newl, opt) = deleteMax(l) + opt match { + case None() => (t, None()) + case Some(lmax) => { + val newt = balance(Node(newl, lmax, Leaf(), rank(newl) + 1)) + (newt, Some(v)) + } + } + } + case Node(_, _, r, _) => deleteMax(r) + case _ => (t, None()) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + def unbalancedDelete(t: Tree, e: BigInt): Tree = { + t match { + case Leaf() => Leaf() //not found case + case Node(l, v, r, h) => + if (e == v) { + if (l == Leaf()) r + else if(r == Leaf()) l + else { + val (newl, opt) = deleteMax(l) + opt match { + case None() => t + case Some(newe) => { + Node(newl, newe, r, max(rank(newl), rank(r)) + 1) + } + } + } + } else if (e < v) { + val newl = avlDelete(l, e) + Node(newl, v, r, max(rank(newl), rank(r)) + 1) + } else { + val newr = avlDelete(r, e) + Node(l, v, newr, max(rank(l), rank(newr)) + 1) + } + } + } + + def avlDelete(t: Tree, e: BigInt): Tree = { + + balance(unbalancedDelete(t, e)) + + } ensuring(res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + def balance(t:Tree) : Tree = { + t match { + case Leaf() => Leaf() // impossible... + case Node(l, v, r, h) => + val bfactor = balanceFactor(t) + // at this poBigInt, the tree is unbalanced + if(bfactor > 1 ) { // left-heavy + val newL = + if (balanceFactor(l) < 0) { // l is right heavy + rotateLeft(l) + } + else l + rotateRight(Node(newL,v,r, max(rank(newL), rank(r)) + 1)) + } + else if(bfactor < -1) { + val newR = + if (balanceFactor(r) > 0) { // r is left heavy + rotateRight(r) + } + else r + rotateLeft(Node(l,v,newR, max(rank(newR), rank(l)) + 1)) + } else t + } + } + + def rotateRight(t:Tree) = { + t match { + case Node(Node(ll, vl, rl, _),v,r, _) => + + val hr = max(rank(rl),rank(r)) + 1 + Node(ll, vl, Node(rl,v,r,hr), max(rank(ll),hr) + 1) + + case _ => t // this should not happen + } } + + + def rotateLeft(t:Tree) = { + t match { + case Node(l, v, Node(lr,vr,rr,_), _) => + + val hl = max(rank(l),rank(lr)) + 1 + Node(Node(l,v,lr,hl), vr, rr, max(hl, rank(rr)) + 1) + case _ => t // this should not happen + } } +} + diff --git a/testcases/orb-testcases/depth/AmortizedQueue.scala b/testcases/orb-testcases/depth/AmortizedQueue.scala new file mode 100644 index 0000000000000000000000000000000000000000..1678d8980ba741be46f464b321bc2c05c6575c80 --- /dev/null +++ b/testcases/orb-testcases/depth/AmortizedQueue.scala @@ -0,0 +1,86 @@ +import leon.instrumentation._ +import leon.invariant._ + +object AmortizedQueue { + sealed abstract class List + case class Cons(head : BigInt, tail : List) extends List + case class Nil() extends List + + case class Queue(front : List, rear : List) + + def size(list : List) : BigInt = (list match { + case Nil() => 0 + case Cons(_, xs) => 1 + size(xs) + }) + + def sizeList(list : List) : BigInt = (list match { + case Nil() => 0 + case Cons(_, xs) => 1 + sizeList(xs) + }) ensuring(res => res >= 0 && tmpl((a,b) => depth <= a*size(list) + b)) + + def qsize(q : Queue) : BigInt = size(q.front) + size(q.rear) + + def asList(q : Queue) : List = concat(q.front, reverse(q.rear)) + + def concat(l1 : List, l2 : List) : List = (l1 match { + case Nil() => l2 + case Cons(x,xs) => Cons(x, concat(xs, l2)) + + }) ensuring (res => size(res) == size(l1) + size(l2) && tmpl((a,b,c) => depth <= a*size(l1) + b)) + + def isAmortized(q : Queue) : Boolean = sizeList(q.front) >= sizeList(q.rear) + + def isEmpty(queue : Queue) : Boolean = queue match { + case Queue(Nil(), Nil()) => true + case _ => false + } + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def reverse(l: List): List = { + reverseRec(l, Nil()) + } ensuring (res => size(l) == size(res) && tmpl((a,b) => depth <= a*size(l) + b)) + + def amortizedQueue(front : List, rear : List) : Queue = { + if (sizeList(rear) <= sizeList(front)) + Queue(front, rear) + else + Queue(concat(front, reverse(rear)), Nil()) + } + + def enqueue(q : Queue, elem : BigInt) : Queue = ({ + + amortizedQueue(q.front, Cons(elem, q.rear)) + + }) ensuring(res => true && tmpl((a,b) => depth <= a*qsize(q) + b)) + + def dequeue(q : Queue) : Queue = { + require(isAmortized(q) && !isEmpty(q)) + q match { + case Queue(Cons(f, fs), rear) => amortizedQueue(fs, rear) + case _ => Queue(Nil(),Nil()) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*qsize(q) + b)) + + def removeLast(l : List) : List = { + require(l != Nil()) + l match { + case Cons(x,Nil()) => Nil() + case Cons(x,xs) => Cons(x, removeLast(xs)) + case _ => Nil() + } + } ensuring(res => size(res) <= size(l) && tmpl((a,b) => depth <= a*size(l) + b)) + + def pop(q : Queue) : Queue = { + require(isAmortized(q) && !isEmpty(q)) + q match { + case Queue(front, Cons(r,rs)) => Queue(front, rs) + case Queue(front, rear) => Queue(removeLast(front), rear) + case _ => Queue(Nil(),Nil()) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*size(q.front) + b)) +} diff --git a/testcases/orb-testcases/depth/BinaryTrie.scala b/testcases/orb-testcases/depth/BinaryTrie.scala new file mode 100755 index 0000000000000000000000000000000000000000..562ab19ea8913551bafd45e429ec28c98cbee064 --- /dev/null +++ b/testcases/orb-testcases/depth/BinaryTrie.scala @@ -0,0 +1,121 @@ +import leon.instrumentation._ +import leon.invariant._ + +import scala.collection.immutable.Set + +object ParallelBinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (res => true && tmpl ((a, c) => depth <= a * listSize(inp) + c)) +} diff --git a/testcases/orb-testcases/depth/BinomialHeap.scala b/testcases/orb-testcases/depth/BinomialHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..7dd9e613cd8964348fd63f1df22d22bab6d6c7f3 --- /dev/null +++ b/testcases/orb-testcases/depth/BinomialHeap.scala @@ -0,0 +1,199 @@ +/** + * @author Ravi + **/ +import leon.instrumentation._ +import leon.invariant._ + + +object BinomialHeap { + sealed abstract class BinomialTree + case class Node(rank: BigInt, elem: Element, children: BinomialHeap) extends BinomialTree + + sealed abstract class ElementAbs + case class Element(n: BigInt) extends ElementAbs + + sealed abstract class BinomialHeap + case class ConsHeap(head: BinomialTree, tail: BinomialHeap) extends BinomialHeap + case class NilHeap() extends BinomialHeap + + sealed abstract class List + case class NodeL(head: BinomialHeap, tail: List) extends List + case class NilL() extends List + + sealed abstract class OptionalTree + case class Some(t : BinomialTree) extends OptionalTree + case class None() extends OptionalTree + + /* Lower or Equal than for Element structure */ + private def leq(a: Element, b: Element) : Boolean = { + a match { + case Element(a1) => { + b match { + case Element(a2) => { + if(a1 <= a2) true + else false + } + } + } + } + } + + /* isEmpty function of the Binomial Heap */ + def isEmpty(t: BinomialHeap) = t match { + case ConsHeap(_,_) => false + case NilHeap() => true + } + + /* Helper function to determine rank of a BinomialTree */ + def rank(t: BinomialTree) : BigInt = t match { + case Node(r, _, _) => r + } + + /* Helper function to get the root element of a BinomialTree */ + def root(t: BinomialTree) : Element = t match { + case Node(_, e, _) => e + } + + /* Linking trees of equal ranks depending on the root element */ + def link(t1: BinomialTree, t2: BinomialTree) : BinomialTree = { + t1 match { + case Node(r, x1, c1) => { + t2 match { + case Node(_, x2, c2) => { + if (leq(x1,x2)) { + Node(r+1, x1, ConsHeap(t2, c1)) + } else { + Node(r+1, x2, ConsHeap(t1, c2)) + } + } + } + } + } + } + + def treeNum(h: BinomialHeap) : BigInt = { + h match { + case ConsHeap(head, tail) => 1 + treeNum(tail) + case NilHeap() => 0 + } + } + + /* Insert a tree into a binomial heap. The tree should be correct in relation to the heap */ + def insTree(t: BinomialTree, h: BinomialHeap) : BinomialHeap = { + h match { + case ConsHeap(head, tail) => { + if (rank(t) < rank(head)) { + ConsHeap(t, h) + } else if (rank(t) > rank(head)) { + ConsHeap(head, insTree(t,tail)) + } else { + insTree(link(t,head), tail) + } + } + case NilHeap() => ConsHeap(t, NilHeap()) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*treeNum(h) + b)) + + /* Merge two heaps together */ + def merge(h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + ConsHeap(head1, merge(tail1, h2)) + } else if (rank(head2) < rank(head1)) { + ConsHeap(head2, merge(h1, tail2)) + } else { + mergeWithCarry(link(head1, head2), tail1, tail2) + } + } + case NilHeap() => h1 + } + } + case NilHeap() => h2 + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*treeNum(h1) + b*treeNum(h2) + c)) + + def mergeWithCarry(t: BinomialTree, h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + t match { + case Node(r, _, _) => { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + + if (rank(t) < rank(head1)) + ConsHeap(t, ConsHeap(head1, merge(tail1, h2))) + else + mergeWithCarry(link(t, head1), tail1, h2) + + } else if (rank(head2) < rank(head1)) { + + if (rank(t) < rank(head2)) + ConsHeap(t, ConsHeap(head2, merge(h1, tail2))) + else + mergeWithCarry(link(t, head2), h1, tail2) + + } else { + ConsHeap(t, mergeWithCarry(link(head1, head2), tail1, tail2)) + } + } + case NilHeap() => { + insTree(t, h1) + } + } + } + case NilHeap() => insTree(t, h2) + } + } + } + } ensuring (res => true && tmpl ((d, e, f) => depth <= d * treeNum(h1) + e * treeNum(h2) + f)) + + //Auxiliary helper function to simplify findMin and deleteMin + def removeMinTree(h: BinomialHeap): (OptionalTree, BinomialHeap) = { + h match { + case ConsHeap(head, NilHeap()) => (Some(head), NilHeap()) + case ConsHeap(head1, tail1) => { + val (opthead2, tail2) = removeMinTree(tail1) + opthead2 match { + case None() => (Some(head1), tail1) + case Some(head2) => + if (leq(root(head1), root(head2))) { + (Some(head1), tail1) + } else { + (Some(head2), ConsHeap(head1, tail2)) + } + } + } + case _ => (None(), NilHeap()) + } + } ensuring (res => treeNum(res._2) <= treeNum(h) && tmpl ((a, b) => depth <= a*treeNum(h) + b)) + + /*def findMin(h: BinomialHeap) : Element = { + val (opt, _) = removeMinTree(h) + opt match { + case Some(Node(_,e,ts1)) => e + case _ => Element(-1) + } + } ensuring(res => true && tmpl((a,b) => time <= a*treeNum(h) + b))*/ + + def minTreeChildren(h: BinomialHeap) : BigInt = { + val (min, _) = removeMinTree(h) + min match { + case None() => 0 + case Some(Node(_,_,ch)) => treeNum(ch) + } + } + + // Discard the minimum element of the extracted min tree and put its children back into the heap + def deleteMin(h: BinomialHeap) : BinomialHeap = { + val (min, ts2) = removeMinTree(h) + min match { + case Some(Node(_,_,ts1)) => merge(ts1, ts2) + case _ => h + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*minTreeChildren(h) + b*treeNum(h) + c)) + +} diff --git a/testcases/orb-testcases/depth/ConcatVariations.scala b/testcases/orb-testcases/depth/ConcatVariations.scala new file mode 100644 index 0000000000000000000000000000000000000000..aea7080d6421e03901f6e1c58db1c5e0c6ed80d9 --- /dev/null +++ b/testcases/orb-testcases/depth/ConcatVariations.scala @@ -0,0 +1,42 @@ +import leon.instrumentation._ +import leon.invariant._ + + +object ConcatVariations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def genL(n: BigInt): List = { + require(n >= 0) + if (n == 0) Nil() + else + Cons(n, genL(n - 1)) + } ensuring (res => size(res) == n && tmpl((a,b) => depth <= a*n + b)) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def f_good(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + if (m == 0) Nil() + else append(genL(n), f_good(m - 1, n)) + + } ensuring(res => size(res) == n*m && tmpl((a,b,c,d) => depth <= a*(n*m) + b*n + c*m +d)) + + def f_worst(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + + if (m == 0) Nil() + else append(f_worst(m - 1, n), genL(n)) + + } ensuring(res => size(res) == n*m && tmpl((a,c,d,e,f) => depth <= a*((n*m)*m)+c*(n*m)+d*n+e*m+f)) +} diff --git a/testcases/orb-testcases/depth/Folds.scala b/testcases/orb-testcases/depth/Folds.scala new file mode 100755 index 0000000000000000000000000000000000000000..305446be53b27c976080968b881119ce06de11e6 --- /dev/null +++ b/testcases/orb-testcases/depth/Folds.scala @@ -0,0 +1,82 @@ +import leon.instrumentation._ +import leon.invariant._ + + +object TreeMaps { + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def parallelSearch(elem : BigInt, t : Tree) : Boolean = { + t match { + case Leaf() => false + case Node(l, x, r) => + if(x == elem) true + else { + val r1 = parallelSearch(elem, r) + val r2 = parallelSearch(elem, l) + if(r1 || r2) true + else false + } + } + } ensuring(res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + + def squareMap(t : Tree) : Tree = { + t match { + case Leaf() => t + case Node(l, x, r) => + val nl = squareMap(l) + val nr = squareMap(r) + Node(nl, x*x, nr) + } + } ensuring (res => true && tmpl((a,b) => depth <= a*height(t) + b)) + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def fact(n : BigInt) : BigInt = { + require(n >= 0) + + if(n == 1 || n == 0) BigInt(1) + else n * fact(n-1) + + } ensuring(res => tmpl((a,b) => depth <= a*n + b)) + + def descending(l: List, k: BigInt) : Boolean = { + l match { + case Nil() => true + case Cons(x, t) => x > 0 && x <= k && descending(t, x-1) + } + } + + def factMap(l: List, k: BigInt): List = { + require(descending(l, k) && k >= 0) + + l match { + case Nil() => Nil() + case Cons(x, t) => { + val f = fact(x) + Cons(f, factMap(t, x-1)) + } + + }} ensuring(res => true && tmpl((a,b) => depth <= a*k + b)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/depth/ForElimParallel.scala b/testcases/orb-testcases/depth/ForElimParallel.scala new file mode 100644 index 0000000000000000000000000000000000000000..a065af8f3383e2c283241758111bca380dbe41fe --- /dev/null +++ b/testcases/orb-testcases/depth/ForElimParallel.scala @@ -0,0 +1,119 @@ +import leon.instrumentation._ +import leon.invariant._ + + +object ForElimination { + + sealed abstract class Tree + case class Node(left: Tree, value: Statement, right: Tree) extends Tree + case class Leaf() extends Tree + + sealed abstract class Statement + case class Print(msg: BigInt, varID: BigInt) extends Statement + case class Assign(varID: BigInt, expr: Expression) extends Statement + case class Skip() extends Statement + case class Block(body: Tree) extends Statement + case class IfThenElse(expr: Expression, then: Statement, elze: Statement) extends Statement + case class While(expr: Expression, body: Statement) extends Statement + case class For(init: Statement, expr: Expression, step: Statement, body: Statement) extends Statement + + sealed abstract class Expression + case class Var(varID: BigInt) extends Expression + case class BigIntLiteral(value: BigInt) extends Expression + case class Plus(lhs: Expression, rhs: Expression) extends Expression + case class Minus(lhs: Expression, rhs: Expression) extends Expression + case class Times(lhs: Expression, rhs: Expression) extends Expression + case class Division(lhs: Expression, rhs: Expression) extends Expression + case class Equals(lhs: Expression, rhs: Expression) extends Expression + case class LessThan(lhs: Expression, rhs: Expression) extends Expression + case class And(lhs: Expression, rhs: Expression) extends Expression + case class Or(lhs: Expression, rhs: Expression) extends Expression + case class Not(expr: Expression) extends Expression + + /*def sizeStat(st: Statement) : BigInt = st match { + case Block(l) => sizeList(l) + 1 + case IfThenElse(c,th,el) => sizeStat(th) + sizeStat(el) + 1 + case While(c,b) => sizeStat(b) + 1 + case For(init,cond,step,body) => sizeStat(init) + sizeStat(step) + sizeStat(body) + case other => 1 + } + + def sizeTree(l: List) : BigInt = l match { + case Node(l,x,r) => sizeTree(l) + sizeTree(r) + sizeStat(x) + case Nil() => 0 + }*/ + + def max(x: BigInt, y: BigInt) = if(x >= y) x else y + + def depthStat(st: Statement) : BigInt = st match { + case Block(t) => depthTree(t) + 1 + case IfThenElse(c,th,el) => max(depthStat(th),depthStat(el)) + 1 + case While(c,b) => depthStat(b) + 1 + case For(init,cond,step,body) => max(max(depthStat(init),depthStat(step)),depthStat(body)) + case other => 1 + } + + def depthTree(t: Tree) : BigInt = t match { + case Node(l,x,r) => max(max(depthTree(l),depthTree(r)),depthStat(x)) + 1 + case Leaf() => 0 + } + + /*def isForFree(stat: Statement): Boolean = (stat match { + case Block(body) => isForFreeTree(body) + case IfThenElse(_, then, elze) => isForFree(then) && isForFree(elze) + case While(_, body) => isForFree(body) + case For(_,_,_,_) => false + case _ => true + }) ensuring(res => true && tmpl((a,b) => depth <= a*depthStat(stat) + b)) + + def isForFreeTree(t: Tree): Boolean = (t match { + case Leaf() => true + case Node(l, x, r) => isForFree(x) && isForFreeTree(l) && isForFreeTree(r) + }) ensuring(res => true && tmpl((a,b) => depth <= a*depthTree(t) + b))*/ + + /*def forLoopsWellFormedTree(t: Tree): Boolean = (t match { + case Leaf() => true + case Node(l, x, r) => forLoopsWellFormed(x) && forLoopsWellFormedTree(l) && forLoopsWellFormedTree(r) + }) ensuring(res => true && tmpl((a,b) => depth <= a*depthTree(t) + b)) + + def forLoopsWellFormed(stat: Statement): Boolean = (stat match { + case Block(body) => forLoopsWellFormedTree(body) + case IfThenElse(_, then, elze) => forLoopsWellFormed(then) && forLoopsWellFormed(elze) + case While(_, body) => forLoopsWellFormed(body) + case For(init, _, step, body) => isForFree(init) && isForFree(step) && forLoopsWellFormed(body) + case _ => true + }) ensuring(res => true && tmpl((a,b) => depth <= a*depthStat(stat) + b))*/ + + def eliminateWhileLoopsTree(t: Tree): Tree = { + t match { + case Leaf() => Leaf() + case Node(l,x,r) => Node(eliminateWhileLoopsTree(l), eliminateWhileLoops(x), eliminateWhileLoopsTree(r)) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*depthTree(t) + b)) + + def eliminateWhileLoops(stat: Statement): Statement = (stat match { + case Block(body) => Block(eliminateWhileLoopsTree(body)) + case IfThenElse(expr, then, elze) => IfThenElse(expr, eliminateWhileLoops(then), eliminateWhileLoops(elze)) + case While(expr, body) => For(Skip(), expr, Skip(), eliminateWhileLoops(body)) + case For(init, expr, step, body) => For(eliminateWhileLoops(init), expr, eliminateWhileLoops(step), eliminateWhileLoops(body)) + case other => other + }) ensuring(res => true && tmpl((a,b) => depth <= a*depthStat(stat) + b)) + + /*def eliminateForLoopsTree(t: Tree): Tree = { + t match { + case Leaf() => Leaf() + case Node(l,x,r) => Node(eliminateForLoopsTree(l), eliminateForLoops(x), eliminateForLoopsTree(r)) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*depthTree(t) + b)) + + def eliminateForLoops(stat: Statement): Statement = { + stat match { + case Block(body) => Block(eliminateForLoopsTree(body)) + case IfThenElse(expr, then, elze) => IfThenElse(expr, eliminateForLoops(then), eliminateForLoops(elze)) + case While(expr, body) => While(expr, eliminateForLoops(body)) + case For(init, expr, step, body) => Block(Node(Leaf(),eliminateForLoops(init),Node(Leaf(), + While(expr, Block(Node(Leaf(),eliminateForLoops(body), Node(Leaf(),eliminateForLoops(step), Leaf())))),Leaf()))) + case other => other + } + } ensuring(res => true && tmpl((a,b) => depth <= a*depthStat(stat) + b))*/ +} diff --git a/testcases/orb-testcases/depth/ForElimination.scala b/testcases/orb-testcases/depth/ForElimination.scala new file mode 100644 index 0000000000000000000000000000000000000000..b1c5ba98f1a7da5c626dd61e55a34b0ed97515c7 --- /dev/null +++ b/testcases/orb-testcases/depth/ForElimination.scala @@ -0,0 +1,103 @@ +import leon.instrumentation._ +import leon.invariant._ + + +object ForElimination { + + sealed abstract class List + case class Nil() extends List + case class Cons(head: Statement, tail: List) extends List + + sealed abstract class Statement + case class Print(msg: BigInt, varID: BigInt) extends Statement + case class Assign(varID: BigInt, expr: Expression) extends Statement + case class Skip() extends Statement + case class Block(body: List) extends Statement + case class IfThenElse(expr: Expression, thenExpr: Statement, elseExpr: Statement) extends Statement + case class While(expr: Expression, body: Statement) extends Statement + case class For(init: Statement, expr: Expression, step: Statement, body: Statement) extends Statement + + sealed abstract class Expression + case class Var(varID: BigInt) extends Expression + case class BigIntLiteral(value: BigInt) extends Expression + case class Plus(lhs: Expression, rhs: Expression) extends Expression + case class Minus(lhs: Expression, rhs: Expression) extends Expression + case class Times(lhs: Expression, rhs: Expression) extends Expression + case class Division(lhs: Expression, rhs: Expression) extends Expression + case class Equals(lhs: Expression, rhs: Expression) extends Expression + case class LessThan(lhs: Expression, rhs: Expression) extends Expression + case class And(lhs: Expression, rhs: Expression) extends Expression + case class Or(lhs: Expression, rhs: Expression) extends Expression + case class Not(expr: Expression) extends Expression + + def sizeStat(st: Statement) : BigInt = st match { + case Block(l) => sizeList(l) + 1 + case IfThenElse(c,th,el) => sizeStat(th) + sizeStat(el) + 1 + case While(c,b) => sizeStat(b) + 1 + case For(init,cond,step,body) => sizeStat(init) + sizeStat(step) + sizeStat(body) + case other => 1 + } + + def sizeList(l: List) : BigInt = l match { + case Cons(h,t) => sizeStat(h) + sizeList(t) + case Nil() => 0 + } + + def isForFree(stat: Statement): Boolean = (stat match { + case Block(body) => isForFreeList(body) + case IfThenElse(_, thenExpr, elseExpr) => isForFree(thenExpr) && isForFree(elseExpr) + case While(_, body) => isForFree(body) + case For(_,_,_,_) => false + case _ => true + }) ensuring(res => true && tmpl((a,b) => depth <= a*sizeStat(stat) + b)) + + def isForFreeList(l: List): Boolean = (l match { + case Nil() => true + case Cons(x, xs) => isForFree(x) && isForFreeList(xs) + }) ensuring(res => true && tmpl((a,b) => depth <= a*sizeList(l) + b)) + + def forLoopsWellFormedList(l: List): Boolean = (l match { + case Nil() => true + case Cons(x, xs) => forLoopsWellFormed(x) && forLoopsWellFormedList(xs) + }) ensuring(res => true && tmpl((a,b) => depth <= a*sizeList(l) + b)) + + def forLoopsWellFormed(stat: Statement): Boolean = (stat match { + case Block(body) => forLoopsWellFormedList(body) + case IfThenElse(_, thenExpr, elseExpr) => forLoopsWellFormed(thenExpr) && forLoopsWellFormed(elseExpr) + case While(_, body) => forLoopsWellFormed(body) + case For(init, _, step, body) => isForFree(init) && isForFree(step) && forLoopsWellFormed(body) + case _ => true + }) ensuring(res => true && tmpl((a,b) => depth <= a*sizeStat(stat) + b)) + + def eliminateWhileLoopsList(l: List): List = { + l match { + case Nil() => Nil() + case Cons(x, xs) => Cons(eliminateWhileLoops(x), eliminateWhileLoopsList(xs)) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*sizeList(l) + b)) + + def eliminateWhileLoops(stat: Statement): Statement = (stat match { + case Block(body) => Block(eliminateWhileLoopsList(body)) + case IfThenElse(expr, thenExpr, elseExpr) => IfThenElse(expr, eliminateWhileLoops(thenExpr), eliminateWhileLoops(elseExpr)) + case While(expr, body) => For(Skip(), expr, Skip(), eliminateWhileLoops(body)) + case For(init, expr, step, body) => For(eliminateWhileLoops(init), expr, eliminateWhileLoops(step), eliminateWhileLoops(body)) + case other => other + }) ensuring(res => true && tmpl((a,b) => depth <= a*sizeStat(stat) + b)) + + def eliminateForLoopsList(l: List): List = { + l match { + case Nil() => Nil() + case Cons(x, xs) => Cons(eliminateForLoops(x), eliminateForLoopsList(xs)) + } + } ensuring(res => true && tmpl((a,b) => depth <= a*sizeList(l) + b)) + + def eliminateForLoops(stat: Statement): Statement = { + stat match { + case Block(body) => Block(eliminateForLoopsList(body)) + case IfThenElse(expr, thenExpr, elseExpr) => IfThenElse(expr, eliminateForLoops(thenExpr), eliminateForLoops(elseExpr)) + case While(expr, body) => While(expr, eliminateForLoops(body)) + case For(init, expr, step, body) => Block(Cons(eliminateForLoops(init), Cons(While(expr, Block(Cons(eliminateForLoops(body), Cons(eliminateForLoops(step), Nil())))), Nil()))) + case other => other + } + } ensuring(res => true && tmpl((a,b) => depth <= a*sizeStat(stat) + b)) +} diff --git a/testcases/orb-testcases/depth/InsertionSort.scala b/testcases/orb-testcases/depth/InsertionSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..e52adc62f4a119393e4a3560af1d20aae7b9c34c --- /dev/null +++ b/testcases/orb-testcases/depth/InsertionSort.scala @@ -0,0 +1,28 @@ +import scala.collection.immutable.Set +import leon.instrumentation._ +import leon.invariant._ + + +object InsertionSort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + def size(l : List) : BigInt = (l match { + case Nil() => 0 + case Cons(_, xs) => 1 + size(xs) + }) + + def sortedIns(e: BigInt, l: List): List = { + l match { + case Nil() => Cons(e,Nil()) + case Cons(x,xs) => if (x <= e) Cons(x,sortedIns(e, xs)) else Cons(e, l) + } + } ensuring(res => size(res) == size(l) + 1 && tmpl((a,b) => depth <= a*size(l) +b)) + + def sort(l: List): List = (l match { + case Nil() => Nil() + case Cons(x,xs) => sortedIns(x, sort(xs)) + + }) ensuring(res => size(res) == size(l) && tmpl((a,b) => depth <= a*(size(l)*size(l)) +b)) +} diff --git a/testcases/orb-testcases/depth/LeftistHeap.scala b/testcases/orb-testcases/depth/LeftistHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..06a59d2d6745821130d2da5a7e6453425c573b81 --- /dev/null +++ b/testcases/orb-testcases/depth/LeftistHeap.scala @@ -0,0 +1,66 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object LeftistHeap { + sealed abstract class Heap + case class Leaf() extends Heap + case class Node(rk : BigInt, value: BigInt, left: Heap, right: Heap) extends Heap + + private def rightHeight(h: Heap) : BigInt = h match { + case Leaf() => 0 + case Node(_,_,_,r) => rightHeight(r) + 1 + } + + private def rank(h: Heap) : BigInt = h match { + case Leaf() => 0 + case Node(rk,_,_,_) => rk + } + + private def hasLeftistProperty(h: Heap) : Boolean = (h match { + case Leaf() => true + case Node(_,_,l,r) => hasLeftistProperty(l) && hasLeftistProperty(r) && rightHeight(l) >= rightHeight(r) && (rank(h) == rightHeight(h)) + }) + + def leftRightHeight(h: Heap) : BigInt = {h match { + case Leaf() => 0 + case Node(_,_,l,r) => rightHeight(l) + }} + + def removeMax(h: Heap) : Heap = { + require(hasLeftistProperty(h)) + h match { + case Node(_,_,l,r) => merge(l, r) + case l @ Leaf() => l + } + } ensuring(res => true && tmpl((a,b) => depth <= a*leftRightHeight(h) + b)) + + private def merge(h1: Heap, h2: Heap) : Heap = { + require(hasLeftistProperty(h1) && hasLeftistProperty(h2)) + h1 match { + case Leaf() => h2 + case Node(_, v1, l1, r1) => h2 match { + case Leaf() => h1 + case Node(_, v2, l2, r2) => + if(v1 > v2) + makeT(v1, l1, merge(r1, h2)) + else + makeT(v2, l2, merge(h1, r2)) + } + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*rightHeight(h1) + b*rightHeight(h2) + c)) + + private def makeT(value: BigInt, left: Heap, right: Heap) : Heap = { + if(rank(left) >= rank(right)) + Node(rank(right) + 1, value, left, right) + else + Node(rank(left) + 1, value, right, left) + } + + def insert(element: BigInt, heap: Heap) : Heap = { + require(hasLeftistProperty(heap)) + + merge(Node(1, element, Leaf(), Leaf()), heap) + + } ensuring(res => true && tmpl((a,b,c) => depth <= a*rightHeight(heap) + c)) +} diff --git a/testcases/orb-testcases/depth/ListOperations.scala b/testcases/orb-testcases/depth/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..9eb3c790707d630b32f42b8d707b367c8622a1be --- /dev/null +++ b/testcases/orb-testcases/depth/ListOperations.scala @@ -0,0 +1,61 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => depth <= a*size(l1) + b)) + + def reverse(l: List): List = { + reverseRec(l, Nil()) + + } ensuring (res => size(l) == size(res) && tmpl((a,b) => depth <= a*size(l) + b)) + + def reverse2(l: List): List = { + l match { + case Nil() => l + case Cons(hd, tl) => append(reverse2(tl), Cons(hd, Nil())) + } + } ensuring (res => size(res) == size(l) && tmpl((a,b) => depth <= a*(size(l)*size(l)) + b)) + + def remove(elem: BigInt, l: List): List = { + l match { + case Nil() => Nil() + case Cons(hd, tl) => if (hd == elem) remove(elem, tl) else Cons(hd, remove(elem, tl)) + } + } ensuring (res => size(l) >= size(res) && tmpl((a,b) => depth <= a*size(l) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + + }) ensuring (res => true && tmpl((a,b) => depth <= a*size(list) + b)) + + def distinct(l: List): List = ( + l match { + case Nil() => Nil() + case Cons(x, xs) => { + val newl = distinct(xs) + if (contains(newl, x)) newl + else Cons(x, newl) + } + }) ensuring (res => size(l) >= size(res) && tmpl((a,b) => depth <= a*(size(l)*size(l)) + b)) +} diff --git a/testcases/orb-testcases/depth/MergeSort.scala b/testcases/orb-testcases/depth/MergeSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..de9fea5347a6ddf27679ae69b7b559a4ca981aa2 --- /dev/null +++ b/testcases/orb-testcases/depth/MergeSort.scala @@ -0,0 +1,56 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object MergeSort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + def size(list:List): BigInt = {list match { + case Nil() => BigInt(0) + case Cons(x,xs) => 1 + size(xs) + }} ensuring(res => res >= 0) + + def length(l:List): BigInt = { + l match { + case Nil() => BigInt(0) + case Cons(x,xs) => 1 + length(xs) + } + } ensuring(res => res >= 0 && res == size(l) && tmpl((a,b) => depth <= a*size(l) + b)) + + def split(l:List,n:BigInt): (List,List) = { + require(n >= 0 && n <= size(l)) + if (n <= 0) (Nil(),l) + else + l match { + case Nil() => (Nil(),l) + case Cons(x,xs) => { + val (fst,snd) = split(xs, n-1) + (Cons(x,fst), snd) + } + } + } ensuring(res => size(res._2) == size(l) - n && size(res._1) == n && tmpl((a,b) => depth <= a*n +b)) + + def merge(aList:List, bList:List):List = (bList match { + case Nil() => aList + case Cons(x,xs) => + aList match { + case Nil() => bList + case Cons(y,ys) => + if (y < x) + Cons(y,merge(ys, bList)) + else + Cons(x,merge(aList, xs)) + } + }) ensuring(res => size(aList)+size(bList) == size(res) && tmpl((a,b,c) => depth <= a*size(aList) + b*size(bList) + c)) + + def mergeSort(list:List):List = (list match { + case Nil() => list + case Cons(x,Nil()) => list + case _ => + val (fst,snd) = split(list,length(list)/2) + merge(mergeSort(fst), mergeSort(snd)) + + }) ensuring(res => size(res) == size(list) && tmpl((a,b) => depth <= a*size(list) + b)) +} diff --git a/testcases/orb-testcases/depth/PrimesParallel.scala b/testcases/orb-testcases/depth/PrimesParallel.scala new file mode 100644 index 0000000000000000000000000000000000000000..98724e97289a565d00f2f96b18754849dcc448e0 --- /dev/null +++ b/testcases/orb-testcases/depth/PrimesParallel.scala @@ -0,0 +1,78 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object PrimesParallel { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + //a program that removes from a list, all multiples of a number 'i' upto 'n' + //the depth of this program is again 1 +// def removeMultiples(l: List, i: BigInt, n: BigInt, incr: BigInt): (List, BigInt) = { +// require(i >= 0 && incr >= 1 && i <= n) +// l match { +// case Nil() => (Nil(), 0) +// case Cons(x, t) => { +// if (x < i) { +// val (r,d) = removeMultiples(t, i, n, incr) +// (Cons(x, r), max(d, 2)) +// +// } else if (x > i) { +// val ni = i + incr +// if (ni > n) (l, 2) +// else { +// val (r,d) = removeMultiples(l, ni, n, incr) +// (r, max(d, 2)) +// } +// +// +// } else { +// val ni = i + incr +// if (ni > n) (t, 2) +// else{ +// val (r,d) = removeMultiples(l, ni, n, incr) +// (r, max(d, 2)) +// } +// } +// } +// } +// } //ensuring (res => true && tmpl ((a) => res._2 <= a)) + //ensuring (res => true && tmpl ((a,b) => time <= a*(size(l) + n - i) + b)) + + //another function with constant depth +// def createList(i: BigInt, n: BigInt) : (List, BigInt) = { +// require(i <= n) +// if(n == i) (Nil(), 0) +// else { +// val (r, d) = createList(i+1, n) +// (Cons(i, r), max(d, 2)) +// } +// } //ensuring(res => true && tmpl((a) => res._2 <= a)) + //ensuring(res => true && tmpl((a,b) => time <= a*(n-i) + b)) + +// def removeNonPrimes(currval: BigInt, l: List, n: BigInt, sqrtn: BigInt): (List, BigInt) = { +// require(currval <= sqrtn && sqrtn <= n && currval >= 1) +// +// val (r,d) = removeMultiples(l, currval, n, currval) +// if(currval == sqrtn) { +// (r, d + 2) +// } else { +// val (res, t) = removeNonPrimes(currval + 1, r, n, sqrtn) +// (res, t + 2) +// } +// } //ensuring(res => true && tmpl((a,b) => res._2 <= a*(sqrtn - currval) + b)) + +// def simplePrimes(n: BigInt, sqrtn : BigInt) : (List, BigInt) = { +// require(sqrtn >= 2 && sqrtn <= n) +// +// val (l, d1) = createList(2, n) +// val (resl, t2) = removeNonPrimes(2, l, n, sqrtn) +// (resl, d1 + t2 + 3) +// } //ensuring(res => true && tmpl((a,b) => res._2 <= a*sqrtn + b)) +} diff --git a/testcases/orb-testcases/depth/PropLogicDepth.scala b/testcases/orb-testcases/depth/PropLogicDepth.scala new file mode 100644 index 0000000000000000000000000000000000000000..881cd61a4b31ad2b40ab22b84651f4362e00878c --- /dev/null +++ b/testcases/orb-testcases/depth/PropLogicDepth.scala @@ -0,0 +1,112 @@ +import scala.collection.immutable.Set +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object PropLogicDepth { + + sealed abstract class Formula + case class And(lhs: Formula, rhs: Formula) extends Formula + case class Or(lhs: Formula, rhs: Formula) extends Formula + case class Implies(lhs: Formula, rhs: Formula) extends Formula + case class Not(f: Formula) extends Formula + case class Literal(id: BigInt) extends Formula + case class True() extends Formula + case class False() extends Formula + + def max(x: BigInt,y: BigInt) = if (x >= y) x else y + + def nestingDepth(f: Formula) : BigInt = (f match { + case And(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Or(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Implies(lhs, rhs) => max(nestingDepth(lhs),nestingDepth(rhs)) + 1 + case Not(f) => nestingDepth(f) + 1 + case _ => 1 + }) + + def removeImplies(f: Formula): Formula = (f match { + case And(lhs, rhs) => And(removeImplies(lhs), removeImplies(rhs)) + case Or(lhs, rhs) => Or(removeImplies(lhs), removeImplies(rhs)) + case Implies(lhs, rhs) => Or(Not(removeImplies(lhs)),removeImplies(rhs)) + case Not(f) => Not(removeImplies(f)) + case _ => f + + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) + + def nnf(formula: Formula): Formula = (formula match { + case And(lhs, rhs) => And(nnf(lhs), nnf(rhs)) + case Or(lhs, rhs) => Or(nnf(lhs), nnf(rhs)) + case Implies(lhs, rhs) => Implies(nnf(lhs), nnf(rhs)) + case Not(And(lhs, rhs)) => Or(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Or(lhs, rhs)) => And(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Implies(lhs, rhs)) => And(nnf(lhs), nnf(Not(rhs))) + case Not(Not(f)) => nnf(f) + case Not(Literal(_)) => formula + case Literal(_) => formula + case Not(True()) => False() + case Not(False()) => True() + case _ => formula + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(formula) + b)) + + def isNNF(f: Formula): Boolean = { f match { + case And(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Or(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Implies(lhs, rhs) => false + case Not(Literal(_)) => true + case Not(_) => false + case _ => true + }} ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) + + def simplify(f: Formula): Formula = (f match { + case And(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is false, return false + //if lhs is true return rhs + //if rhs is true return lhs + (sl,sr) match { + case (False(), _) => False() + case (_, False()) => False() + case (True(), _) => sr + case (_, True()) => sl + case _ => And(sl, sr) + } + } + case Or(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is true, return true + //if lhs is false return rhs + //if rhs is false return lhs + (sl,sr) match { + case (True(), _) => True() + case (_, True()) => True() + case (False(), _) => sr + case (_, False()) => sl + case _ => Or(sl, sr) + } + } + case Implies(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs is false return true + //if rhs is true return true + //if lhs is true return rhs + //if rhs is false return Not(rhs) + (sl,sr) match { + case (False(), _) => True() + case (_, True()) => True() + case (True(), _) => sr + case (_, False()) => Not(sl) + case _ => Implies(sl, sr) + } + } + case Not(True()) => False() + case Not(False()) => True() + case _ => f + + }) ensuring((res) => true && tmpl((a,b) => depth <= a*nestingDepth(f) + b)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/depth/QSortDepth.scala b/testcases/orb-testcases/depth/QSortDepth.scala new file mode 100644 index 0000000000000000000000000000000000000000..cd8f16627836a3d23067dafaac9bcac627140d6b --- /dev/null +++ b/testcases/orb-testcases/depth/QSortDepth.scala @@ -0,0 +1,41 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object QSortDepth { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + def size(l:List): BigInt = {l match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }} + + case class Triple(fst:List,snd:List, trd: List) + + def append(aList:List,bList:List): List = {aList match { + case Nil() => bList + case Cons(x, xs) => Cons(x,append(xs,bList)) + }} ensuring(res => size(res) == size(aList) + size(bList) && tmpl((a,b) => depth <= a*size(aList) +b)) + + def partition(n:BigInt,l:List) : Triple = (l match { + case Nil() => Triple(Nil(), Nil(), Nil()) + case Cons(x,xs) => { + val t = partition(n,xs) + if (n < x) Triple(t.fst, t.snd, Cons(x,t.trd)) + else if(n == x) Triple(t.fst, Cons(x,t.snd), t.trd) + else Triple(Cons(x,t.fst), t.snd, t.trd) + } + }) ensuring(res => (size(l) == size(res.fst) + size(res.snd) + size(res.trd)) && tmpl((a,b) => depth <= a*size(l) +b)) + + def quickSort(l:List): List = (l match { + case Nil() => Nil() + case Cons(x,Nil()) => l + case Cons(x,xs) => { + val t = partition(x, xs) + append(append(quickSort(t.fst), Cons(x, t.snd)), quickSort(t.trd)) + } + case _ => l + }) ensuring(res => size(res) == size(l) && tmpl((a,b,c) => depth <= a*(size(l)*size(l)) + b*size(l) + c)) +} diff --git a/testcases/orb-testcases/depth/RedBlackTree.scala b/testcases/orb-testcases/depth/RedBlackTree.scala new file mode 100755 index 0000000000000000000000000000000000000000..9bde5ebdf9f5f59c133498602beb7ae2c5a9c0c4 --- /dev/null +++ b/testcases/orb-testcases/depth/RedBlackTree.scala @@ -0,0 +1,113 @@ +import leon.instrumentation._ +import leon.invariant._ +import scala.collection.immutable.Set +import leon.annotation._ + +object RedBlackTree { + sealed abstract class Color + case class Red() extends Color + case class Black() extends Color + + sealed abstract class Tree + case class Empty() extends Tree + case class Node(color: Color, left: Tree, value: BigInt, right: Tree) extends Tree + + def twopower(x: BigInt) : BigInt = { + require(x >= 0) + if(x < 1) 1 + else + 2* twopower(x - 1) + } + + def size(t: Tree): BigInt = { + require(blackBalanced(t)) + (t match { + case Empty() => 0 + case Node(_, l, v, r) => size(l) + 1 + size(r) + }) + } //ensuring (res => true && tmpl((a,b) => twopower(blackHeight(t)) <= a*res + b)) + + def blackHeight(t : Tree) : BigInt = { + t match { + case Empty() => 0 + case Node(Black(), l, _, _) => blackHeight(l) + 1 + case Node(Red(), l, _, _) => blackHeight(l) + } + } + + //We consider leaves to be black by definition + def isBlack(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(),_,_,_) => true + case _ => false + } + + def redNodesHaveBlackChildren(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(), l, _, r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case Node(Red(), l, _, r) => isBlack(l) && isBlack(r) && redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => false + } + + def redDescHaveBlackChildren(t: Tree) : Boolean = t match { + case Node(_,l,_,r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => true + } + + def blackBalanced(t : Tree) : Boolean = t match { + case Node(_,l,_,r) => blackBalanced(l) && blackBalanced(r) && blackHeight(l) == blackHeight(r) + case _ => true + } + + // <<insert element x into the tree t>> + def ins(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t)) + + t match { + case Empty() => Node(Red(),Empty(),x,Empty()) + case Node(c,a,y,b) => + if(x < y) { + val t1 = ins(x, a) + balance(c, t1, y, b) + } + else if (x == y){ + Node(c,a,y,b) + } + else{ + val t1 = ins(x, b) + balance(c,a,y,t1) + } + } + } ensuring(res => true && tmpl((a,b) => depth <= a*blackHeight(t) + b)) + + def makeBlack(n: Tree): Tree = { + n match { + case Node(Red(),l,v,r) => Node(Black(),l,v,r) + case _ => n + } + } + + def add(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t) ) + val t1 = ins(x, t) + makeBlack(t1) + + } ensuring(res => true && tmpl((a,b) => depth <= a*blackHeight(t) + b)) + + def balance(co: Color, l: Tree, x: BigInt, r: Tree): Tree = { + Node(co,l,x,r) + match { + case Node(Black(),Node(Red(),Node(Red(),a,xV,b),yV,c),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),Node(Red(),a,xV,Node(Red(),b,yV,c)),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),Node(Red(),b,yV,c),zV,d)) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),b,yV,Node(Red(),c,zV,d))) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case _ => Node(co,l,x,r) + } + } + + +} diff --git a/testcases/orb-testcases/depth/SpeedBenchmarks.scala b/testcases/orb-testcases/depth/SpeedBenchmarks.scala new file mode 100644 index 0000000000000000000000000000000000000000..c8679d68de7065fd03de68f762233c998a6c5fdb --- /dev/null +++ b/testcases/orb-testcases/depth/SpeedBenchmarks.scala @@ -0,0 +1,109 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object SpeedBenchmarks { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + sealed abstract class StringBuffer + case class Chunk(str: List, next: StringBuffer) extends StringBuffer + case class Empty() extends StringBuffer + + def length(sb: StringBuffer) : BigInt = sb match { + case Chunk(_, next) => 1 + length(next) + case _ => 0 + } + + def sizeBound(sb: StringBuffer, k: BigInt) : Boolean ={ + sb match { + case Chunk(str, next) => size(str) <= k && sizeBound(next, k) + case _ => 0 <= k + } + } + + /** + * Fig. 1 of SPEED, POPL'09: The functional version of the implementation. + * Equality check of two string buffers + */ + def Equals(str1: List, str2: List, s1: StringBuffer, s2: StringBuffer, k: BigInt) : Boolean = { + require(sizeBound(s1, k) && sizeBound(s2, k) && size(str1) <= k && size(str2) <= k && k >= 0) + + (str1, str2) match { + case (Cons(h1,t1), Cons(h2,t2)) => { + + if(h1 != h2) false + else Equals(t1,t2, s1,s2, k) + } + case (Cons(_,_), Nil()) => { + //load from s2 + s2 match { + case Chunk(str, next) => Equals(str1, str, s1, next, k) + case Empty() => false + } + } + case (Nil(), Cons(_,_)) => { + //load from s1 + s1 match { + case Chunk(str, next) => Equals(str, str2, next, s2, k) + case Empty() => false + } + } + case _ =>{ + //load from both + (s1,s2) match { + case (Chunk(nstr1, next1),Chunk(nstr2, next2)) => Equals(nstr1, nstr2, next1, next2, k) + case (Empty(),Chunk(nstr2, next2)) => Equals(str1, nstr2, s1, next2, k) + case (Chunk(nstr1, next1), Empty()) => Equals(nstr1, str2, next1, s2, k) + case _ => true + } + } + } + } ensuring(res => true && tmpl((a,b,c,d,e) => depth <= a*((k+1)*(length(s1) + length(s2))) + b*size(str1) + e)) + + def max(x: BigInt, y: BigInt) : BigInt = if(x >= y) x else y + + //Fig. 2 of Speed POPL'09 + def Dis1(x : BigInt, y : BigInt, n: BigInt, m: BigInt) : BigInt = { + if(x >= n) 0 + else { + if(y < m) Dis1(x, y+1, n, m) + else Dis1(x+1, y, n, m) + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*max(0,n-x) + b*max(0,m-y) + c)) + + //Fig. 2 of Speed POPL'09 + def Dis2(x : BigInt, z : BigInt, n: BigInt) : BigInt = { + if(x >= n) 0 + else { + if(z > x) Dis2(x+1, z, n) + else Dis2(x, z+1, n) + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*max(0,n-x) + b*max(0,n-z) + c)) + + //Pg. 138, Speed POPL'09 + def Dis3(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { + require((b && t == 1) || (!b && t == -1)) + if(x > n || x < 0) 0 + else { + if(b) Dis3(x+t, b, t, n) + else Dis3(x-t, b, t, n) + } + } ensuring(res => true && tmpl((a,c) => depth <= a*max(0,(n-x)) + c)) + + //Pg. 138, Speed POPL'09 + def Dis4(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { + if(x > n || x < 0) 0 + else { + if(b) Dis4(x+t, b, t, n) + else Dis4(x-t, b, t, n) + } + } ensuring(res => true && tmpl((a,c,d,e) => (((b && t >= 0) || (!b && t < 0)) && depth <= a*max(0,(n-x)) + c) + || (((!b && t >= 0) || (b && t < 0)) && depth <= d*max(0,x) + e))) +} diff --git a/testcases/orb-testcases/depth/TreeOperations.scala b/testcases/orb-testcases/depth/TreeOperations.scala new file mode 100755 index 0000000000000000000000000000000000000000..d6199f16345d0e39ba3239949d8ac0486d3345e5 --- /dev/null +++ b/testcases/orb-testcases/depth/TreeOperations.scala @@ -0,0 +1,93 @@ +import leon.instrumentation._ +import leon.invariant._ +import leon.annotation._ + +object TreeOperations { + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def listSize(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + listSize(t) + }) + + def size(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + size(l) + size(r) + 1 + } + } + } + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def insert(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Node(Leaf(), elem, Leaf()) + case Node(l, x, r) => if (x <= elem) Node(l, x, insert(elem, r)) + else Node(insert(elem, l), x, r) + } + } ensuring (res => height(res) <= height(t) + 1 && tmpl((a,b) => depth <= a*height(t) + b)) + + def addAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) =>{ + val newt = insert(x, t) + addAll(xs, newt) + } + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*(listSize(l) * (height(t) + listSize(l))) + b*listSize(l) + c)) + + def remove(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Leaf() + case Node(l, x, r) => { + + if (x < elem) Node(l, x, remove(elem, r)) + else if (x > elem) Node(remove(elem, l), x, r) + else { + t match { + case Node(Leaf(), x, Leaf()) => Leaf() + case Node(Leaf(), x, Node(_, rx, _)) => Node(Leaf(), rx, remove(rx, r)) + case Node(Node(_, lx, _), x, r) => Node(remove(lx, l), lx, r) + case _ => Leaf() + } + } + } + } + } ensuring (res => height(res) <= height(t) && tmpl ((a, b, c) => depth <= a*height(t) + b)) + + def removeAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) => removeAll(xs, remove(x, t)) + } + } ensuring(res => true && tmpl((a,b,c) => depth <= a*(listSize(l) * height(t)) + b*listSize(l) + c)) + + def contains(elem : BigInt, t : Tree) : Boolean = { + t match { + case Leaf() => false + case Node(l, x, r) => + if(x == elem) true + else if (x < elem) contains(elem, r) + else contains(elem, l) + } + } ensuring (res => true && tmpl((a,b) => depth <= a*height(t) + b)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/numerical/ConcatVariationsAbs.scala b/testcases/orb-testcases/numerical/ConcatVariationsAbs.scala new file mode 100644 index 0000000000000000000000000000000000000000..bff880ab30ac1495c7ab6c706265376f1fafe7ad --- /dev/null +++ b/testcases/orb-testcases/numerical/ConcatVariationsAbs.scala @@ -0,0 +1,43 @@ +import leon.invariant._ + +object ConcatVariationsAbs { + def genL(n: BigInt): BigInt = { + require(n >= 0) + if (n == 0) + BigInt(2) + else + 4 + genL(n - 1) + } ensuring (res => tmpl((a, b) => res <= a * n + b)) + + def append(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else + append(l1 - 1, l2 + 1) + 5 + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def f_good(m: BigInt, n: BigInt): BigInt = { + require(0 <= m && 0 <= n) + if (m == 0) BigInt(2) + else { + val t1 = genL(n) + val t2 = f_good(m - 1, n) + val t3 = append(n, n * (m - 1)) + (t1 + t2 + t3 + 6) + } + + } ensuring (res => tmpl((a, b, c, d) => res <= a * (n * m) + b * n + c * m + d)) + + def f_worst(m: BigInt, n: BigInt): BigInt = { + require(0 <= m && 0 <= n) + if (m == 0) BigInt(2) + else { + val t1 = genL(n) + val t2 = f_worst(m - 1, n) + val t3 = append(n * (m - 1), n) + (t1 + t2 + t3 + 6) + } + + } ensuring (res => tmpl((a, c, d, e, f) => res <= a * ((n * m) * m) + c * (n * m) + d * n + e * m + f)) +} diff --git a/testcases/orb-testcases/numerical/ListAppendAbs.scala b/testcases/orb-testcases/numerical/ListAppendAbs.scala new file mode 100755 index 0000000000000000000000000000000000000000..63038cc74064b928534a0e53ac0baf0184c09d78 --- /dev/null +++ b/testcases/orb-testcases/numerical/ListAppendAbs.scala @@ -0,0 +1,20 @@ +import leon.invariant._ + +object ListAppendAbs +{ + def app(x: BigInt) : BigInt = { + require(x >=0) + + app0(x,1) + + } ensuring(res => res == x + 1) + + def app0(a: BigInt, b: BigInt) : BigInt = { + require(a >=0 && b >=0) + + if(a <= 0) + b + else + app0(a-1,b+1) + } ensuring(res => tmpl((p, q, r) => (p*res + q*a + r*b == 0 && q != 0))) +} diff --git a/testcases/orb-testcases/numerical/LogarithmTest.scala b/testcases/orb-testcases/numerical/LogarithmTest.scala new file mode 100755 index 0000000000000000000000000000000000000000..00fdc552dcb7b0d691139f07578fed5d654d0272 --- /dev/null +++ b/testcases/orb-testcases/numerical/LogarithmTest.scala @@ -0,0 +1,30 @@ +import leon.invariant._ +import leon.annotation._ + +object LogarithmTest { + + @monotonic + def log(x: BigInt) : BigInt = { + require(x >= 0) + if(x <= 1) BigInt(0) + else { + 1 + log(x/2) + } + } ensuring(_ >= 0) + + def binarySearchAbs(x: BigInt, min: BigInt, max: BigInt): BigInt = { + require(max - min >= 0) + if (max - min <= 1) BigInt(2) + else { + val mid = (min + max) / 2 + if (x < mid) { + binarySearchAbs(x, min, mid) + 5 + } else if (x > mid) { + binarySearchAbs(x, mid + 1, max) + 7 + } else + BigInt(8) + } + } ensuring(res => tmpl((a,b) => res <= a*log(max - min) + b)) + //ensuring(res => tmpl((a,b) => res <= 7*log(max - min) + 2)) + // +} diff --git a/testcases/orb-testcases/numerical/QueueAbs.scala b/testcases/orb-testcases/numerical/QueueAbs.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7aee3a93d4ca9ee058b2f0695f001d29fdf3acc --- /dev/null +++ b/testcases/orb-testcases/numerical/QueueAbs.scala @@ -0,0 +1,70 @@ +import leon.invariant._ + +object AmortizedQueue { + def concat(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else + concat(l1 - 1, l2 + 1) + 5 + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def reverseRec(l1: BigInt, l2: BigInt): BigInt = { + require(l1 >= 0 && l2 >= 0) + if (l1 == 0) + BigInt(3) + else { + reverseRec(l1 - 1, l2 + 1) + 6 + } + } ensuring (res => tmpl((a, b) => res <= a * l1 + b)) + + def reverse(l: BigInt): BigInt = { + require(l >= 0) + reverseRec(l, 0) + 1 + } ensuring (res => tmpl((a, b) => res <= a * l + b)) + + def create(front: BigInt, rear: BigInt): BigInt = { + require(front >= 0 && rear >= 0) + if (rear <= front) + BigInt(4) + else { + val t1 = reverse(rear) + val t2 = concat(front, rear) + t1 + t2 + 7 + } + } + + def enqueue(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 0 && front >= 0 && rear >= 0) + create(front, rear) + 5 + } ensuring (res => tmpl((a, b) => res <= a * q + b)) + + def dequeue(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 1 && front >= rear && rear >= 0) + if (front >= 1) { + create(front - 1, rear) + 4 + } else { + //since front should be greater than rear, here rear should be 0 as well + BigInt(5) + } + } ensuring (res => tmpl((a, b) => res <= a * q + b)) + + def removeLast(l: BigInt): BigInt = { + require(l >= 1) + if (l == 1) { + BigInt(4) + } else { + removeLast(l - 1) + 6 + } + } ensuring (res => tmpl((a, b) => res <= a * l + b)) + + def pop(q: BigInt, front: BigInt, rear: BigInt): BigInt = { + require(q == front + rear && q >= 1 && front >= rear && rear >= 0) + if (rear >= 1) { + BigInt(3) + } else { + val t1 = removeLast(front) + t1 + 5 + } + } ensuring (res => tmpl((a, b) => res <= a * q + b)) +} diff --git a/testcases/orb-testcases/numerical/SimpleInterProc.scala b/testcases/orb-testcases/numerical/SimpleInterProc.scala new file mode 100755 index 0000000000000000000000000000000000000000..b2ac527b54cefa3d1dadb063ac58a770dcc76bf9 --- /dev/null +++ b/testcases/orb-testcases/numerical/SimpleInterProc.scala @@ -0,0 +1,16 @@ +object SimpleInterProc +{ + def s(x: BigInt) : BigInt = { + if(x < 0) + makePositive(x) + else + s(x-1) + 1 + } ensuring(res => res != -1) + + def makePositive(y : BigInt) : BigInt = { + 2*negate(y) + } + def negate(c : BigInt) : BigInt={ + -c + } +} \ No newline at end of file diff --git a/testcases/orb-testcases/numerical/SimpleLoop.scala b/testcases/orb-testcases/numerical/SimpleLoop.scala new file mode 100755 index 0000000000000000000000000000000000000000..6a2cdb3d9958f5ad41783d3723b1c08ef7c0ba16 --- /dev/null +++ b/testcases/orb-testcases/numerical/SimpleLoop.scala @@ -0,0 +1,9 @@ +object SimpleLoop +{ + def s(x: BigInt) : BigInt = { + if(x < 0) + BigInt(0) + else + s(x-1) + 1 + } ensuring(res => res != -1) +} \ No newline at end of file diff --git a/testcases/orb-testcases/numerical/see-saw.scala b/testcases/orb-testcases/numerical/see-saw.scala new file mode 100644 index 0000000000000000000000000000000000000000..894a8caed2a298a22a89d8bb6925cc03f022407c --- /dev/null +++ b/testcases/orb-testcases/numerical/see-saw.scala @@ -0,0 +1,15 @@ +object SeeSaw { + def s(x: BigInt, y: BigInt, z: BigInt): BigInt = { + require(y >= 0) + + if (x >= 100) { + y + } else if (x <= z) { //some condition + s(x + 1, y + 2, z) + } else if (x <= z + 9) { //some condition + s(x + 1, y + 3, z) + } else { + s(x + 2, y + 1, z) + } + } ensuring (res => (100 - x <= 2 * res)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/stack/BinaryTrie.scala b/testcases/orb-testcases/stack/BinaryTrie.scala new file mode 100644 index 0000000000000000000000000000000000000000..f2dfd876cdbc51c969f4afc7d1548810be111c02 --- /dev/null +++ b/testcases/orb-testcases/stack/BinaryTrie.scala @@ -0,0 +1,120 @@ +import leon.invariant._ +import leon.instrumentation._ +//import scala.collection.immutable.Set + +object BinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (_ => stack <= ? * listSize(inp) + ?) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (_ => stack <= ? * listSize(inp) + ?) +} diff --git a/testcases/orb-testcases/stack/BinomialHeap.scala b/testcases/orb-testcases/stack/BinomialHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..e18b9b1438b9c66925750dd86a2367f76f5866a0 --- /dev/null +++ b/testcases/orb-testcases/stack/BinomialHeap.scala @@ -0,0 +1,207 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BinomialHeap { + //sealed abstract class TreeNode + case class TreeNode(rank: BigInt, elem: Element, children: BinomialHeap) + case class Element(n: BigInt) + + sealed abstract class BinomialHeap + case class ConsHeap(head: TreeNode, tail: BinomialHeap) extends BinomialHeap + case class NilHeap() extends BinomialHeap + + sealed abstract class List + case class NodeL(head: BinomialHeap, tail: List) extends List + case class NilL() extends List + + sealed abstract class OptionalTree + case class Some(t : TreeNode) extends OptionalTree + case class None() extends OptionalTree + + /* Lower or Equal than for Element structure */ + private def leq(a: Element, b: Element) : Boolean = { + a match { + case Element(a1) => { + b match { + case Element(a2) => { + if(a1 <= a2) true + else false + } + } + } + } + } + + /* isEmpty function of the Binomial Heap */ + def isEmpty(t: BinomialHeap) = t match { + case ConsHeap(_,_) => false + case _ => true + } + + /* Helper function to determine rank of a TreeNode */ + def rank(t: TreeNode) : BigInt = t.rank /*t match { + case TreeNode(r, _, _) => r + }*/ + + /* Helper function to get the root element of a TreeNode */ + def root(t: TreeNode) : Element = t.elem /*t match { + case TreeNode(_, e, _) => e + }*/ + + /* Linking trees of equal ranks depending on the root element */ + def link(t1: TreeNode, t2: TreeNode): TreeNode = { + if (leq(t1.elem, t2.elem)) { + TreeNode(t1.rank + 1, t1.elem, ConsHeap(t2, t1.children)) + } else { + TreeNode(t1.rank + 1, t2.elem, ConsHeap(t1, t2.children)) + } + } + + def treeNum(h: BinomialHeap) : BigInt = { + h match { + case ConsHeap(head, tail) => 1 + treeNum(tail) + case _ => 0 + } + } + + /* Insert a tree into a binomial heap. The tree should be correct in relation to the heap */ + def insTree(t: TreeNode, h: BinomialHeap) : BinomialHeap = { + h match { + case ConsHeap(head, tail) => { + if (rank(t) < rank(head)) { + ConsHeap(t, h) + } else if (rank(t) > rank(head)) { + ConsHeap(head, insTree(t,tail)) + } else { + insTree(link(t,head), tail) + } + } + case _ => ConsHeap(t, NilHeap()) + } + } ensuring(res => tmpl((a,b) => stack <= a*treeNum(h) + b)) + + /* Merge two heaps together */ + def merge(h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + ConsHeap(head1, merge(tail1, h2)) + } else if (rank(head2) < rank(head1)) { + ConsHeap(head2, merge(h1, tail2)) + } else { + mergeWithCarry(link(head1, head2), tail1, tail2) + } + } + case _ => h1 + } + } + case _ => h2 + } + } ensuring(res => tmpl((a,b,c) => stack <= a*treeNum(h1) + b*treeNum(h2) + c)) + + def mergeWithCarry(t: TreeNode, h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + + if (rank(t) < rank(head1)) + ConsHeap(t, ConsHeap(head1, merge(tail1, h2))) + else + mergeWithCarry(link(t, head1), tail1, h2) + + } else if (rank(head2) < rank(head1)) { + + if (rank(t) < rank(head2)) + ConsHeap(t, ConsHeap(head2, merge(h1, tail2))) + else + mergeWithCarry(link(t, head2), h1, tail2) + + } else { + ConsHeap(t, mergeWithCarry(link(head1, head2), tail1, tail2)) + } + } + case _ => { + insTree(t, h1) + } + } + } + case _ => insTree(t, h2) + } + } ensuring (res => tmpl((d, e, f) => stack <= d * treeNum(h1) + e * treeNum(h2) + f)) + + //Auxiliary helper function to simplefy findMin and deleteMin + def removeMinTree(h: BinomialHeap): (OptionalTree, BinomialHeap) = { + h match { + case ConsHeap(head, NilHeap()) => (Some(head), NilHeap()) + case ConsHeap(head1, tail1) => { + val (opthead2, tail2) = removeMinTree(tail1) + opthead2 match { + case Some(head2) => + if (leq(root(head1), root(head2))) { + (Some(head1), tail1) + } else { + (Some(head2), ConsHeap(head1, tail2)) + } + case _ => (Some(head1), tail1) + } + } + case _ => (None(), NilHeap()) + } + } ensuring (res => treeNum(res._2) <= treeNum(h) && tmpl((a, b) => stack <= a * treeNum(h) + b)) + + /*def findMin(h: BinomialHeap) : Element = { + val (opt, _) = removeMinTree(h) + opt match { + case Some(TreeNode(_,e,ts1)) => e + case _ => Element(-1) + } + } ensuring(res => true && tmpl((a,b) => time <= a*treeNum(h) + b))*/ + + def minTreeChildren(h: BinomialHeap) : BigInt = { + val (min, _) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ch)) => treeNum(ch) + case _ => 0 + } + } + + // Discard the minimum element of the extracted min tree and put its children back into the heap + def deleteMin(h: BinomialHeap) : BinomialHeap = { + val (min, ts2) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ts1)) => merge(ts1, ts2) + case _ => h + } + } ensuring(res => tmpl((a,b,c) => stack <= a*minTreeChildren(h) + b*treeNum(h) + c)) + + /*def heapSize(h: BinomialHeap) : BigInt = { + h match { + NilHeap() => 0 + ConsHeap(head, tail) => + treeSize(head) + heapSize(tail) + } + } + + def treeSize(tree: TreeNode) : BigInt = { + val (_, _, children) = tree + heapSize(children) + 1 + } + + @monotonic + def twopower(x: BigInt) : BigInt = { + require(x >= 0) + if(x < 1) 1 + else + 2* twopower(x - 1) + } + + def sizeProperty(tree: TreeNode) : BigInt = { + val (r, _, _) = tree + treeSize(tree) == twopower(r) + }*/ + +} diff --git a/testcases/orb-testcases/stack/ListOperations.scala b/testcases/orb-testcases/stack/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..3ae4a2d1705051b185c444c7342bf9b180fc4086 --- /dev/null +++ b/testcases/orb-testcases/stack/ListOperations.scala @@ -0,0 +1,35 @@ +import leon.invariant._ +import leon.instrumentation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => stack <= a*size(l1) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + + }) ensuring (res => tmpl((a,b) => stack <= a*size(list) + b)) + + def distinct(l: List): List = ( + l match { + case Nil() => Nil() + case Cons(x, xs) => { + val newl = distinct(xs) + if (contains(newl, x)) newl + else Cons(x, newl) + } + }) ensuring (res => size(l) >= size(res) && tmpl((a,b) => stack <= a*size(l) + b)) +} diff --git a/testcases/orb-testcases/stack/MergeSort.scala b/testcases/orb-testcases/stack/MergeSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..7a104689f06714c51e042291d7660d8b46c92541 --- /dev/null +++ b/testcases/orb-testcases/stack/MergeSort.scala @@ -0,0 +1,63 @@ +import leon.invariant._ +import leon.instrumentation._ + +import leon.annotation._ + +object MergeSort { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(list: List): BigInt = (list match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }) //ensuring(res => true && tmpl((a) => res >= 0)) + + def length(l: List): BigInt = { + l match { + case Nil() => BigInt(0) + case Cons(x,xs) => 1 + length(xs) + } + } ensuring(res => res == size(l) && tmpl((a,b) => stack <= a*size(l) + b)) + + def split(l: List, n: BigInt): (List, List) = { + require(n >= 0 && n <= size(l)) + if (n <= 0) (Nil(),l) + else + l match { + case Nil() => (Nil(),l) + case Cons(x,xs) => { + if(n == 1) (Cons(x,Nil()), xs) + else { + val (fst,snd) = split(xs, n-1) + (Cons(x,fst), snd) + } + } + } + } ensuring(res => size(res._2) == size(l) - n && size(res._1) == n && size(res._2) + size(res._1) == size(l) && tmpl((a,b) => stack <= a*size(l) +b)) + + def merge(aList: List, bList: List): List = (bList match { + case Nil() => aList + case Cons(x,xs) => + aList match { + case Nil() => bList + case Cons(y,ys) => + if (y < x) + Cons(y,merge(ys, bList)) + else + Cons(x,merge(aList, xs)) + } + }) ensuring(res => size(aList) + size(bList) == size(res) && tmpl((a,b,c) => stack <= a*size(aList) + b*size(bList) + c)) + + def mergeSort(list: List): List = { + list match { + case Cons(x, Nil()) => list + case Cons(_, Cons(_, _)) => + val lby2 = length(list) / 2 + val (fst, snd) = split(list, lby2) + merge(mergeSort(fst), mergeSort(snd)) + + case _ => list + } + } ensuring(res => size(res) == size(list) && tmpl((a,b) => stack <= a*size(list) + b)) +} diff --git a/testcases/orb-testcases/stack/QuickSort.scala b/testcases/orb-testcases/stack/QuickSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..20150475348464c5d8c35caaa9a865be0049eace --- /dev/null +++ b/testcases/orb-testcases/stack/QuickSort.scala @@ -0,0 +1,41 @@ +import leon.invariant._ +import leon.instrumentation._ + +object QuickSort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + def size(l:List): BigInt = {l match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }} + + case class Triple(fst:List,snd:List, trd: List) + + def append(aList:List,bList:List): List = {aList match { + case Nil() => bList + case Cons(x, xs) => Cons(x,append(xs,bList)) + }} ensuring(res => size(res) == size(aList) + size(bList) && tmpl((a,b) => stack <= a*size(aList) +b)) + + def partition(n:BigInt,l:List) : Triple = (l match { + case Nil() => Triple(Nil(), Nil(), Nil()) + case Cons(x,xs) => { + val t = partition(n,xs) + if (n < x) Triple(t.fst, t.snd, Cons(x,t.trd)) + else if(n == x) Triple(t.fst, Cons(x,t.snd), t.trd) + else Triple(Cons(x,t.fst), t.snd, t.trd) + } + }) ensuring(res => (size(l) == size(res.fst) + size(res.snd) + size(res.trd)) && tmpl((a,b) => stack <= a*size(l) +b)) + + def quickSort(l:List): List = (l match { + case Nil() => Nil() + case Cons(x,Nil()) => l + case Cons(x,xs) => { + val t = partition(x, xs) + append(append(quickSort(t.fst), Cons(x, t.snd)), quickSort(t.trd)) + } + case _ => l + }) ensuring(res => size(l) == size(res) && tmpl((a,b,c,d) => stack <= a*size(l) + d)) +} + diff --git a/testcases/orb-testcases/stack/RedBlackTree.scala b/testcases/orb-testcases/stack/RedBlackTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..44f84a9b3e90e98a362348c8862892e384851344 --- /dev/null +++ b/testcases/orb-testcases/stack/RedBlackTree.scala @@ -0,0 +1,110 @@ +import leon.invariant._ +import leon.instrumentation._ +import scala.collection.immutable.Set + +object RedBlackTree { + sealed abstract class Color + case class Red() extends Color + case class Black() extends Color + + sealed abstract class Tree + case class Empty() extends Tree + case class Node(color: Color, left: Tree, value: BigInt, right: Tree) extends Tree + + def twopower(x: BigInt) : BigInt = { + require(x >= 0) + if(x < 1) 1 + else + 2* twopower(x - 1) + } + + def size(t: Tree): BigInt = { + require(blackBalanced(t)) + (t match { + case Empty() => BigInt(0) + case Node(_, l, v, r) => size(l) + 1 + size(r) + }) + } ensuring (res => tmpl((a,b) => twopower(blackHeight(t)) <= a*res + b)) + + def blackHeight(t : Tree) : BigInt = { + t match { + case Node(Black(), l, _, _) => blackHeight(l) + 1 + case Node(Red(), l, _, _) => blackHeight(l) + case _ => 0 + } + } + + //We consider leaves to be black by definition + def isBlack(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(),_,_,_) => true + case _ => false + } + + def redNodesHaveBlackChildren(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(), l, _, r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case Node(Red(), l, _, r) => isBlack(l) && isBlack(r) && redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => false + } + + def redDescHaveBlackChildren(t: Tree) : Boolean = t match { + case Node(_,l,_,r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => true + } + + def blackBalanced(t : Tree) : Boolean = t match { + case Node(_,l,_,r) => blackBalanced(l) && blackBalanced(r) && blackHeight(l) == blackHeight(r) + case _ => true + } + + // <<insert element x BigInto the tree t>> + def ins(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t)) + + t match { + case Empty() => Node(Red(),Empty(),x,Empty()) + case Node(c,a,y,b) => + if(x < y) { + val t1 = ins(x, a) + balance(c, t1, y, b) + } + else if (x == y){ + Node(c,a,y,b) + } + else{ + val t1 = ins(x, b) + balance(c,a,y,t1) + } + } + } ensuring(res => tmpl((a,b) => stack <= a*blackHeight(t) + b)) + + def makeBlack(n: Tree): Tree = { + n match { + case Node(Red(),l,v,r) => Node(Black(),l,v,r) + case _ => n + } + } + + def add(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t) ) + val t1 = ins(x, t) + makeBlack(t1) + + } ensuring(res => tmpl((a,b) => stack <= a*blackHeight(t) + b)) + + def balance(co: Color, l: Tree, x: BigInt, r: Tree): Tree = { + Node(co,l,x,r) + match { + case Node(Black(),Node(Red(),Node(Red(),a,xV,b),yV,c),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),Node(Red(),a,xV,Node(Red(),b,yV,c)),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),Node(Red(),b,yV,c),zV,d)) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),b,yV,Node(Red(),c,zV,d))) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case _ => Node(co,l,x,r) + } + } +} diff --git a/testcases/orb-testcases/stack/SpeedBenchmarks.scala b/testcases/orb-testcases/stack/SpeedBenchmarks.scala new file mode 100644 index 0000000000000000000000000000000000000000..c1c59d592b2b0b59cfe79b780aaca479fcbb222d --- /dev/null +++ b/testcases/orb-testcases/stack/SpeedBenchmarks.scala @@ -0,0 +1,75 @@ +import leon.invariant._ +import leon.instrumentation._ +import leon.math._ + +object SpeedBenchmarks { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + sealed abstract class StringBuffer + case class Chunk(str: List, next: StringBuffer) extends StringBuffer + case class Empty() extends StringBuffer + + def length(sb: StringBuffer) : BigInt = sb match { + case Chunk(_, next) => 1 + length(next) + case _ => 0 + } + + def sizeBound(sb: StringBuffer, k: BigInt) : Boolean ={ + sb match { + case Chunk(str, next) => size(str) <= k && sizeBound(next, k) + case _ => 0 <= k + } + } + + def sizeBuffer(sb: StringBuffer): BigInt = { + sb match { + case Chunk(str, next) => size(str) + sizeBuffer(sb) + case Empty() => 0 + } + } + + /** + * Fig. 1 of SPEED, POPL'09: The functional version of the implementation. + * Equality check of two string buffers + */ + def Equals(str1: List, str2: List, s1: StringBuffer, s2: StringBuffer, k: BigInt) : Boolean = { + require(sizeBound(s1, k) && sizeBound(s2, k) && size(str1) <= k && size(str2) <= k && k >= 0) + + (str1, str2) match { + case (Cons(h1,t1), Cons(h2,t2)) => { + if(h1 != h2) false + else Equals(t1,t2, s1,s2, k) + } + case (Cons(_,_), Nil()) => { + //load from s2 + s2 match { + case Chunk(str, next) => Equals(str1, str, s1, next, k) + case Empty() => false + } + } + case (Nil(), Cons(_,_)) => { + //load from s1 + s1 match { + case Chunk(str, next) => Equals(str, str2, next, s2, k) + case Empty() => false + } + } + case _ =>{ + //load from both + (s1,s2) match { + case (Chunk(nstr1, next1),Chunk(nstr2, next2)) => Equals(nstr1, nstr2, next1, next2, k) + case (Empty(),Chunk(nstr2, next2)) => Equals(str1, nstr2, s1, next2, k) + case (Chunk(nstr1, next1), Empty()) => Equals(nstr1, str2, next1, s2, k) + case _ => true + } + } + } + } ensuring(res => tmpl((a,b,c,d,e) => stack <= a*max(sizeBuffer(s1), sizeBuffer(s2)) + c*(k+1) + e)) +} diff --git a/testcases/orb-testcases/stack/TreeOperations.scala b/testcases/orb-testcases/stack/TreeOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..123031d7d0cb523f164f69aabb5b8f0d4aef4ca9 --- /dev/null +++ b/testcases/orb-testcases/stack/TreeOperations.scala @@ -0,0 +1,93 @@ +import leon.invariant._ +import leon.instrumentation._ + + +object TreeOperations { + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def listSize(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + listSize(t) + }) + + def size(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + size(l) + size(r) + 1 + } + } + } + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def insert(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Node(Leaf(), elem, Leaf()) + case Node(l, x, r) => if (x <= elem) Node(l, x, insert(elem, r)) + else Node(insert(elem, l), x, r) + } + } ensuring (res => height(res) <= height(t) + 1 && tmpl((a,b) => stack <= a*height(t) + b)) + + def addAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) =>{ + val newt = insert(x, t) + addAll(xs, newt) + } + } + } ensuring(res => tmpl((a,b,c) => stack <= a*(listSize(l) * (height(t) + listSize(l))) + b*listSize(l) + c)) + + def remove(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Leaf() + case Node(l, x, r) => { + + if (x < elem) Node(l, x, remove(elem, r)) + else if (x > elem) Node(remove(elem, l), x, r) + else { + t match { + case Node(Leaf(), x, Leaf()) => Leaf() + case Node(Leaf(), x, Node(_, rx, _)) => Node(Leaf(), rx, remove(rx, r)) + case Node(Node(_, lx, _), x, r) => Node(remove(lx, l), lx, r) + case _ => Leaf() + } + } + } + } + } ensuring (res => height(res) <= height(t) && tmpl ((a, b, c) => stack <= a*height(t) + b)) + + def removeAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) => removeAll(xs, remove(x, t)) + } + } ensuring(res => tmpl((a,b,c) => stack <= a*(listSize(l) * height(t)) + b*listSize(l) + c)) + + def contains(elem : BigInt, t : Tree) : Boolean = { + t match { + case Leaf() => false + case Node(l, x, r) => + if(x == elem) true + else if (x < elem) contains(elem, r) + else contains(elem, l) + } + } ensuring (res => tmpl((a,b) => stack <= a*height(t) + b)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/timing/AVLTree.scala b/testcases/orb-testcases/timing/AVLTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..d34787eba6ab5c416bffd10cd5d50a72addc9448 --- /dev/null +++ b/testcases/orb-testcases/timing/AVLTree.scala @@ -0,0 +1,195 @@ +import leon.invariant._ +import leon.instrumentation._ +import leon.math._ + +/** + * created by manos and modified by ravi. + * BST property cannot be verified + */ +object AVLTree { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(left : Tree, value : BigInt, right: Tree, rank : BigInt) extends Tree + + sealed abstract class OptionInt + case class None() extends OptionInt + case class Some(i: BigInt) extends OptionInt + + //def min(i1:BigInt, i2:BigInt) : BigInt = if (i1<=i2) i1 else i2 + //def max(i1:BigInt, i2:BigInt) : BigInt = if (i1>=i2) i1 else i2 + + /*def twopower(x: BigInt) : BigInt = { + //require(x >= 0) + if(x < 1) 1 + else + 3/2 * twopower(x - 1) + } ensuring(res => res >= 1 template((a) => a <= 0))*/ + + def rank(t: Tree) : BigInt = { + t match { + case Leaf() => 0 + case Node(_,_,_,rk) => rk + } + } //ensuring(res => res >= 0) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r, _) => { + val hl = height(l) + val hr = height(r) + max(hl,hr) + 1 + } + } + } + + def size(t: Tree): BigInt = { + //require(isAVL(t)) + (t match { + case Leaf() => 0 + case Node(l, _, r,_) => size(l) + 1 + size(r) + }) + + } + //ensuring (res => true template((a,b) => height(t) <= a*res + b)) + + def rankHeight(t: Tree) : Boolean = t match { + case Leaf() => true + case Node(l,_,r,rk) => rankHeight(l) && rankHeight(r) && rk == height(t) + } + + def balanceFactor(t : Tree) : BigInt = { + t match{ + case Leaf() => 0 + case Node(l, _, r, _) => rank(l) - rank(r) + } + } + + /*def isAVL(t:Tree) : Boolean = { + t match { + case Leaf() => true + case Node(l,_,r,rk) => isAVL(l) && isAVL(r) && balanceFactor(t) >= -1 && balanceFactor(t) <= 1 && rankHeight(t) //isBST(t) && + } + }*/ + + def unbalancedInsert(t: Tree, e : BigInt) : Tree = { + t match { + case Leaf() => Node(Leaf(), e, Leaf(), 1) + case Node(l,v,r,h) => + if (e == v) t + else if (e < v){ + val newl = avlInsert(l,e) + Node(newl, v, r, max(rank(newl), rank(r)) + 1) + } + else { + val newr = avlInsert(r,e) + Node(l, v, newr, max(rank(l), rank(newr)) + 1) + } + } + } + + def avlInsert(t: Tree, e : BigInt) : Tree = { + + balance(unbalancedInsert(t,e)) + + } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) + //ensuring(res => time <= 276*height(t) + 38) + //minbound: ensuring(res => time <= 138*height(t) + 19) + + def deletemax(t: Tree): (Tree, OptionInt) = { + + t match { + case Node(Leaf(), v, Leaf(), _) => (Leaf(), Some(v)) + case Node(l, v, Leaf(), _) => { + val (newl, opt) = deletemax(l) + opt match { + case None() => (t, None()) + case Some(lmax) => { + val newt = balance(Node(newl, lmax, Leaf(), rank(newl) + 1)) + (newt, Some(v)) + } + } + } + case Node(_, _, r, _) => deletemax(r) + case _ => (t, None()) + } + } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) + + def unbalancedDelete(t: Tree, e: BigInt): Tree = { + t match { + case Leaf() => Leaf() //not found case + case Node(l, v, r, h) => + if (e == v) { + if (l == Leaf()) r + else if(r == Leaf()) l + else { + val (newl, opt) = deletemax(l) + opt match { + case None() => t + case Some(newe) => { + Node(newl, newe, r, max(rank(newl), rank(r)) + 1) + } + } + } + } else if (e < v) { + val newl = avlDelete(l, e) + Node(newl, v, r, max(rank(newl), rank(r)) + 1) + } else { + val newr = avlDelete(r, e) + Node(l, v, newr, max(rank(l), rank(newr)) + 1) + } + } + } + + def avlDelete(t: Tree, e: BigInt): Tree = { + + balance(unbalancedDelete(t, e)) + + } ensuring(res => tmpl((a,b) => time <= a*height(t) + b)) + + def balance(t:Tree) : Tree = { + t match { + case Leaf() => Leaf() // impossible... + case Node(l, v, r, h) => + val bfactor = balanceFactor(t) + // at this point, the tree is unbalanced + if(bfactor > 1 ) { // left-heavy + val newL = + if (balanceFactor(l) < 0) { // l is right heavy + rotateLeft(l) + } + else l + rotateRight(Node(newL,v,r, max(rank(newL), rank(r)) + 1)) + } + else if(bfactor < -1) { + val newR = + if (balanceFactor(r) > 0) { // r is left heavy + rotateRight(r) + } + else r + rotateLeft(Node(l,v,newR, max(rank(newR), rank(l)) + 1)) + } else t + } + } + + def rotateRight(t:Tree) = { + t match { + case Node(Node(ll, vl, rl, _),v,r, _) => + + val hr = max(rank(rl),rank(r)) + 1 + Node(ll, vl, Node(rl,v,r,hr), max(rank(ll),hr) + 1) + + case _ => t // this should not happen + } } + + + def rotateLeft(t:Tree) = { + t match { + case Node(l, v, Node(lr,vr,rr,_), _) => + + val hl = max(rank(l),rank(lr)) + 1 + Node(Node(l,v,lr,hl), vr, rr, max(hl, rank(rr)) + 1) + case _ => t // this should not happen + } } +} + diff --git a/testcases/orb-testcases/timing/AmortizedQueue.scala b/testcases/orb-testcases/timing/AmortizedQueue.scala new file mode 100644 index 0000000000000000000000000000000000000000..23b5c35726d4e40f152e5d80878a96952266cbfb --- /dev/null +++ b/testcases/orb-testcases/timing/AmortizedQueue.scala @@ -0,0 +1,88 @@ +import leon.invariant._ +import leon.instrumentation._ + +object AmortizedQueue { + sealed abstract class List + case class Cons(head : BigInt, tail : List) extends List + case class Nil() extends List + + case class Queue(front : List, rear : List) + + def size(list : List) : BigInt = (list match { + case Nil() => 0 + case Cons(_, xs) => 1 + size(xs) + }) + + def sizeList(list: List): BigInt = { + list match { + case Nil() => BigInt(0) + case Cons(_, xs) => 1 + sizeList(xs) + } + }ensuring((res : BigInt) => res >= 0 && tmpl((a, b) => time <= a * size(list) + b)) + + def qsize(q : Queue) : BigInt = size(q.front) + size(q.rear) + + def asList(q : Queue) : List = concat(q.front, reverse(q.rear)) + + def concat(l1 : List, l2 : List) : List = (l1 match { + case Nil() => l2 + case Cons(x,xs) => Cons(x, concat(xs, l2)) + + }) ensuring (res => size(res) == size(l1) + size(l2) && tmpl((a,b,c) => time <= a*size(l1) + b)) + + def isAmortized(q : Queue) : Boolean = sizeList(q.front) >= sizeList(q.rear) + + def isEmpty(queue : Queue) : Boolean = queue match { + case Queue(Nil(), Nil()) => true + case _ => false + } + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + + def reverse(l: List): List = { + reverseRec(l, Nil()) + } ensuring (res => size(l) == size(res) && tmpl((a,b) => time <= a*size(l) + b)) + + def amortizedQueue(front : List, rear : List) : Queue = { + if (sizeList(rear) <= sizeList(front)) + Queue(front, rear) + else + Queue(concat(front, reverse(rear)), Nil()) + } + + def enqueue(q : Queue, elem : BigInt) : Queue = ({ + + amortizedQueue(q.front, Cons(elem, q.rear)) + + }) ensuring(res => true && tmpl((a,b) => time <= a*qsize(q) + b)) + + def dequeue(q : Queue) : Queue = { + require(isAmortized(q) && !isEmpty(q)) + q match { + case Queue(Cons(f, fs), rear) => amortizedQueue(fs, rear) + case _ => Queue(Nil(),Nil()) + } + } ensuring(res => true && tmpl((a,b) => time <= a*qsize(q) + b)) + + def removeLast(l : List) : List = { + require(l != Nil()) + l match { + case Cons(x,Nil()) => Nil() + case Cons(x,xs) => Cons(x, removeLast(xs)) + case _ => Nil() + } + } ensuring(res => size(res) <= size(l) && tmpl((a,b) => time <= a*size(l) + b)) + + def pop(q : Queue) : Queue = { + require(isAmortized(q) && !isEmpty(q)) + q match { + case Queue(front, Cons(r,rs)) => Queue(front, rs) + case Queue(front, rear) => Queue(removeLast(front), rear) + case _ => Queue(Nil(),Nil()) + } + } ensuring(res => true && tmpl((a,b) => time <= a*size(q.front) + b)) +} diff --git a/testcases/orb-testcases/timing/BinaryTrie.scala b/testcases/orb-testcases/timing/BinaryTrie.scala new file mode 100644 index 0000000000000000000000000000000000000000..a1de6ee0e13bb53383b1bba6548e9e0fa449a166 --- /dev/null +++ b/testcases/orb-testcases/timing/BinaryTrie.scala @@ -0,0 +1,119 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BinaryTrie { + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(nvalue: BigInt, left: Tree, right: Tree) extends Tree + + sealed abstract class IList + case class Cons(head: BigInt, tail: IList) extends IList + case class Nil() extends IList + + def listSize(l: IList): BigInt = (l match { + case Nil() => 0 + case Cons(x, xs) => 1 + listSize(xs) + }) + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(x, l, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def find(inp: IList, t: Tree): Tree = { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + t match { + case Leaf() => t + case Node(v, l, r) => { + if (y > 0) find(xs, l) else find(xs, r) + } + } + } + case _ => t + } + } ensuring (_ => time <= ? * listSize(inp) + ?) + + def insert(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => t + case Cons(x, xs) => { + val newch = insert(xs, Leaf()) + newch match { + case Leaf() => Node(x, Leaf(), Leaf()) + case Node(y, _, _) => if (y > 0) Node(x, newch, Leaf()) else Node(y, Leaf(), newch) + } + } + } + + } + case Node(v, l, r) => { + inp match { + case Nil() => t + case Cons(x, Nil()) => t + case Cons(x, xs @ Cons(y, _)) => { + val ch = if (y > 0) l else r + if (y > 0) + Node(v, insert(xs, ch), r) + else + Node(v, l, insert(xs, ch)) + } + case _ => t + } + } + } + } ensuring (_ => time <= ? * listSize(inp) + ?) + + def create(inp: IList): Tree = { + insert(inp, Leaf()) + } ensuring (res => true && tmpl((a, c) => time <= a * listSize(inp) + c)) + + def delete(inp: IList, t: Tree): Tree = { + t match { + case Leaf() => { + inp match { + case Nil() => Leaf() + case Cons(x ,xs) => { + //the input is not in the tree, so do nothing + Leaf() + } + } + } + case Node(v, l, r) => { + inp match { + case Nil() => { + //the tree has extensions of the input list so do nothing + t + } + case Cons(x, Nil()) => { + //if "l" and "r" are nil, remove the node + if(l == Leaf() && r == Leaf()) Leaf() + else t + } + case Cons(x ,xs@Cons(y, _)) => { + val ch = if(y > 0) l else r + val newch = delete(xs, ch) + if(newch == Leaf() && ((y > 0 && r == Leaf()) || (y <= 0 && l == Leaf()))) Leaf() + else { + if(y > 0) + Node(v, newch, r) + else + Node(v, l, newch) + } + } + case _ => t + } + } + } + } ensuring (_ => time <= ? * listSize(inp) + ?) +} diff --git a/testcases/orb-testcases/timing/BinomialHeap.scala b/testcases/orb-testcases/timing/BinomialHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..81b990d41323f353098f5cd02feddf3e25ec9264 --- /dev/null +++ b/testcases/orb-testcases/timing/BinomialHeap.scala @@ -0,0 +1,181 @@ +import leon.invariant._ +import leon.instrumentation._ + +object BinomialHeap { + //sealed abstract class TreeNode + case class TreeNode(rank: BigInt, elem: Element, children: BinomialHeap) + case class Element(n: BigInt) + + sealed abstract class BinomialHeap + case class ConsHeap(head: TreeNode, tail: BinomialHeap) extends BinomialHeap + case class NilHeap() extends BinomialHeap + + sealed abstract class List + case class NodeL(head: BinomialHeap, tail: List) extends List + case class NilL() extends List + + sealed abstract class OptionalTree + case class Some(t : TreeNode) extends OptionalTree + case class None() extends OptionalTree + + /* Lower or Equal than for Element structure */ + private def leq(a: Element, b: Element) : Boolean = { + a match { + case Element(a1) => { + b match { + case Element(a2) => { + if(a1 <= a2) true + else false + } + } + } + } + } + + /* isEmpty function of the Binomial Heap */ + def isEmpty(t: BinomialHeap) = t match { + case ConsHeap(_,_) => false + case _ => true + } + + /* Helper function to determine rank of a TreeNode */ + def rank(t: TreeNode) : BigInt = t.rank /*t match { + case TreeNode(r, _, _) => r + }*/ + + /* Helper function to get the root element of a TreeNode */ + def root(t: TreeNode) : Element = t.elem /*t match { + case TreeNode(_, e, _) => e + }*/ + + /* Linking trees of equal ranks depending on the root element */ + def link(t1: TreeNode, t2: TreeNode): TreeNode = { + if (leq(t1.elem, t2.elem)) { + TreeNode(t1.rank + 1, t1.elem, ConsHeap(t2, t1.children)) + } else { + TreeNode(t1.rank + 1, t2.elem, ConsHeap(t1, t2.children)) + } + } + + def treeNum(h: BinomialHeap) : BigInt = { + h match { + case ConsHeap(head, tail) => 1 + treeNum(tail) + case _ => 0 + } + } + + /* Insert a tree into a binomial heap. The tree should be correct in relation to the heap */ + def insTree(t: TreeNode, h: BinomialHeap) : BinomialHeap = { + h match { + case ConsHeap(head, tail) => { + if (rank(t) < rank(head)) { + ConsHeap(t, h) + } else if (rank(t) > rank(head)) { + ConsHeap(head, insTree(t,tail)) + } else { + insTree(link(t,head), tail) + } + } + case _ => ConsHeap(t, NilHeap()) + } + } ensuring(_ => time <= ? * treeNum(h) + ?) + + /* Merge two heaps together */ + def merge(h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + ConsHeap(head1, merge(tail1, h2)) + } else if (rank(head2) < rank(head1)) { + ConsHeap(head2, merge(h1, tail2)) + } else { + mergeWithCarry(link(head1, head2), tail1, tail2) + } + } + case _ => h1 + } + } + case _ => h2 + } + } ensuring(_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) + + def mergeWithCarry(t: TreeNode, h1: BinomialHeap, h2: BinomialHeap): BinomialHeap = { + h1 match { + case ConsHeap(head1, tail1) => { + h2 match { + case ConsHeap(head2, tail2) => { + if (rank(head1) < rank(head2)) { + + if (rank(t) < rank(head1)) + ConsHeap(t, ConsHeap(head1, merge(tail1, h2))) + else + mergeWithCarry(link(t, head1), tail1, h2) + + } else if (rank(head2) < rank(head1)) { + + if (rank(t) < rank(head2)) + ConsHeap(t, ConsHeap(head2, merge(h1, tail2))) + else + mergeWithCarry(link(t, head2), h1, tail2) + + } else { + ConsHeap(t, mergeWithCarry(link(head1, head2), tail1, tail2)) + } + } + case _ => { + insTree(t, h1) + } + } + } + case _ => insTree(t, h2) + } + } ensuring (_ => time <= ? * treeNum(h1) + ? * treeNum(h2) + ?) + + //Auxiliary helper function to simplefy findMin and deleteMin + def removeMinTree(h: BinomialHeap): (OptionalTree, BinomialHeap) = { + h match { + case ConsHeap(head, NilHeap()) => (Some(head), NilHeap()) + case ConsHeap(head1, tail1) => { + val (opthead2, tail2) = removeMinTree(tail1) + opthead2 match { + case Some(head2) => + if (leq(root(head1), root(head2))) { + (Some(head1), tail1) + } else { + (Some(head2), ConsHeap(head1, tail2)) + } + case _ => (Some(head1), tail1) + } + } + case _ => (None(), NilHeap()) + } + } ensuring (res => treeNum(res._2) <= treeNum(h) && time <= ? * treeNum(h) + ?) + + /*def findMin(h: BinomialHeap) : Element = { + val (opt, _) = removeMinTree(h) + opt match { + case Some(TreeNode(_,e,ts1)) => e + case _ => Element(-1) + } + } ensuring(res => true && tmpl((a,b) => time <= a*treeNum(h) + b))*/ + + def minTreeChildren(h: BinomialHeap) : BigInt = { + val (min, _) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ch)) => treeNum(ch) + case _ => 0 + } + } + + // Discard the minimum element of the extracted min tree and put its children back into the heap + def deleteMin(h: BinomialHeap) : BinomialHeap = { + val (min, ts2) = removeMinTree(h) + min match { + case Some(TreeNode(_,_,ts1)) => merge(ts1, ts2) + case _ => h + } + } ensuring(_ => time <= ? * minTreeChildren(h) + ? * treeNum(h) + ?) + +} diff --git a/testcases/orb-testcases/timing/ConcTrees.scala b/testcases/orb-testcases/timing/ConcTrees.scala new file mode 100644 index 0000000000000000000000000000000000000000..9145789294ea99cce0323261b6505d67e7492648 --- /dev/null +++ b/testcases/orb-testcases/timing/ConcTrees.scala @@ -0,0 +1,536 @@ +package conctrees + +import leon.instrumentation._ +import leon.collection._ +import leon.lang._ +import ListSpecs._ +import leon.annotation._ +import leon.invariant._ + +object ConcTrees { + + def max(x: BigInt, y: BigInt): BigInt = if (x >= y) x else y + def abs(x: BigInt): BigInt = if (x < 0) -x else x + + sealed abstract class Conc[T] { + + def isEmpty: Boolean = { + this == Empty[T]() + } + + def isLeaf: Boolean = { + this match { + case Empty() => true + case Single(_) => true + case _ => false + } + } + + def isNormalized: Boolean = this match { + case Append(_, _) => false + case _ => true + } + + def valid: Boolean = { + concInv && balanced && appendInv + } + + /** + * (a) left and right trees of conc node should be non-empty + * (b) they cannot be append nodes + */ + def concInv: Boolean = this match { + case CC(l, r) => + !l.isEmpty && !r.isEmpty && + l.isNormalized && r.isNormalized && + l.concInv && r.concInv + case _ => true + } + + def balanced: Boolean = { + this match { + case CC(l, r) => + l.level - r.level >= -1 && l.level - r.level <= 1 && + l.balanced && r.balanced + case _ => true + } + } + + /** + * (a) Right subtree of an append node is not an append node + * (b) If the tree is of the form a@Append(b@Append(_,_),_) then + * a.right.level < b.right.level + * (c) left and right are not empty + */ + def appendInv: Boolean = this match { + case Append(l, r) => + !l.isEmpty && !r.isEmpty && + l.valid && r.valid && + r.isNormalized && + (l match { + case Append(_, lr) => + lr.level > r.level + case _ => + l.level > r.level + }) + case _ => true + } + + val level: BigInt = { + (this match { + case Empty() => 0 + case Single(x) => 0 + case CC(l, r) => + 1 + max(l.level, r.level) + case Append(l, r) => + 1 + max(l.level, r.level) + }): BigInt + } ensuring (_ >= 0) + + val size: BigInt = { + (this match { + case Empty() => 0 + case Single(x) => 1 + case CC(l, r) => + l.size + r.size + case Append(l, r) => + l.size + r.size + }): BigInt + } ensuring (_ >= 0) + + def toList: List[T] = { + this match { + case Empty() => Nil[T]() + case Single(x) => Cons(x, Nil[T]()) + case CC(l, r) => + l.toList ++ r.toList // note: left elements precede the right elements in the list + case Append(l, r) => + l.toList ++ r.toList + } + } ensuring (res => res.size == this.size) + } + + case class Empty[T]() extends Conc[T] + case class Single[T](x: T) extends Conc[T] + case class CC[T](left: Conc[T], right: Conc[T]) extends Conc[T] + case class Append[T](left: Conc[T], right: Conc[T]) extends Conc[T] + + /*class Chunk(val array: Array[T], val size: Int, val k: Int) extends Leaf[T] { + def level = 0 + override def toString = s"Chunk(${array.mkString("", ", ", "")}; $size; $k)" + }*/ + + def lookup[T](xs: Conc[T], i: BigInt): T = { + require(xs.valid && !xs.isEmpty && i >= 0 && i < xs.size) + xs match { + case Single(x) => x + case CC(l, r) => + if (i < l.size) { + lookup(l, i) + } else { + lookup(r, i - l.size) + } + case Append(l, r) => + if (i < l.size) { + lookup(l, i) + } else { + lookup(r, i - l.size) + } + } + } ensuring (res => tmpl((a,b) => time <= a*xs.level + b) && // lookup time is linear in the height + res == xs.toList(i) && // correctness + instAppendIndexAxiom(xs, i)) // an auxiliary axiom instantiation that is required for the proof + + // @library + def instAppendIndexAxiom[T](xs: Conc[T], i: BigInt): Boolean = { + require(0 <= i && i < xs.size) + xs match { + case CC(l, r) => + appendIndex(l.toList, r.toList, i) + case Append(l, r) => + appendIndex(l.toList, r.toList, i) + case _ => true + } + }.holds + + @library + def update[T](xs: Conc[T], i: BigInt, y: T): Conc[T] = { + require(xs.valid && !xs.isEmpty && i >= 0 && i < xs.size) + xs match { + case Single(x) => Single(y) + case CC(l, r) => + if (i < l.size) + CC(update(l, i, y), r) + else + CC(l, update(r, i - l.size, y)) + case Append(l, r) => + if (i < l.size) + Append(update(l, i, y), r) + else + Append(l, update(r, i - l.size, y)) + } + } ensuring (res => res.level == xs.level && // heights of the input and output trees are equal + res.valid && // tree invariants are preserved + tmpl((a,b) => time <= a*xs.level + b) && // update time is linear in the height of the tree + res.toList == xs.toList.updated(i, y) && // correctness + numTrees(res) == numTrees(xs) && //auxiliary property that preserves the potential function + instAppendUpdateAxiom(xs, i, y)) // an auxiliary axiom instantiation + + @library + def instAppendUpdateAxiom[T](xs: Conc[T], i: BigInt, y: T): Boolean = { + require(i >= 0 && i < xs.size) + xs match { + case CC(l, r) => + appendUpdate(l.toList, r.toList, i, y) + case Append(l, r) => + appendUpdate(l.toList, r.toList, i, y) + case _ => true + } + }.holds + + /** + * A generic concat that applies to general concTrees + */ + @library + def concat[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid) + concatNormalized(normalize(xs), normalize(ys)) + } + + /** + * This concat applies only to normalized trees. + * This prevents concat from being recursive + */ + @library + def concatNormalized[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid && + xs.isNormalized && ys.isNormalized) + (xs, ys) match { + case (xs, Empty()) => xs + case (Empty(), ys) => ys + case _ => + concatNonEmpty(xs, ys) + } + } ensuring (res => res.valid && // tree invariants + res.level <= max(xs.level, ys.level) + 1 && // height invariants + res.level >= max(xs.level, ys.level) && + (res.toList == xs.toList ++ ys.toList) && // correctness + res.isNormalized //auxiliary properties + ) + + //@library + def concatNonEmpty[T](xs: Conc[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid && + xs.isNormalized && ys.isNormalized && + !xs.isEmpty && !ys.isEmpty) + + val diff = ys.level - xs.level + if (diff >= -1 && diff <= 1) + CC(xs, ys) + else if (diff < -1) { + // ys is smaller than xs + xs match { + case CC(l, r) => + if (l.level >= r.level) + CC(l, concatNonEmpty(r, ys)) + else { + r match { + case CC(rl, rr) => + val nrr = concatNonEmpty(rr, ys) + if (nrr.level == xs.level - 3) + CC(l, CC(rl, nrr)) + else + CC(CC(l, rl), nrr) + } + } + } + } else { + ys match { + case CC(l, r) => + if (r.level >= l.level) + CC(concatNonEmpty(xs, l), r) + else { + l match { + case CC(ll, lr) => + val nll = concatNonEmpty(xs, ll) + if (nll.level == ys.level - 3) + CC(CC(nll, lr), r) + else + CC(nll, CC(lr, r)) + } + } + } + } + } ensuring (res => tmpl((a,b) => time <= a*abs(xs.level - ys.level) + b) && // time bound + res.level <= max(xs.level, ys.level) + 1 && // height invariants + res.level >= max(xs.level, ys.level) && + res.balanced && res.appendInv && res.concInv && //this is should not be needed. But, seems necessary for leon + res.valid && // tree invariant is preserved + res.toList == xs.toList ++ ys.toList && // correctness + res.isNormalized && // auxiliary properties + appendAssocInst(xs, ys) // instantiation of an axiom + ) + + @library + def appendAssocInst[T](xs: Conc[T], ys: Conc[T]): Boolean = { + (xs match { + case CC(l, r) => + appendAssoc(l.toList, r.toList, ys.toList) && //instantiation of associativity of concatenation + (r match { + case CC(rl, rr) => + appendAssoc(rl.toList, rr.toList, ys.toList) && + appendAssoc(l.toList, rl.toList, rr.toList ++ ys.toList) + case _ => true + }) + case _ => true + }) && + (ys match { + case CC(l, r) => + appendAssoc(xs.toList, l.toList, r.toList) && + (l match { + case CC(ll, lr) => + appendAssoc(xs.toList, ll.toList, lr.toList) && + appendAssoc(xs.toList ++ ll.toList, lr.toList, r.toList) + case _ => true + }) + case _ => true + }) + }.holds + + @library + def insert[T](xs: Conc[T], i: BigInt, y: T): Conc[T] = { + require(xs.valid && i >= 0 && i <= xs.size && + xs.isNormalized) //note the precondition + xs match { + case Empty() => Single(y) + case Single(x) => + if (i == 0) + CC(Single(y), xs) + else + CC(xs, Single(y)) + case CC(l, r) if i < l.size => + concatNonEmpty(insert(l, i, y), r) + case CC(l, r) => + concatNonEmpty(l, insert(r, i - l.size, y)) + } + } ensuring (res => res.valid && res.isNormalized && // tree invariants + res.level - xs.level <= 1 && res.level >= xs.level && // height of the output tree is at most 1 greater than that of the input tree + tmpl((a,b) => time <= a*xs.level + b) && // time is linear in the height of the tree + res.toList == xs.toList.insertAt(i, y) && // correctness + insertAppendAxiomInst(xs, i, y) // instantiation of an axiom + ) + + @library + def insertAppendAxiomInst[T](xs: Conc[T], i: BigInt, y: T): Boolean = { + require(i >= 0 && i <= xs.size) + xs match { + case CC(l, r) => appendInsert(l.toList, r.toList, i, y) + case _ => true + } + }.holds + + //TODO: why with instrumentation we are not able prove the running time here ? (performance bug ?) + @library + def split[T](xs: Conc[T], n: BigInt): (Conc[T], Conc[T]) = { + require(xs.valid && xs.isNormalized) + xs match { + case Empty() => + (Empty(), Empty()) + case s @ Single(x) => + if (n <= 0) { //a minor fix + (Empty(), s) + } else { + (s, Empty()) + } + case CC(l, r) => + if (n < l.size) { + val (ll, lr) = split(l, n) + (ll, concatNormalized(lr, r)) + } else if (n > l.size) { + val (rl, rr) = split(r, n - l.size) + (concatNormalized(l, rl), rr) + } else { + (l, r) + } + } + } ensuring (res => res._1.valid && res._2.valid && // tree invariants are preserved + res._1.isNormalized && res._2.isNormalized && + xs.level >= res._1.level && xs.level >= res._2.level && // height bounds of the resulting tree + tmpl((a,b,c,d) => time <= a*xs.level + b*res._1.level + c*res._2.level + d) && // time is linear in height + res._1.toList == xs.toList.take(n) && res._2.toList == xs.toList.drop(n) && // correctness + instSplitAxiom(xs, n) // instantiation of an axiom + ) + + @library + def instSplitAxiom[T](xs: Conc[T], n: BigInt): Boolean = { + xs match { + case CC(l, r) => + appendTakeDrop(l.toList, r.toList, n) + case _ => true + } + }.holds + + @library + def append[T](xs: Conc[T], x: T): Conc[T] = { + require(xs.valid) + val ys = Single[T](x) + xs match { + case xs @ Append(_, _) => + appendPriv(xs, ys) + case CC(_, _) => + Append(xs, ys) //creating an append node + case Empty() => + ys + case Single(_) => + CC(xs, ys) + } + } ensuring (res => res.valid && //conctree invariants + res.toList == xs.toList ++ Cons(x, Nil[T]()) && //correctness + res.level <= xs.level + 1 && + tmpl((a,b,c) => time <= a*numTrees(xs) - b*numTrees(res) + c) //time bound (worst case) + ) + + /** + * This is a private method and is not exposed to the + * clients of conc trees + */ + @library + def appendPriv[T](xs: Append[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid && + !ys.isEmpty && ys.isNormalized && + xs.right.level >= ys.level) + + if (xs.right.level > ys.level) + Append(xs, ys) + else { + val zs = CC(xs.right, ys) + xs.left match { + case l @ Append(_, _) => + appendPriv(l, zs) + case l if l.level <= zs.level => //note: here < is not possible + CC(l, zs) + case l => + Append(l, zs) + } + } + } ensuring (res => res.valid && //conc tree invariants + res.toList == xs.toList ++ ys.toList && //correctness invariants + res.level <= xs.level + 1 && + tmpl((a,b,c) => time <= a*numTrees(xs) - b*numTrees(res) + c) && //time bound (worst case) + appendAssocInst2(xs, ys)) + + @library + def appendAssocInst2[T](xs: Conc[T], ys: Conc[T]): Boolean = { + xs match { + case CC(l, r) => + appendAssoc(l.toList, r.toList, ys.toList) + case Append(l, r) => + appendAssoc(l.toList, r.toList, ys.toList) + case _ => true + } + }.holds + + @library + def numTrees[T](t: Conc[T]): BigInt = { + t match { + case Append(l, r) => numTrees(l) + 1 + case _ => BigInt(1) + } + } ensuring (res => res >= 0) + + @library + def normalize[T](t: Conc[T]): Conc[T] = { + require(t.valid) + t match { + case Append(l @ Append(_, _), r) => + wrap(l, r) + case Append(l, r) => + concatNormalized(l, r) + case _ => t + } + } ensuring (res => res.valid && + res.isNormalized && + res.toList == t.toList && //correctness + res.size == t.size && res.level <= t.level && //normalize preserves level and size + tmpl((a,b) => time <= a*t.level + b) //time bound (a little over approximate) + ) + + @library + def wrap[T](xs: Append[T], ys: Conc[T]): Conc[T] = { + require(xs.valid && ys.valid && ys.isNormalized && + xs.right.level >= ys.level) + val nr = concatNormalized(xs.right, ys) + xs.left match { + case l @ Append(_, _) => + wrap(l, nr) + case l => + concatNormalized(l, nr) + } + } ensuring (res => res.valid && + res.isNormalized && + res.toList == xs.toList ++ ys.toList && //correctness + res.size == xs.size + ys.size && //other auxiliary properties + res.level <= xs.level && + tmpl((a,b,c) => time <= a*xs.level - b*ys.level + c) && //time bound + appendAssocInst2(xs, ys)) //some lemma instantiations + + /** + * A class that represents an operation on a concTree. + * opid - an integer that denotes the function that has to be performed e.g. append, insert, update ... + * opid <= 0 => the operation is lookup + * opid == 1 => the operation is update + * opid == 2 => the operation is insert + * opid == 3 => the operation is split + * opid >= 4 => the operation is append + * index, x - denote the arguments the function given by opid + */ + case class Operation[T](opid: BigInt, /*argument to the operations*/ index: BigInt /*for lookup, update, insert, split*/ , + x: T /*for update, insert, append*/ ) + + /** + * Proving amortized running time of 'Append' when used ephimerally. + * ops- a arbitrary sequence of operations, + * noaps - number of append operations in the list + */ + def performOperations[T](xs: Conc[T], ops: List[Operation[T]], noaps: BigInt): Conc[T] = { + require(xs.valid && noaps >= 0) + ops match { + case Cons(Operation(id, i, _), tail) if id <= 0 => + //we need to perform a lookup operation, but do the operation only if + //preconditions hold + // val _ = if (0 <= i && i < xs.size) + // lookup(xs, i) + // else BigInt(0) + performOperations(xs, tail, noaps) //returns the time taken by appends in the remaining operations + + case Cons(Operation(id, i, x), tail) if id == 1 => + val newt = if (0 <= i && i < xs.size) + update(xs, i, x) + else xs + //note that only the return value is used by the subsequent operations (emphimeral use) + performOperations(newt, tail, noaps) + + case Cons(Operation(id, i, x), tail) if id == 2 => + val newt = if (0 <= i && i <= xs.size) + insert(normalize(xs), i, x) + else xs + performOperations(newt, tail, noaps) + + case Cons(Operation(id, n, _), tail) if id == 3 => + //use the larger tree to perform the remaining operations + val (newl, newr) = split(normalize(xs), n) + val newt = if (newl.size >= newr.size) newl else newr + performOperations(newt, tail, noaps) + + case Cons(Operation(id, _, x), tail) if noaps > 0 => + //here, we need to perform append operation + val newt = append(xs, x) + val r = performOperations(newt, tail, noaps - 1) + r //time taken by this append and those that follow it + + case _ => + xs + } + } ensuring (res => tmpl((a, b) => time <= a*noaps + b*numTrees(xs))) +//res._2 <= noaps + 2*nops*(xs.level + res._1.level)+ numTrees(xs) +} diff --git a/testcases/orb-testcases/timing/ConcatVariations.scala b/testcases/orb-testcases/timing/ConcatVariations.scala new file mode 100644 index 0000000000000000000000000000000000000000..a94fb418a48db5b27ab88cfbbe84a59233f394b0 --- /dev/null +++ b/testcases/orb-testcases/timing/ConcatVariations.scala @@ -0,0 +1,42 @@ +import leon.invariant._ +import leon.instrumentation._ + + +object ConcatVariations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def genL(n: BigInt): List = { + require(n >= 0) + if (n == 0) Nil() + else + Cons(n, genL(n - 1)) + } ensuring (res => size(res) == n && tmpl((a,b) => time <= a*n + b)) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + + def f_good(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + if (m == 0) Nil() + else append(genL(n), f_good(m - 1, n)) + + } ensuring(res => size(res) == n*m && tmpl((a,b,c,d) => time <= a*(n*m) + b*n + c*m +d)) + + def f_worst(m: BigInt, n: BigInt): List = { + require(0 <= m && 0 <= n) + + if (m == 0) Nil() + else append(f_worst(m - 1, n), genL(n)) + + } ensuring(res => size(res) == n*m && tmpl((a,c,d,e,f) => time <= a*((n*m)*m)+c*(n*m)+d*n+e*m+f)) +} diff --git a/testcases/orb-testcases/timing/ConstantPropagation.scala b/testcases/orb-testcases/timing/ConstantPropagation.scala new file mode 100644 index 0000000000000000000000000000000000000000..76af84834063ba67788d05dee977da962e79f372 --- /dev/null +++ b/testcases/orb-testcases/timing/ConstantPropagation.scala @@ -0,0 +1,290 @@ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon._ +import leon.invariant._ +import leon.instrumentation._ + +object IntLattice { + abstract class Element + case class Bot() extends Element + case class Top() extends Element + case class BigIntVal(x: BigInt) extends Element + + def height: BigInt = { + /** + * A number that depends on the lattice definition. + * In simplest case it has height 3 (_|_ (bot) <= BigInts <= T (top)) + */ + 3 + } + + def join(oldVal: Element, newVal: Element) = (oldVal, newVal) match { + case (Bot(), any) => any // bot is the identity for join + case (any, Bot()) => any + case (Top(), _) => Top() // top joined with anything is top + case (_, Top()) => Top() + case (BigIntVal(x), BigIntVal(y)) if (x == y) => BigIntVal(y) + case _ => + //here old and new vals are different BigIntegers + Top() + } +} + +object LatticeOps { + import IntLattice._ + + def add(a: Element, b: Element): Element = { + (a, b) match { + case (Bot(), _) => Bot() + case (_, Bot()) => Bot() + case (Top(), _) => Top() + case (_, Top()) => Top() + case (BigIntVal(x), BigIntVal(y)) => BigIntVal(x + y) + } + } + + def multiply(a: Element, b: Element): Element = { + (a, b) match { + case (_, BigIntVal(x)) if x == 0 => BigIntVal(0) + case (BigIntVal(x), _) if x == 0 => BigIntVal(0) + case (Bot(), _) => Bot() + case (_, Bot()) => Bot() + case (Top(), _) => Top() + case (_, Top()) => Top() + case (BigIntVal(x), BigIntVal(y)) => BigIntVal(x * y) + } + } +} + +object ConstantPropagation { + import IntLattice._ + import LatticeOps._ + + abstract class Expr + case class Times(lhs: Expr, rhs: Expr) extends Expr + case class Plus(lhs: Expr, rhs: Expr) extends Expr + case class BigIntLiteral(v: BigInt) extends Expr + case class FunctionCall(calleeId: BigInt, args: List[Expr]) extends Expr + case class IfThenElse(cond: Expr, thenExpr: Expr, elseExpr: Expr) extends Expr + case class Identifier(id: BigInt) extends Expr + + /** + * Definition of a function AST + */ + case class Function(id: BigInt, params: List[Expr], body: Expr) + + /** + * Assuming that the functions are ordered from callee to + * caller and there is no mutual recursion + */ + case class Program(funcs: List[Function]) + + def size(l: List[Function]): BigInt = { + l match { + case Cons(_, t) => 1 + size(t) + case Nil() => 0 + } + } + + def sizeExprList(exprs: List[Expr]): BigInt = { + exprs match { + case Nil() => 0 + case Cons(currExpr, otherExprs) => sizeExpr(currExpr) + sizeExprList(otherExprs) + } + } + + def sizeExpr(e: Expr): BigInt = { + e match { + case Plus(l, r) => 1 + sizeExpr(l) + sizeExpr(r) + case Times(l, r) => 1 + sizeExpr(l) + sizeExpr(r) + case FunctionCall(c, args) => { + 1 + sizeExprList(args) + } + case IfThenElse(c, th, el) => + 1 + sizeExpr(c) + sizeExpr(th) + sizeExpr(el) + case _ => 1 + } + } + + def sizeFuncList(funcs: List[Function]): BigInt = { + funcs match { + case Nil() => 0 + case Cons(currFunc, otherFuncs) => + 1 + sizeExpr(currFunc.body) + sizeFuncList(otherFuncs) + } + } + + def initToBot(l: List[Function]): List[(BigInt /*function id*/ , Element)] = { + l match { + case Nil() => Nil[(BigInt /*function id*/ , Element)]() + case Cons(fun, tail) => Cons((fun.id, Bot()), initToBot(tail)) + } + } ensuring (_ => time <= ? * size(l) + ?) + + def foldConstants(p: Program): Program = { + val initVals = initToBot(p.funcs) + val fvals = computeSummaries(p, initToBot(p.funcs), height) + val newfuns = transformFuns(p.funcs, fvals) + Program(newfuns) + } ensuring(_ => time <= ? * (sizeFuncList(p.funcs)*height) + ? * height + ? * size(p.funcs) + ?) + + /** + * The initVals is the initial values for the + * values of the functions + */ + def computeSummaries(p: Program, initVals: List[(BigInt /*function id*/ , Element)], noIters: BigInt): List[(BigInt /*function id*/ , Element)] = { + require(noIters >= 0) + if (noIters <= 0) { + initVals + } else + computeSummaries(p, analyzeFuns(p.funcs, initVals, initVals), noIters - 1) + } ensuring(_ => time <= ? * (sizeFuncList(p.funcs)*noIters) + ? * noIters + ?) + + /** + * Initial fvals and oldVals are the same + * but as the function progresses, fvals will only have the olds values + * of the functions that are yet to be processed, whereas oldVals will remain the same. + */ + def analyzeFuns(funcs: List[Function], fvals: List[(BigInt, Element)], oldVals: List[(BigInt, Element)]): List[(BigInt, Element)] = { + (funcs, fvals) match { + case (Cons(f, otherFuns), Cons((fid, fval), otherVals)) => + val newval = analyzeFunction(f, oldVals) + val approxVal = join(fval, newval) //creates an approximation of newVal to ensure convergence + Cons((fid, approxVal), analyzeFuns (otherFuns, otherVals, oldVals)) + case _ => + Nil[(BigInt, Element)]() //this also handles precondition violations e.g. lists aren't of same size etc. + } + } ensuring (_ => time <= ? * sizeFuncList(funcs) + ?) + + @library + def getFunctionVal(funcId: BigInt, funcVals: List[(BigInt, Element)]): Element = { + funcVals match { + case Nil() => Bot() + case Cons((currFuncId, currFuncVal), otherFuncVals) if (currFuncId == funcId) => currFuncVal + case Cons(_, otherFuncVals) => + getFunctionVal(funcId, otherFuncVals) + } + } ensuring (_ => time <= 1) + + + def analyzeExprList(l: List[Expr], funcVals: List[(BigInt, Element)]): List[Element] = { + l match { + case Nil() => Nil[Element]() + case Cons(expr, otherExprs) => Cons(analyzeExpr(expr, funcVals), analyzeExprList(otherExprs, funcVals)) + } + } ensuring (_ => time <= ? * sizeExprList(l) + ?) + + /** + * Returns the value of the expression when "Abstractly Interpreted" + * using the lattice. + */ + def analyzeExpr(e: Expr, funcVals: List[(BigInt, Element)]): Element = { + e match { + case Times(lhs: Expr, rhs: Expr) => { + val lval = analyzeExpr(lhs, funcVals) + val rval = analyzeExpr(rhs, funcVals) + multiply(lval, rval) + } + case Plus(lhs: Expr, rhs: Expr) => { + val lval = analyzeExpr(lhs, funcVals) + val rval = analyzeExpr(rhs, funcVals) + add(lval, rval) + } + case FunctionCall(id, args: List[Expr]) => { + getFunctionVal(id, funcVals) + } + case IfThenElse(c, th, el) => { + //analyze then and else branches and join their values + //TODO: this can be made more precise e.g. if 'c' is + //a non-zero value it can only execute the then branch. + val v1 = analyzeExpr(th, funcVals) + val v2 = analyzeExpr(el, funcVals) + join(v1, v2) + } + case lit @ BigIntLiteral(v) => + BigIntVal(v) + + case Identifier(_) => Bot() + } + } ensuring (_ => time <= ? * sizeExpr(e) + ?) + + + def analyzeFunction(f: Function, oldVals: List[(BigInt, Element)]): Element = { + // traverse the body of the function and simplify constants + // for function calls assume the value given by oldVals + // also for if-then-else statments, take a join of the values along if and else branches + // assume that bot op any = bot and top op any = top (but this can be made more precise). + analyzeExpr(f.body, oldVals) + } ensuring (_ => time <= ? * sizeExpr(f.body) + ?) + + + def transformExprList(l: List[Expr], funcVals: List[(BigInt, Element)]): List[Expr] = { + l match { + case Nil() => Nil[Expr]() + case Cons(expr, otherExprs) => Cons(transformExpr(expr, funcVals), + transformExprList(otherExprs, funcVals)) + } + } ensuring (_ => time <= ? * sizeExprList(l) + ?) + + /** + * Returns the folded expression + */ + def transformExpr(e: Expr, funcVals: List[(BigInt, Element)]): Expr = { + e match { + case Times(lhs: Expr, rhs: Expr) => { + val foldedLHS = transformExpr(lhs, funcVals) + val foldedRHS = transformExpr(rhs, funcVals) + (foldedLHS, foldedRHS) match { + case (BigIntLiteral(x), BigIntLiteral(y)) => + BigIntLiteral(x * y) + case _ => + Times(foldedLHS, foldedRHS) + } + } + case Plus(lhs: Expr, rhs: Expr) => { + val foldedLHS = transformExpr(lhs, funcVals) + val foldedRHS = transformExpr(rhs, funcVals) + (foldedLHS, foldedRHS) match { + case (BigIntLiteral(x), BigIntLiteral(y)) => + BigIntLiteral(x + y) + case _ => + Plus(foldedLHS, foldedRHS) + } + } + case FunctionCall(calleeid, args: List[Expr]) => { + getFunctionVal(calleeid, funcVals) match { + case BigIntVal(x) => + BigIntLiteral(x) + case _ => + val foldedArgs = transformExprList(args, funcVals) + FunctionCall(calleeid, foldedArgs) + } + } + case IfThenElse(c, th, el) => { + val foldedCond = transformExpr(c, funcVals) + val foldedTh = transformExpr(th, funcVals) + val foldedEl = transformExpr(el, funcVals) + foldedCond match { + case BigIntLiteral(x) => { + if (x != 0) foldedTh + else foldedEl + } + case _ => IfThenElse(foldedCond, foldedTh, foldedEl) + } + } + case _ => e + } + } ensuring (_ => time <= ? * sizeExpr(e) + ?) + + + def transformFuns(funcs: List[Function], fvals: List[(BigInt, Element)]): List[Function] = { + funcs match { + case Cons(f, otherFuns) => + val newfun = Function(f.id, f.params, transformExpr(f.body, fvals)) + Cons(newfun, transformFuns(otherFuns, fvals)) + case _ => + Nil[Function]() + } + } ensuring (_ => time <= ? * sizeFuncList(funcs) + ?) +} \ No newline at end of file diff --git a/testcases/orb-testcases/timing/Folds.scala b/testcases/orb-testcases/timing/Folds.scala new file mode 100644 index 0000000000000000000000000000000000000000..2dac7e26684a0764429f0a14b8098160bcd0a0e4 --- /dev/null +++ b/testcases/orb-testcases/timing/Folds.scala @@ -0,0 +1,77 @@ +import leon.invariant._ +import leon.instrumentation._ + +object TreeMaps { + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def size(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => size(l) + size(r) + 1 + } + } + + def parallelSearch(elem : BigInt, t : Tree) : Boolean = { + t match { + case Node(l, x, r) => + if(x == elem) true + else { + val r1 = parallelSearch(elem, r) + val r2 = parallelSearch(elem, l) + if(r1 || r2) true + else false + } + case Leaf() => false + } + } ensuring(res => true && tmpl((a,b) => time <= a*size(t) + b)) + + + def squareMap(t : Tree) : Tree = { + t match { + case Node(l, x, r) => + val nl = squareMap(l) + val nr = squareMap(r) + Node(nl, x*x, nr) + case _ => t + } + } ensuring (res => true && tmpl((a,b) => time <= a*size(t) + b)) + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def fact(n : BigInt) : BigInt = { + require(n >= 0) + + if(n == 1 || n == 0) BigInt(1) + else n * fact(n-1) + + } ensuring(res => true && tmpl((a,b) => time <= a*n + b)) + + def descending(l: List, k: BigInt) : Boolean = { + l match { + case Nil() => true + case Cons(x, t) => x > 0 && x <= k && descending(t, x-1) + } + } + + def factMap(l: List, k: BigInt): List = { + require(descending(l, k) && k >= 0) + + l match { + case Nil() => Nil() + case Cons(x, t) => { + val f = fact(x) + Cons(f, factMap(t, x-1)) + } + + }} ensuring(res => true && tmpl((a,b) => time <= a*(k*k) + b)) +} \ No newline at end of file diff --git a/testcases/orb-testcases/timing/ForElimination.scala b/testcases/orb-testcases/timing/ForElimination.scala new file mode 100644 index 0000000000000000000000000000000000000000..c76da1507651ca301816ef6fa26c6e468cf6df3d --- /dev/null +++ b/testcases/orb-testcases/timing/ForElimination.scala @@ -0,0 +1,102 @@ +import leon.invariant._ +import leon.instrumentation._ + +object ForElimination { + + sealed abstract class List + case class Nil() extends List + case class Cons(head: Statement, tail: List) extends List + + sealed abstract class Statement + case class Print(msg: BigInt, varID: BigInt) extends Statement + case class Assign(varID: BigInt, expr: Expression) extends Statement + case class Skip() extends Statement + case class Block(body: List) extends Statement + case class IfThenElse(expr: Expression, thenExpr: Statement, elseExpr: Statement) extends Statement + case class While(expr: Expression, body: Statement) extends Statement + case class For(init: Statement, expr: Expression, step: Statement, body: Statement) extends Statement + + sealed abstract class Expression + case class Var(varID: BigInt) extends Expression + case class IntLiteral(value: BigInt) extends Expression + case class Plus(lhs: Expression, rhs: Expression) extends Expression + case class Minus(lhs: Expression, rhs: Expression) extends Expression + case class Times(lhs: Expression, rhs: Expression) extends Expression + case class Division(lhs: Expression, rhs: Expression) extends Expression + case class Equals(lhs: Expression, rhs: Expression) extends Expression + case class LessThan(lhs: Expression, rhs: Expression) extends Expression + case class And(lhs: Expression, rhs: Expression) extends Expression + case class Or(lhs: Expression, rhs: Expression) extends Expression + case class Not(expr: Expression) extends Expression + + def sizeStat(st: Statement) : BigInt = st match { + case Block(l) => sizeList(l) + 1 + case IfThenElse(c,th,el) => sizeStat(th) + sizeStat(el) + 1 + case While(c,b) => sizeStat(b) + 1 + case For(init,cond,step,body) => sizeStat(init) + sizeStat(step) + sizeStat(body) + case other => 1 + } + + def sizeList(l: List) : BigInt = l match { + case Cons(h,t) => sizeStat(h) + sizeList(t) + case Nil() => 0 + } + + def isForFree(stat: Statement): Boolean = (stat match { + case Block(body) => isForFreeList(body) + case IfThenElse(_, thenExpr, elseExpr) => isForFree(thenExpr) && isForFree(elseExpr) + case While(_, body) => isForFree(body) + case For(_,_,_,_) => false + case _ => true + }) ensuring(res => true && tmpl((a,b) => time <= a*sizeStat(stat) + b)) + + def isForFreeList(l: List): Boolean = (l match { + case Nil() => true + case Cons(x, xs) => isForFree(x) && isForFreeList(xs) + }) ensuring(res => true && tmpl((a,b) => time <= a*sizeList(l) + b)) + + def forLoopsWellFormedList(l: List): Boolean = (l match { + case Nil() => true + case Cons(x, xs) => forLoopsWellFormed(x) && forLoopsWellFormedList(xs) + }) ensuring(res => true && tmpl((a,b) => time <= a*sizeList(l) + b)) + + def forLoopsWellFormed(stat: Statement): Boolean = (stat match { + case Block(body) => forLoopsWellFormedList(body) + case IfThenElse(_, thenExpr, elseExpr) => forLoopsWellFormed(thenExpr) && forLoopsWellFormed(elseExpr) + case While(_, body) => forLoopsWellFormed(body) + case For(init, _, step, body) => isForFree(init) && isForFree(step) && forLoopsWellFormed(body) + case _ => true + }) ensuring(res => true && tmpl((a,b) => time <= a*sizeStat(stat) + b)) + + def eliminateWhileLoopsList(l: List): List = { + l match { + case Nil() => Nil() + case Cons(x, xs) => Cons(eliminateWhileLoops(x), eliminateWhileLoopsList(xs)) + } + } ensuring(res => true && tmpl((a,b) => time <= a*sizeList(l) + b)) + + def eliminateWhileLoops(stat: Statement): Statement = (stat match { + case Block(body) => Block(eliminateWhileLoopsList(body)) + case IfThenElse(expr, thenExpr, elseExpr) => IfThenElse(expr, eliminateWhileLoops(thenExpr), eliminateWhileLoops(elseExpr)) + case While(expr, body) => For(Skip(), expr, Skip(), eliminateWhileLoops(body)) + case For(init, expr, step, body) => For(eliminateWhileLoops(init), expr, eliminateWhileLoops(step), eliminateWhileLoops(body)) + case other => other + }) ensuring(res => true && tmpl((a,b) => time <= a*sizeStat(stat) + b)) + + def eliminateForLoopsList(l: List): List = { + l match { + case Nil() => Nil() + case Cons(x, xs) => Cons(eliminateForLoops(x), eliminateForLoopsList(xs)) + } + } ensuring(res => true && tmpl((a,b) => time <= a*sizeList(l) + b)) + + def eliminateForLoops(stat: Statement): Statement = { + stat match { + case Block(body) => Block(eliminateForLoopsList(body)) + case IfThenElse(expr, thenExpr, elseExpr) => IfThenElse(expr, eliminateForLoops(thenExpr), eliminateForLoops(elseExpr)) + case While(expr, body) => While(expr, eliminateForLoops(body)) + case For(init, expr, step, body) => Block(Cons(eliminateForLoops(init), Cons(While(expr, Block(Cons(eliminateForLoops(body), Cons(eliminateForLoops(step), Nil())))), Nil()))) + case other => other + } + } ensuring(res => true && tmpl((a,b) => time <= a*sizeStat(stat) + b)) +} diff --git a/testcases/orb-testcases/timing/InsertionSort.scala b/testcases/orb-testcases/timing/InsertionSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..8fd79a2e89f60441fd522584fae4197079f9294e --- /dev/null +++ b/testcases/orb-testcases/timing/InsertionSort.scala @@ -0,0 +1,26 @@ +import leon.invariant._ +import leon.instrumentation._ + +object InsertionSort { + sealed abstract class List + case class Cons(head: BigInt, tail:List) extends List + case class Nil() extends List + + def size(l : List) : BigInt = (l match { + case Cons(_, xs) => 1 + size(xs) + case _ => 0 + }) + + def sortedIns(e: BigInt, l: List): List = { + l match { + case Cons(x,xs) => if (x <= e) Cons(x,sortedIns(e, xs)) else Cons(e, l) + case _ => Cons(e,Nil()) + } + } ensuring(res => size(res) == size(l) + 1 && tmpl((a,b) => time <= a*size(l) +b && depth <= a*size(l) +b)) + + def sort(l: List): List = (l match { + case Cons(x,xs) => sortedIns(x, sort(xs)) + case _ => Nil() + + }) ensuring(res => size(res) == size(l) && tmpl((a,b) => time <= a*(size(l)*size(l)) +b && rec <= a*size(l) + b)) +} diff --git a/testcases/orb-testcases/timing/LeftistHeap.scala b/testcases/orb-testcases/timing/LeftistHeap.scala new file mode 100644 index 0000000000000000000000000000000000000000..2d3cd389ab13e437b6f2458eafc7b841bc005f4f --- /dev/null +++ b/testcases/orb-testcases/timing/LeftistHeap.scala @@ -0,0 +1,82 @@ +import leon.invariant._ +import leon.instrumentation._ +import leon.annotation._ + +object LeftistHeap { + sealed abstract class Heap + case class Leaf() extends Heap + case class Node(rk : BigInt, value: BigInt, left: Heap, right: Heap) extends Heap + + private def rightHeight(h: Heap) : BigInt = h match { + case Leaf() => 0 + case Node(_,_,_,r) => rightHeight(r) + 1 + } + + private def rank(h: Heap) : BigInt = h match { + case Leaf() => 0 + case Node(rk,_,_,_) => rk + } + + private def hasLeftistProperty(h: Heap) : Boolean = (h match { + case Leaf() => true + case Node(_,_,l,r) => hasLeftistProperty(l) && hasLeftistProperty(r) && rightHeight(l) >= rightHeight(r) && (rank(h) == rightHeight(h)) + }) + + @monotonic + def twopower(x: BigInt) : BigInt = { + require(x >= 0) + if(x < 1) 1 + else + 2* twopower(x - 1) + } + + def size(t: Heap): BigInt = { + require(hasLeftistProperty(t)) + (t match { + case Leaf() => BigInt(0) + case Node(_,v, l, r) => size(l) + 1 + size(r) + }) + } ensuring (res => true && tmpl((a,b) => twopower(rightHeight(t)) <= a*res + b)) + + def leftRightHeight(h: Heap) : BigInt = {h match { + case Leaf() => 0 + case Node(_,_,l,r) => rightHeight(l) + }} + + def removeMax(h: Heap) : Heap = { + require(hasLeftistProperty(h)) + h match { + case Node(_,_,l,r) => merge(l, r) + case l @ Leaf() => l + } + } ensuring(res => true && tmpl((a,b) => time <= a*leftRightHeight(h) + b)) + + private def merge(h1: Heap, h2: Heap) : Heap = { + require(hasLeftistProperty(h1) && hasLeftistProperty(h2)) + h1 match { + case Leaf() => h2 + case Node(_, v1, l1, r1) => h2 match { + case Leaf() => h1 + case Node(_, v2, l2, r2) => + if(v1 > v2) + makeT(v1, l1, merge(r1, h2)) + else + makeT(v2, l2, merge(h1, r2)) + } + } + } ensuring(res => true && tmpl((a,b,c) => time <= a*rightHeight(h1) + b*rightHeight(h2) + c)) + + private def makeT(value: BigInt, left: Heap, right: Heap) : Heap = { + if(rank(left) >= rank(right)) + Node(rank(right) + 1, value, left, right) + else + Node(rank(left) + 1, value, right, left) + } + + def insert(element: BigInt, heap: Heap) : Heap = { + require(hasLeftistProperty(heap)) + + merge(Node(1, element, Leaf(), Leaf()), heap) + + } ensuring(res => true && tmpl((a,b,c) => time <= a*rightHeight(heap) + c)) +} diff --git a/testcases/orb-testcases/timing/ListOperations.scala b/testcases/orb-testcases/timing/ListOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..444029464261b21ec3a4dc434c31bed4904101b8 --- /dev/null +++ b/testcases/orb-testcases/timing/ListOperations.scala @@ -0,0 +1,61 @@ +import leon.invariant._ +import leon.instrumentation._ + +object ListOperations { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + def append(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => Cons(x, append(xs, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + + def reverseRec(l1: List, l2: List): List = (l1 match { + case Nil() => l2 + case Cons(x, xs) => reverseRec(xs, Cons(x, l2)) + + }) ensuring (res => size(l1) + size(l2) == size(res) && tmpl((a,b) => time <= a*size(l1) + b)) + //ensuring (res => size(l1) + size(l2) == size(res) && time <= 4*size(l1) + 1) + + def reverse(l: List): List = { + reverseRec(l, Nil()) + + } ensuring (res => size(l) == size(res) && tmpl((a,b) => time <= a*size(l) + b)) + + def reverse2(l: List): List = { + l match { + case Nil() => l + case Cons(hd, tl) => append(reverse2(tl), Cons(hd, Nil())) + } + } ensuring (res => size(res) == size(l) && tmpl((a,b) => time <= a*(size(l)*size(l)) + b)) + + def remove(elem: BigInt, l: List): List = { + l match { + case Nil() => Nil() + case Cons(hd, tl) => if (hd == elem) remove(elem, tl) else Cons(hd, remove(elem, tl)) + } + } ensuring (res => size(l) >= size(res) && tmpl((a,b) => time <= a*size(l) + b)) + + def contains(list: List, elem: BigInt): Boolean = (list match { + case Nil() => false + case Cons(x, xs) => x == elem || contains(xs, elem) + + }) ensuring (res => true && tmpl((a,b) => time <= a*size(list) + b)) + + def distinct(l: List): List = ( + l match { + case Nil() => Nil() + case Cons(x, xs) => { + val newl = distinct(xs) + if (contains(newl, x)) newl + else Cons(x, newl) + } + }) ensuring (res => size(l) >= size(res) && tmpl((a,b) => time <= a*(size(l)*size(l)) + b)) +} diff --git a/testcases/orb-testcases/timing/MergeSort.scala b/testcases/orb-testcases/timing/MergeSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..aa900aada57f70fb4ef7e63b8cbeefd82abe9457 --- /dev/null +++ b/testcases/orb-testcases/timing/MergeSort.scala @@ -0,0 +1,76 @@ +import leon.invariant._ +import leon.instrumentation._ + +import leon.annotation._ + +object MergeSort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + //case class Pair(fst:List,snd:List) + + @monotonic + def log(x: BigInt) : BigInt = { + require(x >= 0) + if(x <= 1) 0 + else { + val k = x/2 + 1 + log(x - k) + } + } ensuring(res => true && tmpl((a) => res >= 0)) + + def size(list:List): BigInt = {list match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }} ensuring(res => true && tmpl((a) => res >= 0)) + + def length(l:List): BigInt = { + l match { + case Nil() => 0 + case Cons(x,xs) => 1 + length(xs) + } + } ensuring(res => res == size(l) && tmpl((a,b) => time <= a*size(l) + b)) + + def split(l:List,n:BigInt): (List,List) = { + require(n >= 0 && n <= size(l)) + if (n <= 0) (Nil(),l) + else + l match { + case Nil() => (Nil(),l) + case Cons(x,xs) => { + if(n == 1) (Cons(x,Nil()), xs) + else { + val (fst,snd) = split(xs, n-1) + (Cons(x,fst), snd) + } + } + } + } ensuring(res => size(res._2) == size(l) - n && size(res._1) == n && tmpl((a,b) => time <= a*n +b)) + + def merge(aList:List, bList:List):List = (bList match { + case Nil() => aList + case Cons(x,xs) => + aList match { + case Nil() => bList + case Cons(y,ys) => + if (y < x) + Cons(y,merge(ys, bList)) + else + Cons(x,merge(aList, xs)) + } + }) ensuring(res => size(aList)+size(bList) == size(res) && tmpl((a,b,c) => time <= a*size(aList) + b*size(bList) + c)) + + def mergeSort(list:List):List = { + list match { + case Cons(x,Nil()) => list + case Cons(_,Cons(_,_)) => + val lby2 = length(list)/2 + val (fst,snd) = split(list,lby2) + //merge(mergeSort(fst,l), mergeSort(snd,len - l)) + merge(mergeSort(fst),mergeSort(snd)) + + case _ => list + + }} ensuring(res => true && tmpl((a,b) => time <= a*(size(list)*log(size(list))) + b)) +} diff --git a/testcases/orb-testcases/timing/PropositionalLogic.scala b/testcases/orb-testcases/timing/PropositionalLogic.scala new file mode 100644 index 0000000000000000000000000000000000000000..22dfdcdec06cef0760222140d669d11ae134a658 --- /dev/null +++ b/testcases/orb-testcases/timing/PropositionalLogic.scala @@ -0,0 +1,115 @@ +import scala.collection.immutable.Set +import leon.invariant._ +import leon.instrumentation._ + +object PropositionalLogic { + + sealed abstract class Formula + case class And(lhs: Formula, rhs: Formula) extends Formula + case class Or(lhs: Formula, rhs: Formula) extends Formula + case class Implies(lhs: Formula, rhs: Formula) extends Formula + case class Not(f: Formula) extends Formula + case class Literal(id: BigInt) extends Formula + case class True() extends Formula + case class False() extends Formula + + case class Pair(f: Formula, b: Boolean) + + sealed abstract class List + case class Cons(x: Pair, xs: List) extends List + case class Nil() extends List + + def size(f : Formula) : BigInt = (f match { + case And(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Or(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Implies(lhs, rhs) => size(lhs) + size(rhs) + 1 + case Not(f) => size(f) + 1 + case _ => 1 + }) + + def removeImplies(f: Formula): Formula = (f match { + case And(lhs, rhs) => And(removeImplies(lhs), removeImplies(rhs)) + case Or(lhs, rhs) => Or(removeImplies(lhs), removeImplies(rhs)) + case Implies(lhs, rhs) => Or(Not(removeImplies(lhs)),removeImplies(rhs)) + case Not(f) => Not(removeImplies(f)) + case _ => f + + }) ensuring(_ => time <= ? * size(f) + ?) + + def nnf(formula: Formula): Formula = (formula match { + case And(lhs, rhs) => And(nnf(lhs), nnf(rhs)) + case Or(lhs, rhs) => Or(nnf(lhs), nnf(rhs)) + case Implies(lhs, rhs) => Implies(nnf(lhs), nnf(rhs)) + case Not(And(lhs, rhs)) => Or(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Or(lhs, rhs)) => And(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Implies(lhs, rhs)) => And(nnf(lhs), nnf(Not(rhs))) + case Not(Not(f)) => nnf(f) + case Not(Literal(_)) => formula + case Literal(_) => formula + case Not(True()) => False() + case Not(False()) => True() + case _ => formula + }) ensuring(_ => time <= ? * size(formula) + ?) + + def isNNF(f: Formula): Boolean = { f match { + case And(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Or(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Implies(lhs, rhs) => false + case Not(Literal(_)) => true + case Not(_) => false + case _ => true + }} ensuring(_ => time <= ? * size(f) + ?) + + def simplify(f: Formula): Formula = (f match { + case And(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is false, return false + //if lhs is true return rhs + //if rhs is true return lhs + (sl,sr) match { + case (False(), _) => False() + case (_, False()) => False() + case (True(), _) => sr + case (_, True()) => sl + case _ => And(sl, sr) + } + } + case Or(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs or rhs is true, return true + //if lhs is false return rhs + //if rhs is false return lhs + (sl,sr) match { + case (True(), _) => True() + case (_, True()) => True() + case (False(), _) => sr + case (_, False()) => sl + case _ => Or(sl, sr) + } + } + case Implies(lhs, rhs) => { + val sl = simplify(lhs) + val sr = simplify(rhs) + + //if lhs is false return true + //if rhs is true return true + //if lhs is true return rhs + //if rhs is false return Not(rhs) + (sl,sr) match { + case (False(), _) => True() + case (_, True()) => True() + case (True(), _) => sr + case (_, False()) => Not(sl) + case _ => Implies(sl, sr) + } + } + case Not(True()) => False() + case Not(False()) => True() + case _ => f + + }) ensuring(_ => time <= ? *size(f) + ?) +} diff --git a/testcases/orb-testcases/timing/QuickSort.scala b/testcases/orb-testcases/timing/QuickSort.scala new file mode 100644 index 0000000000000000000000000000000000000000..2f4101dc8636ff9a3dada89e73a4f11c1e908a90 --- /dev/null +++ b/testcases/orb-testcases/timing/QuickSort.scala @@ -0,0 +1,43 @@ +import leon.invariant._ +import leon.instrumentation._ + +object QuickSort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + def size(l:List): BigInt = {l match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }} + + case class Triple(fst:List,snd:List, trd: List) + + def append(aList:List,bList:List): List = {aList match { + case Nil() => bList + case Cons(x, xs) => Cons(x,append(xs,bList)) + }} ensuring(res => size(res) == size(aList) + size(bList) && tmpl((a,b) => time <= a*size(aList) +b)) + + def partition(n:BigInt,l:List) : Triple = (l match { + case Nil() => Triple(Nil(), Nil(), Nil()) + case Cons(x,xs) => { + val t = partition(n,xs) + if (n < x) Triple(t.fst, t.snd, Cons(x,t.trd)) + else if(n == x) Triple(t.fst, Cons(x,t.snd), t.trd) + else Triple(Cons(x,t.fst), t.snd, t.trd) + } + }) ensuring(res => (size(l) == size(res.fst) + size(res.snd) + size(res.trd)) && tmpl((a,b) => time <= a*size(l) +b)) + + //Unable to prove n^2 upper bound :-( + def quickSort(l:List): List = (l match { + case Nil() => Nil() + case Cons(x,Nil()) => l + case Cons(x,xs) => { + val t = partition(x, xs) + append(append(quickSort(t.fst), Cons(x, t.snd)), quickSort(t.trd)) + } + case _ => l + }) + //ensuring(res => size(l) == size(res) && tmpl((a,b,c,d) => time <= a*(size(l)*size(l)) + c*size(l) + d)) +} + diff --git a/testcases/orb-testcases/timing/RedBlackTree.scala b/testcases/orb-testcases/timing/RedBlackTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..7ad0a1088ba09eac9195d9808d8ef6d647b84130 --- /dev/null +++ b/testcases/orb-testcases/timing/RedBlackTree.scala @@ -0,0 +1,112 @@ +import leon.invariant._ +import leon.instrumentation._ +import scala.collection.immutable.Set + +object RedBlackTree { + sealed abstract class Color + case class Red() extends Color + case class Black() extends Color + + sealed abstract class Tree + case class Empty() extends Tree + case class Node(color: Color, left: Tree, value: BigInt, right: Tree) extends Tree + + def twopower(x: BigInt) : BigInt = { + require(x >= 0) + if(x < 1) 1 + else + 2* twopower(x - 1) + } + + def size(t: Tree): BigInt = { + require(blackBalanced(t)) + (t match { + case Empty() => BigInt(0) + case Node(_, l, v, r) => size(l) + 1 + size(r) + }) + } ensuring (res => tmpl((a,b) => twopower(blackHeight(t)) <= a*res + b)) + + def blackHeight(t : Tree) : BigInt = { + t match { + case Node(Black(), l, _, _) => blackHeight(l) + 1 + case Node(Red(), l, _, _) => blackHeight(l) + case _ => 0 + } + } + + //We consider leaves to be black by definition + def isBlack(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(),_,_,_) => true + case _ => false + } + + def redNodesHaveBlackChildren(t: Tree) : Boolean = t match { + case Empty() => true + case Node(Black(), l, _, r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case Node(Red(), l, _, r) => isBlack(l) && isBlack(r) && redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => false + } + + def redDescHaveBlackChildren(t: Tree) : Boolean = t match { + case Node(_,l,_,r) => redNodesHaveBlackChildren(l) && redNodesHaveBlackChildren(r) + case _ => true + } + + def blackBalanced(t : Tree) : Boolean = t match { + case Node(_,l,_,r) => blackBalanced(l) && blackBalanced(r) && blackHeight(l) == blackHeight(r) + case _ => true + } + + // <<insert element x BigInto the tree t>> + def ins(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t)) + + t match { + case Empty() => Node(Red(),Empty(),x,Empty()) + case Node(c,a,y,b) => + if(x < y) { + val t1 = ins(x, a) + balance(c, t1, y, b) + } + else if (x == y){ + Node(c,a,y,b) + } + else{ + val t1 = ins(x, b) + balance(c,a,y,t1) + } + } + } ensuring(res => tmpl((a,b) => time <= a*blackHeight(t) + b)) + + def makeBlack(n: Tree): Tree = { + n match { + case Node(Red(),l,v,r) => Node(Black(),l,v,r) + case _ => n + } + } + + def add(x: BigInt, t: Tree): Tree = { + require(redNodesHaveBlackChildren(t) && blackBalanced(t) ) + val t1 = ins(x, t) + makeBlack(t1) + + } ensuring(res => tmpl((a,b) => time <= a*blackHeight(t) + b)) + + def balance(co: Color, l: Tree, x: BigInt, r: Tree): Tree = { + Node(co,l,x,r) + match { + case Node(Black(),Node(Red(),Node(Red(),a,xV,b),yV,c),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),Node(Red(),a,xV,Node(Red(),b,yV,c)),zV,d) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),Node(Red(),b,yV,c),zV,d)) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case Node(Black(),a,xV,Node(Red(),b,yV,Node(Red(),c,zV,d))) => + Node(Red(),Node(Black(),a,xV,b),yV,Node(Black(),c,zV,d)) + case _ => Node(co,l,x,r) + } + } + + +} diff --git a/testcases/orb-testcases/timing/SortingCombined.scala b/testcases/orb-testcases/timing/SortingCombined.scala new file mode 100644 index 0000000000000000000000000000000000000000..914d997aff4768b4ca5d5eb33cc46d2007477ee7 --- /dev/null +++ b/testcases/orb-testcases/timing/SortingCombined.scala @@ -0,0 +1,116 @@ +import leon.invariant._ +import leon.instrumentation._ +import leon.annotation._ + +object Sort { + sealed abstract class List + case class Cons(head:BigInt,tail:List) extends List + case class Nil() extends List + + //case class Pair(fst:List,snd:List) + + // @monotonic + def log(x: BigInt) : BigInt = { + //require(x >= 0) + if(x <= 1) 0 + else 1 + log(x/2) + } //ensuring(res=> true && tmpl((b) => res >= b)) + + def size(list:List): BigInt = {list match { + case Nil() => 0 + case Cons(x,xs) => 1 + size(xs) + }} + + def length(l:List): BigInt = { + l match { + case Nil() => 0 + case Cons(x,xs) => 1 + length(xs) + } + } ensuring(res => res == size(l) && tmpl((a,b) => time <= a*size(l) + b)) + + def split(l:List,n:BigInt): (List,List) = { + require(n >= 0 && n <= size(l)) + if (n <= 0) (Nil(),l) + else + l match { + case Nil() => (Nil(),l) + case Cons(x,xs) => { + if(n == 1) (Cons(x,Nil()), xs) + else { + val (fst,snd) = split(xs, n-1) + (Cons(x,fst), snd) + } + } + } + } ensuring(res => size(res._2) == size(l) - n && size(res._1) == n && tmpl((a,b) => time <= a*n +b)) + + def merge(aList:List, bList:List):List = (bList match { + case Nil() => aList + case Cons(x,xs) => + aList match { + case Nil() => bList + case Cons(y,ys) => + if (y < x) + Cons(y,merge(ys, bList)) + else + Cons(x,merge(aList, xs)) + } + }) ensuring(res => size(aList)+size(bList) == size(res) && tmpl((a,b,c) => time <= a*size(aList) + b*size(bList) + c)) + + def mergeSort(list:List, len: BigInt):List = { + require(len == size(list)) + + list match { + case Cons(x,Nil()) => list + case Cons(_,Cons(_,_)) => + val l = len/2 + val (fst,snd) = split(list,l) + merge(mergeSort(fst,l), mergeSort(snd,len - l)) + + case _ => list + + }} //ensuring(res => size(res) == size(list) && tmpl((a,b,c) => time <= a*(size(list)*size(list)) + c)) + //&& tmpl((a,b) => time <= a*size(list) + b)) + //ensuring(res => true && tmpl((a,b) => time <= a*(size(list)*log(size(list))) + b)) + case class Triple(fst:List,snd:List, trd: List) + + def append(aList:List,bList:List): List = {aList match { + case Nil() => bList + case Cons(x, xs) => Cons(x,append(xs,bList)) + }} ensuring(res => size(res) == size(aList) + size(bList) && tmpl((a,b) => time <= a*size(aList) +b)) + + def partition(n:BigInt,l:List) : Triple = (l match { + case Nil() => Triple(Nil(), Nil(), Nil()) + case Cons(x,xs) => { + val t = partition(n,xs) + if (n < x) Triple(t.fst, t.snd, Cons(x,t.trd)) + else if(n == x) Triple(t.fst, Cons(x,t.snd), t.trd) + else Triple(Cons(x,t.fst), t.snd, t.trd) + } + }) ensuring(res => (size(l) == size(res.fst) + size(res.snd) + size(res.trd)) && tmpl((a,b) => time <= a*size(l) +b)) + + //Unable to prove n^2 upper bound :-( + def quickSort(l:List): List = (l match { + case Nil() => Nil() + case Cons(x,Nil()) => l + case Cons(x,xs) => { + val t = partition(x, xs) + append(append(quickSort(t.fst), Cons(x, t.snd)), quickSort(t.trd)) + } + case _ => l + }) + + def sortedIns(e: BigInt, l: List): List = { + l match { + case Nil() => Cons(e,Nil()) + case Cons(x,xs) => if (x <= e) Cons(x,sortedIns(e, xs)) else Cons(e, l) + } + } ensuring(res => size(res) == size(l) + 1 && tmpl((a,b) => time <= a*size(l) +b)) + + def sort(l: List): List = (l match { + case Nil() => Nil() + case Cons(x,xs) => sortedIns(x, sort(xs)) + + }) ensuring(res => size(res) == size(l) && tmpl((a,b) => time <= a*(size(l)*size(l)) +b)) + +} diff --git a/testcases/orb-testcases/timing/SpeedBenchmarks.scala b/testcases/orb-testcases/timing/SpeedBenchmarks.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7349ab260eeec44f80222b7893e7cc16ea08b08 --- /dev/null +++ b/testcases/orb-testcases/timing/SpeedBenchmarks.scala @@ -0,0 +1,109 @@ +import leon.invariant._ +import leon.instrumentation._ + +object SpeedBenchmarks { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + def size(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) + + sealed abstract class StringBuffer + case class Chunk(str: List, next: StringBuffer) extends StringBuffer + case class Empty() extends StringBuffer + + def length(sb: StringBuffer) : BigInt = sb match { + case Chunk(_, next) => 1 + length(next) + case _ => 0 + } + + def sizeBound(sb: StringBuffer, k: BigInt) : Boolean ={ + sb match { + case Chunk(str, next) => size(str) <= k && sizeBound(next, k) + case _ => 0 <= k + } + } + + /** + * Fig. 1 of SPEED, POPL'09: The functional version of the implementation. + * Equality check of two string buffers + */ + def Equals(str1: List, str2: List, s1: StringBuffer, s2: StringBuffer, k: BigInt) : Boolean = { + require(sizeBound(s1, k) && sizeBound(s2, k) && size(str1) <= k && size(str2) <= k && k >= 0) + + (str1, str2) match { + case (Cons(h1,t1), Cons(h2,t2)) => { + + if(h1 != h2) false + else Equals(t1,t2, s1,s2, k) + } + case (Cons(_,_), Nil()) => { + //load from s2 + s2 match { + case Chunk(str, next) => Equals(str1, str, s1, next, k) + case Empty() => false + } + } + case (Nil(), Cons(_,_)) => { + //load from s1 + s1 match { + case Chunk(str, next) => Equals(str, str2, next, s2, k) + case Empty() => false + } + } + case _ =>{ + //load from both + (s1,s2) match { + case (Chunk(nstr1, next1),Chunk(nstr2, next2)) => Equals(nstr1, nstr2, next1, next2, k) + case (Empty(),Chunk(nstr2, next2)) => Equals(str1, nstr2, s1, next2, k) + case (Chunk(nstr1, next1), Empty()) => Equals(nstr1, str2, next1, s2, k) + case _ => true + } + } + } + } ensuring(res => true && tmpl((a,b,c,d,e) => time <= a*((k+1)*(length(s1) + length(s2))) + b*size(str1) + e)) + //ensuring(res => true && tmpl((a,b,c,d,e) => time <= a*(k*(length(s1) + length(s2))) + b*size(str1) + c*length(s1) + d*length(s2) + e)) + + def max(x: BigInt, y: BigInt) : BigInt = if(x >= y) x else y + + //Fig. 2 of Speed POPL'09 + def Dis1(x : BigInt, y : BigInt, n: BigInt, m: BigInt) : BigInt = { + if(x >= n) 0 + else { + if(y < m) Dis1(x, y+1, n, m) + else Dis1(x+1, y, n, m) + } + } ensuring(res => true && tmpl((a,b,c) => time <= a*max(0,n-x) + b*max(0,m-y) + c)) + + //Fig. 2 of Speed POPL'09 + def Dis2(x : BigInt, z : BigInt, n: BigInt) : BigInt = { + if(x >= n) 0 + else { + if(z > x) Dis2(x+1, z, n) + else Dis2(x, z+1, n) + } + } ensuring(res => true && tmpl((a,b,c) => time <= a*max(0,n-x) + b*max(0,n-z) + c)) + + //Pg. 138, Speed POPL'09 + def Dis3(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { + require((b && t == 1) || (!b && t == -1)) + if(x > n || x < 0) 0 + else { + if(b) Dis3(x+t, b, t, n) + else Dis3(x-t, b, t, n) + } + } ensuring(res => true && tmpl((a,c) => time <= a*max(0,(n-x)) + c)) + + //Pg. 138, Speed POPL'09 + def Dis4(x : BigInt, b : Boolean, t: BigInt, n: BigInt) : BigInt = { + if(x > n || x < 0) 0 + else { + if(b) Dis4(x+t, b, t, n) + else Dis4(x-t, b, t, n) + } + } ensuring(res => true && tmpl((a,c,d,e) => (((b && t >= 0) || (!b && t < 0)) && time <= a*max(0,(n-x)) + c) + || (((!b && t >= 0) || (b && t < 0)) && time <= d*max(0,x) + e))) +} diff --git a/testcases/orb-testcases/timing/TreeOperations.scala b/testcases/orb-testcases/timing/TreeOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..2e73dcd77ddfa630ba3801514ec96482fabfdb1d --- /dev/null +++ b/testcases/orb-testcases/timing/TreeOperations.scala @@ -0,0 +1,93 @@ +import leon.invariant._ +import leon.instrumentation._ + + +object TreeOperations { + + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case class Nil() extends List + + sealed abstract class Tree + case class Node(left: Tree, value: BigInt, right: Tree) extends Tree + case class Leaf() extends Tree + + def listSize(l: List): BigInt = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + listSize(t) + }) + + def size(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + size(l) + size(r) + 1 + } + } + } + + def height(t: Tree): BigInt = { + t match { + case Leaf() => 0 + case Node(l, x, r) => { + val hl = height(l) + val hr = height(r) + if (hl > hr) hl + 1 else hr + 1 + } + } + } + + def insert(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Node(Leaf(), elem, Leaf()) + case Node(l, x, r) => if (x <= elem) Node(l, x, insert(elem, r)) + else Node(insert(elem, l), x, r) + } + } ensuring (res => height(res) <= height(t) + 1 && tmpl((a,b) => time <= a*height(t) + b)) + + def addAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) =>{ + val newt = insert(x, t) + addAll(xs, newt) + } + } + } ensuring(res => tmpl((a,b,c) => time <= a*(listSize(l) * (height(t) + listSize(l))) + b*listSize(l) + c)) + + def remove(elem: BigInt, t: Tree): Tree = { + t match { + case Leaf() => Leaf() + case Node(l, x, r) => { + + if (x < elem) Node(l, x, remove(elem, r)) + else if (x > elem) Node(remove(elem, l), x, r) + else { + t match { + case Node(Leaf(), x, Leaf()) => Leaf() + case Node(Leaf(), x, Node(_, rx, _)) => Node(Leaf(), rx, remove(rx, r)) + case Node(Node(_, lx, _), x, r) => Node(remove(lx, l), lx, r) + case _ => Leaf() + } + } + } + } + } ensuring (res => height(res) <= height(t) && tmpl ((a, b, c) => time <= a*height(t) + b)) + + def removeAll(l: List, t: Tree): Tree = { + l match { + case Nil() => t + case Cons(x, xs) => removeAll(xs, remove(x, t)) + } + } ensuring(res => tmpl((a,b,c) => time <= a*(listSize(l) * height(t)) + b*listSize(l) + c)) + + def contains(elem : BigInt, t : Tree) : Boolean = { + t match { + case Leaf() => false + case Node(l, x, r) => + if(x == elem) true + else if (x < elem) contains(elem, r) + else contains(elem, l) + } + } ensuring (res => tmpl((a,b) => time <= a*height(t) + b)) +} \ No newline at end of file