From e7e93f4bf1e260363e985767a1f3d9dd32e2b815 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Wed, 8 Jul 2015 11:11:00 +0200 Subject: [PATCH] Make CaseClassSelector consistent with other Expr's --- .../frontends/scalac/CodeExtraction.scala | 2 +- .../scala/leon/purescala/Constructors.scala | 12 +++++++-- .../scala/leon/purescala/Definitions.scala | 15 ++++++----- src/main/scala/leon/purescala/ExprOps.scala | 10 +++---- .../scala/leon/purescala/Expressions.scala | 27 +------------------ .../scala/leon/purescala/Extractors.scala | 2 +- src/main/scala/leon/purescala/TypeOps.scala | 2 +- .../scala/leon/synthesis/rules/ADTDual.scala | 4 +-- .../scala/leon/synthesis/rules/ADTSplit.scala | 2 +- .../leon/synthesis/rules/DetupleInput.scala | 2 +- .../leon/termination/ChainComparator.scala | 4 +-- src/main/scala/leon/utils/InliningPhase.scala | 6 ++--- .../leon/verification/InductionTactic.scala | 2 +- 13 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3a94aa0fc..8970b698a 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1453,7 +1453,7 @@ trait CodeExtraction extends ASTExtractors { val fieldID = cct.fields.find(_.id.name == name).get.id - CaseClassSelector(cct, rec, fieldID) + caseClassSelector(cct, rec, fieldID) //BigInt methods case (IsTyped(a1, IntegerType), "+", List(IsTyped(a2, IntegerType))) => diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 6ffb833e3..05ed1e034 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -79,9 +79,17 @@ object Constructors { case None => sys.error(s"$actualType cannot be a subtype of $formalType!") } - } - + + def caseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { + caseClass match { + case CaseClass(ct, fields) if ct.classDef == classType.classDef => + fields(ct.classDef.selectorID2Index(selector)) + case _ => + CaseClassSelector(classType, caseClass, selector) + } + } + private def filterCases(scrutType : TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = { val casesFiltered = scrutType match { case c: CaseClassType => diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 2d37cce58..8b3e6ca9a 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -274,16 +274,17 @@ object Definitions { _fields = fields } - val isAbstract = false - def selectorID2Index(id: Identifier) : Int = { - val index = fields.zipWithIndex.find(_._1.id == id).map(_._2) - - index.getOrElse { - scala.sys.error("Could not find '"+id+"' ("+id.uniqueName+") within "+fields.map(_.id.uniqueName).mkString(", ")) - } + val index = fields.indexWhere(_.id == id) + + if (index < 0) { + scala.sys.error( + "Could not find '"+id+"' ("+id.uniqueName+") within "+ + fields.map(_.id.uniqueName).mkString(", ") + ) + } else index } lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 73e4aaed9..727e0d94b 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -392,7 +392,7 @@ object ExprOps { Some(letTuple(ids, v, tupleSelect(b, ts, true))) case CaseClassSelector(cct, cc: CaseClass, id) => - Some(CaseClassSelector(cct, cc, id)) + Some(caseClassSelector(cct, cc, id)) case IfExpr(c, thenn, elze) if (thenn == elze) && isDeterministic(e) => Some(thenn) @@ -644,7 +644,7 @@ object ExprOps { case CaseClassPattern(_, cct, subps) => val subExprs = (subps zip cct.fields) map { - case (p, f) => p.binder.map(_.toVariable).getOrElse(CaseClassSelector(cct, in, f.id)) + case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id)) } // Special case to get rid of Cons(a,b) match { case Cons(c,d) => .. } @@ -705,7 +705,7 @@ object ExprOps { case CaseClassPattern(ob, cct, subps) => assert(cct.fields.size == subps.size) val pairs = cct.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => rec(CaseClassSelector(cct, in, p._1), p._2)) + val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) val together = and(bind(ob, in) +: subTests :_*) and(IsInstanceOf(cct, in), together) @@ -727,7 +727,7 @@ object ExprOps { case CaseClassPattern(b, ccd, subps) => assert(ccd.fields.size == subps.size) val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ccd, in, p._1), p._2)) val together = subMaps.flatten.toMap b match { case Some(id) => Map(id -> in) ++ together @@ -1283,7 +1283,7 @@ object ExprOps { val v = Variable(on) recSelectors.map{ s => - and(isType, expr, not(replace(Map(v -> CaseClassSelector(cct, v, s)), expr))) + and(isType, expr, not(replace(Map(v -> caseClassSelector(cct, v, s)), expr))) } } }.flatten diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index b49c61216..77fcdb40c 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -270,34 +270,9 @@ object Expressions { val getType = BooleanType } - object CaseClassSelector { - def apply(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { - caseClass match { - case CaseClass(ct, fields) => - if (ct.classDef == classType.classDef) { - fields(ct.classDef.selectorID2Index(selector)) - } else { - new CaseClassSelector(classType, caseClass, selector) - } - case _ => new CaseClassSelector(classType, caseClass, selector) - } - } - - def unapply(ccs: CaseClassSelector): Option[(CaseClassType, Expr, Identifier)] = { - Some((ccs.classType, ccs.caseClass, ccs.selector)) - } - } - - class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr { + case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr { val selectorIndex = classType.classDef.selectorID2Index(selector) val getType = classType.fieldsTypes(selectorIndex) - - override def equals(that: Any): Boolean = (that != null) && (that match { - case t: CaseClassSelector => (t.classType, t.caseClass, t.selector) == (classType, caseClass, selector) - case _ => false - }) - - override def hashCode: Int = (classType, caseClass, selector).hashCode + 9 } /* Arithmetic */ diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index d74812fd6..d17b58a96 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -28,7 +28,7 @@ object Extractors { case SetCardinality(t) => Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) case CaseClassSelector(cd, e, sel) => - Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(cd, es.head, sel))) + Some((Seq(e), (es: Seq[Expr]) => caseClassSelector(cd, es.head, sel))) case IsInstanceOf(cd, e) => Some((Seq(e), (es: Seq[Expr]) => IsInstanceOf(cd, es.head))) case TupleSelect(t, i) => diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 49529c035..4b77380bf 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -270,7 +270,7 @@ object TypeOps { CaseClass(tpeSub(ct).asInstanceOf[CaseClassType], args.map(srec)).copiedFrom(cc) case cc @ CaseClassSelector(ct, e, sel) => - CaseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) + caseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) case cc @ IsInstanceOf(ct, e) => IsInstanceOf(tpeSub(ct).asInstanceOf[ClassType], srec(e)).copiedFrom(cc) diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 3f68419ed..ddf37760e 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -19,10 +19,10 @@ case object ADTDual extends NormalizingRule("ADTDual") { val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) }.unzip if (toRemove.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 38bc0455a..111757777 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -74,7 +74,7 @@ case object ADTSplit extends Rule("ADT Split.") { val cases = for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield { if (sol.pre != BooleanLiteral(true)) { val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield { - (arg, CaseClassSelector(cct, id.toVariable, field.id)) + (arg, caseClassSelector(cct, id.toVariable, field.id)) }).toMap globalPre ::= and(IsInstanceOf(cct, Variable(id)), replaceFromIDs(substs, sol.pre)) } else { diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index ae5f731b9..ed25fc4b3 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -23,7 +23,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case cct @ CaseClassType(ccd, _) if !ccd.isAbstract => val newIds = cct.fields.map{ vd => FreshIdentifier(vd.id.name, vd.getType, true) } - val map = (ccd.fields zip newIds).map{ case (vd, nid) => nid -> CaseClassSelector(cct, Variable(id), vd.id) }.toMap + val map = (ccd.fields zip newIds).map{ case (vd, nid) => nid -> caseClassSelector(cct, Variable(id), vd.id) }.toMap (newIds.toList, CaseClass(cct, newIds.map(Variable)), map) diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index bbee27665..2ed2a08de 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -34,7 +34,7 @@ trait ChainComparator { self : StructuralSize => def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { case ContainerType(cct, fields) => powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => - rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId))) + rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId))) }) case TupleType(tpes) => powerSetToFunSet((0 until tpes.length).flatMap { case index => @@ -50,7 +50,7 @@ trait ChainComparator { self : StructuralSize => def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { case ContainerType(cct, fields) => fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => - rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId))) + rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId))) }.toSet case TupleType(tpes) => (0 until tpes.length).flatMap { case index => diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 62e62a9e2..dea267601 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -5,10 +5,10 @@ package leon.utils import leon._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.Types._ import purescala.TypeOps._ import purescala.ExprOps._ import purescala.DefOps._ +import purescala.Constructors.caseClassSelector object InliningPhase extends TransformationPhase { @@ -27,14 +27,14 @@ object InliningPhase extends TransformationPhase { def simplifyImplicitClass(e: Expr) = e match { case CaseClassSelector(cct, cc: CaseClass, id) => - Some(CaseClassSelector(cct, cc, id)) + Some(caseClassSelector(cct, cc, id)) case _ => None } def simplify(e: Expr) = { - fixpoint(postMap(simplifyImplicitClass _))(e) + fixpoint(postMap(simplifyImplicitClass))(e) } for (fd <- p.definedFunctions) { diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index c780c647f..7bbd332c6 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -23,7 +23,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = { val childrenOfSameType = cct.fields.filter(_.getType == parentType) for (field <- childrenOfSameType) yield { - CaseClassSelector(cct, expr, field.id) + caseClassSelector(cct, expr, field.id) } } -- GitLab