diff --git a/src/it/scala/inox/solvers/unrolling/InductiveUnrollingSuite.scala b/src/it/scala/inox/solvers/unrolling/InductiveUnrollingSuite.scala index 5484297fd89083e9f2779377fa04fbdd8e1d76d8..753145736961353b597aa080960bc6760a6eecf0 100644 --- a/src/it/scala/inox/solvers/unrolling/InductiveUnrollingSuite.scala +++ b/src/it/scala/inox/solvers/unrolling/InductiveUnrollingSuite.scala @@ -15,9 +15,9 @@ class InductiveUnrollingSuite extends SolvingTestSuite { val head = FreshIdentifier("head") val tail = FreshIdentifier("tail") - val List = mkAbstractClass(listID)("A")(Seq(consID, nilID)) - val Nil = mkCaseClass(nilID)("A")(Some(listID))(_ => Seq.empty) - val Cons = mkCaseClass(consID)("A")(Some(listID)) { + val List = mkSort(listID)("A")(Seq(consID, nilID)) + val Nil = mkConstructor(nilID)("A")(Some(listID))(_ => Seq.empty) + val Cons = mkConstructor(consID)("A")(Some(listID)) { case Seq(aT) => Seq(ValDef(head, aT), ValDef(tail, T(listID)(aT))) } diff --git a/src/it/scala/inox/solvers/unrolling/SimpleUnrollingSuite.scala b/src/it/scala/inox/solvers/unrolling/SimpleUnrollingSuite.scala index dcc1bcaa286dbf294054d3f967c1d52c014c1a66..9912471668ef692f76026ea1e29d7679d1b4404d 100644 --- a/src/it/scala/inox/solvers/unrolling/SimpleUnrollingSuite.scala +++ b/src/it/scala/inox/solvers/unrolling/SimpleUnrollingSuite.scala @@ -17,9 +17,9 @@ class SimpleUnrollingSuite extends SolvingTestSuite { val head = FreshIdentifier("head") val tail = FreshIdentifier("tail") - val List = mkAbstractClass(listID)("A")(Seq(consID, nilID)) - val Nil = mkCaseClass(nilID)("A")(Some(listID))(_ => Seq.empty) - val Cons = mkCaseClass(consID)("A")(Some(listID)) { + val List = mkSort(listID)("A")(Seq(consID, nilID)) + val Nil = mkConstructor(nilID)("A")(Some(listID))(_ => Seq.empty) + val Cons = mkConstructor(consID)("A")(Some(listID)) { case Seq(aT) => Seq(ValDef(head, aT), ValDef(tail, T(listID)(aT))) } @@ -48,7 +48,7 @@ class SimpleUnrollingSuite extends SolvingTestSuite { SimpleSolverAPI(SolverFactory.default(program)).solveSAT(clause) match { case SatWithModel(model) => symbols.valuateWithModel(model)(vd) match { - case CaseClass(ClassType(`consID`, Seq(IntegerType)), _) => + case ADT(ADTType(`consID`, Seq(IntegerType)), _) => // success!! case r => fail("Unexpected valuation: " + r) @@ -69,7 +69,7 @@ class SimpleUnrollingSuite extends SolvingTestSuite { SimpleSolverAPI(SolverFactory.default(program)).solveSAT(clause) match { case SatWithModel(model) => symbols.valuateWithModel(model)(vd) match { - case CaseClass(ClassType(`nilID`, Seq(`tp`)), Seq()) => + case ADT(ADTType(`nilID`, Seq(`tp`)), Seq()) => // success!! case r => fail("Unexpected valuation: " + r) diff --git a/src/main/scala/inox/Program.scala b/src/main/scala/inox/Program.scala index 457efc2d6cd40bb1240d09fbfd562391c2c9f283..94bd3b43e09544fe30ee6a4e0b8e5ca146daa39f 100644 --- a/src/main/scala/inox/Program.scala +++ b/src/main/scala/inox/Program.scala @@ -30,10 +30,10 @@ trait Program { val ctx = Program.this.ctx } - def extend(functions: Seq[trees.FunDef] = Seq.empty, classes: Seq[trees.ClassDef] = Seq.empty): + def extend(functions: Seq[trees.FunDef] = Seq.empty, adts: Seq[trees.ADTDefinition] = Seq.empty): Program { val trees: Program.this.trees.type } = new Program { val trees: Program.this.trees.type = Program.this.trees - val symbols = Program.this.symbols.extend(functions, classes) + val symbols = Program.this.symbols.extend(functions, adts) val ctx = Program.this.ctx } } diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala index 664408f30f0e4f8dfb6308026c1652fb5256fb93..f854c5ba976da2082235bd642ff40925bc02bc48 100644 --- a/src/main/scala/inox/ast/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -86,12 +86,12 @@ trait Constructors { /** Simplifies the provided case class selector. * @see [[purescala.Expressions.CaseClassSelector]] */ - def caseClassSelector(caseClass: Expr, selector: Identifier): Expr = { - caseClass match { - case CaseClass(ct, fields) if !ct.tcd.hasInvariant => - fields(ct.tcd.cd.asInstanceOf[CaseClassDef].selectorID2Index(selector)) + def adtSelector(adt: Expr, selector: Identifier): Expr = { + adt match { + case a @ ADT(tp, fields) if !tp.getADT.hasInvariant => + fields(tp.getADT.toConstructor.definition.selectorID2Index(selector)) case _ => - CaseClassSelector(caseClass, selector) + ADTSelector(adt, selector) } } @@ -265,7 +265,7 @@ trait Constructors { } /** $encodingof expr.asInstanceOf[tpe], returns `expr` if it already is of type `tpe`. */ - def asInstOf(expr: Expr, tpe: ClassType) = { + def asInstOf(expr: Expr, tpe: ADTType) = { if (symbols.isSubtypeOf(expr.getType, tpe)) { expr } else { @@ -273,7 +273,7 @@ trait Constructors { } } - def isInstOf(expr: Expr, tpe: ClassType) = { + def isInstOf(expr: Expr, tpe: ADTType) = { if (symbols.isSubtypeOf(expr.getType, tpe)) { BooleanLiteral(true) } else { diff --git a/src/main/scala/inox/ast/DSL.scala b/src/main/scala/inox/ast/DSL.scala index 52a867c9634d8b29625c27ddf217680b2619ae10..691b0a35bd904fccfbbd1ce770a8f09e1fe426fb 100644 --- a/src/main/scala/inox/ast/DSL.scala +++ b/src/main/scala/inox/ast/DSL.scala @@ -14,7 +14,7 @@ import scala.language.implicitConversions * (in the form of a function) to which the newly created identifiers will be passed. * 2) No implicit conversions are provided where there would be ambiguity. * This refers mainly to Identifiers, which can be transformed to - * [[inox.ast.Types.ClassType]] or [[inox.ast.Expressions.FunctionInvocation]] or ... . + * [[inox.ast.Types.ADTType]] or [[inox.ast.Expressions.FunctionInvocation]] or ... . * Instead one-letter constructors are provided. */ trait DSL { @@ -60,12 +60,12 @@ trait DSL { // Misc. - def getField(selector: Identifier) = CaseClassSelector(e, selector) + def getField(selector: Identifier) = ADTSelector(e, selector) def apply(es: Expr*) = Application(e, es.toSeq) - def isInstOf(tp: ClassType) = IsInstanceOf(e, tp) - def asInstOf(tp: ClassType) = AsInstanceOf(e, tp) + def isInstOf(tp: ADTType) = IsInstanceOf(e, tp) + def asInstOf(tp: ADTType) = AsInstanceOf(e, tp) } // Literals @@ -108,8 +108,8 @@ trait DSL { FunctionInvocation(fd.id, Seq.empty, args.toSeq) } - implicit class CaseClassToExpr(ct: ClassType) { - def apply(args: Expr*) = CaseClass(ct, args) + implicit class ADTTypeToExpr(adt: ADTType) { + def apply(args: Expr*) = ADT(adt, args) } implicit class GenValue(tp: TypeParameter) { @@ -187,11 +187,11 @@ trait DSL { /* Types */ def T(tp1: Type, tp2: Type, tps: Type*) = TupleType(tp1 :: tp2 :: tps.toList) - def T(id: Identifier) = new IdToClassType(id) + def T(id: Identifier) = new IdToADTType(id) def T(str: String) = TypeParameter.fresh(str) - class IdToClassType(id: Identifier) { - def apply(tps: Type*) = ClassType(id, tps.toSeq) + class IdToADTType(id: Identifier) { + def apply(tps: Type*) = ADTType(id, tps.toSeq) } implicit class FunctionTypeBuilder(to: Type) { @@ -242,22 +242,22 @@ trait DSL { new FunDef(id, tParamDefs, params, retType, body, Set()) } - def mkAbstractClass(id: Identifier) - (tParamNames: String*) - (children: Seq[Identifier]) = { + def mkSort(id: Identifier) + (tParamNames: String*) + (cons: Seq[Identifier]) = { val tParams = tParamNames map TypeParameter.fresh val tParamDefs = tParams map TypeParameterDef - new AbstractClassDef(id, tParamDefs, children, Set()) + new ADTSort(id, tParamDefs, cons, Set()) } - def mkCaseClass(id: Identifier) - (tParamNames: String*) - (parent: Option[Identifier]) - (fieldBuilder: Seq[TypeParameter] => Seq[ValDef]) = { + def mkConstructor(id: Identifier) + (tParamNames: String*) + (sort: Option[Identifier]) + (fieldBuilder: Seq[TypeParameter] => Seq[ValDef]) = { val tParams = tParamNames map TypeParameter.fresh val tParamDefs = tParams map TypeParameterDef val fields = fieldBuilder(tParams) - new CaseClassDef(id, tParamDefs, parent, fields, Set()) + new ADTConstructor(id, tParamDefs, sort, fields, Set()) } // TODO: Remove this at some point diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 1ebe9ba2787af164d625652af09145b09bd0f68b..57bf9d6ddb0c985f31e92b96ed89906fd808e43e 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -23,7 +23,7 @@ trait Definitions { self: Trees => abstract class LookupException(id: Identifier, what: String) extends Exception("Lookup failed for " + what + " with symbol " + id) case class FunctionLookupException(id: Identifier) extends LookupException(id, "function") - case class ClassLookupException(id: Identifier) extends LookupException(id, "class") + case class ADTLookupException(id: Identifier) extends LookupException(id, "class") case class NotWellFormedException(id: Identifier, s: Symbols) extends Exception(s"$id not well formed in $s") @@ -93,7 +93,7 @@ trait Definitions { self: Trees => with Constructors with Paths { self0: Symbols => - val classes: Map[Identifier, ClassDef] + val adts: Map[Identifier, ADTDefinition] val functions: Map[Identifier, FunDef] protected val trees: self.type = self @@ -109,13 +109,13 @@ trait Definitions { self: Trees => // for some mysterious reason. implicit def implicitSymbols: this.type = this - private val typedClassCache: MutableMap[(Identifier, Seq[Type]), Option[TypedClassDef]] = MutableMap.empty - def lookupClass(id: Identifier): Option[ClassDef] = classes.get(id) - def lookupClass(id: Identifier, tps: Seq[Type]): Option[TypedClassDef] = - typedClassCache.getOrElseUpdate(id -> tps, lookupClass(id).map(_.typed(tps))) + private val typedADTCache: MutableMap[(Identifier, Seq[Type]), Option[TypedADTDefinition]] = MutableMap.empty + def lookupADT(id: Identifier): Option[ADTDefinition] = adts.get(id) + def lookupADT(id: Identifier, tps: Seq[Type]): Option[TypedADTDefinition] = + typedADTCache.getOrElseUpdate(id -> tps, lookupADT(id).map(_.typed(tps))) - def getClass(id: Identifier): ClassDef = lookupClass(id).getOrElse(throw ClassLookupException(id)) - def getClass(id: Identifier, tps: Seq[Type]): TypedClassDef = lookupClass(id, tps).getOrElse(throw ClassLookupException(id)) + def getADT(id: Identifier): ADTDefinition = lookupADT(id).getOrElse(throw ADTLookupException(id)) + def getADT(id: Identifier, tps: Seq[Type]): TypedADTDefinition = lookupADT(id, tps).getOrElse(throw ADTLookupException(id)) private val typedFunctionCache: MutableMap[(Identifier, Seq[Type]), Option[TypedFunDef]] = MutableMap.empty def lookupFunction(id: Identifier): Option[FunDef] = functions.get(id) @@ -127,13 +127,13 @@ trait Definitions { self: Trees => override def toString: String = asString(PrinterOptions.fromSymbols(this, InoxContext.printNames)) override def asString(implicit opts: PrinterOptions): String = { - classes.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") + + adts.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") + "\n\n-----------\n\n" + functions.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") } def transform(t: TreeTransformer): Symbols - def extend(functions: Seq[FunDef] = Seq.empty, classes: Seq[ClassDef] = Seq.empty): Symbols + def extend(functions: Seq[FunDef] = Seq.empty, adts: Seq[ADTDefinition] = Seq.empty): Symbols } case class TypeParameterDef(tp: TypeParameter) extends Definition { @@ -141,11 +141,11 @@ trait Definitions { self: Trees => val id = tp.id } - /** A trait that represents flags that annotate a ClassDef with different attributes */ - sealed trait ClassFlag + /** A trait that represents flags that annotate an ADTDefinition with different attributes */ + sealed trait ADTFlag - object ClassFlag { - def fromName(name: String, args: Seq[Option[Any]]): ClassFlag = Annotation(name, args) + object ADTFlag { + def fromName(name: String, args: Seq[Option[Any]]): ADTFlag = Annotation(name, args) } /** A trait that represents flags that annotate a FunDef with different attributes */ @@ -159,27 +159,25 @@ trait Definitions { self: Trees => } // Compiler annotations given in the source code as @annot - case class Annotation(annot: String, args: Seq[Option[Any]]) extends FunctionFlag with ClassFlag - /** Denotes that this class is refined by invariant ''id'' */ - case class HasADTInvariant(id: Identifier) extends ClassFlag + case class Annotation(annot: String, args: Seq[Option[Any]]) extends FunctionFlag with ADTFlag + /** Denotes that this adt is refined by invariant ''id'' */ + case class HasADTInvariant(id: Identifier) extends ADTFlag // Is inlined case object IsInlined extends FunctionFlag - /** Represents a class definition (either an abstract- or a case-class). - * In functional terms, abstract classes are ADTs and case classes are ADT constructors. - */ - sealed trait ClassDef extends Definition { + /** Represents an ADT definition (either the ADT sort or a constructor). */ + sealed trait ADTDefinition extends Definition { val id: Identifier val tparams: Seq[TypeParameterDef] - val flags: Set[ClassFlag] + val flags: Set[ADTFlag] def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap /** The root of the class hierarchy */ - def root(implicit s: Symbols): ClassDef + def root(implicit s: Symbols): ADTDefinition - /** An invariant that refines this [[ClassDef]] */ + /** An invariant that refines this [[ADTDefinition]] */ def invariant(implicit s: Symbols): Option[FunDef] = { val rt = root if (rt ne this) rt.invariant @@ -188,35 +186,36 @@ trait Definitions { self: Trees => def hasInvariant(implicit s: Symbols): Boolean = invariant.isDefined - val isAbstract: Boolean + val isSort: Boolean def typeArgs = tparams map (_.tp) - def typed(tps: Seq[Type])(implicit s: Symbols): TypedClassDef - def typed(implicit s: Symbols): TypedClassDef + def typed(tps: Seq[Type])(implicit s: Symbols): TypedADTDefinition + def typed(implicit s: Symbols): TypedADTDefinition } - /** Abstract classes / ADTs */ - class AbstractClassDef(val id: Identifier, - val tparams: Seq[TypeParameterDef], - val children: Seq[Identifier], - val flags: Set[ClassFlag]) extends ClassDef { - val isAbstract = true - - def descendants(implicit s: Symbols): Seq[CaseClassDef] = children - .map(id => s.getClass(id) match { - case ccd: CaseClassDef => ccd + /** Algebraic datatype sort definition. + * An ADT sort is linked to a series of constructors (ADTConstructor) for this particular sort. */ + class ADTSort(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val cons: Seq[Identifier], + val flags: Set[ADTFlag]) extends ADTDefinition { + val isSort = true + + def constructors(implicit s: Symbols): Seq[ADTConstructor] = cons + .map(id => s.getADT(id) match { + case cons: ADTConstructor => cons case _ => throw NotWellFormedException(id, s) }) def isInductive(implicit s: Symbols): Boolean = { - def induct(tpe: Type, seen: Set[ClassDef]): Boolean = tpe match { - case ct: ClassType => - val tcd = ct.lookupClass.getOrElse(throw ClassLookupException(ct.id)) - val root = tcd.cd.root - seen(root) || (tcd match { - case tccd: TypedCaseClassDef => - tccd.fields.exists(vd => induct(vd.getType, seen + root)) + def induct(tpe: Type, seen: Set[ADTDefinition]): Boolean = tpe match { + case adt: ADTType => + val tadt = adt.lookupADT.getOrElse(throw ADTLookupException(adt.id)) + val root = tadt.definition.root + seen(root) || (tadt match { + case tcons: TypedADTConstructor => + tcons.fields.exists(vd => induct(vd.getType, seen + root)) case _ => false }) case TupleType(tpes) => @@ -224,18 +223,18 @@ trait Definitions { self: Trees => case _ => false } - if (this == root && !this.isAbstract) false - else descendants.exists { ccd => - ccd.fields.exists(vd => induct(vd.getType, Set(root))) + if (this == root && !this.isSort) false + else constructors.exists { cons => + cons.fields.exists(vd => induct(vd.getType, Set(root))) } } - def root(implicit s: Symbols): ClassDef = this + def root(implicit s: Symbols): ADTDefinition = this - def typed(implicit s: Symbols): TypedAbstractClassDef = typed(tparams.map(_.tp)) - def typed(tps: Seq[Type])(implicit s: Symbols): TypedAbstractClassDef = { + def typed(implicit s: Symbols): TypedADTSort = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type])(implicit s: Symbols): TypedADTSort = { require(tps.length == tparams.length) - TypedAbstractClassDef(this, tps) + TypedADTSort(this, tps) } } @@ -243,17 +242,17 @@ trait Definitions { self: Trees => * * @param id * @param tparams - * @param parent + * @param sort * @param fields * @param flags */ - class CaseClassDef(val id: Identifier, - val tparams: Seq[TypeParameterDef], - val parent: Option[Identifier], - val fields: Seq[ValDef], - val flags: Set[ClassFlag]) extends ClassDef { + class ADTConstructor(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val sort: Option[Identifier], + val fields: Seq[ValDef], + val flags: Set[ADTFlag]) extends ADTDefinition { - val isAbstract = false + val isSort = false /** Returns the index of the field with the specified id */ def selectorID2Index(id: Identifier) : Int = { val index = fields.indexWhere(_.id == id) @@ -266,57 +265,57 @@ trait Definitions { self: Trees => } else index } - def root(implicit s: Symbols): ClassDef = parent.map(id => s.getClass(id).root).getOrElse(this) + def root(implicit s: Symbols): ADTDefinition = sort.map(id => s.getADT(id).root).getOrElse(this) - def typed(implicit s: Symbols): TypedCaseClassDef = typed(tparams.map(_.tp)) - def typed(tps: Seq[Type])(implicit s: Symbols): TypedCaseClassDef = { + def typed(implicit s: Symbols): TypedADTConstructor = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type])(implicit s: Symbols): TypedADTConstructor = { require(tps.length == tparams.length) - TypedCaseClassDef(this, tps) + TypedADTConstructor(this, tps) } } - /** Represents a [[ClassDef]] whose type parameters have been instantiated to ''tps'' */ - sealed abstract class TypedClassDef extends Tree { - val cd: ClassDef + /** Represents an [[ADTDefinition]] whose type parameters have been instantiated to ''tps'' */ + sealed abstract class TypedADTDefinition extends Tree { + val definition: ADTDefinition val tps: Seq[Type] implicit val symbols: Symbols - val id: Identifier = cd.id + val id: Identifier = definition.id /** The root of the class hierarchy */ - lazy val root: TypedClassDef = cd.root.typed(tps) - lazy val invariant: Option[TypedFunDef] = cd.invariant.map(_.typed(tps)) + lazy val root: TypedADTDefinition = definition.root.typed(tps) + lazy val invariant: Option[TypedFunDef] = definition.invariant.map(_.typed(tps)) lazy val hasInvariant: Boolean = invariant.isDefined - def toType = ClassType(cd.id, tps) + def toType = ADTType(definition.id, tps) - def toCase = this match { - case tccd: TypedCaseClassDef => tccd - case _ => throw NotWellFormedException(cd.id, symbols) + def toConstructor = this match { + case tcons: TypedADTConstructor => tcons + case _ => throw NotWellFormedException(definition.id, symbols) } - def toAbstract = this match { - case accd: TypedAbstractClassDef => accd - case _ => throw NotWellFormedException(cd.id, symbols) + def toSort = this match { + case tsort: TypedADTSort => tsort + case _ => throw NotWellFormedException(definition.id, symbols) } } - /** Represents an [[AbstractClassDef]] whose type parameters have been instantiated to ''tps'' */ - case class TypedAbstractClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit val symbols: Symbols) extends TypedClassDef { - def descendants: Seq[TypedCaseClassDef] = cd.descendants.map(_.typed(tps)) + /** Represents an [[ADTSort]] whose type parameters have been instantiated to ''tps'' */ + case class TypedADTSort(definition: ADTSort, tps: Seq[Type])(implicit val symbols: Symbols) extends TypedADTDefinition { + def constructors: Seq[TypedADTConstructor] = definition.constructors.map(_.typed(tps)) } - /** Represents a [[CaseClassDef]] whose type parameters have been instantiated to ''tps'' */ - case class TypedCaseClassDef(cd: CaseClassDef, tps: Seq[Type])(implicit val symbols: Symbols) extends TypedClassDef { + /** Represents an [[ADTConstructor]] whose type parameters have been instantiated to ''tps'' */ + case class TypedADTConstructor(definition: ADTConstructor, tps: Seq[Type])(implicit val symbols: Symbols) extends TypedADTDefinition { lazy val fields: Seq[ValDef] = { - val tmap = (cd.typeArgs zip tps).toMap - if (tmap.isEmpty) cd.fields - else cd.fields.map(vd => vd.copy(tpe = symbols.instantiateType(vd.getType, tmap))) + val tmap = (definition.typeArgs zip tps).toMap + if (tmap.isEmpty) definition.fields + else definition.fields.map(vd => vd.copy(tpe = symbols.instantiateType(vd.getType, tmap))) } lazy val fieldsTypes = fields.map(_.tpe) - lazy val parent: Option[TypedAbstractClassDef] = cd.parent.map(id => symbols.getClass(id) match { - case acd: AbstractClassDef => TypedAbstractClassDef(acd, tps) + lazy val sort: Option[TypedADTSort] = definition.sort.map(id => symbols.getADT(id) match { + case sort: ADTSort => TypedADTSort(sort, tps) case _ => throw NotWellFormedException(id, symbols) }) } diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index b14cb7cac96901990bd7b81f91a1c0741fdb67e1..377f5ea22cd66794f337b4daffefd29e6ad9548d 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -230,17 +230,17 @@ trait Expressions { self: Trees => * @param ct The case class name and inherited attributes * @param args The arguments of the case class */ - case class CaseClass(ct: ClassType, args: Seq[Expr]) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = ct.lookupClass match { - case Some(tcd: TypedCaseClassDef) => checkParamTypes(args.map(_.getType), tcd.fieldsTypes, ct) + case class ADT(adt: ADTType, args: Seq[Expr]) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = adt.lookupADT match { + case Some(tcons: TypedADTConstructor) => checkParamTypes(args.map(_.getType), tcons.fieldsTypes, adt) case _ => Untyped } } /** $encodingof `.isInstanceOf[...]` */ - case class IsInstanceOf(expr: Expr, classType: ClassType) extends Expr with CachingTyped { + case class IsInstanceOf(expr: Expr, tpe: ADTType) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = - if (s.typesCompatible(expr.getType, classType)) BooleanType else Untyped + if (s.typesCompatible(expr.getType, tpe)) BooleanType else Untyped } /** $encodingof `expr.asInstanceOf[tpe]` @@ -248,7 +248,7 @@ trait Expressions { self: Trees => * Introduced by matchToIfThenElse to transform match-cases to type-correct * if bodies. */ - case class AsInstanceOf(expr: Expr, tpe: ClassType) extends Expr with CachingTyped { + case class AsInstanceOf(expr: Expr, tpe: ADTType) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = if (s.typesCompatible(tpe, expr.getType)) tpe else Untyped } @@ -258,25 +258,24 @@ trait Expressions { self: Trees => * If you are not sure about the requirement you should use * [[Constructors#caseClassSelector purescala's constructor caseClassSelector]] */ - case class CaseClassSelector(caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { + case class ADTSelector(adt: Expr, selector: Identifier) extends Expr with CachingTyped { - def selectorIndex(implicit s: Symbols) = classIndex.map(_._2).getOrElse { + def selectorIndex(implicit s: Symbols) = constructor.map(_.definition.selectorID2Index(selector)).getOrElse { throw FatalError("Not well formed selector: " + this) } - def classIndex(implicit s: Symbols) = caseClass.getType match { - case ct: ClassType => ct.lookupClass match { - case Some(tcd: TypedCaseClassDef) => - Some(tcd, tcd.cd.selectorID2Index(selector)) + def constructor(implicit s: Symbols) = adt.getType match { + case adt: ADTType => adt.lookupADT.flatMap { + case tcons: TypedADTConstructor => Some(tcons) case _ => None } case _ => None } - protected def computeType(implicit s: Symbols): Type = classIndex match { - case Some((tcd, ind)) => tcd.fieldsTypes(ind) - case _ => Untyped - } + protected def computeType(implicit s: Symbols): Type = constructor.map { tccd => + val index = tccd.definition.selectorID2Index(selector) + tccd.fieldsTypes(index) + }.getOrElse(Untyped) } /** $encodingof `... == ...` */ diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index 81c60e0dc56b32a15562ee4a3bd18db9c3717ecd..70f6eb66e1b88b494422b75645ead189a5f33210 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -20,12 +20,12 @@ trait TreeDeconstructor { (Seq(e), Seq(), (es, tps) => t.StringLength(es.head)) case s.SetCardinality(e) => (Seq(e), Seq(), (es, tps) => t.SetCardinality(es.head)) - case s.CaseClassSelector(e, sel) => - (Seq(e), Seq(), (es, tps) => t.CaseClassSelector(es.head, sel)) + case s.ADTSelector(e, sel) => + (Seq(e), Seq(), (es, tps) => t.ADTSelector(es.head, sel)) case s.IsInstanceOf(e, ct) => - (Seq(e), Seq(ct), (es, tps) => t.IsInstanceOf(es.head, tps.head.asInstanceOf[t.ClassType])) + (Seq(e), Seq(ct), (es, tps) => t.IsInstanceOf(es.head, tps.head.asInstanceOf[t.ADTType])) case s.AsInstanceOf(e, ct) => - (Seq(e), Seq(ct), (es, tps) => t.AsInstanceOf(es.head, tps.head.asInstanceOf[t.ClassType])) + (Seq(e), Seq(ct), (es, tps) => t.AsInstanceOf(es.head, tps.head.asInstanceOf[t.ADTType])) case s.TupleSelect(e, i) => (Seq(e), Seq(), (es, tps) => t.TupleSelect(es.head, i)) case s.Lambda(args, body) => ( @@ -111,7 +111,7 @@ trait TreeDeconstructor { case s.FunctionInvocation(id, tps, args) => (args, tps, (es, tps) => t.FunctionInvocation(id, tps, es)) case s.Application(caller, args) => (caller +: args, Seq(), (es, tps) => t.Application(es.head, es.tail)) - case s.CaseClass(ct, args) => (args, Seq(ct), (es, tps) => t.CaseClass(tps.head.asInstanceOf[t.ClassType], es)) + case s.ADT(adt, args) => (args, Seq(adt), (es, tps) => t.ADT(tps.head.asInstanceOf[t.ADTType], es)) case s.And(args) => (args, Seq(), (es, _) => t.And(es)) case s.Or(args) => (args, Seq(), (es, _) => t.Or(es)) case s.SubString(t1, a, b) => (t1 :: a :: b :: Nil, Seq(), (es, _) => t.SubString(es(0), es(1), es(2))) @@ -176,7 +176,7 @@ trait TreeDeconstructor { } def deconstruct(tp: s.Type): (Seq[s.Type], Seq[t.Type] => t.Type) = tp match { - case s.ClassType(ccd, ts) => (ts, ts => t.ClassType(ccd, ts)) + case s.ADTType(d, ts) => (ts, ts => t.ADTType(d, ts)) case s.TupleType(ts) => (ts, t.TupleType) case s.SetType(tp) => (Seq(tp), ts => t.SetType(ts.head)) case s.BagType(tp) => (Seq(tp), ts => t.BagType(ts.head)) diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index 513443fb02784554a9067a2f7b43807da5e81584..080a5c47751793b5cda6cf7d83b8fe3c370dfc66 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -108,8 +108,8 @@ trait Printers { case Choose(res, pred) => p"choose(($res) => $pred)" - case e@CaseClass(cct, args) => - p"$cct($args)" + case e @ ADT(adt, args) => + p"$adt($args)" case And(exprs) => optP { p"${nary(exprs, " && ")}" @@ -155,7 +155,7 @@ trait Printers { case TupleSelect(t, i) => p"$t._$i" case AsInstanceOf(e, ct) => p"""$e.asInstanceOf[$ct]""" case IsInstanceOf(e, cct) => p"$e.isInstanceOf[$cct]" - case CaseClassSelector(e, id) => p"$e.$id" + case ADTSelector(e, id) => p"$e.$id" case FunctionInvocation(id, tps, args) => p"$id${nary(tps, ", ", "[", "]")}" @@ -284,21 +284,21 @@ trait Printers { case MapType(ft, tt) => p"Map[$ft, $tt]" case TupleType(tpes) => p"($tpes)" case FunctionType(fts, tt) => p"($fts) => $tt" - case c: ClassType => - p"${c.id}${nary(c.tps, ", ", "[", "]")}" + case adt: ADTType => + p"${adt.id}${nary(adt.tps, ", ", "[", "]")}" // Definitions - case acd: AbstractClassDef => - p"abstract class ${acd.id}${nary(acd.tparams, ", ", "[", "]")}" + case sort: ADTSort => + p"abstract class ${sort.id}${nary(sort.tparams, ", ", "[", "]")}" - case ccd: CaseClassDef => - p"case class ${ccd.id}" - p"${nary(ccd.tparams, ", ", "[", "]")}" - p"(${ccd.fields})" + case cons: ADTConstructor => + p"case class ${cons.id}" + p"${nary(cons.tparams, ", ", "[", "]")}" + p"(${cons.fields})" - ccd.parent.foreach { par => + cons.sort.foreach { s => // Remember child and parents tparams are simple bijection - p" extends $par${nary(ccd.tparams, ", ", "[", "]")}" + p" extends $s${nary(cons.tparams, ", ", "[", "]")}" } case fd: FunDef => @@ -396,7 +396,7 @@ trait Printers { case (pa: PrettyPrintable, _) => pa.printRequiresParentheses(within) case (_, None) => false case (_, Some( - _: Definition | _: Let | _: IfExpr | _: CaseClass | _: Lambda | _: Choose | _: Tuple + _: Definition | _: Let | _: IfExpr | _: ADT | _: Lambda | _: Choose | _: Tuple )) => false case (ex: StringConcat, Some(_: StringConcat)) => false case (_, Some(_: FunctionInvocation)) => false diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index f26a9be3a25824e2e4f5d52aeb88173d35043381..27e66692b3f1aae900cdb5463d3efbeb5b910784 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -57,7 +57,7 @@ trait SymbolOps { self: TypeOps => def step(e: Expr): Option[Expr] = e match { case Not(t) => Some(not(t)) case UMinus(t) => Some(uminus(t)) - case CaseClassSelector(e, sel) => Some(caseClassSelector(e, sel)) + case ADTSelector(e, sel) => Some(adtSelector(e, sel)) case AsInstanceOf(e, ct) => Some(asInstOf(e, ct)) case Equals(t1, t2) => Some(equality(t1, t2)) case Implies(t1, t2) => Some(implies(t1, t2)) @@ -78,8 +78,8 @@ trait SymbolOps { self: TypeOps => case TupleSelect(Let(id, v, b), ts) => Some(Let(id, v, tupleSelect(b, ts, true))) - case CaseClassSelector(cc: CaseClass, id) => - Some(caseClassSelector(cc, id).copiedFrom(e)) + case ADTSelector(cc: ADT, id) => + Some(adtSelector(cc, id).copiedFrom(e)) case IfExpr(c, thenn, elze) if thenn == elze => Some(thenn) @@ -274,31 +274,31 @@ trait SymbolOps { self: TypeOps => defs.foldRight(bd){ case ((vd, e), body) => Let(vd, e, body) } } - private def hasInstance(tcd: TypedClassDef): Boolean = { - val recursive = Set(tcd, tcd.root) + private def hasInstance(tadt: TypedADTDefinition): Boolean = { + val recursive = Set(tadt, tadt.root) - def isRecursive(tpe: Type, seen: Set[TypedClassDef]): Boolean = tpe match { - case ct: ClassType => - val ctcd = ct.tcd - if (seen(ctcd)) { + def isRecursive(tpe: Type, seen: Set[TypedADTDefinition]): Boolean = tpe match { + case adt: ADTType => + val tadt = adt.getADT + if (seen(tadt)) { false - } else if (recursive(ctcd)) { + } else if (recursive(tadt)) { true - } else ctcd match { - case tcc: TypedCaseClassDef => - tcc.fieldsTypes.exists(isRecursive(_, seen + ctcd)) + } else tadt match { + case tcons: TypedADTConstructor => + tcons.fieldsTypes.exists(isRecursive(_, seen + tadt)) case _ => false } case _ => false } - val tcds = tcd match { - case tacd: TypedAbstractClassDef => tacd.descendants - case tccd: TypedCaseClassDef => Seq(tccd) + val tconss = tadt match { + case tsort: TypedADTSort => tsort.constructors + case tcons: TypedADTConstructor => Seq(tcons) } - tcds.exists { tcd => - tcd.fieldsTypes.forall(tpe => !isRecursive(tpe, Set.empty)) + tconss.exists { tcons => + tcons.fieldsTypes.forall(tpe => !isRecursive(tpe, Set.empty)) } } @@ -316,17 +316,17 @@ trait SymbolOps { self: TypeOps => case MapType(fromType, toType) => FiniteMap(Seq(), simplestValue(toType), fromType) case TupleType(tpes) => Tuple(tpes.map(simplestValue)) - case ct @ ClassType(id, tps) => - val tcd = ct.lookupClass.getOrElse(throw ClassLookupException(id)) - if (!hasInstance(tcd)) scala.sys.error(ct +" does not seem to be well-founded") + case adt @ ADTType(id, tps) => + val tadt = adt.getADT + if (!hasInstance(tadt)) scala.sys.error(adt +" does not seem to be well-founded") - val tccd @ TypedCaseClassDef(cd, tps) = tcd match { - case tacd: TypedAbstractClassDef => - tacd.descendants.filter(hasInstance(_)).sortBy(_.fields.size).head - case tccd: TypedCaseClassDef => tccd + val tcons @ TypedADTConstructor(cons, tps) = tadt match { + case tsort: TypedADTSort => + tsort.constructors.filter(hasInstance(_)).sortBy(_.fields.size).head + case tcons: TypedADTConstructor => tcons } - CaseClass(ClassType(cd.id, tps), tccd.fieldsTypes.map(simplestValue)) + ADT(tcons.toType, tcons.fieldsTypes.map(simplestValue)) case tp: TypeParameter => GenericValue(tp, 0) @@ -376,12 +376,11 @@ trait SymbolOps { self: TypeOps => prev flatMap { case seq => Stream(seq, seq :+ curr) } }.flatten cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from) } - case ct: ClassType => ct.lookupClass match { - case Some(tccd: TypedCaseClassDef) => - cartesianProduct(tccd.fieldsTypes map valuesOf) map (CaseClass(ct, _)) - case Some(accd: TypedAbstractClassDef) => - interleave(accd.descendants.map(tccd => valuesOf(tccd.toType))) - case None => throw ClassLookupException(ct.id) + case adt: ADTType => adt.getADT match { + case tcons: TypedADTConstructor => + cartesianProduct(tcons.fieldsTypes map valuesOf) map (ADT(adt, _)) + case tsort: TypedADTSort => + interleave(tsort.constructors.map(tcons => valuesOf(tcons.toType))) } } } @@ -587,7 +586,7 @@ trait SymbolOps { self: TypeOps => /** Returns true if expr is a value of type t */ def isValueOfType(e: Expr, t: Type): Boolean = { def unWrapSome(s: Expr) = s match { - case CaseClass(_, Seq(a)) => a + case ADT(_, Seq(a)) => a case _ => s } (e, t) match { @@ -610,9 +609,9 @@ trait SymbolOps { self: TypeOps => case (FiniteMap(elems, default, kt), MapType(from, to)) => (kt == from) < s"$kt not equal to $from" && (default.getType == to) < s"${default.getType} not equal to $to" && (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type $from" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" )) - case (CaseClass(ct, args), ct2: ClassType) => - isSubtypeOf(ct, ct2) < s"$ct not a subtype of $ct2" && - ((args zip ct.tcd.toCase.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2) < s"${argstyped._1} not a value of type ${argstyped._2}" )) + case (ADT(adt, args), adt2: ADTType) => + isSubtypeOf(adt, adt2) < s"$adt not a subtype of $adt2" && + ((args zip adt.getADT.toConstructor.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2) < s"${argstyped._1} not a value of type ${argstyped._2}" )) case (Lambda(valdefs, body), FunctionType(ins, out)) => variablesOf(e).isEmpty && (valdefs zip ins forall (vdin => isSubtypeOf(vdin._2, vdin._1.getType) < s"${vdin._2} is not a subtype of ${vdin._1.getType}")) && diff --git a/src/main/scala/inox/ast/TypeOps.scala b/src/main/scala/inox/ast/TypeOps.scala index b24d8e1ae57bf20655f9253246425cae671d8dc7..467c8953999dbdb6c8d5ff0d40f914ac923ce27e 100644 --- a/src/main/scala/inox/ast/TypeOps.scala +++ b/src/main/scala/inox/ast/TypeOps.scala @@ -74,33 +74,33 @@ trait TypeOps { case (_, _: TypeParameter) => None - case (ct1: ClassType, ct2: ClassType) => - val cd1 = ct1.tcd.cd - val cd2 = ct2.tcd.cd - val bound: Option[ClassDef] = if (allowSub) { - val an1 = Seq(cd1, cd1.root) - val an2 = Seq(cd2, cd2.root) + case (adt1: ADTType, adt2: ADTType) => + val def1 = adt1.getADT.definition + val def2 = adt2.getADT.definition + val bound: Option[ADTDefinition] = if (allowSub) { + val an1 = Seq(def1, def1.root) + val an2 = Seq(def2, def2.root) if (isLub) { (an1.reverse zip an2.reverse) - .takeWhile(((_: ClassDef) == (_: ClassDef)).tupled) + .takeWhile(((_: ADTDefinition) == (_: ADTDefinition)).tupled) .lastOption.map(_._1) } else { // Lower bound - if(an1.contains(cd2)) Some(cd1) - else if (an2.contains(cd1)) Some(cd2) + if(an1.contains(def2)) Some(def1) + else if (an2.contains(def1)) Some(def2) else None } } else { - (cd1 == cd2).option(cd1) + (def1 == def2).option(def1) } for { - cd <- bound - (subs, map) <- flatten((ct1.tps zip ct2.tps).map { case (tp1, tp2) => + adtDef <- bound + (subs, map) <- flatten((adt1.tps zip adt2.tps).map { case (tp1, tp2) => // Class types are invariant! typeBound(tp1, tp2, isLub, allowSub = false) }) - } yield (cd.typed(subs).toType, map) + } yield (adtDef.typed(subs).toType, map) case (FunctionType(from1, to1), FunctionType(from2, to2)) => if (from1.size != from2.size) None @@ -201,7 +201,7 @@ trait TypeOps { } def bestRealType(t: Type): Type = t match { - case (c: ClassType) => c.tcd.root.toType + case (adt: ADTType) => adt.getADT.root.toType case NAryType(tps, builder) => builder(tps.map(bestRealType)) } @@ -237,27 +237,27 @@ trait TypeOps { t <- typeCardinality(to) f <- typeCardinality(from) } yield Math.pow(t + 1, f).toInt - case ct: ClassType => ct.tcd match { - case tccd: TypedCaseClassDef => - cards(tccd.fieldsTypes).map(_.product) + case adt: ADTType => adt.getADT match { + case tcons: TypedADTConstructor => + cards(tcons.fieldsTypes).map(_.product) - case accd: TypedAbstractClassDef => + case tsort: TypedADTSort => val possibleChildTypes = utils.fixpoint((tpes: Set[Type]) => { tpes.flatMap(tpe => Set(tpe) ++ (tpe match { - case ct: ClassType => ct.tcd match { - case tccd: TypedCaseClassDef => tccd.fieldsTypes - case tacd: TypedAbstractClassDef => (Set(tacd) ++ tacd.descendants).map(_.toType) + case adt: ADTType => adt.getADT match { + case tcons: TypedADTConstructor => tcons.fieldsTypes + case tsort: TypedADTSort => (Set(tsort) ++ tsort.constructors).map(_.toType) } case _ => Set.empty }) ) - })(accd.descendants.map(_.toType).toSet) + })(tsort.constructors.map(_.toType).toSet) - if (possibleChildTypes(accd.toType)) { + if (possibleChildTypes(tsort.toType)) { None } else { - cards(accd.descendants.map(_.toType)).map(_.sum) + cards(tsort.constructors.map(_.toType)).map(_.sum) } } case _ => None @@ -269,11 +269,11 @@ trait TypeOps { def rec(tpe: Type): Unit = if (!dependencies.isDefinedAt(tpe)) { val next = tpe match { - case ct: ClassType => ct.tcd match { - case accd: TypedAbstractClassDef => - accd.descendants.map(_.toType) - case tccd: TypedCaseClassDef => - tccd.fieldsTypes ++ tccd.parent.map(_.toType) + case adt: ADTType => adt.getADT match { + case tsort: TypedADTSort => + tsort.constructors.map(_.toType) + case tcons: TypedADTConstructor => + tcons.fieldsTypes ++ tcons.sort.map(_.toType) } case NAryType(tps, _) => tps diff --git a/src/main/scala/inox/ast/Types.scala b/src/main/scala/inox/ast/Types.scala index 92fb6c242372c5658fd7143d050612d7c91ed8a6..6a5210f5a235091637b53cc2ef76d7d9add5ddf5 100644 --- a/src/main/scala/inox/ast/Types.scala +++ b/src/main/scala/inox/ast/Types.scala @@ -80,9 +80,9 @@ trait Types { self: Trees => case class MapType(from: Type, to: Type) extends Type case class FunctionType(from: Seq[Type], to: Type) extends Type - case class ClassType(id: Identifier, tps: Seq[Type]) extends Type { - def lookupClass(implicit s: Symbols): Option[TypedClassDef] = s.lookupClass(id, tps) - def tcd(implicit s: Symbols): TypedClassDef = s.getClass(id, tps) + case class ADTType(id: Identifier, tps: Seq[Type]) extends Type { + def lookupADT(implicit s: Symbols): Option[TypedADTDefinition] = s.lookupADT(id, tps) + def getADT(implicit s: Symbols): TypedADTDefinition = s.getADT(id, tps) } /** NAryType extractor to extract any Type in a consistent way. diff --git a/src/main/scala/inox/datagen/SolverDataGen.scala b/src/main/scala/inox/datagen/SolverDataGen.scala index 18b8a39636525b244b75bcf945acecba9b20bcfd..d13360915f17bd6ec6d19937bb09c2258ca37232 100644 --- a/src/main/scala/inox/datagen/SolverDataGen.scala +++ b/src/main/scala/inox/datagen/SolverDataGen.scala @@ -25,13 +25,13 @@ trait SolverDataGen extends DataGenerator { self => FreeableIterator.empty } else { - var cdToId: Map[ClassDef, Identifier] = Map.empty + var cdToId: Map[ADTDefinition, Identifier] = Map.empty var fds: Seq[FunDef] = Seq.empty def sizeFor(of: Expr): Expr = bestRealType(of.getType) match { - case ct: ClassType => - val tcd = ct.tcd - val root = tcd.cd.root + case adt: ADTType => + val tadt = adt.getADT + val root = tadt.definition.root val id = cdToId.getOrElse(root, { import dsl._ @@ -39,29 +39,29 @@ trait SolverDataGen extends DataGenerator { self => val tparams = root.tparams.map(_.freshen) cdToId += root -> id - def typed(ccd: CaseClassDef) = TypedCaseClassDef(ccd, tparams.map(_.tp)) - def sizeOfCaseClass(ccd: CaseClassDef, expr: Expr): Expr = - typed(ccd).fields.foldLeft(IntegerLiteral(1): Expr) { (i, f) => + def typed(cons: ADTConstructor) = TypedADTConstructor(cons, tparams.map(_.tp)) + def sizeOfConstructor(cons: ADTConstructor, expr: Expr): Expr = + typed(cons).fields.foldLeft(IntegerLiteral(1): Expr) { (i, f) => plus(i, sizeFor(expr.getField(f.id))) } - val x = Variable(FreshIdentifier("x", true), tcd.root.toType) + val x = Variable(FreshIdentifier("x", true), tadt.root.toType) fds +:= new FunDef(id, tparams, Seq(x.toVal), IntegerType, root match { - case acd: AbstractClassDef => - val (child +: rest) = acd.descendants - def sizeOf(ccd: CaseClassDef) = sizeOfCaseClass(ccd, x.asInstOf(typed(ccd).toType)) + case sort: ADTSort => + val (child +: rest) = sort.constructors + def sizeOf(cons: ADTConstructor) = sizeOfConstructor(cons, x.asInstOf(typed(cons).toType)) rest.foldLeft(sizeOf(child)) { (elze, ccd) => if_ (x.isInstOf(typed(ccd).toType)) { sizeOf(ccd) } else_ { elze } } - case ccd: CaseClassDef => - sizeOfCaseClass(ccd, x) + case cons: ADTConstructor => + sizeOfConstructor(cons, x) }, Set.empty) id }) - FunctionInvocation(id, ct.tps, Seq(of)) + FunctionInvocation(id, adt.tps, Seq(of)) case tt @ TupleType(tps) => val exprs = for ((t,i) <- tps.zipWithIndex) yield { diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index 59a0eee042710090766fcc669657f6800e04fc04..0ecd015b65b633773157ab868dd2ae7d185bdf19 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -140,10 +140,10 @@ trait RecursiveEvaluator case _ => BooleanLiteral(lv == rv) } - case CaseClass(cct, args) => - val cc = CaseClass(cct, args.map(e)) - cct.tcd.invariant.foreach { tfd => - val v = Variable(FreshIdentifier("x", true), cct) + case ADT(adt, args) => + val cc = ADT(adt, args.map(e)) + adt.getADT.invariant.foreach { tfd => + val v = Variable(FreshIdentifier("x", true), adt) e(tfd.applied(Seq(v)))(rctx.withNewVar(v.toVal, cc), gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => @@ -166,10 +166,10 @@ trait RecursiveEvaluator val le = e(expr) BooleanLiteral(isSubtypeOf(le.getType, ct)) - case CaseClassSelector(expr, sel) => + case ADTSelector(expr, sel) => e(expr) match { - case CaseClass(ct, args) => args(ct.tcd.cd match { - case ccd: CaseClassDef => ccd.selectorID2Index(sel) + case ADT(adt, args) => args(adt.getADT.definition match { + case cons: ADTConstructor => cons.selectorID2Index(sel) case _ => throw RuntimeError("Unexpected case class type") }) case le => throw EvalError(typeErrorMsg(le, expr.getType)) diff --git a/src/main/scala/inox/grammars/BaseGrammars.scala b/src/main/scala/inox/grammars/BaseGrammars.scala index b5800c3334e69bd04af96a776d87b955db4f9973..47d88d0445b27d0cd8bb347209002f233fb9e420 100644 --- a/src/main/scala/inox/grammars/BaseGrammars.scala +++ b/src/main/scala/inox/grammars/BaseGrammars.scala @@ -53,16 +53,16 @@ trait BaseGrammars { self: GrammarsUniverse => nonTerminal(stps, Tuple, Constructor(isTerminal = false)) ) - case ct: ClassType => - ct.tcd match { - case cct: TypedCaseClassDef => + case adt: ADTType => + adt.getADT match { + case tcons: TypedADTConstructor => List( - nonTerminal(cct.fields.map(_.getType), CaseClass(ct, _), tagOf(cct.cd) ) + nonTerminal(tcons.fields.map(_.getType), ADT(adt, _), tagOf(tcons.definition) ) ) - case act: TypedAbstractClassDef => - act.descendants.map { cct => - nonTerminal(cct.fields.map(_.getType), CaseClass(cct.toType, _), tagOf(cct.cd) ) + case tsort: TypedADTSort => + tsort.constructors.map { tcons => + nonTerminal(tcons.fields.map(_.getType), ADT(tcons.toType, _), tagOf(tcons.definition) ) } } diff --git a/src/main/scala/inox/grammars/Tags.scala b/src/main/scala/inox/grammars/Tags.scala index a6de1b1fbb6c4d16f544711c088bcb7f84f25c6a..919b40d58bb62d1d83dd8cec48a3f6dfb1d9aff4 100644 --- a/src/main/scala/inox/grammars/Tags.scala +++ b/src/main/scala/inox/grammars/Tags.scala @@ -5,7 +5,7 @@ package grammars trait Tags { self: GrammarsUniverse => import program._ - import trees.CaseClassDef + import trees.ADTConstructor import trees.FunDef /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ @@ -61,6 +61,6 @@ trait Tags { self: GrammarsUniverse => case _ => false } - def tagOf(cct: CaseClassDef) = Constructor(cct.fields.isEmpty) + def tagOf(cons: ADTConstructor) = Constructor(cons.fields.isEmpty) //def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) } diff --git a/src/main/scala/inox/grammars/ValueGrammars.scala b/src/main/scala/inox/grammars/ValueGrammars.scala index 81055c3baa639e0a5410d7c3369bf7e72a9cf0ec..cf3f60cc8aceaf66b607308cfe4f61d3ca3fca1c 100644 --- a/src/main/scala/inox/grammars/ValueGrammars.scala +++ b/src/main/scala/inox/grammars/ValueGrammars.scala @@ -59,16 +59,16 @@ trait ValueGrammars { self: GrammarsUniverse => nonTerminal(stps, Tuple, Constructor(stps.isEmpty)) ) - case ct: ClassType => - ct.tcd match { - case cct: TypedCaseClassDef => + case adt: ADTType => + adt.getADT match { + case tcons: TypedADTConstructor => List( - nonTerminal(cct.fields.map(_.getType), CaseClass(cct.toType, _), tagOf(cct.cd)) + nonTerminal(tcons.fields.map(_.getType), ADT(adt, _), tagOf(tcons.definition)) ) - case act: TypedAbstractClassDef => - act.descendants.map { cct => - nonTerminal(cct.fields.map(_.getType), CaseClass(cct.toType, _), tagOf(cct.cd)) + case tsort: TypedADTSort => + tsort.constructors.map { tcons => + nonTerminal(tcons.fields.map(_.getType), ADT(tcons.toType, _), tagOf(tcons.definition)) } } diff --git a/src/main/scala/inox/grammars/aspects/SimilarToAspects.scala b/src/main/scala/inox/grammars/aspects/SimilarToAspects.scala index b7041cabf47020abbf3779217c575fdc7fa81ac6..9c226caf5c886eb041d1fac0104c44de5aae7dd5 100644 --- a/src/main/scala/inox/grammars/aspects/SimilarToAspects.scala +++ b/src/main/scala/inox/grammars/aspects/SimilarToAspects.scala @@ -118,17 +118,17 @@ trait SimilarToAspects { self: GrammarsUniverse => } val ccVariations: Prods = e match { - case CaseClass(cct, args) => - val resType = cct.tcd.toCase + case ADT(adt, args) => + val resType = adt.getADT.toConstructor val neighbors = resType.root match { - case acd: TypedAbstractClassDef => - acd.descendants diff Seq(resType) - case ccd: TypedCaseClassDef => + case tsort: TypedADTSort => + tsort.constructors diff Seq(resType) + case tcons: TypedADTConstructor => Nil } for (scct <- neighbors if scct.fieldsTypes == resType.fieldsTypes) yield { - term(CaseClass(scct.toType, args)) + term(ADT(scct.toType, args)) } case _ => Nil diff --git a/src/main/scala/inox/grammars/utils/Helpers.scala b/src/main/scala/inox/grammars/utils/Helpers.scala index 8e7d37cf8702b14d60d76e6c03c625569fefe0a3..08e91da1a71434423bfc1fe53c3d9529b79e2ec7 100644 --- a/src/main/scala/inox/grammars/utils/Helpers.scala +++ b/src/main/scala/inox/grammars/utils/Helpers.scala @@ -44,7 +44,7 @@ trait Helpers { self: GrammarsUniverse => def terminatingCalls(prog: Program, wss: Seq[FunctionInvocation], pc: Path, tpe: Option[Type], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { def subExprsOf(expr: Expr, v: EVariable): Option[(EVariable, Expr)] = expr match { - case CaseClassSelector(r, _) => subExprsOf(r, v) + case ADTSelector(r, _) => subExprsOf(r, v) case (r: EVariable) if leastUpperBound(r.getType, v.getType).isDefined => Some(r -> v) case _ => None } @@ -53,7 +53,7 @@ trait Helpers { self: GrammarsUniverse => val one = IntegerLiteral(1) val knownSmallers = (pc.bindings.flatMap { // @nv: used to check both Equals(id, selector) and Equals(selector, id) - case (id, s @ CaseClassSelector(r, _)) => subExprsOf(s, id.toVariable) + case (id, s @ ADTSelector(r, _)) => subExprsOf(s, id.toVariable) case _ => None } ++ pc.conditions.flatMap { case GreaterThan(v: EVariable, `z`) => @@ -68,8 +68,8 @@ trait Helpers { self: GrammarsUniverse => }).groupBy(_._1).mapValues(v => v.map(_._2)) def argsSmaller(e: Expr, tpe: Type): Seq[Expr] = e match { - case CaseClass(cct, args) => - (cct.tcd.asInstanceOf[TypedCaseClassDef].fields.map(_.getType) zip args).collect { + case ADT(adt, args) => + (adt.getADT.toConstructor.fields.map(_.getType) zip args).collect { case (t, e) if isSubtypeOf(t, tpe) => List(e) ++ argsSmaller(e, tpe) }.flatten diff --git a/src/main/scala/inox/package.scala b/src/main/scala/inox/package.scala index 649f8d2d1461e4e4bfad0875d34f43e3563b815d..9ae8863f52c1f22eeced0c52c11ccf111c477dff 100644 --- a/src/main/scala/inox/package.scala +++ b/src/main/scala/inox/package.scala @@ -26,12 +26,12 @@ package object inox { object InoxProgram { def apply(ictx: InoxContext, functions: Seq[inox.trees.FunDef], - classes: Seq[inox.trees.ClassDef]): InoxProgram = new Program { + adts: Seq[inox.trees.ADTDefinition]): InoxProgram = new Program { val trees = inox.trees val ctx = ictx val symbols = new inox.trees.Symbols( functions.map(fd => fd.id -> fd).toMap, - classes.map(cd => cd.id -> cd).toMap) + adts.map(cd => cd.id -> cd).toMap) } def apply(ictx: InoxContext, sym: inox.trees.Symbols): InoxProgram = new Program { @@ -50,7 +50,7 @@ package object inox { class Symbols( val functions: Map[Identifier, FunDef], - val classes: Map[Identifier, ClassDef] + val adts: Map[Identifier, ADTDefinition] ) extends AbstractSymbols { def transform(t: TreeTransformer) = new Symbols( @@ -61,19 +61,19 @@ package object inox { t.transform(fd.returnType), t.transform(fd.fullBody), fd.flags)), - classes.mapValues { - case acd: AbstractClassDef => acd - case ccd: CaseClassDef => new CaseClassDef( - ccd.id, - ccd.tparams, - ccd.parent, - ccd.fields.map(t.transform), - ccd.flags) + adts.mapValues { + case sort: ADTSort => sort + case cons: ADTConstructor => new ADTConstructor( + cons.id, + cons.tparams, + cons.sort, + cons.fields.map(t.transform), + cons.flags) }) - def extend(functions: Seq[FunDef] = Seq.empty, classes: Seq[ClassDef] = Seq.empty) = new Symbols( + def extend(functions: Seq[FunDef] = Seq.empty, adts: Seq[ADTDefinition] = Seq.empty) = new Symbols( this.functions ++ functions.map(fd => fd.id -> fd), - this.classes ++ classes.map(cd => cd.id -> cd) + this.adts ++ adts.map(cd => cd.id -> cd) ) } diff --git a/src/main/scala/inox/solvers/ADTManagers.scala b/src/main/scala/inox/solvers/ADTManagers.scala index ca0f74363b83e9478defebb53c9fe5636ddac9c9..22d8617d81236126cfdf9e3b04971a7282d17596 100644 --- a/src/main/scala/inox/solvers/ADTManagers.scala +++ b/src/main/scala/inox/solvers/ADTManagers.scala @@ -54,15 +54,15 @@ trait ADTManagers { for (scc <- sccs.map(scc => scc.map(bestRealType))) { val declarations = (for (tpe <- scc if !declared(tpe)) yield (tpe match { - case ct: ClassType => - val (root, deps) = ct.tcd.root match { - case tacd: TypedAbstractClassDef => - (tacd, tacd.descendants) - case tccd: TypedCaseClassDef => - (tccd, Seq(tccd)) + case adt: ADTType => + val (root, deps) = adt.getADT.root match { + case tsort: TypedADTSort => + (tsort, tsort.constructors) + case tcons: TypedADTConstructor => + (tcons, Seq(tcons)) } - Some(ct -> DataType(freshId(root.id), deps.map { tccd => + Some(adt -> DataType(freshId(root.id), deps.map { tccd => Constructor(freshId(tccd.id), tccd.toType, tccd.fields.map(vd => freshId(vd.id) -> vd.tpe)) })) diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index c072fda15b720353b417c83ae4194a731b0498c7..6b6dcb0634411b0b21491df4e207642d6233e93d 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -142,7 +142,7 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { /* Helper functions */ protected def normalizeType(t: Type): Type = t match { - case ct: ClassType => ct.lookupClass.get.root.toType + case adt: ADTType => adt.getADT.root.toType case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) case _ => t } @@ -185,7 +185,7 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { case FunctionType(from, to) => Ints.IntSort() - case tpe @ (_: ClassType | _: TupleType | _: TypeParameter | UnitType) => + case tpe @ (_: ADTType | _: TupleType | _: TypeParameter | UnitType) => declareStructuralSort(tpe) case other => @@ -307,21 +307,21 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { Seq(), newBody) - case s @ CaseClassSelector(e, id) => + case s @ ADTSelector(e, id) => declareSort(e.getType) - val selector = selectors.toB((e.getType, s.classIndex.get._2)) + val selector = selectors.toB((e.getType, s.selectorIndex)) FunctionApplication(selector, Seq(toSMT(e))) - case AsInstanceOf(expr, cct) => + case AsInstanceOf(expr, adt) => toSMT(expr) - case io @ IsInstanceOf(e, cct) => - declareSort(cct) - val cases = cct.lookupClass match { - case Some(act: TypedAbstractClassDef) => - act.descendants - case Some(cct: TypedCaseClassDef) => - Seq(cct) + case io @ IsInstanceOf(e, adt) => + declareSort(adt) + val cases = adt.lookupADT match { + case Some(tsort: TypedADTSort) => + tsort.constructors + case Some(tcons: TypedADTConstructor) => + Seq(tcons) case None => unsupported(io, "isInstanceOf on non-class") } @@ -335,9 +335,9 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { SmtLibConstructors.or(oneOf.map(FunctionApplication(_, Seq(es: Term))))) } - case CaseClass(cct, es) => - declareSort(cct) - val constructor = constructors.toB(cct) + case ADT(adt, es) => + declareSort(adt) + val constructor = constructors.toB(adt) if (es.isEmpty) { constructor } else { @@ -627,29 +627,29 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { case (SimpleSymbol(s), _) if constructors.containsB(s) => constructors.toA(s) match { - case ct: ClassType => - CaseClass(ct, Nil) + case adt: ADTType => + ADT(adt, Nil) case t => unsupported(t, "woot? for a single constructor for non-case-object") } case (FunctionApplication(SimpleSymbol(s), List(e)), _) if testers.containsB(s) => testers.toA(s) match { - case cct: ClassType => - IsInstanceOf(fromSMT(e, cct), cct) + case adt: ADTType => + IsInstanceOf(fromSMT(e, adt), adt) } case (FunctionApplication(SimpleSymbol(s), List(e)), _) if selectors.containsB(s) => selectors.toA(s) match { - case (ct: ClassType, i) => - CaseClassSelector(fromSMT(e, ct), ct.lookupClass.get.toCase.fields(i).id) + case (adt: ADTType, i) => + ADTSelector(fromSMT(e, adt), adt.getADT.toConstructor.fields(i).id) } case (FunctionApplication(SimpleSymbol(s), args), _) if constructors.containsB(s) => constructors.toA(s) match { - case ct: ClassType => - val rargs = args.zip(ct.lookupClass.get.toCase.fields.map(_.getType)).map(fromSMT) - CaseClass(ct, rargs) + case adt: ADTType => + val rargs = args.zip(adt.getADT.toConstructor.fields.map(_.getType)).map(fromSMT) + ADT(adt, rargs) case tt: TupleType => val rargs = args.zip(tt.bases).map(fromSMT) diff --git a/src/main/scala/inox/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala index 57f38edf416de918e4fc07aff768eabd7fb2866d..b1be97cc57fb640e790e83ae23c4de849801c42c 100644 --- a/src/main/scala/inox/solvers/theories/BagEncoder.scala +++ b/src/main/scala/inox/solvers/theories/BagEncoder.scala @@ -11,8 +11,9 @@ trait BagEncoder extends TheoryEncoder { val BagID = FreshIdentifier("Bag") val f = FreshIdentifier("f") - val bagClass = mkCaseClass(BagID)("T")(None)( - { case Seq(aT) => Seq(ValDef(f, aT =>: IntegerType)) }) + val bagADT = mkConstructor(BagID)("T")(None) { + case Seq(aT) => Seq(ValDef(f, aT =>: IntegerType)) + } val Bag = T(BagID) @@ -117,7 +118,7 @@ trait BagEncoder extends TheoryEncoder { import targetProgram._ override def transform(e: Expr): Expr = e match { - case cc @ CaseClass(ClassType(BagID, Seq(tpe)), Seq(Lambda(Seq(vd), body))) => + case cc @ ADT(ADTType(BagID, Seq(tpe)), Seq(Lambda(Seq(vd), body))) => val Variable = vd.toVariable def rec(expr: Expr): Seq[(Expr, Expr)] = expr match { case IfExpr(Equals(Variable, k), v, elze) => rec(elze) :+ (transform(k) -> transform(v)) @@ -127,7 +128,7 @@ trait BagEncoder extends TheoryEncoder { FiniteBag(rec(body), transform(tpe)).copiedFrom(e) - case cc @ CaseClass(ClassType(BagID, Seq(tpe)), args) => + case cc @ ADT(ADTType(BagID, Seq(tpe)), args) => throw new Unsupported(e, "Unexpected argument to bag constructor") case FunctionInvocation(AddID, Seq(_), Seq(bag, elem)) => @@ -152,7 +153,7 @@ trait BagEncoder extends TheoryEncoder { } override def transform(tpe: Type): Type = tpe match { - case ClassType(BagID, Seq(base)) => BagType(transform(base)).copiedFrom(tpe) + case ADTType(BagID, Seq(base)) => BagType(transform(base)).copiedFrom(tpe) case _ => super.transform(tpe) } } diff --git a/src/main/scala/inox/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala index 7199069aa47a789a00827226e75380f075e24eca..58573b4f92ca98e2f134d9cdee8b3612470630a4 100644 --- a/src/main/scala/inox/solvers/theories/StringEncoder.scala +++ b/src/main/scala/inox/solvers/theories/StringEncoder.scala @@ -17,16 +17,16 @@ trait StringEncoder extends TheoryEncoder { val head = FreshIdentifier("head") val tail = FreshIdentifier("tail") - val stringClass = new AbstractClassDef(stringID, Seq.empty, Seq(stringConsID, stringNilID), Set.empty) - val stringNilClass = new CaseClassDef(stringNilID, Seq.empty, Some(stringID), Seq.empty, Set.empty) - val stringConsClass = new CaseClassDef(stringConsID, Seq.empty, Some(stringID), Seq( + val stringADT = new ADTSort(stringID, Seq.empty, Seq(stringConsID, stringNilID), Set.empty) + val stringNilADT = new ADTConstructor(stringNilID, Seq.empty, Some(stringID), Seq.empty, Set.empty) + val stringConsADT = new ADTConstructor(stringConsID, Seq.empty, Some(stringID), Seq( ValDef(head, CharType), - ValDef(tail, ClassType(stringID, Seq.empty)) + ValDef(tail, ADTType(stringID, Seq.empty)) ), Set.empty) - val String : ClassType = T(stringID)() - val StringNil : ClassType = T(stringNilID)() - val StringCons : ClassType = T(stringConsID)() + val String : ADTType = T(stringID)() + val StringNil : ADTType = T(stringNilID)() + val StringCons : ADTType = T(stringConsID)() val SizeID = FreshIdentifier("size") val Size = mkFunDef(SizeID)()(_ => ( @@ -84,8 +84,8 @@ trait StringEncoder extends TheoryEncoder { private val stringBijection = new Bijection[String, Expr]() private def convertToString(e: Expr): String = stringBijection.cachedA(e)(e match { - case CaseClass(StringCons, Seq(CharLiteral(c), l)) => (if(c < 31) (c + 97).toChar else c) + convertToString(l) - case CaseClass(StringNil, Seq()) => "" + case ADT(StringCons, Seq(CharLiteral(c), l)) => (if(c < 31) (c + 97).toChar else c) + convertToString(l) + case ADT(StringNil, Seq()) => "" }) private def convertFromString(v: String): Expr = stringBijection.cachedB(v) { @@ -112,7 +112,7 @@ trait StringEncoder extends TheoryEncoder { val decoder = new TreeTransformer { override def transform(e: Expr): Expr = e match { - case cc @ CaseClass(ct, args) if ct == StringNil || ct == StringCons => + case cc @ ADT(adt, args) if adt == StringNil || adt == StringCons => StringLiteral(convertToString(cc)).copiedFrom(cc) case FunctionInvocation(SizeID, Seq(), Seq(a)) => StringLength(transform(a)).copiedFrom(e) diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala index f739a0f4034803f9075a8d594ec90422eebb8428..caa22aeb8fdc1cb6311e408a3236f56d0e4c675e 100644 --- a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala @@ -18,8 +18,8 @@ trait DatatypeTemplates { self: Templates => type Functions = Set[(Encoded, FunctionType, Encoded)] /** Represents a type unfolding of a free variable (or input) in the unfolding procedure */ - case class TemplateTypeInfo(tcd: TypedAbstractClassDef, arg: Encoded) { - override def toString = tcd.toType.asString + "(" + asString(arg) + ")" + case class TemplateTypeInfo(tsort: TypedADTSort, arg: Encoded) { + override def toString = tsort.toType.asString + "(" + asString(arg) + ")" } private val cache: MutableMap[Type, DatatypeTemplate] = MutableMap.empty @@ -29,7 +29,7 @@ trait DatatypeTemplates { self: Templates => mkTemplate(tpe).instantiate(start, sym) } - private val requireChecking: MutableSet[TypedClassDef] = MutableSet.empty + private val requireChecking: MutableSet[TypedADTDefinition] = MutableSet.empty private val requireCache: MutableMap[Type, Boolean] = MutableMap.empty def requiresUnrolling(tpe: Type): Boolean = requireCache.get(tpe) match { @@ -38,18 +38,18 @@ trait DatatypeTemplates { self: Templates => val res = tpe match { case ft: FunctionType => true - case ct: ClassType => ct.tcd match { - case tccd: TypedCaseClassDef => tccd.parent.isDefined - case tcd if requireChecking(tcd.root) => false - case tcd => - requireChecking += tcd.root - val classTypes = tcd.root +: (tcd.root match { - case (tacd: TypedAbstractClassDef) => tacd.descendants + case adt: ADTType => adt.getADT match { + case tcons: TypedADTConstructor => tcons.sort.isDefined + case tadt if requireChecking(tadt.root) => false + case tadt => + requireChecking += tadt.root + val classTypes = tadt.root +: (tadt.root match { + case (tsort: TypedADTSort) => tsort.constructors case _ => Seq.empty }) classTypes.exists(ct => ct.hasInvariant || (ct match { - case tccd: TypedCaseClassDef => tccd.fieldsTypes.exists(requiresUnrolling) + case tcons: TypedADTConstructor => tcons.fieldsTypes.exists(requiresUnrolling) case _ => false })) } @@ -85,9 +85,9 @@ trait DatatypeTemplates { self: Templates => @inline def iff(e1: Expr, e2: Expr): Unit = storeGuarded(pathVar, Equals(e1, e2)) - var types = Map[Variable, Set[(TypedAbstractClassDef, Expr)]]() - @inline def storeType(pathVar: Variable, tacd: TypedAbstractClassDef, arg: Expr): Unit = { - types += pathVar -> (types.getOrElse(pathVar, Set.empty) + (tacd -> arg)) + var types = Map[Variable, Set[(TypedADTSort, Expr)]]() + @inline def storeType(pathVar: Variable, tsort: TypedADTSort, arg: Expr): Unit = { + types += pathVar -> (types.getOrElse(pathVar, Set.empty) + (tsort -> arg)) } var functions = Map[Variable, Set[Expr]]() @@ -99,38 +99,38 @@ trait DatatypeTemplates { self: Templates => case tpe if !requiresUnrolling(tpe) => // nothing to do here! - case ct: ClassType => - val tcd = ct.tcd + case adt: ADTType => + val tadt = adt.getADT - if (tcd.hasInvariant) { - storeGuarded(pathVar, tcd.invariant.get.applied(Seq(expr))) + if (tadt.hasInvariant) { + storeGuarded(pathVar, tadt.invariant.get.applied(Seq(expr))) } - if (tcd.cd.isAbstract && tcd.toAbstract.cd.isInductive) { - storeType(pathVar, tcd.toAbstract, expr) - } else if (tcd != tcd.root) { - storeGuarded(pathVar, IsInstanceOf(expr, tcd.toType)) + if (tadt.definition.isSort && tadt.toSort.definition.isInductive) { + storeType(pathVar, tadt.toSort, expr) + } else if (tadt != tadt.root) { + storeGuarded(pathVar, IsInstanceOf(expr, tadt.toType)) - val tpe = tcd.toType - for (vd <- tcd.toCase.fields) { - rec(pathVar, CaseClassSelector(AsInstanceOf(expr, tpe), vd.id)) + val tpe = tadt.toType + for (vd <- tadt.toConstructor.fields) { + rec(pathVar, ADTSelector(AsInstanceOf(expr, tpe), vd.id)) } } else { - val matchers = tcd.root match { - case (act: TypedAbstractClassDef) => act.descendants - case (cct: TypedCaseClassDef) => Seq(cct) + val matchers = tadt.root match { + case (tsort: TypedADTSort) => tsort.constructors + case (tcons: TypedADTConstructor) => Seq(tcons) } - for (tccd <- matchers) { - val tpe = tccd.toType + for (tcons <- matchers) { + val tpe = tcons.toType if (requiresUnrolling(tpe)) { val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) storeCond(pathVar, newBool) iff(and(pathVar, IsInstanceOf(expr, tpe)), newBool) - for (vd <- tccd.fields) { - rec(newBool, CaseClassSelector(AsInstanceOf(expr, tpe), vd.id)) + for (vd <- tcons.fields) { + rec(newBool, ADTSelector(AsInstanceOf(expr, tpe), vd.id)) } } } @@ -260,8 +260,8 @@ trait DatatypeTemplates { self: Templates => val newTypeInfos = blockers.flatMap(id => typeInfos.get(id).map(id -> _)) typeInfos --= blockers - for ((blocker, (gen, _, _, tps)) <- newTypeInfos; info @ TemplateTypeInfo(tcd, arg) <- tps) { - val template = mkTemplate(tcd.toType) + for ((blocker, (gen, _, _, tps)) <- newTypeInfos; info @ TemplateTypeInfo(tadt, arg) <- tps) { + val template = mkTemplate(tadt.toType) newClauses ++= template.instantiate(blocker, arg) } diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 4918b5c9ad3eb2997ac0cf89b570badb496e0984..09f333eb54e65bb5c4f5b64fa2077b959d8cbdd2 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -175,9 +175,9 @@ trait AbstractUnrollingSolver case (e, i) => rec(e, TupleSelect(selector, i + 1)) }, Tuple) - case CaseClass(cct, es) => reconstruct((cct.tcd.toCase.fields zip es).map { - case (vd, e) => rec(e, CaseClassSelector(selector, vd.id)) - }, CaseClass(cct, _)) + case ADT(adt, es) => reconstruct((adt.getADT.toConstructor.fields zip es).map { + case (vd, e) => rec(e, ADTSelector(selector, vd.id)) + }, ADT(adt, _)) case _ => (Seq.empty, (es: Seq[Expr]) => expr) } diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala index 08f82c01cf02e3bfcc6af1323b24a433ba6eb6f5..5c18892492296e4a6c070e41955164dea7f0f9b1 100644 --- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala @@ -166,7 +166,7 @@ trait AbstractZ3Solver def typeToSortRef(tt: Type): ADTSortReference = { val tpe = tt match { - case ct: ClassType => ct.tcd.root.toType + case adt: ADTType => adt.getADT.root.toType case _ => tt } @@ -225,7 +225,7 @@ trait AbstractZ3Solver case Int32Type | BooleanType | IntegerType | RealType | CharType => sorts(oldtt) - case tpe @ (_: ClassType | _: TupleType | _: TypeParameter | UnitType) => + case tpe @ (_: ADTType | _: TupleType | _: TypeParameter | UnitType) => sorts.getOrElse(tpe, declareStructuralSort(tpe)) case tt @ SetType(base) => @@ -370,32 +370,32 @@ trait AbstractZ3Solver val selector = selectors.toB((tpe, i-1)) selector(rec(t)) - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) + case c @ ADT(adt, args) => + typeToSort(adt) // Making sure the sort is defined + val constructor = constructors.toB(adt) constructor(args.map(rec): _*) - case c @ CaseClassSelector(cc, sel) => - val cct = cc.getType - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct -> c.selectorIndex) + case c @ ADTSelector(cc, sel) => + val adt = cc.getType + typeToSort(adt) // Making sure the sort is defined + val selector = selectors.toB(adt -> c.selectorIndex) selector(rec(cc)) - case AsInstanceOf(expr, ct) => + case AsInstanceOf(expr, adt) => rec(expr) - case IsInstanceOf(e, ct) => ct.tcd match { - case tacd: TypedAbstractClassDef => - tacd.descendants match { - case Seq(tccd) => - rec(IsInstanceOf(e, tccd.toType)) + case IsInstanceOf(e, adt) => adt.getADT match { + case tsort: TypedADTSort => + tsort.constructors match { + case Seq(tcons) => + rec(IsInstanceOf(e, tcons.toType)) case more => - val v = Variable(FreshIdentifier("e", true), ct) - rec(Let(v.toVal, e, orJoin(more map (tccd => IsInstanceOf(v, tccd.toType))))) + val v = Variable(FreshIdentifier("e", true), adt) + rec(Let(v.toVal, e, orJoin(more map (tcons => IsInstanceOf(v, tcons.toType))))) } - case tccd: TypedCaseClassDef => - typeToSort(ct) - val tester = testers.toB(ct) + case tcons: TypedADTConstructor => + typeToSort(adt) + val tester = testers.toB(adt) tester(rec(e)) } @@ -554,8 +554,8 @@ trait AbstractZ3Solver FunctionInvocation(tfd.id, tfd.tps, args.zip(tfd.params).map{ case (a, p) => rec(a, p.getType) }) } else if (constructors containsB decl) { constructors.toA(decl) match { - case ct: ClassType => - CaseClass(ct, args.zip(ct.tcd.toCase.fieldsTypes).map { case (a, t) => rec(a, t) }) + case adt: ADTType => + ADT(adt, args.zip(adt.getADT.toConstructor.fieldsTypes).map { case (a, t) => rec(a, t) }) case UnitType => UnitLiteral()