diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index f8509be9c4d678b10a52f7736a406bb078b6bfbc..6246e4897053906e3a25e2eefebaf62dda11aa85 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -391,7 +391,7 @@ trait ASTExtractors { } object ExCaseClass { - def unapply(cd: ClassDef): Option[(String,Symbol,Seq[(Symbol,ValDef)], Template)] = cd match { + def unapply(cd: ClassDef): Option[(String,Symbol,Seq[(Symbol,ValDef)], Seq[(Symbol,ValDef)], Template)] = cd match { case ClassDef(_, name, tparams, impl) if isCaseClass(cd) || isImplicitClass(cd) => { val constructor: DefDef = impl.children.find { case ExConstructorDef() => true @@ -408,6 +408,18 @@ trait ASTExtractors { df.symbol } + val vars = impl.children.collect { + case vf: ValDef if vf.symbol.isPrivate && vf.symbol.isVar => vf + } + val varAccessors = impl.children.collect { + case df@DefDef(_, name, _, _, _, _) if + !df.symbol.isStable && df.symbol.isAccessor && !df.symbol.isParamAccessor && + !name.endsWith("_$eq") => df + } + val varsFinal = varAccessors.zip(vars).map(p => (p._1.symbol, p._2)) + //println("extracted vars: " + vars) + //println("extracted var accessors: " + varAccessors) + //if (symbols.size != valDefs.size) { // println(" >>>>> " + cd.name) // symbols foreach println @@ -416,7 +428,7 @@ trait ASTExtractors { val args = symbols zip valDefs - Some((name.toString, cd.symbol, args, impl)) + Some((name.toString, cd.symbol, args, varsFinal, impl)) } case _ => None } @@ -500,6 +512,36 @@ trait ASTExtractors { case _ => None } } + + object ExMutatorAccessorFunction { + def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, Tree)] = dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if( + vparamss.size <= 1 && name != nme.CONSTRUCTOR && + !dd.symbol.isSynthetic && dd.symbol.isAccessor && name.endsWith("_$eq") + ) => + Some((dd.symbol, tparams.map(_.symbol), vparamss.flatten, tpt.tpe, rhs)) + case _ => None + } + } + object ExMutableFieldDef { + + /** Matches a definition of a strict var field inside a class constructor */ + def unapply(vd: SymTree) : Option[(Symbol, Type, Tree)] = { + val sym = vd.symbol + vd match { + // Implemented fields + case ValDef(mods, name, tpt, rhs) if ( + !sym.isCaseAccessor && !sym.isParamAccessor && + !sym.isLazy && !sym.isSynthetic && !sym.isAccessor && sym.isVar + ) => + println("matched a var accessor field: sym is: " + sym) + println("getterIn is: " + sym.getterIn(sym.owner)) + // Since scalac uses the accessor symbol all over the place, we pass that instead: + Some( (sym.getterIn(sym.owner),tpt.tpe,rhs) ) + case _ => None + } + } + } object ExFieldDef { /** Matches a definition of a strict field inside a class constructor */ @@ -509,7 +551,7 @@ trait ASTExtractors { // Implemented fields case ValDef(mods, name, tpt, rhs) if ( !sym.isCaseAccessor && !sym.isParamAccessor && - !sym.isLazy && !sym.isSynthetic && !sym.isAccessor + !sym.isLazy && !sym.isSynthetic && !sym.isAccessor && !sym.isVar ) => // Since scalac uses the accessor symbol all over the place, we pass that instead: Some( (sym.getterIn(sym.owner),tpt.tpe,rhs) ) @@ -705,6 +747,7 @@ trait ASTExtractors { object ExAssign { def unapply(tree: Assign): Option[(Symbol,Tree)] = tree match { case Assign(id@Ident(_), rhs) => Some((id.symbol, rhs)) + //case Assign(sym@Select(This(_), v), rhs) => Some((sym.symbol, rhs)) case _ => None } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 0451b4c08603fb1b10ae91e2e38cc8902cd1649d..6ea915993cf482b62d543e8b3f1c1a2d6c4dac68 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -12,6 +12,7 @@ import Definitions.{ ClassDef => LeonClassDef, ModuleDef => LeonModuleDef, ValDef => LeonValDef, + VarDef => LeonVarDef, Import => LeonImport, _ } @@ -133,6 +134,10 @@ trait CodeExtraction extends ASTExtractors { def withNewMutableVar(nvar: (Symbol, () => LeonExpr)) = { copy(mutableVars = mutableVars + nvar) } + + def withNewMutableVars(nvars: Traversable[(Symbol, () => LeonExpr)]) = { + copy(mutableVars = mutableVars ++ nvars) + } } private var currentFunDef: FunDef = null @@ -211,10 +216,10 @@ trait CodeExtraction extends ASTExtractors { // ignore case ExAbstractClass(o2, sym, tmpl) => - seenClasses += sym -> ((Nil, tmpl)) + seenClasses += sym -> ((Nil, Nil, tmpl)) - case ExCaseClass(o2, sym, args, tmpl) => - seenClasses += sym -> ((args, tmpl)) + case ExCaseClass(o2, sym, args, vars, tmpl) => + seenClasses += sym -> ((args, vars, tmpl)) case ExObjectDef(n, templ) => for (t <- templ.body if !t.isEmpty) t match { @@ -222,10 +227,10 @@ trait CodeExtraction extends ASTExtractors { // ignore case ExAbstractClass(_, sym, tmpl) => - seenClasses += sym -> ((Nil, tmpl)) + seenClasses += sym -> ((Nil, Nil, tmpl)) - case ExCaseClass(_, sym, args, tmpl) => - seenClasses += sym -> ((args, tmpl)) + case ExCaseClass(_, sym, args, vars, tmpl) => + seenClasses += sym -> ((args, vars, tmpl)) case _ => } @@ -245,7 +250,7 @@ trait CodeExtraction extends ASTExtractors { case t @ ExAbstractClass(o2, sym, _) => Some(getClassDef(sym, t.pos)) - case t @ ExCaseClass(o2, sym, args, _) => + case t @ ExCaseClass(o2, sym, args, vars, _) => Some(getClassDef(sym, t.pos)) case t @ ExObjectDef(n, templ) => @@ -259,7 +264,7 @@ trait CodeExtraction extends ASTExtractors { case ExAbstractClass(_, sym, _) => Some(getClassDef(sym, t.pos)) - case ExCaseClass(_, sym, _, _) => + case ExCaseClass(_, sym, _, _, _) => Some(getClassDef(sym, t.pos)) // Functions @@ -281,6 +286,10 @@ trait CodeExtraction extends ASTExtractors { case ExFieldDef(sym, _, _) => Some(defineFieldFunDef(sym, false)(DefContext())) + // var + case ExMutableFieldDef(sym, _, _) => + Some(defineFieldFunDef(sym, false)(DefContext())) + // All these are expected, but useless case ExCaseClassSyntheticJunk() | ExConstructorDef() @@ -290,6 +299,12 @@ trait CodeExtraction extends ASTExtractors { case d if (d.symbol.isImplicit && d.symbol.isSynthetic) => None + //vars are never accessed directly so we extract accessors and mutators and + //ignore bare variables + case d if d.symbol.isVar => + None + + // Everything else is unexpected case tree => println(tree) @@ -350,7 +365,7 @@ trait CodeExtraction extends ASTExtractors { case ExAbstractClass(_, sym, tpl) => extractClassMembers(sym, tpl) - case ExCaseClass(_, sym, _, tpl) => + case ExCaseClass(_, sym, _, _, tpl) => extractClassMembers(sym, tpl) case ExObjectDef(n, templ) => @@ -362,7 +377,7 @@ trait CodeExtraction extends ASTExtractors { case ExAbstractClass(_, sym, tpl) => extractClassMembers(sym, tpl) - case ExCaseClass(_, sym, _, tpl) => + case ExCaseClass(_, sym, _, _, tpl) => extractClassMembers(sym, tpl) case t => @@ -408,7 +423,7 @@ trait CodeExtraction extends ASTExtractors { } } - private var seenClasses = Map[Symbol, (Seq[(Symbol, ValDef)], Template)]() + private var seenClasses = Map[Symbol, (Seq[(Symbol, ValDef)], Seq[(Symbol, ValDef)], Template)]() private var classesToClasses = Map[Symbol, LeonClassDef]() def oracleType(pos: Position, tpe: LeonType) = { @@ -445,9 +460,9 @@ trait CodeExtraction extends ASTExtractors { case Some(cd) => cd case None => if (seenClasses contains sym) { - val (args, tmpl) = seenClasses(sym) + val (args, vars, tmpl) = seenClasses(sym) - extractClassDef(sym, args, tmpl) + extractClassDef(sym, args, vars, tmpl) } else { outOfSubsetError(pos, "Class "+sym.fullName+" not defined?") } @@ -463,6 +478,7 @@ trait CodeExtraction extends ASTExtractors { } private var isMethod = Set[Symbol]() + private var isMutator = Set[Symbol]() private var methodToClass = Map[FunDef, LeonClassDef]() private var classToInvariants = Map[Symbol, Set[Tree]]() @@ -487,7 +503,7 @@ trait CodeExtraction extends ASTExtractors { paramsToDefaultValues += (theParam -> fd) } - def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = { + def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], vars: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = { //println(s"Extracting $sym") @@ -613,6 +629,17 @@ trait CodeExtraction extends ASTExtractors { for (tp <- ccd.tparams) check(tp, Set.empty) + val varFields = vars.map(t => { + val sym = t._1 + val vd = t._2 + val tpe = leonType(vd.tpt.tpe)(defCtx, sym.pos) + val id = cachedWithOverrides(sym, Some(ccd), tpe) + val value = extractTree(vd.children(1))(DefContext()) + if (tpe != id.getType) println(tpe, id.getType) + LeonVarDef(id.setPos(vd.pos), value).setPos(vd.pos) + }) + ccd.setVarFields(varFields) + case _ => } @@ -662,7 +689,7 @@ trait CodeExtraction extends ASTExtractors { // normal fields case t @ ExFieldDef(fsym, _, _) => - //println(fsym + "matched as ExFieldDef") + println(fsym + "matched as ExFieldDef") // we will be using the accessor method of this field everywhere isMethod += fsym val fd = defineFieldFunDef(fsym, false, Some(cd))(defCtx) @@ -670,6 +697,25 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) + case t @ ExMutableFieldDef(fsym, _, _) => + println(fsym + "matched as ExMutableFieldDef") + // we will be using the accessor method of this field everywhere + //isMethod += fsym + //val fd = defineFieldFunDef(fsym, false, Some(cd))(defCtx) + //methodToClass += fd -> cd + + //cd.registerMethod(fd) + + case t@ ExMutatorAccessorFunction(fsym, _, _, _, _) => + println("FOUND mutator: " + t) + println("accessed: " + fsym.accessed) + isMutator += fsym + //val fd = defineFunDef(fsym, Some(cd))(defCtx) + + //methodToClass += fd -> cd + + //cd.registerMethod(fd) + case other => } @@ -688,7 +734,7 @@ trait CodeExtraction extends ASTExtractors { val topOfHierarchy = sym.overrideChain.last funOrFieldSymsToIds.cachedB(topOfHierarchy){ - FreshIdentifier(sym.name.toString, tpe) + FreshIdentifier(sym.name.toString.trim, tpe) //trim because sometimes Scala names end with a trailing space, looks nicer without the space } } @@ -827,6 +873,31 @@ trait CodeExtraction extends ASTExtractors { extractFunBody(fd, Seq(), body)(DefContext(tparamsMap.toMap)) } + case t @ ExMutableFieldDef(sym, _, body) => // if !sym.isSynthetic && !sym.isAccessor => + //val fd = defsToDefs(sym) + //val tparamsMap = ctparamsMap + + //if(body != EmptyTree) { + // extractFunBody(fd, Seq(), body)(DefContext(tparamsMap.toMap)) + //} + + case ExMutatorAccessorFunction(sym, tparams, params, _, body) => + //val fd = defsToDefs(sym) + + //val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap + + //val classSym = ocsym.get + //val cd = classesToClasses(classSym).asInstanceOf[CaseClassDef] + //val classVarDefs = seenClasses(classSym)._2 + //val mutableFields = classVarDefs.zip(cd.varFields).map(p => (p._1._1, () => p._2.toVariable)) + + //val dctx = DefContext(tparamsMap) + //val pctx = dctx.withNewMutableVars(mutableFields) + + //if(body != EmptyTree) { + // extractFunBody(fd, params, body)(pctx) + //} + case _ => } } @@ -1234,15 +1305,17 @@ trait CodeExtraction extends ASTExtractors { LetVar(newID, valTree, restTree) } - case ExAssign(sym, rhs) => dctx.mutableVars.get(sym) match { + case a@ExAssign(sym, rhs) => { + println("extracted assign: " + sym + " = " + rhs) + dctx.mutableVars.get(sym) match { case Some(fun) => val Variable(id) = fun() val rhsTree = extractTree(rhs) Assignment(id, rhsTree) case None => - outOfSubsetError(tr, "Undeclared variable.") - } + outOfSubsetError(a, "Undeclared variable.") + }} case wh @ ExWhile(cond, body) => val condTree = extractTree(cond) @@ -1568,7 +1641,7 @@ trait CodeExtraction extends ASTExtractors { case c @ ExCall(rec, sym, tps, args) => // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method. val rrec = rec match { - case t if (defsToDefs contains sym) && !isMethod(sym) => + case t if (defsToDefs contains sym) && !isMethod(sym) && !isMutator(sym) => null case _ => extractTree(rec) @@ -1605,6 +1678,18 @@ trait CodeExtraction extends ASTExtractors { caseClassSelector(cct, rec, fieldID) + //mutable variables + case (IsTyped(rec, cct: CaseClassType), name, List(e1)) if isMutator(sym) => + println("Searching for mutator: " + name) + println(cct.classDef.varFields) + val id = cct.classDef.varFields.find(_.id.name == name.dropRight(2)).get.id + FieldAssignment(rec, id, e1) + + case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.classDef.varFields.exists(_.id.name == name) => + val id = cct.classDef.varFields.find(_.id.name == name).get.id + MutableFieldAccess(cct, rec, id) + + //String methods case (IsTyped(a1, StringType), "toString", List()) => a1 @@ -1798,6 +1883,7 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, CharType), "<=", List(IsTyped(a2, CharType))) => LessEquals(a1, a2) + case (a1, name, a2) => val typea1 = a1.getType val typea2 = a2.map(_.getType).mkString(",") diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 733eaf124dbb6dd8499c81347f7fc864c23ca453..8a93e435e26a43bcada96871164dfb424b6dd58f 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -58,6 +58,17 @@ object Definitions { def toVariable : Variable = Variable(id) } + case class VarDef(id: Identifier, value: Expr) extends Definition with Typed { + self: Serializable => + + val getType = id.getType + + def subDefinitions = Seq() + + /** Transform this [[VarDef]] into a [[Expressions.Variable Variable]] */ + def toVariable : Variable = Variable(id) + } + /** A wrapper for a program. For now a program is simply a single object. */ case class Program(units: List[UnitDef]) extends Definition { val id = FreshIdentifier("program") @@ -371,6 +382,14 @@ object Definitions { _fields = fields } + private var _varFields = Seq[VarDef]() + + def varFields = _varFields + + def setVarFields(fields: Seq[VarDef]) { + _varFields = fields + } + val isAbstract = false def selectorID2Index(id: Identifier) : Int = { diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 0e9e11d171a15e2b0b5fb77d0b3f0e65d44be38b..4f6f0f73470eb39fc8de1735466dc1f0d350c457 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -11,6 +11,7 @@ import ExprOps._ import Types._ import Constructors._ import TypeOps.instantiateType +import xlang.Expressions._ object MethodLifting extends TransformationPhase { @@ -162,6 +163,8 @@ object MethodLifting extends TransformationPhase { def thisToReceiver(e: Expr): Option[Expr] = e match { case th@This(ct) => Some(asInstOf(receiver.toVariable, ct).setPos(th)) + case a@Assignment(v, lhs) if cd.asInstanceOf[CaseClassDef].varFields.exists(vd => vd.id == v) => + Some(FieldAssignment(receiver.toVariable, v, lhs).setPos(a)) case _ => None } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 8b681460826345dad6078145bb9cb01e1be73eb9..485a4c9b1827b6e6fe6af3aa8bbf834d87ff85d7 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -370,6 +370,9 @@ class PrettyPrinter(opts: PrinterOptions, p"$id : ${vd.getType}" vd.defaultValue.foreach { fd => p" = ${fd.body.get}" } + case vd @ VarDef(id, value) => + p"var $id : ${vd.getType} = $value" + case This(_) => p"this" case (tfd: TypedFunDef) => p"typed def ${tfd.id}[${tfd.tps}]" case TypeParameterDef(tp) => p"$tp" @@ -528,9 +531,10 @@ class PrettyPrinter(opts: PrinterOptions, p" extends ${par.id}${nary(tparams, ", ", "[", "]")}" } - if (ccd.methods.nonEmpty) { + if (ccd.methods.nonEmpty || ccd.varFields.nonEmpty) { p"""| { - | ${nary(ccd.methods, "\n\n")} + | ${(if(ccd.methods.nonEmpty) { nary(ccd.methods, "\n\n") } else "")} + | ${(if(ccd.varFields.nonEmpty) { nary(ccd.varFields, "\n\n") } else "")} |}""" } diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 6f2518d549be6b1d439ff9cb08594558d02449d5..575aea838780c1583eb7f39c6ee80a96973688b1 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -98,7 +98,7 @@ object Types { assert(classDef.tparams.size == tps.size) - lazy val fields = { + def fields = { val tmap = (classDef.tparams zip tps).toMap if (tmap.isEmpty) { classDef.fields diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala index 7eb391088ea64c276aa0bbf7836e4a9511aa978c..aa62d377378ec20940d8bc60e469c5769dfd0487 100644 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -19,12 +19,17 @@ object AntiAliasingPhase extends TransformationPhase { val description = "Make aliasing explicit" override def apply(ctx: LeonContext, pgm: Program): Program = { + + updateCaseClassesWithVarFields(pgm) + println(pgm) + val fds = allFunDefs(pgm) fds.foreach(fd => checkAliasing(fd)(ctx)) var updatedFunctions: Map[FunDef, FunDef] = Map() val effects = effectsAnalysis(pgm) + println("effects: " + effects.filter(e => e._2.size > 0).map(e => (e._1.id, e._2))) //for each fun def, all the vars the the body captures. Only //mutable types. @@ -33,7 +38,7 @@ object AntiAliasingPhase extends TransformationPhase { } yield { val allFreeVars = fd.body.map(bd => variablesOf(bd)).getOrElse(Set()) val freeVars = allFreeVars -- fd.params.map(_.id) - val mutableFreeVars = freeVars.filter(id => id.getType.isInstanceOf[ArrayType]) + val mutableFreeVars = freeVars.filter(id => isMutableType(id.getType)) (fd, mutableFreeVars) }).toMap @@ -54,7 +59,7 @@ object AntiAliasingPhase extends TransformationPhase { } val res = replaceFunDefs(pgm)(fd => updatedFunctions.get(fd), (fi, fd) => None) - //println(res._1) + println(res._1) res._1 } @@ -166,12 +171,20 @@ object AntiAliasingPhase extends TransformationPhase { (None, bindings) } - case l@Let(id, IsTyped(v, ArrayType(_)), b) => { + case as@FieldAssignment(o, id, v) => { + val ro@Variable(oid) = o + if(bindings.contains(oid)) + (Some(Assignment(oid, copy(o, id, v))), bindings) + else + (None, bindings) + } + + case l@Let(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { val varDecl = LetVar(id, v, b).setPos(l) (Some(varDecl), bindings + id) } - case l@LetVar(id, IsTyped(v, ArrayType(_)), b) => { + case l@LetVar(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { (None, bindings + id) } @@ -253,12 +266,10 @@ object AntiAliasingPhase extends TransformationPhase { case None => effects += (fd -> Set()) case Some(body) => { - val mutableParams = fd.params.filter(vd => vd.getType match { - case ArrayType(_) => true - case _ => false - }) + val mutableParams = fd.params.filter(vd => isMutableType(vd.getType)) val mutatedParams = mutableParams.filter(vd => exists { case ArrayUpdate(Variable(a), _, _) => a == vd.id + case FieldAssignment(Variable(a), _, _) => a == vd.id case _ => false }(body)) val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{ @@ -380,4 +391,61 @@ object AntiAliasingPhase extends TransformationPhase { pgm.definedFunctions.flatMap(fd => fd.body.toSet.flatMap((bd: Expr) => nestedFunDefsOf(bd)) + fd) + + + private def isMutableType(tpe: TypeTree): Boolean = + tpe.isInstanceOf[ArrayType] || tpe.isInstanceOf[ClassType] + + + private def copy(expr: Expr, id: Identifier, nv: Expr) = { + val ct@CaseClassType(ccd, _) = expr.getType + val newFields = ccd.fields.map(vd => + if(vd.id == id) + nv + else + CaseClassSelector(CaseClassType(ct.classDef, ct.tps), expr, vd.id) + ) + + CaseClass(CaseClassType(ct.classDef, ct.tps), newFields).setPos(expr.getPos) + } + + private def updateCaseClassesWithVarFields(program: Program) = { + val extras = (for { + ccd <- program.definedClasses.collect{ case (c: CaseClassDef) => c } + } yield { + (ccd, ccd.varFields.map(vd => (ValDef(vd.id), vd.value))) + }) + updateCaseClassFields(extras)(program) + } + + private def updateCaseClassFields(extras: Seq[(CaseClassDef, Seq[(ValDef, Expr)])])(program: Program) = { + + def updateBody(body: Expr): Expr = { + preMap({ + case CaseClass(ct, args) => extras.find(p => p._1 == ct.classDef).map{ + case (ccd, extraFields) => + CaseClass(CaseClassType(ccd, ct.tps), args ++ extraFields.map{ case (_, v) => v }) + } + case fa@MutableFieldAccess(cct, rec, id) => + Some(CaseClassSelector(CaseClassType(cct.classDef, cct.tps), rec, id)) + //extras.find(p => p._1 == cct.classDef).map{ + // case (ccd, extraFields) => caseClassSelector(cct, rec, id) + //} + case _ => None + })(body) + } + + extras.foreach{ case (ccd, extraFields) => ccd.setFields(ccd.fields ++ extraFields.map(_._1)) } + for { + fd <- program.definedFunctions + } { + fd.body = fd.body.map(body => updateBody(body)) + fd.precondition = fd.precondition.map(pre => updateBody(pre)) + fd.postcondition = fd.postcondition.map(post => updateBody(post)) + } + extras.foreach{ case (ccd, _) => ccd.setVarFields(Nil) } + + } + + } diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala index 98214ee640bd95227c0b759113ab74d2c9555d94..f7204b3091ab12fbd5c3f9fad370145a41b43851 100644 --- a/src/main/scala/leon/xlang/Expressions.scala +++ b/src/main/scala/leon/xlang/Expressions.scala @@ -53,6 +53,29 @@ object Expressions { } } + case class MutableFieldAccess(classType: CaseClassType, obj: Expr, varId: Identifier) extends XLangExpr with Extractable with PrettyPrintable { + val getType = classType.classDef.varFields.find(_.id == varId).map(_.getType).getOrElse(Untyped) + + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(obj), (es: Seq[Expr]) => MutableFieldAccess(classType, es(0), varId))) + } + + def printWith(implicit pctx: PrinterContext) { + p"${obj}.${varId}" + } + } + case class FieldAssignment(obj: Expr, varId: Identifier, expr: Expr) extends XLangExpr with Extractable with PrettyPrintable { + val getType = UnitType + + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(obj, expr), (es: Seq[Expr]) => FieldAssignment(es(0), varId, es(1)))) + } + + def printWith(implicit pctx: PrinterContext) { + p"${obj}.${varId} = ${expr};" + } + } + case class While(cond: Expr, body: Expr) extends XLangExpr with Extractable with PrettyPrintable { val getType = UnitType var invariant: Option[Expr] = None diff --git a/src/test/resources/regression/verification/xlang/invalid/FunctionCaching.scala b/src/test/resources/regression/verification/xlang/invalid/FunctionCaching.scala new file mode 100644 index 0000000000000000000000000000000000000000..3043093585e3f8893774ab0d717024647faedece --- /dev/null +++ b/src/test/resources/regression/verification/xlang/invalid/FunctionCaching.scala @@ -0,0 +1,35 @@ +import leon.lang._ +import leon.collection._ + +object FunctionCaching { + + case class FunCache() { + var cached: Map[BigInt, BigInt] = Map() + } + + def fun(x: BigInt)(implicit funCache: FunCache): BigInt = { + funCache.cached.get(x) match { + case None() => + val res = 2*x + 42 + funCache.cached = funCache.cached.updated(x, res) + res + case Some(cached) => + cached + 1 + } + } + + def funWronglyCached(x: BigInt, trash: List[BigInt]): Boolean = { + implicit val cache = FunCache() + val res1 = fun(x) + multipleCalls(trash) + val res2 = fun(x) + res1 == res2 + } holds + + + def multipleCalls(args: List[BigInt])(implicit funCache: FunCache): Unit = args match { + case Nil() => () + case x::xs => fun(x); multipleCalls(xs) + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunctionCaching.scala b/src/test/resources/regression/verification/xlang/valid/FunctionCaching.scala new file mode 100644 index 0000000000000000000000000000000000000000..aa079f46bdadfa5c7b58942cb33a50ac5a7b7823 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunctionCaching.scala @@ -0,0 +1,42 @@ +import leon.lang._ +import leon.collection._ + +object FunctionCaching { + + case class FunCache() { + var cached: Map[BigInt, BigInt] = Map() + } + + def fun(x: BigInt)(implicit funCache: FunCache): BigInt = { + funCache.cached.get(x) match { + case None() => + val res = 2*x + 42 + funCache.cached = funCache.cached.updated(x, res) + res + case Some(cached) => + cached + } + } ensuring(res => old(funCache).cached.get(x) match { + case None() => true + case Some(v) => v == res + }) + + def funProperlyCached(x: BigInt, trash: List[BigInt]): Boolean = { + implicit val cache = FunCache() + val res1 = fun(x) + multipleCalls(trash, x) + val res2 = fun(x) + res1 == res2 + } holds + + def multipleCalls(args: List[BigInt], x: BigInt)(implicit funCache: FunCache): Unit = { + require(funCache.cached.get(x).forall(_ == 2*x + 42)) + args match { + case Nil() => () + case y::ys => + fun(y) + multipleCalls(ys, x) + } + } ensuring(_ => funCache.cached.get(x).forall(_ == 2*x + 42)) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation1.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..9ab1c321c4aa9bf251cdd8df599f5e3dde4946e0 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation1.scala @@ -0,0 +1,20 @@ +import leon.lang._ + +object ObjectParamMutation1 { + + case class A() { + var y: Int = 10 + } + + def update(a: A): Int = { + a.y = 12 + a.y + } ensuring(res => res == 12) + + def f(): Int = { + val a = A() + update(a) + a.y + } ensuring(res => res == 12) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation2.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..ceb28d217a47e73b1e1b8b7e87ce20eb61feeba5 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation2.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ObjectParamMutation2 { + + case class A() { + var y: Int = 10 + } + + def update(a: A): Unit = { + a.y = 12 + } ensuring(_ => a.y == 12) + + def f(): Int = { + val a = A() + update(a) + a.y + } ensuring(res => res == 12) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation3.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation3.scala new file mode 100644 index 0000000000000000000000000000000000000000..ff0b34cdaaf085dc9f290d8bc36b004d50176857 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation3.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ObjectParamMutation3 { + + case class A() { + var y: Int = 10 + } + + def update(a: A): Unit = { + a.y = a.y + 3 + } ensuring(_ => a.y == old(a).y + 3) + + def f(): Int = { + val a = A() + update(a) + a.y + } ensuring(res => res == 13) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation4.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation4.scala new file mode 100644 index 0000000000000000000000000000000000000000..2c4643cdb75a03efbfd9e074de77998233c6d8c5 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation4.scala @@ -0,0 +1,24 @@ +import leon.lang._ + +object ObjectParamMutation4 { + + case class A() { + var y: Int = 10 + } + + def swapY(a1: A, a2: A): Unit = { + val tmp = a1.y + a1.y = a2.y + a2.y = tmp + } ensuring(_ => a1.y == old(a2).y && a2.y == old(a1).y) + + def f(): (Int, Int) = { + val a1 = A() + val a2 = A() + a1.y = 12 + a2.y = 42 + swapY(a1, a2) + (a1.y, a2.y) + } ensuring(res => res._1 == 42 && res._2 == 12) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation5.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation5.scala new file mode 100644 index 0000000000000000000000000000000000000000..dc78d309492e484714c70f163ba5797a3670aba1 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation5.scala @@ -0,0 +1,22 @@ +import leon.lang._ + +object ObjectParamMutation5 { + + case class A() { + var x: Int = 10 + var y: Int = 13 + } + + def swap(a: A): Unit = { + val tmp = a.x + a.x = a.y + a.y = tmp + } ensuring(_ => a.x == old(a).y && a.y == old(a).x) + + def f(): A = { + val a = A() + swap(a) + a + } ensuring(res => res.x == 13 && res.y == 10) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation6.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation6.scala new file mode 100644 index 0000000000000000000000000000000000000000..d4dbf63bd2a47227521b2f64f3c42fae995fc96e --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation6.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ObjectParamMutation6 { + + case class A() { + var x: BigInt = 0 + } + + def inc(a: A): Unit = { + a.x += 1 + } ensuring(_ => a.x == old(a).x + 1) + + def f(): BigInt = { + val a = A() + inc(a); inc(a); inc(a) + a.x + } ensuring(res => res == 3) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation7.scala b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation7.scala new file mode 100644 index 0000000000000000000000000000000000000000..90e48180d447dc091d7ffeb1a21840869cabd03d --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ObjectParamMutation7.scala @@ -0,0 +1,24 @@ +import leon.lang._ + +object ObjectParamMutation7 { + + case class A(a: Int) { + var x: BigInt = 0 + var y: BigInt = 0 + var z: BigInt = 0 + } + + def inc(a: A): Unit = { + require(a.x >= 0 && a.y >= 0 && a.z >= 0) + a.x += 1 + a.y += 1 + a.z += 1 + } ensuring(_ => a.x == old(a).x + 1 && a.y == old(a).y + 1 && a.z == old(a).z + 1) + + def f(): A = { + val a = A(0) + inc(a); inc(a); inc(a) + a + } ensuring(res => res.x == res.y && res.y == res.z && res.z == 3) + +} diff --git a/testcases/verification/xlang/FunctionCaching.scala b/testcases/verification/xlang/FunctionCaching.scala new file mode 100644 index 0000000000000000000000000000000000000000..0a7c477a67d188790f765934aabddfc903ef957e --- /dev/null +++ b/testcases/verification/xlang/FunctionCaching.scala @@ -0,0 +1,42 @@ +import leon.lang._ +import leon.collection._ + +object FunctionCaching { + + case class Cache() { + var cached: Map[BigInt, BigInt] = Map() + //contains the set of elements where cache has been used + var cacheHit: Set[BigInt] = Set() + } + + def cachedFun(f: (BigInt) => BigInt, x: BigInt)(implicit cache: Cache) = { + cache.cached.get(x) match { + case None() => + val res = f(x) + cache.cached = cache.cached.updated(x, res) + res + case Some(cached) => + cache.cacheHit = cache.cacheHit ++ Set(x) + cached + } + } + + def funProperlyCached(x: BigInt, fun: (BigInt) => BigInt, trash: List[BigInt]): Boolean = { + implicit val cache = Cache() + val res1 = cachedFun(fun, x) + multipleCalls(trash, x, fun) + val res2 = cachedFun(fun, x) + res1 == res2 && cache.cacheHit.contains(x) + } holds + + def multipleCalls(args: List[BigInt], x: BigInt, fun: (BigInt) => BigInt)(implicit cache: Cache): Unit = { + require(cache.cached.isDefinedAt(x)) + args match { + case Nil() => () + case y::ys => + cachedFun(fun, y) + multipleCalls(ys, x, fun) + } + } ensuring(_ => old(cache).cached.get(x) == cache.cached.get(x)) + +}