diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index a28fe8b899240843f8c5d1d5333ecd6a45dd0ffe..45c6597d16c6de399faa6f415493118396274237 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -72,8 +72,8 @@ class Analysis(val program: Program) { reporter.info(vc) // reporter.info("Negated:") // reporter.info(negate(vc)) - // reporter.info("Negated, expanded:") - // reporter.info(expandLets(negate(vc))) + reporter.info("Negated, expanded:") + reporter.info(expandLets(negate(vc))) // try all solvers until one returns a meaningful answer solverExtensions.find(se => { @@ -105,24 +105,62 @@ class Analysis(val program: Program) { if(post.isEmpty) { BooleanLiteral(true) } else { - if(prec.isEmpty) - replace(Map(ResultVariable() -> body), post.get) - else - Implies(prec.get, replace(Map(ResultVariable() -> body), post.get)) + val resFresh = FreshIdentifier("result", true).setType(body.getType) + val bodyAndPost = Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post.get)) + val beforeInlining = if(prec.isEmpty) { + bodyAndPost + } else { + Implies(prec.get, bodyAndPost) + } + + val (newExpr, sideExprs) = inlineFunctionsAndContracts(beforeInlining) + + if(sideExprs.isEmpty) { + newExpr + } else { + Implies(And(sideExprs), newExpr) + } } } - def rewritePatternMatching(expr: Expr) : Expr = { - def isPMExpr(e: Expr) : Boolean = e match { - case MatchExpr(_, _) => true - case _ => false + def inlineFunctionsAndContracts(expr: Expr) : (Expr, Seq[Expr]) = { + var extras : List[Expr] = Nil + + val isFunCall: Function[Expr,Boolean] = _.isInstanceOf[FunctionInvocation] + def applyToCall(e: Expr) : Expr = e match { + case f @ FunctionInvocation(fd, args) => { + val fArgsAsVars: List[Variable] = fd.args.map(_.toVariable).toList + val fParamsAsLetVars: List[Identifier] = fd.args.map(a => FreshIdentifier("arg", true).setType(a.tpe)).toList + + def mkBigLet(ex: Expr) : Expr = (fParamsAsLetVars zip args).foldRight(ex)((iap, e) => { + Let(iap._1, iap._2, e) + }) + + val substMap = Map[Expr,Expr]((fArgsAsVars zip fParamsAsLetVars.map(Variable(_))) : _*) + if(fd.hasPostcondition) { + val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) + extras = mkBigLet(replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get)) :: extras + newVar + } else if(fd.hasImplementation && !program.isRecursive(fd)) { // means we can inline at least one level... + mkBigLet(replace(substMap, fd.body.get)) + } else { // we can't do much for calls to recursive functions or to functions with no bodies + f + } + } + case o => o } + (searchAndApply(isFunCall, applyToCall, expr), extras.reverse) + } + + def rewritePatternMatching(expr: Expr) : Expr = { + def isPMExpr(e: Expr) : Boolean = e.isInstanceOf[MatchExpr] + def rewritePM(e: Expr) : Expr = e match { case m @ MatchExpr(_, _) => m case _ => e } - replace(isPMExpr, rewritePM, expr) + searchAndApply(isPMExpr, rewritePM, expr) } } diff --git a/src/purescala/Common.scala b/src/purescala/Common.scala index ddb2129e10913f3b875af0a5d8c85d7b7ce4b1a6..1d5365c6b315b1b2be1859f8bae2a2da56f4cc46 100644 --- a/src/purescala/Common.scala +++ b/src/purescala/Common.scala @@ -4,7 +4,7 @@ object Common { import TypeTrees.Typed // the type is left blank (NoType) for Identifiers that are not variables - class Identifier private[Common](val name: String, val id: Int) extends Typed { + class Identifier private[Common](val name: String, val id: Int, alwaysShowUniqueID: Boolean = false) extends Typed { override def equals(other: Any): Boolean = { if(other == null || !other.isInstanceOf[Identifier]) false @@ -18,6 +18,8 @@ object Common { if(purescala.Settings.showIDs) { // angle brackets: name + "\u3008" + id + "\u3009" name + "[" + id + "]" + } else if(alwaysShowUniqueID) { + name + id } else { name } @@ -34,6 +36,6 @@ object Common { } object FreshIdentifier { - def apply(name: String) : Identifier = new Identifier(name, UniqueCounter.next) + def apply(name: String, alwaysShowUniqueID: Boolean = false) : Identifier = new Identifier(name, UniqueCounter.next, alwaysShowUniqueID) } } diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index 410903e8f76c582cc52166b0dc3c9d6e2d94f08f..590b4842757ebe747359d88bdf4751dca66fca55 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -14,6 +14,8 @@ object Definitions { case class VarDecl(id: Identifier, tpe: TypeTree) extends Typed { override def getType = tpe override def setType(tt: TypeTree) = scala.Predef.error("Can't set type of VarDecl.") + + def toVariable : Variable = Variable(id).setType(tpe) } type VarDecls = Seq[VarDecl] @@ -21,30 +23,95 @@ object Definitions { /** A wrapper for a program. For now a program is simply a single object. The * name is meaningless and we just use the package name as id. */ case class Program(id: Identifier, mainObject: ObjectDef) extends Definition { - lazy val callGraph: Map[FunDef,Seq[FunDef]] = computeCallGraph + def definedFunctions = mainObject.definedFunctions + def definedClasses = mainObject.definedClasses + def classHierarchyRoots = mainObject.classHierarchyRoots + def callGraph = mainObject.callGraph + def calls(f1: FunDef, f2: FunDef) = mainObject.calls(f1, f2) + def callers(f1: FunDef) = mainObject.callers(f1) + def callees(f1: FunDef) = mainObject.callees(f1) + def transitiveCallGraph = mainObject.transitiveCallGraph + def transitivelyCalls(f1: FunDef, f2: FunDef) = mainObject.transitivelyCalls(f1, f2) + def transitiveCallers(f1: FunDef) = mainObject.transitiveCallers(f1) + def transitiveCallees(f1: FunDef) = mainObject.transitiveCallees(f1) + def isRecursive(f1: FunDef) = mainObject.isRecursive(f1) + } + + /** Objects work as containers for class definitions, functions (def's) and + * val's. */ + case class ObjectDef(id: Identifier, defs : Seq[Definition], invariants: Seq[Expr]) extends Definition { + lazy val definedFunctions : Seq[FunDef] = defs.filter(_.isInstanceOf[FunDef]).map(_.asInstanceOf[FunDef]) + + lazy val definedClasses : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]) + + lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) + + lazy val (callGraph, callers, callees) = { + var resSet: Set[(FunDef,FunDef)] = + new scala.collection.immutable.HashSet[(FunDef,FunDef)]() + + def isFunCall(e: Expr) : Boolean = e.isInstanceOf[FunctionInvocation] + def applyToFunCall(f1: FunDef)(e: Expr) : Expr = e match { + case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); f } + case o => o + } + + for(funDef <- definedFunctions) { + funDef.precondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) + funDef.body.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) + funDef.postcondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) + } + + var callers: Map[FunDef,Set[FunDef]] = + new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] + var callees: Map[FunDef,Set[FunDef]] = + new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] - def computeCallGraph: Map[FunDef,Seq[FunDef]] = Map.empty + for(funDef <- definedFunctions) { + val clrs = resSet.filter(_._2 == funDef).map(_._1) + val cles = resSet.filter(_._1 == funDef).map(_._2) + callers = callers + (funDef -> clrs) + callees = callees + (funDef -> cles) + } - // checks whether f2 can be called from f1 - def calls(f1: FunDef, f2: FunDef) : Boolean = { - false + (resSet, callers, callees) } - def transitivelyCalls(f1: FunDef, f2: FunDef) : Boolean = { - false + // checks whether f1's body, pre or post contain calls to f2 + def calls(f1: FunDef, f2: FunDef) : Boolean = callGraph((f1,f2)) + + lazy val (transitiveCallGraph, transitiveCallers, transitiveCallees) = { + var resSet : Set[(FunDef,FunDef)] = callGraph + var change = true + + while(change) { + change = false + for(f1 <- definedFunctions; f2 <- callers(f1); f3 <- callees(f1)) { + if(!resSet(f2,f3)) { + change = true + resSet = resSet + ((f2,f3)) + } + } + } + + var tCallers: Map[FunDef,Set[FunDef]] = + new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] + var tCallees: Map[FunDef,Set[FunDef]] = + new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] + + for(funDef <- definedFunctions) { + val clrs = resSet.filter(_._2 == funDef).map(_._1) + val cles = resSet.filter(_._1 == funDef).map(_._2) + tCallers = tCallers + (funDef -> clrs) + tCallees = tCallees + (funDef -> cles) + } + + (resSet, tCallers, tCallees) } - } - /** Objects work as containers for class definitions, functions (def's) and - * val's. */ - case class ObjectDef(id: Identifier, defs : Seq[Definition], invariants: Seq[Expr]) extends Definition { - // Watch out ! Use only when object is completely built ! - lazy val getDefinedClasses = computeDefinedClasses - // ...this one can be used safely at anytime. - def computeDefinedClasses : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]) + def transitivelyCalls(f1: FunDef, f2: FunDef) : Boolean = transitiveCallGraph((f1,f2)) - lazy val getClassHierarchyRoots = computeClassHierarchyRoots - def computeClassHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) + def isRecursive(f: FunDef) = transitivelyCalls(f, f) } /** Useful because case classes and classes are somewhat unified in some @@ -57,7 +124,6 @@ object Definitions { def setParent(parent: AbstractClassDef) : self.type def hasParent: Boolean = parent.isDefined val isAbstract: Boolean - // val fields: VarDecls } /** Will be used at some point as a common ground for case classes (which diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 5c51c6666a3b78ffbda2658b62f01a7c66912540..8ae49b07048a4d8eabb57bbc087419d0a78c93d0 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -262,20 +262,23 @@ object Trees { // Warning ! This may loop forever if the substitutions are not // well-formed! def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - replace(substs.isDefinedAt(_), substs(_), expr) + searchAndApply(substs.isDefinedAt(_), substs(_), expr) } // the replacement map should be understood as follows: // - on each subexpression, checkFun checks whether it should be replaced // - repFun is applied is checkFun succeeded - def replace(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr) : Expr = { - def rec(ex: Expr) : Expr = ex match { - case _ if (checkFun(ex)) => { + def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr) : Expr = { + def rec(ex: Expr, skip: Expr = null) : Expr = ex match { + case _ if (ex != skip && checkFun(ex)) => { val newExpr = repFun(ex) if(newExpr.getType == NoType) { Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) } - rec(newExpr) + if(ex == newExpr) + rec(ex, ex) + else + rec(newExpr) } case l @ Let(i,e,b) => Let(i, rec(e), rec(b)).setType(l.getType) case f @ FunctionInvocation(fd, args) => FunctionInvocation(fd, args.map(rec(_))).setType(f.getType) diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 6fc9e4881aefcd8751ba19f19c7fbd56d104c58e..5f8dfc26ca77c17395c003439a547427b3f0d3c1 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -30,6 +30,9 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { z3 = new Z3Context(z3cfg) } prepareSorts + + println(prog.callGraph.map(p => (p._1.id.name, p._2.id.name).toString)) + println(prog.transitiveCallGraph.map(p => (p._1.id.name, p._2.id.name).toString)) } private var intSort : Z3Sort = null @@ -47,7 +50,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { intSort = z3.mkIntSort boolSort = z3.mkBoolSort - val roots = program.mainObject.getClassHierarchyRoots + val roots = program.classHierarchyRoots val indexMap: Map[ClassTypeDef,Int] = Map(roots.zipWithIndex: _*) //println("indexMap: " + indexMap) diff --git a/testcases/Account.scala b/testcases/Account.scala index a3a0e863b58fd03f678c7c6cabc704eeae82f289..66c5e352a8b0f3ad2394a84dc3dfca6e33c1b5c1 100644 --- a/testcases/Account.scala +++ b/testcases/Account.scala @@ -2,9 +2,12 @@ object Account { sealed abstract class AccLike case class Acc(checking : Int, savings : Int) extends AccLike + def sameTotal(a1 : Acc, a2 : Acc) : Boolean = { + a1.checking + a1.savings == a2.checking + a2.savings + } + def toSavings(x : Int, a : Acc) : Acc = { - require (a.checking > x) + require (x >= 0 && a.checking >= x) Acc(a.checking - x, a.savings + x) - // a match { case Acc(c0,s0) => Acc(c0 - x, s0 + x) } - } ensuring (_.checking >= 0) + } ensuring (res => (res.checking >= 0 && sameTotal(a, res))) }