diff --git a/library/annotation/package.scala b/library/annotation/package.scala index a473add8225ce35d251051ba4c1f40a220562ce2..78121f038e15d1844f15f9ff45cdd807b4660708 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -10,7 +10,7 @@ package object annotation { @ignore class induct extends StaticAnnotation @ignore - class traceInduct extends StaticAnnotation + class traceInduct(name: String = "") extends StaticAnnotation @ignore class ignore extends StaticAnnotation @ignore diff --git a/src/main/scala/leon/verification/TraceInductionTactic.scala b/src/main/scala/leon/verification/TraceInductionTactic.scala index 138148231874737a50f1cb54a3fdbee8cf422af4..c739f9feb9f9d66cd1d8a03f7a633ef7cf57e8a3 100644 --- a/src/main/scala/leon/verification/TraceInductionTactic.scala +++ b/src/main/scala/leon/verification/TraceInductionTactic.scala @@ -5,34 +5,136 @@ package verification import purescala.Definitions._ import purescala.Expressions._ +import purescala.Constructors._ +import purescala.ExprOps._ +import purescala.DefOps._ +import purescala.Common._ +import purescala.Types._ +import purescala.TypeOps._ +import purescala.Extractors._ +import invariant.util.PredicateUtil._ +import leon.utils._ /** * This tactic applies only to non-recursive functions. - * Inducts over the recursive calls of the first recursive procedure, in the body of `funDef` + * Inducts over the recursive calls of the first recursive procedure in the body of `funDef` */ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { - val description : String = + val description: String = "A tactic that performs induction over the recursions of a function." + + val cg = vctx.program.callGraph + val defaultTactic = new DefaultTactic(vctx) - implicit protected val ctx = vctx.context - - def generateVCs(fd: FunDef): Seq[VC] = { - generatePostconditions(fd) ++ - generatePreconditions(fd) ++ - generateCorrectnessConditions(fd) - } - - def generatePostconditions(function: FunDef): Seq[VC] - def generatePreconditions(function: FunDef): Seq[VC] - def generateCorrectnessConditions(function: FunDef): Seq[VC] - - protected def sizeLimit(s: String, limit: Int) = { - require(limit > 3) - // Crop the call to display it properly - val res = s.takeWhile(_ != '\n').take(limit) - if (res == s) { - res - } else { - res + " ..." + def generatePostconditions(function: FunDef): Seq[VC] = { + assert(!cg.isRecursive(function) && function.body.isDefined) + val inductFunname = function.extAnnotations("traceInduct") match { + case Seq(Some(arg: String)) => Some(arg) + case a => None } + // pritn debug info + if(inductFunname.isDefined) + ctx.reporter.debug("Extracting induction pattern from: "+inductFunname.get)(DebugSectionVerification) + + if (function.hasPostcondition) { + // construct post(body) + val prop = application(function.postcondition.get, Seq(function.body.get)) + val paramVars = function.paramIds.map(_.toVariable) + // extract the first recursive call when scanning `prop` AST from left to right + var funInv: Option[FunctionInvocation] = None + preTraversal { + case _ if funInv.isDefined => + // do nothing + case fi @ FunctionInvocation(tfd, args) if cg.isRecursive(tfd.fd) // function is recurisve + && args.forall(paramVars.contains) // all arguments should be parameters + && args.toSet.size == args.size => // all arguments are unique + if (inductFunname.isDefined) { + if (inductFunname.get == tfd.fd.id.name) + funInv = Some(fi) + } else + funInv = Some(fi) + case _ => + }(prop) + funInv match { + case None => + ctx.reporter.warning("Cannot discover induction pattern! Falling back to normal tactic.") + defaultTactic.generatePostconditions(function) + case Some(finv) => + // create a new function that realizes the tactic + val tactFun = new FunDef(FreshIdentifier(function.id.name + "-VCTact"), function.tparams, + function.params, BooleanType) + tactFun.precondition = function.precondition + // the body of tactFun is a conjunction of induction pattern of finv, and the property + val callee = finv.tfd.fd + val paramIndex = paramVars.zipWithIndex.toMap + val frame = finv.args.map { case v: Variable => v } + val footprint = paramVars.filterNot(frame.contains) + val indexedFootprint = footprint.map { a => paramIndex(a) -> a }.toMap + + // the returned expression will have boolean value + def inductPattern(e: Expr): Expr = { + e match { + case IfExpr(c, th, el) => + createAnd(Seq(inductPattern(c), + IfExpr(c, inductPattern(th), inductPattern(el)))) + + case MatchExpr(scr, cases) => + val scrpat = inductPattern(scr) + val casePats = cases.map{ + case MatchCase(pat, optGuard, rhs) => + val guardPat = optGuard.toSeq.map(inductPattern _) + (guardPat, MatchCase(pat, optGuard, inductPattern(rhs))) + } + val pats = scrpat +: casePats.flatMap(_._1) :+ MatchExpr(scr, casePats.map(_._2)) + createAnd(pats) + + case Let(i, v, b) => + createAnd(Seq(inductPattern(v), Let(i, v, inductPattern(b)))) + + case FunctionInvocation(tfd, args) => + val argPattern = createAnd(args.map(inductPattern)) + if (tfd.fd == callee) { // self recursive call ? + // create a tactFun invocation to mimic the recursion pattern + val indexedArgs = (args zip frame).map { + case (a, f) => paramIndex(f) -> a + }.toMap ++ indexedFootprint + val recArgs = (0 until indexedArgs.size).map(indexedArgs) + val recCall = FunctionInvocation(TypedFunDef(tactFun, tactFun.tparams.map(_.tp)), recArgs) + createAnd(Seq(argPattern, recCall)) + } else { + argPattern + } + + case Operator(args, op) => + // conjoin all the expressions and return them + createAnd(args.map(inductPattern)) + } + } + val argsMap = callee.params.map(_.id).zip(finv.args).toMap + val tparamMap = callee.tparams.zip(finv.tfd.tps).toMap + val inlinedBody = replaceFromIDs(argsMap, + instantiateType(callee.body.get, tparamMap, Map())) + val inductScheme = inductPattern(inlinedBody) + // add body, pre and post for the tactFun + tactFun.body = Some(createAnd(Seq(inductScheme, prop))) + tactFun.precondition = function.precondition + // postcondition is `holds` + val resid = FreshIdentifier("holds", BooleanType) + tactFun.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) + + // print debug info if needed + ctx.reporter.debug("Autogenerated tactic fun: "+tactFun)(DebugSectionVerification) + + // generate vcs using the tactfun + defaultTactic.generatePostconditions(tactFun) ++ + defaultTactic.generatePreconditions(tactFun) ++ + defaultTactic.generateCorrectnessConditions(tactFun) + } + } else Seq() } + + def generatePreconditions(function: FunDef): Seq[VC] = + defaultTactic.generatePreconditions(function) + + def generateCorrectnessConditions(function: FunDef): Seq[VC] = + defaultTactic.generateCorrectnessConditions(function) } diff --git a/src/main/scala/leon/verification/VerificationPhase.scala b/src/main/scala/leon/verification/VerificationPhase.scala index ea09de821427e39cffa3ea9c304e7d1e5a910a4c..c0539203aa529605ce4852a9b10c95d011e5f4ad 100644 --- a/src/main/scala/leon/verification/VerificationPhase.scala +++ b/src/main/scala/leon/verification/VerificationPhase.scala @@ -71,12 +71,15 @@ object VerificationPhase extends SimpleLeonPhase[Program,VerificationReport] { def generateVCs(vctx: VerificationContext, toVerify: Seq[FunDef]): Seq[VC] = { val defaultTactic = new DefaultTactic(vctx) val inductionTactic = new InductionTactic(vctx) + val trInductTactic = new TraceInductionTactic(vctx) val vcs = for(funDef <- toVerify) yield { val tactic: Tactic = if (funDef.annotations.contains("induct")) { inductionTactic - } else { + } else if(funDef.annotations.contains("traceInduct")){ + trInductTactic + }else { defaultTactic } diff --git a/testcases/verification/datastructures/TraceInductTest.scala b/testcases/verification/datastructures/TraceInductTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..731eb0621fe0b9a99b483388edae96da6911ecb3 --- /dev/null +++ b/testcases/verification/datastructures/TraceInductTest.scala @@ -0,0 +1,56 @@ +import leon.annotation._ +import leon.lang._ +import leon.collection._ + +object TraceInductTest { + sealed abstract class IList + case class ICons(head: BigInt, tail: IList) extends IList + case class INil() extends IList + + // proved with unrolling=0 + def size(l: IList): BigInt = (l match { + case INil() => BigInt(0) + case ICons(_, t) => 1 + size(t) + }) //ensuring(res => res >= 0) + + @traceInduct + def nonNegSize(l: IList): Boolean = { + size(l) >= 0 + } holds + + def reverse0(l1: IList, l2: IList): IList = (l1 match { + case INil() => l2 + case ICons(x, xs) => reverse0(xs, ICons(x, l2)) + }) + + def content(l: IList): Set[BigInt] = l match { + case INil() => Set.empty[BigInt] + case ICons(x, xs) => Set(x) ++ content(xs) + } + + @traceInduct("reverse0") + def revPreservesContent(l1: IList, l2: IList): Boolean = { + content(l1) ++ content(l2) == content(reverse0(l1, l2)) + } holds + + def insertAtIndex[T](l: List[T], i: BigInt, y: T): List[T] = { + require(0 <= i && i <= l.size) + l match { + case Nil() => + Cons[T](y, Nil()) + case _ if i == 0 => + Cons[T](y, l) + case Cons(x, tail) => + Cons[T](x, insertAtIndex(tail, i - 1, y)) + } + } + + // A lemma about `append` and `insertAtIndex` + @traceInduct("insertAtIndex") + def appendInsertIndex[T](l1: List[T], l2: List[T], i: BigInt, y: T): Boolean = { + require(0 <= i && i <= l1.size + l2.size) + (insertAtIndex((l1 ++ l2), i, y) == ( + if (i < l1.size) insertAtIndex(l1, i, y) ++ l2 + else l1 ++ insertAtIndex(l2, (i - l1.size), y))) + }.holds +}