diff --git a/doc/purescala.rst b/doc/purescala.rst index f3c40b5fd57e5620fa33e2872ef51f7bfeb4583a..a7260db81416b0b4c1b6249d1b3dff5be0aa4479 100644 --- a/doc/purescala.rst +++ b/doc/purescala.rst @@ -47,11 +47,12 @@ ADT roots need to be defined as abstract, unless the ADT is defined with only on abstract class MyType +An abstract class can be extended by other abstract classes. Case Classes ************ -This abstract root can be extended by a case-class, defining several fields: +The abstract root can also be extended by a case-class, defining several fields: .. code-block:: scala @@ -101,7 +102,7 @@ Leon supports type parameters for classes and functions. Methods ------- -You can currently define methods in ADT roots: +You can define methods in classes. .. code-block:: scala @@ -113,6 +114,33 @@ You can currently define methods in ADT roots: def test(a: List[Int]) = a.contains(42) +It is possible to define abstract methods in abstract classes and implement them in case classes. +It is also possible to override methods. + +.. code-block:: scala + + abstract class A { + def x(a: Int): Int + } + + abstract class B extends A { + def x(a: Int) = { + require(a > 0) + 42 + } ensuring { _ >= 0 } + } + + case class C(c: Int) extends B { + override def x(i: Int) = { + require(i >= 0) + if (i == 0) 0 + else c + x(i-1) + } ensuring ( _ == c * i ) + } + + case class D() extends B + +It is not possible, however, to call methods of a superclass with the ``super`` keyword. Specifications -------------- diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 79eb4d614ed6ff2ef3f2f5f80e71b2fc5308189f..cfda320f7b8d14e69e68e87d68884eb95b7b851a 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -94,7 +94,9 @@ trait CodeGeneration { case UnitType => "Z" case c : ClassType => - leonClassToJVMInfo(c.classDef).map { case (n, _) => "L" + n + ";" }.getOrElse("Unsupported class " + c.id) + leonClassToJVMInfo(c.classDef).map { case (n, _) => "L" + n + ";" }.getOrElse( + throw CompilationException("Unsupported class " + c.id) + ) case _ : TupleType => "L" + TupleClass + ";" @@ -134,7 +136,6 @@ trait CodeGeneration { * @param owner The module/class that contains $funDef */ def compileFunDef(funDef : FunDef, owner : Definition) { - val isStatic = owner.isInstanceOf[ModuleDef] val cf = classes(owner) @@ -153,7 +154,7 @@ trait CodeGeneration { mn, realParams : _* ) - m.setFlags(( + m.setFlags(( if (isStatic) METHOD_ACC_PUBLIC | METHOD_ACC_FINAL | @@ -263,7 +264,7 @@ trait CodeGeneration { } ch << InvokeSpecial(ccName, constructorName, ccApplySig) - case CaseClassInstanceOf(cct, e) => + case IsInstanceOf(cct, e) => val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { throw CompilationException("Unknown class : " + cct.id) } @@ -1025,7 +1026,7 @@ trait CodeGeneration { ch << Label(innerElse) mkBranch(e, thenn, elze, ch) - case cci@CaseClassInstanceOf(cct, e) => + case cci@IsInstanceOf(cct, e) => mkExpr(cci, ch) ch << IfEq(elze) << Goto(thenn) @@ -1250,16 +1251,31 @@ trait CodeGeneration { } // definition of the constructor - if (fields.isEmpty && !params.doInstrument && !params.requireMonitor) cf.addDefaultConstructor else { + locally { val constrParams = if (params.requireMonitor) { Seq("L" + MonitorClass + ";") } else Seq() val cch = cf.addConstructor(constrParams : _*).codeHandler - // Abstract classes are hierarchy roots, so call java.lang.Object constructor + + for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false) } + for (field <- strictFields) { initStrictField(cch, cName, field, false)} + + // Call parent constructor cch << ALoad(0) - cch << InvokeSpecial("java/lang/Object", constructorName, "()V") + acd.parent match { + case Some(parent) => + val pName = defToJVMName(parent.classDef) + // Load monitor object + if (params.requireMonitor) cch << ALoad(1) + val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + cch << InvokeSpecial(pName, constructorName, constrSig) + + case None => + // Call constructor of java.lang.Object + cch << InvokeSpecial("java/lang/Object", constructorName, "()V") + } // Initialize special monitor field if (params.doInstrument) { @@ -1267,10 +1283,7 @@ trait CodeGeneration { cch << Ldc(0) cch << PutField(cName, instrumentedField, "I") } - - for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false) } - for (field <- strictFields) { initStrictField(cch, cName, field, false)} - + cch << RETURN cch.freeze } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index f42ce69b468dbd960bb0786b01220448e96a1165..85625077de933c167285f2f82c4d3107d887f4ff 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -36,13 +36,10 @@ class CompilationUnit(val ctx: LeonContext, val cName = defToJVMName(df) val cf = df match { - case ccd: CaseClassDef => - val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) + case cd: ClassDef => + val pName = cd.parent.map(parent => defToJVMName(parent.classDef)) new ClassFile(cName, pName) - case acd: AbstractClassDef => - new ClassFile(cName, None) - case ob: ModuleDef => new ClassFile(cName, None) @@ -393,9 +390,6 @@ class CompilationUnit(val ctx: LeonContext, ch << RETURN ch.freeze } - - - } @@ -404,9 +398,7 @@ class CompilationUnit(val ctx: LeonContext, // First define all classes/ methods/ functions for (u <- program.units) { - val (parents, children) = u.algebraicDataTypes.toSeq.unzip - - for ( cls <- parents ++ children.flatten ++ u.singleCaseClasses) { + for ( cls <- u.definedClassesOrdered ) { defineClass(cls) for (meth <- cls.methods) { defToModuleOrClass += meth -> cls @@ -427,18 +419,15 @@ class CompilationUnit(val ctx: LeonContext, // Compile everything for (u <- program.units) { - for ((parent, children) <- u.algebraicDataTypes) { - compileAbstractClassDef(parent) - - for (c <- children) { - compileCaseClassDef(c) + for (c <- u.definedClassesOrdered) { + c match { + case acd: AbstractClassDef => + compileAbstractClassDef(acd) + case ccd: CaseClassDef => + compileCaseClassDef(ccd) } } - for(single <- u.singleCaseClasses) { - compileCaseClassDef(single) - } - for (m <- u.modules) compileModule(m) } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 4265fbecdb1d26e585a63fa61b2ae727b3dc96d1..0fff170048236ae86d102cc0b4d238956f5f64df 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.ExprOps._ import purescala.Expressions._ import purescala.Types._ +import purescala.TypeOps.isSubtypeOf import purescala.Constructors._ import purescala.Extractors._ @@ -221,12 +222,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case CaseClass(cd, args) => CaseClass(cd, args.map(e)) - case CaseClassInstanceOf(cct, expr) => + case IsInstanceOf(ct, expr) => val le = e(expr) - BooleanLiteral(le.getType match { - case CaseClassType(cd2, _) if cd2 == cct.classDef => true - case _ => false - }) + BooleanLiteral(isSubtypeOf(le.getType, ct)) case CaseClassSelector(ct1, expr, sel) => val le = e(expr) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index a1ac056ed3a91c52830d16dbed7210a467c7aa5e..c14ade4c7cf6c77e4b2dc91b75b2ed0b582d12a0 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -500,15 +500,13 @@ trait CodeExtraction extends ASTExtractors { val acd = AbstractClassDef(id, tparams, parent).setPos(sym.pos) classesToClasses += sym -> acd + parent.foreach(_.classDef.registerChildren(acd)) acd } else { val ccd = CaseClassDef(id, tparams, parent, sym.isModuleClass).setPos(sym.pos) - - parent.foreach(_.classDef.registerChildren(ccd)) - - classesToClasses += sym -> ccd + parent.foreach(_.classDef.registerChildren(ccd)) val fields = args.map { case (symbol, t) => val tpt = t.tpt @@ -538,6 +536,8 @@ trait CodeExtraction extends ASTExtractors { ccd } + + // We collect the methods and fields for (d <- tmpl.body) d match { case EmptyTree => @@ -548,10 +548,7 @@ trait CodeExtraction extends ASTExtractors { // Normal methods case t @ ExFunctionDef(fsym, _, _, _, _) => - if (parent.isDefined) { - outOfSubsetError(t, "Only hierarchy roots can define methods") - } - val fd = defineFunDef(fsym)(defCtx) + val fd = defineFunDef(fsym, Some(cd))(defCtx) isMethod += fsym methodToClass += fd -> cd @@ -574,11 +571,7 @@ trait CodeExtraction extends ASTExtractors { // Lazy fields case t @ ExLazyAccessorFunction(fsym, _, _) => - if (parent.isDefined) { - outOfSubsetError(t, "Only hierarchy roots can define lazy fields") - } - - val fd = defineFieldFunDef(fsym, true)(defCtx) + val fd = defineFieldFunDef(fsym, true, Some(cd))(defCtx) isMethod += fsym methodToClass += fd -> cd @@ -586,14 +579,10 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) // normal fields - case t @ ExFieldDef(fsym, _, _) => - if (parent.isDefined) { - outOfSubsetError(t, "Only hierarchy roots can define fields") - } - + case t @ ExFieldDef(fsym, _, _) => // we will be using the accessor method of this field everywhere val fsymAsMethod = fsym - val fd = defineFieldFunDef(fsymAsMethod, false)(defCtx) + val fd = defineFieldFunDef(fsymAsMethod, false, Some(cd))(defCtx) isMethod += fsymAsMethod methodToClass += fd -> cd @@ -606,9 +595,9 @@ trait CodeExtraction extends ASTExtractors { cd } - private var defsToDefs = Map[Symbol, FunDef]() + private var defsToDefs = Map[Symbol, FunDef]() - private def defineFunDef(sym: Symbol)(implicit dctx: DefContext): FunDef = { + private def defineFunDef(sym: Symbol, within: Option[LeonClassDef] = None)(implicit dctx: DefContext): FunDef = { // Type params of the function itself val tparams = extractTypeParams(sym.typeParams.map(_.tpe)) @@ -627,7 +616,19 @@ trait CodeExtraction extends ASTExtractors { val name = sym.name.toString - val fd = new FunDef(FreshIdentifier(name).setPos(sym.pos), tparamsDef, returnType, newParams) + val id = { + if (sym.overrideChain.length > 1) { + (for { + cd <- within + p <- cd.parent + m <- p.classDef.methods.find(_.id.name == name) + } yield m.id).getOrElse(FreshIdentifier(name)) + } else { + FreshIdentifier(name) + } + } + + val fd = new FunDef(id.setPos(sym.pos), tparamsDef, returnType, newParams) fd.setPos(sym.pos) @@ -642,7 +643,7 @@ trait CodeExtraction extends ASTExtractors { fd } - private def defineFieldFunDef(sym : Symbol, isLazy : Boolean)(implicit dctx : DefContext) : FunDef = { + private def defineFieldFunDef(sym : Symbol, isLazy : Boolean, within: Option[LeonClassDef] = None)(implicit dctx : DefContext) : FunDef = { val nctx = dctx.copy(tparams = dctx.tparams) @@ -650,7 +651,18 @@ trait CodeExtraction extends ASTExtractors { val name = sym.name.toString - val fd = new FunDef(FreshIdentifier(name).setPos(sym.pos), Seq(), returnType, Seq()) + val id = + if (sym.overrideChain.length == 1) { + FreshIdentifier(name) + } else { + ( for { + cd <- within + p <- cd.parent + m <- p.classDef.methods.find(_.id.name == name) + } yield m.id).getOrElse(FreshIdentifier(name)) + } + + val fd = new FunDef(id.setPos(sym.pos), Seq(), returnType, Seq()) fd.setPos(sym.pos) fd.addFlag(IsField(isLazy)) @@ -1337,25 +1349,22 @@ trait CodeExtraction extends ASTExtractors { val ccRec = extractTree(cc) val checkType = extractType(tt) checkType match { - case cct @ CaseClassType(ccd, tps) => { - val rootType: LeonClassDef = if(ccd.parent != None) ccd.parent.get.classDef else ccd - + case ct: ClassType => if(!ccRec.getType.isInstanceOf[ClassType]) { - outOfSubsetError(tr, "isInstanceOf can only be used with a case class") + outOfSubsetError(tr, "isInstanceOf can only be used with a class") } else { - val testedExprType = ccRec.getType.asInstanceOf[ClassType].classDef - val testedExprRootType: LeonClassDef = if(testedExprType.parent != None) testedExprType.parent.get.classDef else testedExprType + val rootType: LeonClassDef = ct.root.classDef + val testedExprType = ccRec.getType.asInstanceOf[ClassType] + val testedExprRootType: LeonClassDef = testedExprType.root.classDef if(rootType != testedExprRootType) { - outOfSubsetError(tr, "isInstanceOf can only be used with compatible case classes") + outOfSubsetError(tr, "isInstanceOf can only be used with compatible classes") } else { - CaseClassInstanceOf(cct, ccRec) + IsInstanceOf(ct, ccRec) } } - } - case _ => { - outOfSubsetError(tr, "isInstanceOf can only be used with a case class") - } + case _ => + outOfSubsetError(tr, "isInstanceOf can only be used with a class") } } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index da9f619efbc8a4e14ac93eb32e0a01aeaa0dcbbb..ab8ab0cf9d754eea80fa9c0af724bcd7084bf76f 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -70,7 +70,6 @@ object Definitions { def definedFunctions = units.flatMap(_.definedFunctions) def definedClasses = units.flatMap(_.definedClasses) def classHierarchyRoots = units.flatMap(_.classHierarchyRoots) - def algebraicDataTypes = units.flatMap(_.algebraicDataTypes).toMap def singleCaseClasses = units.flatMap(_.singleCaseClasses) def modules = { units.flatMap(_.defs.collect { @@ -111,34 +110,6 @@ object Definitions { } } } -/* - // import pack._ - case class PackageImport(pack : PackageRef) extends Import { - val id = FreshIdentifier("import " + (pack mkString ".")) - def importedDefs(implicit pgm: Program): Seq[Definition] = for { - u <- DefOps.unitsInPackage(pgm, pack) - d <- u.subDefinitions - ret <- d match { - case m: ModuleDef if m.isPackageObject => - m.subDefinitions - case other => - Seq(other) - } - } yield ret - } - // import pack.(...).df - case class SingleImport(df : Definition) extends Import { - val id = FreshIdentifier(s"import ${df.id.toString}") - def importedDefs(implicit pgm: Program): Seq[Definition] = - List(df) - } - // import pack.(...).df._ - case class WildcardImport(df : Definition) extends Import { - val id = FreshIdentifier(s"import ${df.id.toString}._") - def importedDefs(implicit pgm: Program): Seq[Definition] = - df.subDefinitions - } - */ case class UnitDef( id: Identifier, @@ -165,11 +136,7 @@ object Definitions { definedClasses.filter(!_.hasParent) } - def algebraicDataTypes = { - definedClasses.collect { - case ccd: CaseClassDef if ccd.hasParent => ccd - }.groupBy(_.parent.get.classDef) - } + def definedClassesOrdered = classHierarchyRoots flatMap { root => root +: root.knownDescendents } def singleCaseClasses = { definedClasses.collect { @@ -239,19 +206,27 @@ object Definitions { _methods = _methods ::: List(fd) } + def unregisterMethod(id: Identifier) = { + _methods = _methods filterNot (_.id == id) + } + def clearMethods() { _methods = Nil } def methods = _methods + lazy val ancestors: Seq[ClassDef] = parent.toSeq flatMap { p => p.classDef +: p.classDef.ancestors } + + lazy val root = ancestors.lastOption.getOrElse(this) + def knownChildren: Seq[ClassDef] = _children def knownDescendents: Seq[ClassDef] = { - knownChildren ++ (knownChildren.map { + knownChildren ++ knownChildren.flatMap { case acd: AbstractClassDef => acd.knownDescendents case _ => Nil - }.foldLeft(List[ClassDef]())(_ ++ _)) + } } def knownCCDescendents: Seq[CaseClassDef] = knownDescendents.collect { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 90f9d82b671f25fdee27e292768a56c17eeac2ad..84d7f515608f06e289756282e4d14990fd8ab482 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -692,19 +692,24 @@ object ExprOps { pattern match { case WildcardPattern(ob) => bind(ob, in) case InstanceOfPattern(ob, ct) => - ct match { - case _: AbstractClassType => - bind(ob, in) - - case cct: CaseClassType => - and(CaseClassInstanceOf(cct, in), bind(ob, in)) + if (ct.parent.isEmpty) { + bind(ob, in) + } else { + val ccs = ct match { + case act: AbstractClassType => + act.knownCCDescendents + case cct: CaseClassType => + Seq(cct) + } + val oneOf = ccs map { IsInstanceOf(_, in) } + and(orJoin(oneOf), bind(ob, in)) } case CaseClassPattern(ob, cct, subps) => assert(cct.fields.size == subps.size) val pairs = cct.fields.map(_.id).toList zip subps.toList val subTests = pairs.map(p => rec(CaseClassSelector(cct, in, p._1), p._2)) val together = and(bind(ob, in) +: subTests :_*) - and(CaseClassInstanceOf(cct, in), together) + and(IsInstanceOf(cct, in), together) case TuplePattern(ob, subps) => { val TupleType(tpes) = in.getType @@ -720,11 +725,7 @@ object ExprOps { } def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { - case WildcardPattern(None) => Map.empty - case WildcardPattern(Some(id)) => Map(id -> in) - case InstanceOfPattern(None, _) => Map.empty - case InstanceOfPattern(Some(id), _) => Map(id -> in) - case CaseClassPattern(b, ccd, subps) => { + case CaseClassPattern(b, ccd, subps) => assert(ccd.fields.size == subps.size) val pairs = ccd.fields.map(_.id).toList zip subps.toList val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) @@ -733,8 +734,8 @@ object ExprOps { case Some(id) => Map(id -> in) ++ together case None => together } - } - case TuplePattern(b, subps) => { + + case TuplePattern(b, subps) => val TupleType(tpes) = in.getType assert(tpes.size == subps.size) @@ -744,9 +745,12 @@ object ExprOps { case Some(id) => map + (id -> in) case None => map } - } - case LiteralPattern(None, lit) => Map() - case LiteralPattern(Some(id), lit) => Map(id -> in) + + case other => + other.binder match { + case None => Map.empty + case Some(b) => Map(b -> in) + } } /** Rewrites all pattern-matching expressions into if-then-else expressions, @@ -1263,7 +1267,7 @@ object ExprOps { case ccd: CaseClassDef => val cct = CaseClassType(ccd, tps) - val isType = CaseClassInstanceOf(cct, Variable(on)) + val isType = IsInstanceOf(cct, Variable(on)) val recSelectors = cct.fields.collect { case vd if vd.getType == on.getType => vd.id @@ -1836,260 +1840,4 @@ object ExprOps { None } - - /** - * Deprecated API - * ======== - */ - - @deprecated("Use postMap instead", "Leon 0.2.1") - def searchAndReplace(f: Expr => Option[Expr])(e: Expr) = postMap(f)(e) - - @deprecated("Use postMap instead", "Leon 0.2.1") - def searchAndReplaceDFS(f: Expr => Option[Expr])(e: Expr) = postMap(f)(e) - - @deprecated("Use exists instead", "Leon 0.2.1") - def contains(e: Expr, matcher: Expr => Boolean): Boolean = exists(matcher)(e) - - /* - * Transforms complicated Ifs into multiple nested if blocks - * It will decompose every OR clauses, and it will group AND clauses checking - * isInstanceOf toghether. - * - * if (a.isInstanceof[T1] && a.tail.isInstanceof[T2] && a.head == a2 || C) { - * T - * } else { - * E - * } - * - * Becomes: - * - * if (a.isInstanceof[T1] && a.tail.isInstanceof[T2]) { - * if (a.head == a2) { - * T - * } else { - * if(C) { - * T - * } else { - * E - * } - * } - * } else { - * if(C) { - * T - * } else { - * E - * } - * } - * - * This transformation runs immediately before patternMatchReconstruction. - * - * Notes: positions are lost. - */ - @deprecated("Mending an expression after matchToIfThenElse is unsafe", "Leon 0.2.4") - def decomposeIfs(e: Expr): Expr = { - def pre(e: Expr): Expr = e match { - case IfExpr(cond, thenn, elze) => - val TopLevelOrs(orcases) = cond - - if (orcases.exists{ case TopLevelAnds(ands) => ands.exists(_.isInstanceOf[CaseClassInstanceOf]) } ) { - if (orcases.tail.nonEmpty) { - pre(IfExpr(orcases.head, thenn, IfExpr(orJoin(orcases.tail), thenn, elze))) - } else { - val TopLevelAnds(andcases) = orcases.head - - val (andis, andnotis) = andcases.partition(_.isInstanceOf[CaseClassInstanceOf]) - - if (andis.isEmpty || andnotis.isEmpty) { - e - } else { - IfExpr(and(andis: _*), IfExpr(and(andnotis: _*), thenn, elze), elze) - } - } - } else { - e - } - case _ => - e - } - - simplePreTransform(pre)(e) - } - - /** - * Reconstructs match expressions from if-then-elses. - * - * Notes: positions are lost. - */ - @deprecated("Mending an expression after matchToIfThenElse is unsafe", "Leon 0.2.4") - def patternMatchReconstruction(e: Expr): Expr = { - def post(e: Expr): Expr = e match { - case IfExpr(cond, thenn, elze) => - val TopLevelAnds(cases) = cond - - if (cases.forall(_.isInstanceOf[CaseClassInstanceOf])) { - // matchingOn might initially be: a : T1, a.tail : T2, b: T2 - def selectorDepth(e: Expr): Int = e match { - case cd: CaseClassSelector => - 1+selectorDepth(cd.caseClass) - case _ => - 0 - } - - var scrutSet = Set[Expr]() - var conditions = Map[Expr, CaseClassType]() - - val matchingOn = cases.collect { case cc : CaseClassInstanceOf => cc } sortBy(cc => selectorDepth(cc.expr)) - for (CaseClassInstanceOf(cct, expr) <- matchingOn) { - conditions += expr -> cct - - expr match { - case cd: CaseClassSelector => - if (!scrutSet.contains(cd.caseClass)) { - // we found a test looking like "a.foo.isInstanceof[..]" - // without a check on "a". - scrutSet += cd - } - case e => - scrutSet += e - } - } - - var substMap = Map[Expr, Expr]() - - def computePatternFor(ct: CaseClassType, prefix: Expr): Pattern = { - - val name = prefix match { - case CaseClassSelector(_, _, id) => id.name - case Variable(id) => id.name - case _ => "tmp" - } - - val binder = FreshIdentifier(name, prefix.getType, true) - - // prefix becomes binder - substMap += prefix -> Variable(binder) - substMap += CaseClassInstanceOf(ct, prefix) -> BooleanLiteral(true) - - val subconds = for (f <- ct.fields) yield { - val fieldSel = CaseClassSelector(ct, prefix, f.id) - if (conditions contains fieldSel) { - computePatternFor(conditions(fieldSel), fieldSel) - } else { - val b = FreshIdentifier(f.id.name, f.getType, true) - substMap += fieldSel -> Variable(b) - WildcardPattern(Some(b)) - } - } - - CaseClassPattern(Some(binder), ct, subconds) - } - - val (scrutinees, patterns) = scrutSet.toSeq.map(s => (s, computePatternFor(conditions(s), s))).unzip - - val scrutinee = tupleWrap(scrutinees) - val pattern = tuplePatternWrap(patterns) - - // We use searchAndReplace to replace the biggest match first - // (topdown). - // So replaceing using Map(a => b, CC(a) => d) will replace - // "CC(a)" by "d" and not by "CC(b)" - val newThen = preMap(substMap.lift)(thenn) - - // Remove unused binders - val vars = variablesOf(newThen) - - def simplerBinder(oid: Option[Identifier]) = oid.filter(vars(_)) - - def simplifyPattern(p: Pattern): Pattern = p match { - case CaseClassPattern(ob, cd, subpatterns) => - CaseClassPattern(simplerBinder(ob), cd, subpatterns map simplifyPattern) - case WildcardPattern(ob) => - WildcardPattern(simplerBinder(ob)) - case TuplePattern(ob, patterns) => - TuplePattern(simplerBinder(ob), patterns map simplifyPattern) - case LiteralPattern(ob,lit) => LiteralPattern(simplerBinder(ob), lit) - case _ => - p - } - - val resCases = List( - SimpleCase(simplifyPattern(pattern), newThen), - SimpleCase(WildcardPattern(None), elze) - ) - - def mergePattern(to: Pattern, anchor: Identifier, pat: Pattern): Pattern = to match { - case CaseClassPattern(ob, cd, subs) => - if (ob == Some(anchor)) { - sys.error("WOOOT: "+to+" <<= "+pat +" on "+anchor) - pat - } else { - CaseClassPattern(ob, cd, subs.map(mergePattern(_, anchor, pat))) - } - case InstanceOfPattern(ob, cd) => - if (ob == Some(anchor)) { - sys.error("WOOOT: "+to+" <<= "+pat +" on "+anchor) - pat - } else { - InstanceOfPattern(ob, cd) - } - - case WildcardPattern(ob) => - if (ob == Some(anchor)) { - pat - } else { - WildcardPattern(ob) - } - case TuplePattern(ob,subs) => - if (ob == Some(anchor)) { - sys.error("WOOOT: "+to+" <<= "+pat +" on "+anchor) - pat - } else { - TuplePattern(ob, subs) - } - case LiteralPattern(ob, lit) => - if (ob == Some(anchor)) { - sys.error("WOOOT: "+to+" <<= "+pat +" on "+anchor) - pat - } else { - LiteralPattern(ob,lit) - } - - } - - val newCases = resCases.flatMap { - case SimpleCase(wp: WildcardPattern, m@MatchExpr(ex, cases)) if ex == scrutinee => - cases - - case c@SimpleCase(pattern, m@MatchExpr(v@Variable(id), cases)) => - if (pattern.binders(id)) { - cases.map { nc => - SimpleCase(mergePattern(pattern, id, nc.pattern), nc.rhs) - } - } else { - Seq(c) - } - case c => - Seq(c) - } - - var finalMatch = matchExpr(scrutinee, List(newCases.head)).asInstanceOf[MatchExpr] - - for (toAdd <- newCases.tail if !isMatchExhaustive(finalMatch)) { - finalMatch = matchExpr(scrutinee, finalMatch.cases :+ toAdd).asInstanceOf[MatchExpr] - } - - finalMatch - - } else { - e - } - case _ => - e - } - - simplePostTransform(post)(e) - } - - } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 0d2a840a30c6cc1ca921ad45687a5d52c2fa0630..960202301371a2226e034e8aa6f72e243aa5a883 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -266,7 +266,7 @@ object Expressions { val getType = ct } - case class CaseClassInstanceOf(classType: CaseClassType, expr: Expr) extends Expr { + case class IsInstanceOf(classType: ClassType, expr: Expr) extends Expr { val getType = BooleanType } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 7591e251d2fe76eec56ed7039234e89ae0ad8cc3..d74812fd6608841e9318d547c5bf20a3e4215f85 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -29,8 +29,8 @@ object Extractors { Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) case CaseClassSelector(cd, e, sel) => Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(cd, es.head, sel))) - case CaseClassInstanceOf(cd, e) => - Some((Seq(e), (es: Seq[Expr]) => CaseClassInstanceOf(cd, es.head))) + case IsInstanceOf(cd, e) => + Some((Seq(e), (es: Seq[Expr]) => IsInstanceOf(cd, es.head))) case TupleSelect(t, i) => Some((Seq(t), (es: Seq[Expr]) => TupleSelect(es.head, i))) case ArrayLength(a) => diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index e63dad484842575fb65e290d7692cf42b821b947..cb124cc5bb0099ff656f6ec6461fb05881927493 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -9,13 +9,70 @@ import Expressions._ import Extractors._ import ExprOps._ import Types._ +import Constructors.and import TypeOps.instantiateType +import Constructors.application object MethodLifting extends TransformationPhase { val name = "Method Lifting" val description = "Translate methods into functions of the companion object" + // Takes cd and its subclasses and creates cases that together will form a composite method. + // fdId is the method id which will be searched for in the subclasses. + // cd is the hierarchy root + // A Seq of MatchCases is returned, along with a boolean that signifies if the matching is complete. + private def makeCases(cd: ClassDef, fdId: Identifier, breakDown: Expr => Expr): (Seq[MatchCase], Boolean) = cd match { + case ccd: CaseClassDef => + ccd.methods.find( _.id == fdId) match { + case None => + (List(), false) + case Some(m) => + val ct = classDefToClassType(ccd).asInstanceOf[CaseClassType] + val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) + val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap + def subst(e: Expr): Expr = e match { + case CaseClassSelector(`ct`, This(`ct`), i) => + Variable(fBinders(i)).setPos(e) + case This(`ct`) => + Variable(binder).setPos(e) + case e => + e + } + val newE = simplePreTransform(subst)(breakDown(m.fullBody)) + val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id)))) + val cse = SimpleCase(CaseClassPattern(Some(binder), ct, subPatts), newE).setPos(newE) + (List(cse), true) + } + case acd: AbstractClassDef => + val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip + val recs = r.flatten + val complete = c forall (x => x) + if (complete) { + // Children define all cases completely, we don't need to add anything + (recs, true) + } else if (!acd.methods.exists( m => m.id == fdId && m.body.nonEmpty)) { + // We don't have anything to add + (recs, false) + } else { + // We have something to add + val m = acd.methods.find( m => m.id == fdId ).get + val at = classDefToClassType(acd).asInstanceOf[AbstractClassType] + val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true) + def subst(e: Expr): Expr = e match { + case This(`at`) => + Variable(binder) + case e => + e + } + val newE = simplePreTransform(subst)(breakDown(m.fullBody)) + val cse = SimpleCase(InstanceOfPattern(Some(binder), at), newE).setPos(newE) + (recs :+ cse, true) + } + + } + + def apply(ctx: LeonContext, program: Program): Program = { // First we create the appropriate functions from methods: @@ -24,6 +81,42 @@ object MethodLifting extends TransformationPhase { val newUnits = for (u <- program.units) yield { var fdsOf = Map[String, Set[FunDef]]() + // Lift methods to the root class + for { + c <- u.definedClassesOrdered + if c.parent.isDefined + fd <- c.methods + if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) + } { + val root = c.ancestors.last + val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap + val tSubst: TypeTree => TypeTree = instantiateType(_, tMap) + + val fdParams = fd.params map { vd => + val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) + ValDef(newId).setPos(vd.getPos) + } + val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap + val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) + + val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd) + newFd.copyContentFrom(fd) + val prec = fd.precondition.getOrElse(BooleanLiteral(true)) + newFd.fullBody = eSubst(withPrecondition( + newFd.fullBody, + Some(and( + prec, + IsInstanceOf( + classDefToClassType(c,root.tparams.map{ _.tp }), + This(classDefToClassType(root)) + ) + )) + )) + + c.unregisterMethod(fd.id) + root.registerMethod(newFd) + } + // 1) Create one function for each method for { cd <- u.classHierarchyRoots if cd.methods.nonEmpty; fd <- cd.methods } { // We import class type params and freshen them @@ -37,7 +130,6 @@ object MethodLifting extends TransformationPhase { val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap)) ValDef(newId).setPos(vd.getPos) } - val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap val receiver = FreshIdentifier("thiss", recType).setPos(cd.id) @@ -45,11 +137,83 @@ object MethodLifting extends TransformationPhase { nfd.copyContentFrom(fd) nfd.setPos(fd) nfd.addFlag(IsMethod(cd)) - nfd.fullBody = postMap{ - case This(ct) if ct.classDef == cd => Some(receiver.toVariable) - case _ => None - }(instantiateType(nfd.fullBody, tparamsMap, paramsMap)) + if (cd.knownDescendents.forall( _.methods.forall(_.id != fd.id))) { + val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap + // Don't need to compose methods + nfd.fullBody = postMap { + case th@This(ct) if ct.classDef == cd => + Some(receiver.toVariable.setPos(th)) + case _ => + None + }(instantiateType(nfd.fullBody, tparamsMap, paramsMap)) + } else { + // We need to compose methods of subclasses + + /* (Type) parameter substitutions that look at all subclasses */ + val paramsMap = (for { + c <- cd.knownDescendents :+ cd + m <- c.methods if m.id == fd.id + (from,to) <- m.params zip fdParams + } yield (from.id, to.id)).toMap + val classParamsMap = (for { + c <- cd.knownDescendents :+ cd + (from, to) <- c.tparams zip ctParams + } yield (from, to.tp)).toMap + val methodParamsMap = (for { + c <- cd.knownDescendents :+ cd + m <- c.methods if m.id == fd.id + (from,to) <- m.tparams zip fd.tparams + } yield (from, to.tp)).toMap + def inst(cs: Seq[MatchCase]) = instantiateType( + MatchExpr(Variable(receiver), cs).setPos(fd), + classParamsMap ++ methodParamsMap, + paramsMap + ) + + /* Separately handle pre, post, body */ + val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true))) + val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse( + Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true)) + )) + val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) + + /* Some obvious simplifications */ + val preSimple = { + val trivial = pre.forall { _.rhs == BooleanLiteral(true) } + if (trivial) None else Some(inst(pre).setPos(fd.getPos)) + } + val postSimple = { + val trivial = post.forall { + case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true + case _ => false + } + if (trivial) None + else { + val resVal = FreshIdentifier("res", retType, true) + Some(Lambda( + Seq(ValDef(resVal)), + inst(post map { cs => cs.copy( rhs = + application(cs.rhs, Seq(Variable(resVal))) + )}) + ).setPos(fd)) + } + } + val bodySimple = { + val trivial = body forall { + case SimpleCase(_, NoTree(_)) => true + case _ => false + } + if (trivial) NoTree(retType) else inst(body) + } + + /* Construct full body */ + nfd.fullBody = withPostcondition( + withPrecondition(bodySimple, preSimple), + postSimple + ) + + } mdToFds += fd -> nfd fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) } @@ -68,16 +232,16 @@ object MethodLifting extends TransformationPhase { ModuleDef(FreshIdentifier(name), fds.toSeq, false) } - // 4) Remove methods in classes - for (cd <- u.definedClasses) { - cd.clearMethods() - } - u.copy(defs = defs ++ newCompanions) } val pgm = Program(newUnits) + // 4) Remove methods in classes + for (cd <- pgm.definedClasses) { + cd.clearMethods() + } + // 5) Replace method calls with function calls for (fd <- pgm.definedFunctions) { fd.fullBody = postMap{ diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 32a56fcfcc9c483bf46f88534b34a3a6c968a950..32de020857ab8edccd9c3d2e4d1a00dc4fb08910 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -185,7 +185,7 @@ class PrettyPrinter(opts: PrinterOptions, case NoTree(tpe) => p"???($tpe)" case Choose(pred) => p"choose($pred)" case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" - case CaseClassInstanceOf(cct, e) => + case IsInstanceOf(cct, e) => if (cct.classDef.isCaseObject) { p"($e == $cct)" } else { diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 8247a1ff5b5ebbd2c7b1863eed5f949b231578c6..d1b99517312cfa305482e0a0cc47156523f4e1b5 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -113,7 +113,7 @@ object TypeOps { } def bestRealType(t: TypeTree) : TypeTree = t match { - case (c: CaseClassType) => c.root + case (c: ClassType) => c.root case NAryType(tps, builder) => builder(tps.map(bestRealType)) } @@ -272,8 +272,8 @@ object TypeOps { case cc @ CaseClassSelector(ct, e, sel) => CaseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) - case cc @ CaseClassInstanceOf(ct, e) => - CaseClassInstanceOf(tpeSub(ct).asInstanceOf[CaseClassType], srec(e)).copiedFrom(cc) + case cc @ IsInstanceOf(ct, e) => + IsInstanceOf(tpeSub(ct).asInstanceOf[ClassType], srec(e)).copiedFrom(cc) case l @ Let(id, value, body) => val newId = freshId(id, tpeSub(id.getType)) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 8bf48519c329e5404438c05b8365795b77f06d41..ed9f730220c32b1a2e4ceb9cba8432aa5551e7f9 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -102,12 +102,12 @@ object Types { lazy val fieldsTypes = fields.map(_.getType) - lazy val root = parent.getOrElse(this) + lazy val root: ClassType = parent.map{ _.root }.getOrElse(this) - lazy val parent = classDef.parent.map { - pct => instantiateType(pct, (classDef.tparams zip tps).toMap) match { + lazy val parent = classDef.parent.map { pct => + instantiateType(pct, (classDef.tparams zip tps).toMap) match { case act: AbstractClassType => act - case t => throw LeonFatalError("Unexpected translated parent type: "+t) + case t => throw LeonFatalError("Unexpected translated parent type: "+t) } } diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index 0c29414816944d66602b829128648b992bbf1638..2ea69dd3ff585c2ee2e7b41c12d9e212df0fca0e 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -4,7 +4,6 @@ package leon package solvers import purescala.Types._ -import purescala.TypeOps._ import purescala.Common._ case class DataType(sym: Identifier, cases: Seq[Constructor]) { @@ -24,16 +23,15 @@ class ADTManager(ctx: LeonContext) { protected def freshId(id: Identifier): Identifier = freshId(id.name) protected def freshId(name: String): Identifier = FreshIdentifier(name) - protected def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { - case act: AbstractClassType => - (act, act.knownCCDescendents) - case cct: CaseClassType => - cct.parent match { - case Some(p) => - getHierarchy(p) - case None => - (cct, List(cct)) - } + protected def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct.parent match { + case Some(p) => + getHierarchy(p) + case None => (ct, ct match { + case act: AbstractClassType => + act.knownCCDescendents + case cct: CaseClassType => + List(cct) + }) } protected var defined = Set[TypeTree]() @@ -85,7 +83,7 @@ class ADTManager(ctx: LeonContext) { val (root, sub) = getHierarchy(ct) if (!(discovered contains root) && !(defined contains root)) { - val sym = freshId(ct.id) + val sym = freshId(root.id) val conss = sub.map { case cct => Constructor(freshId(cct.id), cct, cct.fields.map(vd => (freshId(vd.id), vd.getType))) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 9981dae3284a7a4be699436e78c5e905e05eddf9..5b14f2f4cc30436a466f1bc38c5d924cd72baf76 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -244,9 +244,11 @@ abstract class SMTLIBSolver(val context: LeonContext, (id2sym(sym), cases.map(toDecl)) } + if (adts.nonEmpty) { + val cmd = DeclareDatatypes(adts) + sendCommand(cmd) + } - val cmd = DeclareDatatypes(adts) - sendCommand(cmd) } protected def declareStructuralSort(t: TypeTree): Sort = { @@ -254,7 +256,7 @@ abstract class SMTLIBSolver(val context: LeonContext, adtManager.defineADT(t) match { case Left(adts) => declareDatatypes(adts) - sorts.toB(t) + sorts.toB(normalizeType(t)) case Right(conflicts) => conflicts.foreach { declareStructuralSort } @@ -325,10 +327,24 @@ abstract class SMTLIBSolver(val context: LeonContext, val selector = selectors.toB((cct, s.selectorIndex)) FunctionApplication(selector, Seq(toSMT(e))) - case CaseClassInstanceOf(cct, e) => + case IsInstanceOf(cct, e) => declareSort(cct) - val tester = testers.toB(cct) - FunctionApplication(tester, Seq(toSMT(e))) + val cases = cct match { + case act: AbstractClassType => + act.knownCCDescendents + case cct: CaseClassType => + Seq(cct) + } + val oneOf = cases map testers.toB + oneOf match { + case Seq(tester) => + FunctionApplication(tester, Seq(toSMT(e))) + case more => + val es = freshSym("e") + SMTLet(VarBinding(es, toSMT(e)), Seq(), + Core.Or((oneOf map (FunctionApplication(_, Seq(es:Term)))): _*) + ) + } case CaseClass(cct, es) => declareSort(cct) @@ -406,7 +422,7 @@ abstract class SMTLIBSolver(val context: LeonContext, */ case m @ FiniteMap(elems, _, _) => val mt @ MapType(from, to) = m.getType - val ms = declareSort(mt) + declareSort(mt) toSMT(RawArrayValue(from, elems.map { case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) @@ -414,7 +430,7 @@ abstract class SMTLIBSolver(val context: LeonContext, case MapGet(m, k) => - val mt @ MapType(from, to) = m.getType + val mt @ MapType(_, to) = m.getType declareSort(mt) // m(k) becomes // (Some-value (select m k)) @@ -424,7 +440,7 @@ abstract class SMTLIBSolver(val context: LeonContext, ) case MapIsDefinedAt(m, k) => - val mt @ MapType(from, to) = m.getType + val mt @ MapType(_, to) = m.getType declareSort(mt) // m.isDefinedAt(k) becomes // (is-Some (select m k)) @@ -434,7 +450,7 @@ abstract class SMTLIBSolver(val context: LeonContext, ) case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(f, t) = m1.getType + val MapType(_, t) = m1.getType elems.foldLeft(toSMT(m1)) { case (m, (k,v)) => ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 321f1234a5db3e671c207838993f8a6a06ea6126..c6f0c342e97d0caf381b5d4882a38c7aea5e3ea5 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -184,12 +184,12 @@ trait AbstractZ3Solver def declareStructuralSort(t: TypeTree): Z3Sort = { //println("///"*40) - //println("Declaring for: "+ct) + //println("Declaring for: "+t) adtManager.defineADT(t) match { case Left(adts) => declareDatatypes(adts.toSeq) - sorts.toB(t) + sorts.toB(normalizeType(t)) case Right(conflicts) => conflicts.foreach { declareStructuralSort } @@ -244,14 +244,12 @@ trait AbstractZ3Solver // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. private def prepareSorts(): Unit = { - val Seq((us, Seq(unitCons), Seq(unitTester), _)) = z3.mkADTSorts( - Seq( - ( - "Unit", - Seq("Unit"), - Seq(Seq()) - ) - ) + z3.mkADTSorts( + Seq(( + "Unit", + Seq("Unit"), + Seq(Seq()) + )) ) //TODO: mkBitVectorType @@ -473,7 +471,16 @@ trait AbstractZ3Solver val selector = selectors.toB(cct, c.selectorIndex) selector(rec(cc)) - case c @ CaseClassInstanceOf(cct, e) => + case IsInstanceOf(act: AbstractClassType, e) => + act.knownCCDescendents match { + case Seq(cct) => + rec(IsInstanceOf(cct, e)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(_, Variable(i)))))) + } + + case IsInstanceOf(cct: CaseClassType, e) => typeToSort(cct) // Making sure the sort is defined val tester = testers.toB(cct) tester(rec(e)) diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 4e4c8b07e1e9ae55bf6f110aedf7d28a8102e8ea..3f68419ed5f9c4e01e83f76e06e306590147c2ef 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -19,10 +19,10 @@ case object ADTDual extends NormalizingRule("ADTDual") { val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, CaseClassInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, CaseClassInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) }.unzip if (toRemove.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/ADTInduction.scala b/src/main/scala/leon/synthesis/rules/ADTInduction.scala index 4578eab14eb9dcd7e70804f359a85c725bf1ec39..c493a540e3d88ab28d6677119bf88955c203ffe8 100644 --- a/src/main/scala/leon/synthesis/rules/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTInduction.scala @@ -68,7 +68,7 @@ case object ADTInduction extends Rule("ADT Induction") { val subPC = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerPC) val subWS = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerWS) - val subPre = CaseClassInstanceOf(cct, Variable(origId)) + val subPre = IsInstanceOf(cct, Variable(origId)) val subProblem = Problem(inputs ::: residualArgs, subWS, andJoin(subPC :: postFs), subPhi, p.xs) diff --git a/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala index bcafde3c54dd7ebadfab9eb8417a4c4c4ceb2406..bdf0f4cf7610019244e05c46b418e4df1cc1a014 100644 --- a/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala @@ -82,7 +82,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { val newMap = trMap.mapValues(v => substAll(Map(id -> CaseClass(cct, subIds.map(Variable))), v)) - InductCase(newIds, newCalls, newPattern, and(pc, CaseClassInstanceOf(cct, Variable(id))), newMap) + InductCase(newIds, newCalls, newPattern, and(pc, IsInstanceOf(cct, Variable(id))), newMap) } }).flatten } else { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 6215814ed33e49e2e573c8cc1d26f49e48c6ebed..38bc0455af310ad7c90713d26fc59feb44c5bbcb 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -23,7 +23,7 @@ case object ADTSplit extends Rule("ADT Split.") { val optCases = for (dcd <- cd.knownDescendents.sortBy(_.id.name)) yield dcd match { case ccd : CaseClassDef => val cct = CaseClassType(ccd, tpes) - val toSat = and(p.pc, CaseClassInstanceOf(cct, Variable(id))) + val toSat = and(p.pc, IsInstanceOf(cct, Variable(id))) val isImplied = solver.solveSAT(toSat) match { case (Some(false), _) => true @@ -76,7 +76,7 @@ case object ADTSplit extends Rule("ADT Split.") { val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield { (arg, CaseClassSelector(cct, id.toVariable, field.id)) }).toMap - globalPre ::= and(CaseClassInstanceOf(cct, Variable(id)), replaceFromIDs(substs, sol.pre)) + globalPre ::= and(IsInstanceOf(cct, Variable(id)), replaceFromIDs(substs, sol.pre)) } else { globalPre ::= BooleanLiteral(true) } diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index ef5f39a7d7311b384cd88492a4d19969bd81403f..10ba905b87b53de902853d1daa770f061503ba2a 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -9,6 +9,7 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Constructors._ +import purescala.Types.CaseClassType case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { @@ -16,12 +17,12 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { def discoverEquivalences(allClauses: Seq[Expr]): Seq[(Expr, Expr)] = { val instanceOfs = allClauses.collect { - case ccio @ CaseClassInstanceOf(cct, s) => ccio + case ccio @ IsInstanceOf(cct, s) => ccio } val clauses = allClauses.filterNot(instanceOfs.toSet) - val ccSubsts = for (CaseClassInstanceOf(cct, s) <- instanceOfs) yield { + val ccSubsts = for (IsInstanceOf(cct: CaseClassType, s) <- instanceOfs) yield { val fieldsVals = (for (f <- cct.fields) yield { val id = f.id diff --git a/src/main/scala/leon/termination/SelfCallsProcessor.scala b/src/main/scala/leon/termination/SelfCallsProcessor.scala index 4f08355dc32264c05ce6f6318b6d7b5b90ccd7b7..60ebe81babe2a95c2924478ec225acc2d0f699a2 100644 --- a/src/main/scala/leon/termination/SelfCallsProcessor.scala +++ b/src/main/scala/leon/termination/SelfCallsProcessor.scala @@ -53,7 +53,7 @@ class SelfCallsProcessor(val checker: TerminationChecker) extends Processor { case Not(expr: Expr) => rec(expr) case Equals(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) case CaseClass(ct, args: Seq[Expr]) => args.exists(arg => rec(arg)) - case CaseClassInstanceOf(ct, expr: Expr) => rec(expr) + case IsInstanceOf(ct, expr: Expr) => rec(expr) case CaseClassSelector(ct, caseClassExpr, selector) => rec(caseClassExpr) case Plus(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) case Minus(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/leon/utils/TypingPhase.scala index 2de4e72c07edb3a7e42d08fa36185433399baaef..ee7fd600dcc75d956ac2feb5839ef09112e3de85 100644 --- a/src/main/scala/leon/utils/TypingPhase.scala +++ b/src/main/scala/leon/utils/TypingPhase.scala @@ -26,7 +26,7 @@ object TypingPhase extends LeonPhase[Program, Program] { * 2) Report warnings in case parts of the tree are not correctly typed * (Untyped). * - * 3) Make sure that abstract classes have at least one descendent + * 3) Make sure that abstract classes have at least one descendant */ def run(ctx: LeonContext)(pgm: Program): Program = { pgm.definedFunctions.foreach(fd => { @@ -34,7 +34,7 @@ object TypingPhase extends LeonPhase[Program, Program] { // Part (1) fd.precondition = { val argTypesPreconditions = fd.params.flatMap(arg => arg.getType match { - case cct : CaseClassType if cct.parent.isDefined => Seq(CaseClassInstanceOf(cct, arg.id.toVariable)) + case cct : CaseClassType if cct.parent.isDefined => Seq(IsInstanceOf(cct, arg.id.toVariable)) case (at : ArrayType) => Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0))) case _ => Seq() }) @@ -54,11 +54,11 @@ object TypingPhase extends LeonPhase[Program, Program] { case Some(p) => Some(Lambda(Seq(ValDef(resId)), and( application(p, Seq(Variable(resId))), - CaseClassInstanceOf(cct, Variable(resId)) + IsInstanceOf(cct, Variable(resId)) ).setPos(p)).setPos(p)) case None => - Some(Lambda(Seq(ValDef(resId)), CaseClassInstanceOf(cct, Variable(resId)))) + Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(cct, Variable(resId)))) } } case _ => fd.postcondition @@ -80,7 +80,7 @@ object TypingPhase extends LeonPhase[Program, Program] { pgm.definedClasses.foreach { case acd: AbstractClassDef => if (acd.knownCCDescendents.isEmpty) { - ctx.reporter.error(acd.getPos, "Class "+acd.id.asString(ctx)+" has no concrete descendent!") + ctx.reporter.error(acd.getPos, "Class "+acd.id.asString(ctx)+" has no concrete descendant!") } case _ => } diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 529b9e5fc859d8cbb8f325bcc39d2b7b1e5cd618..c780c647f3fae01cdc2a17c3e904100c9ceff743 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -40,7 +40,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { } val vc = implies( - and(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd)), + and(IsInstanceOf(cct, arg.toVariable), precOrTrue(fd)), implies(andJoin(subCases), application(post, Seq(body))) ) @@ -77,7 +77,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { } val vc = implies( - andJoin(Seq(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd), path) ++ subCases), + andJoin(Seq(IsInstanceOf(cct, arg.toVariable), precOrTrue(fd), path) ++ subCases), tfd.withParamSubst(args, pre) ) diff --git a/src/test/resources/regression/frontends/passing/Overrides.scala b/src/test/resources/regression/frontends/passing/Overrides.scala new file mode 100644 index 0000000000000000000000000000000000000000..c3be61be2482367e28a2d54c49cd7c7c1a54e6ab --- /dev/null +++ b/src/test/resources/regression/frontends/passing/Overrides.scala @@ -0,0 +1,16 @@ +object Overrides { + + abstract class A[T] { + def x[A](a: A): (A,T) + } + + abstract class B[R] extends A[R] { + def x[B](b: B) = x(b) + } + + case class C[W](c: W) extends B[W] { + override def x[C](f: C) = (f,c) + } + + case class D[Z]() extends B[Z] +} diff --git a/src/test/resources/regression/verification/purescala/invalid/Overrides.scala b/src/test/resources/regression/verification/purescala/invalid/Overrides.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae6b4ebd4bea6dffdba88491ced95824719613c1 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Overrides.scala @@ -0,0 +1,22 @@ +object Overrides { + abstract class A { + def x(a: Int): Int + } + + abstract class B extends A { + def x(a: Int) = { + require(a > 0) + 42 + } ensuring { _ >= 0 } + } + + case class C(c: Int) extends B { + override def x(i: Int) = { + require(i >= 0) + if (i == 0) 0 + else c + x(i-1) + } ensuring ( _ != c * i ) + } + + case class D() extends B +} diff --git a/src/test/resources/regression/verification/purescala/valid/Overrides.scala b/src/test/resources/regression/verification/purescala/valid/Overrides.scala new file mode 100644 index 0000000000000000000000000000000000000000..4b6767e6b80e7159c223a022762cd5080f65d84e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Overrides.scala @@ -0,0 +1,22 @@ +object Overrides { + abstract class A { + def x(a: Int): Int + } + + abstract class B extends A { + def x(a: Int) = { + require(a > 0) + 42 + } ensuring { _ >= 0 } + } + + case class C(c: Int) extends B { + override def x(i: Int) = { + require(i >= 0) + if (i == 0) 0 + else c + x(i-1) + } ensuring ( _ == c * i ) + } + + case class D() extends B +} diff --git a/src/test/scala/leon/test/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/test/evaluators/EvaluatorSuite.scala index d1957bcf38fa885b2d66b8411f30d9ffb420852b..b5c623ab62252b2515de76f44c45054c8f4a4838 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorSuite.scala @@ -625,4 +625,46 @@ class EvaluatorSuite extends leon.test.LeonTestSuite { checkLambda(e, mkCall("foo4", TWO), { case Lambda(Seq(vd), Plus(Variable(id), TWO)) if vd.id == id => true }) } } + + test("Methods") { + val p = + """object Program { + | abstract class A + | + | abstract class B extends A { + | def foo(i: BigInt) = { + | require(i > 0) + | i + 1 + | } ensuring ( _ >= 0 ) + | } + | + | case class C(x: BigInt) extends B { + | val y = BigInt(42) + | override def foo(i: BigInt) = { + | x + y + (if (i>0) i else -i) + | } ensuring ( _ >= x ) + | } + | + | case class D() extends A + | + | def f1 = { + | val c = C(42) + | (if (c.foo(0) + c.x > 0) c else D()).isInstanceOf[B] + | } + | def f2 = D().isInstanceOf[B] + | def f3 = C(42).isInstanceOf[A] + |} + | + | + """.stripMargin + + implicit val prog = parseString(p) + val evaluators = prepareEvaluators + for(e <- evaluators) { + // Some simple math. + checkComp(e, mkCall("f1"), BooleanLiteral(true)) + checkComp(e, mkCall("f2"), BooleanLiteral(false)) + checkComp(e, mkCall("f3"), BooleanLiteral(true)) + } + } }