diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 9a36ab3292caf5c34e3b5dd2b9b86744d3f70628..f3abb7513d9800035d95b13a7c231b42672ecfea 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -134,6 +134,12 @@ trait CodeGeneration { } ch << instr + case Assert(cond, oerr, body) => + mkExpr(IfExpr(Not(cond), Error(oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) + + case Ensuring(body, id, post) => + mkExpr(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id))), ch) + case Let(i,d,b) => mkExpr(d, ch) val slot = ch.getFreshVar diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 92f2b579f509faceb099c0eb6417ac1db526fa73..425cdc19d70041d31c79b40cb46e3e0c4be3fa76 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -76,6 +76,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu val first = e(ex) e(b)(rctx.withNewVar(i, first), gctx) + case Assert(cond, oerr, body) => + e(IfExpr(Not(cond), Error(oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) + + case Ensuring(body, id, post) => + e(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id)))) + case Error(desc) => throw RuntimeError("Error reached in evaluation: " + desc) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 97fbbc2ac69919c452256b57bb2152fea490a720..c018e48fa47232d62a655f93123d2952148b7a24 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -140,16 +140,26 @@ trait ASTExtractors { object ExRequiredExpression { /** Extracts the 'require' contract from an expression (only if it's the * first call in the block). */ - def unapply(tree: Block): Option[(Tree,Tree)] = tree match { - case Block(Apply(ExSelected("scala", "Predef", "require"), contractBody :: Nil) :: rest, body) => - if(rest.isEmpty) - Some((body,contractBody)) - else - Some((Block(rest,body),contractBody)) + def unapply(tree: Apply): Option[Tree] = tree match { + case Apply(ExSelected("scala", "Predef", "require"), contractBody :: Nil) => + Some(contractBody) case _ => None } } + object ExAssertExpression { + /** Extracts the 'assert' contract from an expression (only if it's the + * first call in the block). */ + def unapply(tree: Apply): Option[(Tree, Option[String])] = tree match { + case Apply(ExSelected("scala", "Predef", "assert"), contractBody :: Nil) => + Some((contractBody, None)) + case Apply(ExSelected("scala", "Predef", "assert"), contractBody :: (error: Literal) :: Nil) => + Some((contractBody, Some(error.value.stringValue))) + case _ => + None + } + } + object ExObjectDef { /** Matches an object with no type parameters, and regardless of its diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 95912df292aad6fdf4efdbd9194e4ddb6f9dcd3d..3474f6d7c4734690360003f6175566c8ee4b8abe 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -560,32 +560,8 @@ trait CodeExtraction extends ASTExtractors { val fctx = dctx.withNewVars(newVars) - - val (body2, ensuring) = body match { - case ExEnsuredExpression(body2, resSym, contract) => - val resId = FreshIdentifier(resSym.name.toString).setType(funDef.returnType).setPos(resSym.pos) - val post = toPureScala(contract)(fctx.withNewVar(resSym -> (() => Variable(resId)))).map( r => (resId, r)) - - (body2, post) - - case t @ ExHoldsExpression(body2) => - val resId = FreshIdentifier("holds").setType(BooleanType).setPos(body.pos) - (body2, Some((resId, Variable(resId).setPos(body.pos)))) - - case _ => - (body, None) - } - - val (body3, require) = body2 match { - case ExRequiredExpression(body3, contract) => - (body3, toPureScala(contract)(fctx)) - - case _ => - (body2, None) - } - val finalBody = try { - Some(flattenBlocks(extractTree(body3)(fctx)) match { + flattenBlocks(extractTree(body)(fctx)) match { case e if e.getType.isInstanceOf[ArrayType] => getOwner(e) match { case Some(Some(fd)) if fd == funDef => @@ -595,16 +571,14 @@ trait CodeExtraction extends ASTExtractors { e case _ => - outOfSubsetError(body3, "Function cannot return an array that is not locally defined") + outOfSubsetError(body, "Function cannot return an array that is not locally defined") } case e => e - }) + } } catch { case e: ImpureCodeEncounteredException => - if (dctx.isProxy) { - // We actually expect errors, no point reporting - } else { + if (!dctx.isProxy) { e.emit() if (ctx.settings.strictCompilation) { reporter.error(funDef.getPos, "Function "+funDef.id.name+" could not be extracted. (Forgot @proxy ?)") @@ -614,30 +588,27 @@ trait CodeExtraction extends ASTExtractors { } funDef.addAnnotation("abstract") - None + NoTree(funDef.returnType) } - val finalRequire = require.filter{ e => + funDef.fullBody = finalBody; + + // Post-extraction sanity checks + + funDef.precondition.foreach { case e => if(containsLetDef(e)) { - reporter.warning(body3.pos, "Function precondtion should not contain nested function definition, ignoring.") - false - } else { - true + reporter.warning(e.getPos, "Function precondtion should not contain nested function definition, ignoring.") + funDef.precondition = None } } - val finalEnsuring = ensuring.filter{ case (id, e) => + funDef.postcondition.foreach { case (id, e) => if(containsLetDef(e)) { - reporter.warning(body3.pos, "Function postcondition should not contain nested function definition, ignoring.") - false - } else { - true + reporter.warning(e.getPos, "Function postcondition should not contain nested function definition, ignoring.") + funDef.postcondition = None } } - funDef.body = finalBody - funDef.precondition = finalRequire - funDef.postcondition = finalEnsuring funDef } @@ -730,6 +701,54 @@ trait CodeExtraction extends ASTExtractors { var rest = tmpRest val res = current match { + case ExEnsuredExpression(body, resSym, contract) => + val resId = FreshIdentifier(resSym.name.toString).setType(extractType(current)).setPos(resSym.pos) + val post = extractTree(contract)(dctx.withNewVar(resSym -> (() => Variable(resId)))) + + val b = try { + extractTree(body) + } catch { + case (e: ImpureCodeEncounteredException) if dctx.isProxy => + NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) + } + + Ensuring(b, resId, post) + + case t @ ExHoldsExpression(body) => + val resId = FreshIdentifier("holds").setType(BooleanType).setPos(current.pos) + val post = Variable(resId).setPos(current.pos) + + val b = try { + extractTree(body) + } catch { + case (e: ImpureCodeEncounteredException) if dctx.isProxy => + NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) + } + + Ensuring(b, resId, post) + + case ExAssertExpression(contract, oerr) => + val const = extractTree(contract) + val b = rest.map(extractTree).getOrElse(UnitLiteral()) + + rest = None + + Assert(const, oerr, b) + + case ExRequiredExpression(contract) => + val pre = extractTree(contract) + + val b = try { + rest.map(extractTree).getOrElse(UnitLiteral()) + } catch { + case (e: ImpureCodeEncounteredException) if dctx.isProxy => + NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) + } + + rest = None + + Require(pre, b) + case ExArrayLiteral(tpe, args) => FiniteArray(args.map(extractTree)).setType(ArrayType(extractType(tpe)(dctx, current.pos))) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index ac99c1146633be43e695bd66e727a608588ed3eb..54b43592e361d96f61a752e39671962b497ad84b 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -181,10 +181,21 @@ object Definitions { /** Functions (= 'methods' of objects) */ class FunDef(val id: Identifier, val tparams: Seq[TypeParameterDef], val returnType: TypeTree, val params: Seq[ValDef]) extends Definition { - var body: Option[Expr] = None - def implementation : Option[Expr] = body - var precondition: Option[Expr] = None - var postcondition: Option[(Identifier, Expr)] = None + + var fullBody: Expr = NoTree(returnType) + + def body: Option[Expr] = withoutSpec(fullBody) + def body_=(b: Option[Expr]) = fullBody = withBody(fullBody, b) + + def precondition = preconditionOf(fullBody) + def precondition_=(oe: Option[Expr]) = { + fullBody = withPrecondition(fullBody, oe) + } + + def postcondition = postconditionOf(fullBody) + def postcondition_=(op: Option[(Identifier, Expr)]) = { + fullBody = withPostcondition(fullBody, op) + } // Metadata kept here after transformations var parent: Option[FunDef] = None @@ -192,16 +203,13 @@ object Definitions { def duplicate: FunDef = { val fd = new FunDef(id, tparams, returnType, params) - fd.body = body - fd.precondition = precondition - fd.postcondition = postcondition + fd.fullBody = fullBody fd.parent = parent fd.orig = orig fd } - def hasImplementation : Boolean = body.isDefined - def hasBody = hasImplementation + def hasBody = body.isDefined def hasPrecondition : Boolean = precondition.isDefined def hasPostcondition : Boolean = postcondition.isDefined @@ -224,6 +232,13 @@ object Definitions { TypedFunDef(this, Nil) } + // Deprecated, old API + @deprecated("Use .body instead", "2.3") + def implementation : Option[Expr] = body + + @deprecated("Use .hasBody instead", "2.3") + def hasImplementation : Boolean = hasBody + } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index fcd088388f1af4117d2e35d67b755b7c52df18af..11a2e7281ed2cbb233233a2a6e2d1db0dd597a06 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -73,6 +73,9 @@ object Extractors { case ListAt(t1,t2) => Some((t1,t2,ListAt)) case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, b))) case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, b))) + case Require(pre, body) => Some((pre, body, Require)) + case Ensuring(body, id, post) => Some((body, post, (b: Expr, p: Expr) => Ensuring(b, id, p))) + case Assert(const, oerr, body) => Some((const, body, (c: Expr, b: Expr) => Assert(c, oerr, b))) case (ex: BinaryExtractable) => ex.extract case _ => None } diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 87bebfe6d38f88dab077c02fbb90916be416eee0..68e9404494d9cd7a29cb065a1d52cdd293148f95 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1852,6 +1852,54 @@ object TreeOps { simplifyArithmetic(expr0) } + /** + * Body manipulation + * ======== + */ + + def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = (pred, expr) match { + case (Some(newPre), Require(pre, b)) => Require(newPre, b) + case (Some(newPre), Ensuring(Require(pre, b), i, p)) => Ensuring(Require(newPre, b), i, p) + case (Some(newPre), Ensuring(b, i, p)) => Ensuring(Require(newPre, b), i, p) + case (Some(newPre), b) => Require(newPre, b) + case (None, Require(pre, b)) => b + case (None, Ensuring(Require(pre, b), i, p)) => Ensuring(b, i, p) + case (None, Ensuring(b, i, p)) => Ensuring(b, i, p) + case (None, b) => b + } + + def withPostcondition(expr: Expr, oie: Option[(Identifier, Expr)]) = (oie, expr) match { + case (Some((nid, npost)), Ensuring(b, id, post)) => Ensuring(b, nid, npost) + case (Some((nid, npost)), b) => Ensuring(b, nid, npost) + case (None, Ensuring(b, i, p)) => b + case (None, b) => b + } + + def withBody(expr: Expr, body: Option[Expr]) = expr match { + case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) + case Ensuring(Require(pre, _), i, post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), i, post) + case Ensuring(_, i, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), i, post) + case _ => body.getOrElse(NoTree(expr.getType)) + } + + def withoutSpec(expr: Expr) = expr match { + case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(Require(pre, b), i, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(b, i, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case b => Option(b).filterNot(_.isInstanceOf[NoTree]) + } + + def preconditionOf(expr: Expr) = expr match { + case Require(pre, _) => Some(pre) + case Ensuring(Require(pre, _), _, _) => Some(pre) + case b => None + } + + def postconditionOf(expr: Expr) = expr match { + case Ensuring(_, i, post) => Some((i, post)) + case _ => None + } + /** * Deprecated API diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index b7624a842c5322fb1e482b0d0c5b849d94973d44..2a56870b90183c225f39a3a1cfa111f9e207da9a 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -21,11 +21,27 @@ object Trees { self: Expr => } + case class NoTree(tpe: TypeTree) extends Expr with Terminal with FixedType { + val fixedType = tpe + } + /* This describes computational errors (unmatched case, taking min of an * empty set, division by zero, etc.). It should always be typed according to * the expected type. */ case class Error(description: String) extends Expr with Terminal + case class Require(pred: Expr, body: Expr) extends Expr with FixedType { + val fixedType = body.getType + } + + case class Ensuring(body: Expr, id: Identifier, pred: Expr) extends Expr with FixedType { + val fixedType = body.getType + } + + case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr with FixedType { + val fixedType = body.getType + } + case class Choose(vars: List[Identifier], pred: Expr) extends Expr with FixedType with UnaryExtractable { assert(!vars.isEmpty) diff --git a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala index 820d1b963f3f9c9bc3a43fba2db016b225e21533..85ad52f78d1870343c9d8678d2ec2176a4eddac6 100644 --- a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala @@ -140,6 +140,13 @@ object FunctionTemplate { def rec(pathVar : Identifier, expr : Expr) : Expr = { expr match { + case a @ Assert(cond, _, body) => + storeGuarded(pathVar, rec(pathVar, cond)) + rec(pathVar, body) + + case e @ Ensuring(body, id, post) => + rec(pathVar, Let(id, body, Assert(post, None, Variable(id)))) + case l @ Let(i, e, b) => val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) exprVars += newExpr diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 9fe4be2cc7b46e971b439766bf639f3c285afd4a..cf993c1b9054144987674d0ac2d8578d5e98b141 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -148,8 +148,6 @@ class FunctionTemplate private( } object FunctionTemplate { - val splitAndOrImplies = false - def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty @@ -191,8 +189,22 @@ object FunctionTemplate { res } + def requireDecomposition(e: Expr) = { + exists{ + case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true + case _ => false + }(e) + } + def rec(pathVar : Identifier, expr : Expr) : Expr = { expr match { + case a @ Assert(cond, _, body) => + storeGuarded(pathVar, rec(pathVar, cond)) + rec(pathVar, body) + + case e @ Ensuring(body, id, post) => + rec(pathVar, Let(id, body, Assert(post, None, Variable(id)))) + case l @ Let(i, e, b) => val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) exprVars += newExpr @@ -221,47 +233,16 @@ object FunctionTemplate { case m : MatchExpr => sys.error("MatchExpr's should have been eliminated.") case i @ Implies(lhs, rhs) => - if (splitAndOrImplies) { - if (containsFunctionCalls(i)) { - rec(pathVar, IfExpr(lhs, rhs, BooleanLiteral(true))) - } else { - i - } - } else { - Implies(rec(pathVar, lhs), rec(pathVar, rhs)) - } + Implies(rec(pathVar, lhs), rec(pathVar, rhs)) case a @ And(parts) => - if (splitAndOrImplies) { - if (containsFunctionCalls(a)) { - val partitions = groupWhile((e: Expr) => !containsFunctionCalls(e), parts) - - val ifExpr = partitions.map(And(_)).reduceRight{ (a: Expr, b: Expr) => IfExpr(a, b, BooleanLiteral(false)) } - - rec(pathVar, ifExpr) - } else { - a - } - } else { - And(parts.map(rec(pathVar, _))) - } + And(parts.map(rec(pathVar, _))) case o @ Or(parts) => - if (splitAndOrImplies) { - if (containsFunctionCalls(o)) { - val partitions = groupWhile((e: Expr) => !containsFunctionCalls(e), parts) - - val ifExpr = partitions.map(Or(_)).reduceRight{ (a: Expr, b: Expr) => IfExpr(a, BooleanLiteral(true), b) } - rec(pathVar, ifExpr) - } else { - o - } - } else { - Or(parts.map(rec(pathVar, _))) - } + Or(parts.map(rec(pathVar, _))) case i @ IfExpr(cond, thenn, elze) => { - if(!containsFunctionCalls(i)) { + if(!requireDecomposition(i)) { i } else { val newBool1 : Identifier = FreshIdentifier("b", true).setType(BooleanType) diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index aa79a14703d0da8a278dca069b1aeed835cb79d4..85c8b60245e0f36626f52031b8498e13c3ad5e6f 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -25,7 +25,7 @@ object Rules { EqualitySplit, InequalitySplit, CEGIS, - Assert, + rules.Assert, DetupleOutput, DetupleInput, ADTSplit, diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index 1a8f6aaeff17d7d853c29559f3ff30f36f87b1e1..be6e98ae3656ec41ac8053ff430e997e6d07d1ad 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -53,7 +53,7 @@ class SimpleTerminationChecker(context: LeonContext, program: Program) extends T return NoGuarantee // This is also too confusing for me to think about now. - if (!funDef.hasImplementation) + if (!funDef.hasBody) return NoGuarantee val sccIndex = funDefToSCCIndex.getOrElse(funDef, { diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 896e16c786d5af7daf83a39febd7a8c217728041..a7e0b46a0832042f4acc989e1dc762f5eae0ff99 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -17,7 +17,8 @@ object PreprocessingPhase extends TransformationPhase { val phases = MethodLifting andThen TypingPhase andThen - CompleteAbstractDefinitions + CompleteAbstractDefinitions andThen + InjectAsserts phases.run(ctx)(p) } diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index 17d0222b3c3625d41636b50654f441bb427d690b..922ca045f16cebbb078059c3c3da595927fbfe47 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -32,10 +32,8 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { import vctx.reporter import vctx.program - val defaultTactic = new DefaultTactic(reporter) - defaultTactic.setProgram(program) - val inductionTactic = new InductionTactic(reporter) - inductionTactic.setProgram(program) + val defaultTactic = new DefaultTactic(vctx) + val inductionTactic = new InductionTactic(vctx) var allVCs = Map[FunDef, List[VerificationCondition]]() @@ -64,11 +62,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { } if(funDef.body.isDefined) { - val funVCs = tactic.generatePreconditions(funDef) ++ - tactic.generatePatternMatchingExhaustivenessChecks(funDef) ++ - tactic.generatePostconditions(funDef) ++ - tactic.generateMiscCorrectnessConditions(funDef) ++ - tactic.generateArrayAccessChecks(funDef) + val funVCs = tactic.generateVCs(funDef) allVCs += funDef -> funVCs.toList } diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index f947286b838a0fca249f07f71341a8bf8f5dc43f..7e03ccebaf9c947c4206b38613474d7f1803f75c 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -11,229 +11,61 @@ import purescala.Definitions._ import scala.collection.mutable.{Map => MutableMap} -class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { +class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { val description = "Default verification condition generation approach" override val shortDescription = "default" - var _prog : Option[Program] = None - def program : Program = _prog match { - case None => throw new Exception("Program never set in DefaultTactic.") - case Some(p) => p - } - - override def setProgram(program: Program) : Unit = { - _prog = Some(program) - } - - def generatePostconditions(functionDefinition: FunDef) : Seq[VerificationCondition] = { - assert(functionDefinition.body.isDefined) - val prec = functionDefinition.precondition - val optPost = functionDefinition.postcondition - val body = matchToIfThenElse(functionDefinition.body.get) - - optPost match { - case None => - Seq() - - case Some((id, post)) => - val theExpr = { - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPost = Let(resFresh, body, replace(Map(Variable(id) -> Variable(resFresh)), matchToIfThenElse(post))) - - val withPrec = if(prec.isEmpty) { - bodyAndPost - } else { - Implies(matchToIfThenElse(prec.get), bodyAndPost) - } - - withPrec - } - Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this).setPos(post)) - } - } - - def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = { - val toRet = if(function.hasBody) { - val pre = matchToIfThenElse(function.body.get) - val cleanBody = expandLets(pre) - - val allPathConds = collectWithPathCondition((t => t match { - case FunctionInvocation(tfd, _) if(tfd.hasPrecondition) => true - case _ => false - }), cleanBody) + def generatePostconditions(fd: FunDef): Seq[VerificationCondition] = { + (fd.postcondition, fd.body) match { + case (Some((id, post)), Some(body)) => + val res = id.freshen + val vc = Implies(precOrTrue(fd), Let(res, safe(body), replace(Map(id.toVariable -> res.toVariable), safe(post)))) - def withPrecIfDefined(path: Seq[Expr], shouldHold: Expr) : Expr = if(function.hasPrecondition) { - Not(And(And(matchToIfThenElse(function.precondition.get) +: path), Not(shouldHold))) - } else { - Not(And(And(path), Not(shouldHold))) - } - - allPathConds.map(pc => { - val path : Seq[Expr] = pc._1 - val fi = pc._2.asInstanceOf[FunctionInvocation] - val FunctionInvocation(tfd, args) = fi - val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) - val newLetIDs = tfd.params.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((tfd.params.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) - val newBody : Expr = replace(substMap, prec) - val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - - new VerificationCondition( - withPrecIfDefined(path, newCall), - function, - VCKind.Precondition, - this.asInstanceOf[DefaultTactic]).setPos(fi) - }).toSeq - } else { - Seq.empty + Seq(new VerificationCondition(vc, fd, VCKind.Postcondition, this).setPos(post)) + case _ => + Nil } - - // println("PRECS VCs FOR " + function.id.name) - // println(toRet.toList.map(vc => vc.posInfo + " -- " + vc.condition).mkString("\n\n")) - - toRet } - - def generatePatternMatchingExhaustivenessChecks(function: FunDef) : Seq[VerificationCondition] = { - val toRet = if(function.hasBody) { - val cleanBody = matchToIfThenElse(function.body.get) - - val allPathConds = collectWithPathCondition((t => t match { - case Error("non-exhaustive match") => true - case _ => false - }), cleanBody) - def withPrecIfDefined(conds: Seq[Expr]) : Expr = if(function.hasPrecondition) { - Not(And(matchToIfThenElse(function.precondition.get), And(conds))) - } else { - Not(And(conds)) - } - - allPathConds.map(pc => - new VerificationCondition( - withPrecIfDefined(pc._1), - function,//if(function.fromLoop) function.parent.get else function, - VCKind.ExhaustiveMatch, - this.asInstanceOf[DefaultTactic]).setPos(pc._2) - ).toSeq - } else { - Seq.empty - } - - // println("MATCHING VCs FOR " + function.id.name) - // println(toRet.toList.map(vc => vc.posInfo + " -- " + vc.condition).mkString("\n\n")) - - toRet - } - - def generateMapAccessChecks(function: FunDef) : Seq[VerificationCondition] = { - val toRet = if (function.hasBody) { - val cleanBody = mapGetWithChecks(matchToIfThenElse(function.body.get)) - - val allPathConds = collectWithPathCondition((t => t match { - case Error("key not found for map access") => true - case _ => false - }), cleanBody) - - def withPrecIfDefined(conds: Seq[Expr]) : Expr = if (function.hasPrecondition) { - Not(And(mapGetWithChecks(matchToIfThenElse(function.precondition.get)), And(conds))) - } else { - Not(And(conds)) - } + def generatePreconditions(fd: FunDef): Seq[VerificationCondition] = { + fd.body match { + case Some(body) => + val calls = collectWithPC { + case c @ FunctionInvocation(tfd, _) if tfd.hasPrecondition => (c, tfd.precondition.get) + }(safe(body)) + + calls.map { + case ((fi @ FunctionInvocation(tfd, args), pre), path) => + val pre2 = replaceFromIDs((tfd.params.map(_.id) zip args).toMap, safe(pre)) + val vc = Implies(And(precOrTrue(fd), path), pre2) + + new VerificationCondition(vc, fd, VCKind.Precondition, this).setPos(fi) + } - allPathConds.map(pc => - new VerificationCondition( - withPrecIfDefined(pc._1), - function, //if(function.fromLoop) function.parent.get else function, - VCKind.MapAccess, - this.asInstanceOf[DefaultTactic]).setPos(pc._2) - ).toSeq - } else { - Seq.empty + case None => + Nil } - - toRet } - def generateArrayAccessChecks(function: FunDef) : Seq[VerificationCondition] = { - val toRet = if (function.hasBody) { - val cleanBody = matchToIfThenElse(function.body.get) - - val allPathConds = CollectorWithPaths { - case expr@ArraySelect(a, i) => (expr, a, i) - case expr@ArrayUpdated(a, i, _) => (expr, a, i) - }.traverse(cleanBody) - - val arrayAccessConditions = allPathConds.map{ - case ((expr, array, index), pathCond) => { - val length = ArrayLength(array) - val negative = LessThan(index, IntLiteral(0)) - val tooBig = GreaterEquals(index, length) - (And(pathCond, Or(negative, tooBig)), expr) + def generateCorrectnessConditions(fd: FunDef): Seq[VerificationCondition] = { + fd.body match { + case Some(body) => + val calls = collectWithPC { + case e @ Error(_) => (e, BooleanLiteral(false)) + case a @ Assert(cond, _, _) => (a, cond) + // Only triggered for inner ensurings, general postconditions are handled by generatePostconditions + case a @ Ensuring(body, id, post) => (a, Let(id, body, post)) + }(safe(body)) + + calls.map { + case ((e, errorCond), path) => + val vc = Implies(And(precOrTrue(fd), path), errorCond) + + new VerificationCondition(vc, fd, VCKind.Correctness, this).setPos(e) } - } - def withPrecIfDefined(conds: Expr) : Expr = if (function.hasPrecondition) { - Not(And(mapGetWithChecks(matchToIfThenElse(function.precondition.get)), conds)) - } else { - Not(conds) - } - - - arrayAccessConditions.map(pc => - new VerificationCondition( - withPrecIfDefined(pc._1), - function, //if(function.fromLoop) function.parent.get else function, - VCKind.ArrayAccess, - this.asInstanceOf[DefaultTactic]).setPos(pc._2) - ).toSeq - } else { - Seq.empty + case None => + Nil } - - toRet - } - - def generateMiscCorrectnessConditions(function: FunDef) : Seq[VerificationCondition] = { - generateMapAccessChecks(function) } - - def collectWithPathCondition(matcher: Expr=>Boolean, expression: Expr) : Set[(Seq[Expr],Expr)] = { - CollectorWithPaths({ - case e if matcher(e) => e - }).traverse(expression).map{ - case (e, And(es)) => (es, e) - case (e1, e2) => (Seq(e2), e1) - }.toSet - } - // prec: there should be no lets and no pattern-matching in this expression - //def collectWithPathCondition(matcher: Expr=>Boolean, expression: Expr) : Set[(Seq[Expr],Expr)] = { - // var collected : Set[(Seq[Expr],Expr)] = Set.empty - - // def rec(expr: Expr, path: List[Expr]) : Unit = { - // if(matcher(expr)) { - // collected = collected + ((path.reverse, expr)) - // } - - // expr match { - // case Let(i,e,b) => { - // rec(e, path) - // rec(b, Equals(Variable(i), e) :: path) - // } - // case IfExpr(cond, thenn, elze) => { - // rec(cond, path) - // rec(thenn, cond :: path) - // rec(elze, Not(cond) :: path) - // } - // case NAryOperator(args, _) => args.foreach(rec(_, path)) - // case BinaryOperator(t1, t2, _) => rec(t1, path); rec(t2, path) - // case UnaryOperator(t, _) => rec(t, path) - // case t : Terminal => ; - // case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr) - // } - // } - - // rec(expression, Nil) - // collected - //} } diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 70824c7edaa57011a80b52d39691e10d668ed527..73b0655bd1ab1c92432815c044c2ee8a63697e80 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -9,132 +9,83 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) { +class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { override val description = "Induction tactic for suitable functions" override val shortDescription = "induction" - private def firstAbsClassDef(args: Seq[ValDef]) : Option[(AbstractClassType, ValDef)] = { + private def firstAbsClassDef(args: Seq[ValDef]): Option[(AbstractClassType, ValDef)] = { args.map(vd => (vd.getType, vd)).collect { case (act: AbstractClassType, vd) => (act, vd) }.headOption - } + } - private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr) : Seq[Expr] = { + private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = { val childrenOfSameType = cct.fields.filter(_.tpe == parentType) for (field <- childrenOfSameType) yield { CaseClassSelector(cct, expr, field.id) } } - override def generatePostconditions(funDef: FunDef) : Seq[VerificationCondition] = { - assert(funDef.body.isDefined) - firstAbsClassDef(funDef.params) match { - case Some((cct, arg)) => - val prec = funDef.precondition - val optPost = funDef.postcondition - val body = matchToIfThenElse(funDef.body.get) - val argAsVar = arg.toVariable - val parentType = cct - - optPost match { - case None => - Seq.empty - case Some((pid, post)) => - for (cct <- parentType.knownCCDescendents) yield { - val selectors = selectorsOfParentType(parentType, cct, argAsVar) - // if no subtrees of parent type, assert property for base case - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPostForArg = Let(resFresh, body, replace(Map(Variable(pid) -> Variable(resFresh)), matchToIfThenElse(post))) - val withPrec = if (prec.isEmpty) bodyAndPostForArg else Implies(matchToIfThenElse(prec.get), bodyAndPostForArg) - - val conditionForChild = - if (selectors.size == 0) - withPrec - else { - val inductiveHypothesis = (for (sel <- selectors) yield { - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPost = Let(resFresh, replace(Map(argAsVar -> sel), body), replace(Map(Variable(pid) -> Variable(resFresh), argAsVar -> sel), matchToIfThenElse(post))) - val withPrec = if (prec.isEmpty) bodyAndPost else Implies(replace(Map(argAsVar -> sel), matchToIfThenElse(prec.get)), bodyAndPost) - withPrec - }) - Implies(And(inductiveHypothesis), withPrec) - } - new VerificationCondition(Implies(CaseClassInstanceOf(cct, argAsVar), conditionForChild), funDef, VCKind.Postcondition, this).setPos(funDef) - } + override def generatePostconditions(fd: FunDef): Seq[VerificationCondition] = { + (fd.body.map(safe), firstAbsClassDef(fd.params), fd.postcondition) match { + case (Some(b), Some((parentType, arg)), Some((id, p))) => + val post = safe(p) + val body = safe(b) + + for (cct <- parentType.knownCCDescendents) yield { + val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) + + val subCases = selectors.map { sel => + val res = id.freshen + replace(Map(arg.toVariable -> sel), + Implies(precOrTrue(fd), Let(res, body, replace(Map(id.toVariable -> res.toVariable), post))) + ) + } + + val vc = Implies(And(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd)), Implies(And(subCases), Let(id, body, post))) + + new VerificationCondition(vc, fd, VCKind.Postcondition, this).setPos(fd) } - case None => - reporter.warning(funDef.getPos, "Could not find abstract class type argument to induct on") - super.generatePostconditions(funDef) + case (body, _, post) => + if (post.isDefined && body.isDefined) { + reporter.warning(fd.getPos, "Could not find abstract class type argument to induct on") + } + super.generatePostconditions(fd) } } - override def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = { - val defaultPrec = super.generatePreconditions(function) - firstAbsClassDef(function.params) match { - case Some((cct, arg)) => { - val toRet = if(function.hasBody) { - val parentType = cct - val cleanBody = expandLets(matchToIfThenElse(function.body.get)) - - val allPathConds = collectWithPathCondition((t => t match { - case FunctionInvocation(tfd, _) if(tfd.hasPrecondition) => true - case _ => false - }), cleanBody) - - def withPrec(path: Seq[Expr], shouldHold: Expr) : Expr = if(function.hasPrecondition) { - Not(And(And(matchToIfThenElse(function.precondition.get) +: path), Not(shouldHold))) - } else { - Not(And(And(path), Not(shouldHold))) - } + override def generatePreconditions(fd: FunDef): Seq[VerificationCondition] = { + (fd.body.map(safe), firstAbsClassDef(fd.params)) match { + case (Some(b), Some((parentType, arg))) => + val body = safe(b) - val conditionsForAllPaths : Seq[Seq[VerificationCondition]] = allPathConds.map(pc => { - val path : Seq[Expr] = pc._1 - val fi = pc._2.asInstanceOf[FunctionInvocation] - val FunctionInvocation(tfd, args) = fi + val calls = collectWithPC { + case fi @ FunctionInvocation(tfd, _) if tfd.hasPrecondition => (fi, safe(tfd.precondition.get)) + }(body) + calls.flatMap { + case ((fi @ FunctionInvocation(tfd, args), pre), path) => for (cct <- parentType.knownCCDescendents) yield { - val argAsVar = arg.toVariable - val selectors = selectorsOfParentType(parentType, cct, argAsVar) - - val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) - val newLetIDs = tfd.params.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((tfd.params.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) - val newBody : Expr = replace(substMap, prec) - val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - - val toProve = withPrec(path, newCall) - - val conditionForChild = - if (selectors.isEmpty) - toProve - else { - val inductiveHypothesis = (for (sel <- selectors) yield { - val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) - val newLetIDs = tfd.params.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((tfd.params.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) - val newBody : Expr = replace(substMap, prec) - val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - - val toReplace = withPrec(path, newCall) - replace(Map(argAsVar -> sel), toReplace) - }) - Implies(And(inductiveHypothesis), toProve) - } - new VerificationCondition(Implies(CaseClassInstanceOf(cct, argAsVar), conditionForChild), function, VCKind.Precondition, this).setPos(fi) + val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) + + val subCases = selectors.map { sel => + replace(Map(arg.toVariable -> sel), + Implies(precOrTrue(fd), replace((tfd.params.map(_.toVariable) zip args).toMap, pre)) + ) + } + + val vc = Implies(And(Seq(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd), path)), Implies(And(subCases), replace((tfd.params.map(_.toVariable) zip args).toMap, pre))) + + new VerificationCondition(vc, fd, VCKind.Precondition, this).setPos(fi) } - }).toSeq + } - conditionsForAllPaths.flatten - } else { - Seq.empty + case (body, _) => + if (body.isDefined) { + reporter.warning(fd.getPos, "Could not find abstract class type argument to induct on") } - toRet - } - case None => { - reporter.warning(function.getPos, "Induction tactic currently supports exactly one argument of abstract class type") - defaultPrec - } + super.generatePreconditions(fd) } } } diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala new file mode 100644 index 0000000000000000000000000000000000000000..294a86e629df8fedf79e920eeef26061b9682fb1 --- /dev/null +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -0,0 +1,39 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package utils + +import purescala.Common._ +import purescala.Trees._ +import xlang.Trees._ +import purescala.TreeOps._ +import purescala.Definitions._ + +object InjectAsserts extends LeonPhase[Program, Program] { + + val name = "Asserts" + val description = "Inject asserts for various corrected conditions (map accesses, array accesses, ..)" + + def run(ctx: LeonContext)(pgm: Program): Program = { + def indexUpTo(i: Expr, e: Expr) = { + And(GreaterEquals(i, IntLiteral(0)), LessThan(i, e)) + } + + pgm.definedFunctions.foreach(fd => { + fd.body = fd.body.map(postMap { + case e @ ArraySelect(a, i) => + Some(Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), e).setPos(e)) + case e @ ArrayUpdated(a, i, _) => + Some(Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), e).setPos(e)) + case e @ ArrayUpdate(a, i, _) => + Some(Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), e).setPos(e)) + case e @ MapGet(m,k) => + Some(Assert(MapIsDefinedAt(m, k), Some("Map undefined at this index"), e).setPos(e)) + case _ => + None + }) + }) + + pgm + } +} diff --git a/src/main/scala/leon/verification/Tactic.scala b/src/main/scala/leon/verification/Tactic.scala index 5fe88ffad0875e578b5d34fe9fbbb16ae67faa70..7f4418023003a14bd70c1d55cfcfc01116d1461b 100644 --- a/src/main/scala/leon/verification/Tactic.scala +++ b/src/main/scala/leon/verification/Tactic.scala @@ -4,15 +4,35 @@ package leon package verification import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ -abstract class Tactic(reporter: Reporter) { +abstract class Tactic(vctx: VerificationContext) { val description : String val shortDescription : String - def setProgram(program: Program) : Unit = {} - def generatePostconditions(function: FunDef) : Seq[VerificationCondition] - def generatePreconditions(function: FunDef) : Seq[VerificationCondition] - def generatePatternMatchingExhaustivenessChecks(function: FunDef) : Seq[VerificationCondition] - def generateMiscCorrectnessConditions(function: FunDef) : Seq[VerificationCondition] - def generateArrayAccessChecks(function: FunDef) : Seq[VerificationCondition] + val program = vctx.program + val reporter = vctx.reporter + + def generateVCs(fd: FunDef): Seq[VerificationCondition] = { + generatePostconditions(fd) ++ + generatePreconditions(fd) ++ + generateCorrectnessConditions(fd) + } + + def generatePostconditions(function: FunDef): Seq[VerificationCondition] + def generatePreconditions(function: FunDef): Seq[VerificationCondition] + def generateCorrectnessConditions(function: FunDef): Seq[VerificationCondition] + + + // Helper functions + protected def safe(e: Expr): Expr = matchToIfThenElse(e) + protected def precOrTrue(fd: FunDef): Expr = fd.precondition match { + case Some(pre) => safe(pre) + case None => BooleanLiteral(true) + } + + protected def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Expr)] = { + CollectorWithPaths(f).traverse(expr) + } } diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 4a23e63e18d1f2ce928ddba204a91176c9f810c3..fc905630d039d8c651443a8c384ebddc3f0f98ad 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -40,6 +40,7 @@ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: V object VCKind extends Enumeration { val Precondition = Value("precond.") val Postcondition = Value("postcond.") + val Correctness = Value("correct.") val ExhaustiveMatch = Value("match.") val MapAccess = Value("map acc.") val ArrayAccess = Value("arr. acc.") diff --git a/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala b/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala new file mode 100644 index 0000000000000000000000000000000000000000..2643eff3e10a772e4bf475475b8eaa88bef2d168 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala @@ -0,0 +1,32 @@ +import leon.lang._ +import leon.annotation._ +import leon._ + +object Operators { + + def foo(a: Int): Int = { + require(a > 0) + + { + val b = a + assert(b > 0, "Hey now") + b + bar(1) + } ensuring { _ < 2 } + + } ensuring { + _ > a + } + + def bar(a: Int): Int = { + require(a > 0) + + { + val b = a + assert(b > 0, "Hey now") + b + 2 + } ensuring { _ > 2 } + + } ensuring { + _ > a + } +} diff --git a/src/test/resources/regression/verification/purescala/valid/Asserts1.scala b/src/test/resources/regression/verification/purescala/valid/Asserts1.scala new file mode 100644 index 0000000000000000000000000000000000000000..abbd39d5e8f88244891e20f763f7ccd150fcf839 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Asserts1.scala @@ -0,0 +1,32 @@ +import leon.lang._ +import leon.annotation._ +import leon._ + +object Operators { + + def foo(a: Int): Int = { + require(a > 0) + + { + val b = a + assert(b > 0, "Hey now") + b + bar(1) + } ensuring { _ > 2 } + + } ensuring { + _ > a + } + + def bar(a: Int): Int = { + require(a > 0) + + { + val b = a + assert(b > 0, "Hey now") + b + 2 + } ensuring { _ > 2 } + + } ensuring { + _ > a + } +}