diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 9dd7bfbd8ad12abeaf0a1cd10a23543a92e3b8b6..6b7bf67e0ce10066f4719a34101fcb0d97cf17bc 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -316,14 +316,16 @@ object DefOps { val fdMap = new utils.Bijection[FunDef , FunDef ] val transformer = new DefinitionTransformer(idMap, fdMap, cdMap) { - override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = expr match { + override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = expr match { case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - val nfi = fiMapF(fi, transform(fd)) getOrElse expr - super.transform(nfi) + fiMapF(fi, transform(fd)) + //val nfi = fiMapF(fi, transform(fd)) getOrElse expr + //Some(super.transform(nfi)) case cc @ CaseClass(cct, args) => - val ncc = ciMapF(cc, transform(cct).asInstanceOf[CaseClassType]) getOrElse expr - super.transform(ncc) - case _ => super.transform(expr) + ciMapF(cc, transform(cct).asInstanceOf[CaseClassType]) + //val ncc = ciMapF(cc, transform(cct).asInstanceOf[CaseClassType]) getOrElse expr + //Some(super.transform(ncc)) + case _ => None } override def transformFunDef(fd: FunDef): Option[FunDef] = fdMapF(fd) diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala index abe08fa34fd66465bbcf2c2ba7c114d5075c592c..2588ed8fe33419c1becb753506362c0ca663388f 100644 --- a/src/main/scala/leon/purescala/DefinitionTransformer.scala +++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala @@ -22,29 +22,41 @@ class DefinitionTransformer( if (ntpe == id.getType && !freshen) id else id.duplicate(tpe = ntpe) } - override def transform(id: Identifier): Identifier = transformId(id, false) + def transformType(tpe: TypeTree): Option[TypeTree] = None + final override def transform(tpe: TypeTree): TypeTree = { + super.transform(transformType(tpe).getOrElse(tpe)) + } - override def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = e match { - case Variable(id) if !(bindings contains id) => - val ntpe = transform(id.getType) - Variable(idMap.getB(id) match { - case Some(nid) if ntpe == nid.getType => nid - case _ => - val nid = transformId(id, false) - idMap += id -> nid - nid - }) - case LetDef(fds, body) => - val rFds = fds map transform - val rBody = transform(body) - LetDef(rFds, rBody).copiedFrom(e) - - case _ => super.transform(e) + final override def transform(id: Identifier): Identifier = transformId(id, false) + + + def transformExpr(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = None + final override def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + transformExpr(e) match { + case Some(r) => super.transform(r) + case None => e match { + case Variable(id) if !(bindings contains id) => + val ntpe = transform(id.getType) + Variable(idMap.getB(id) match { + case Some(nid) if ntpe == nid.getType => nid + case _ => + val nid = transformId(id, false) + idMap += id -> nid + nid + }) + case LetDef(fds, body) => + val rFds = fds map transform + val rBody = transform(body) + LetDef(rFds, rBody).copiedFrom(e) + + case _ => super.transform(e) + } + } } protected def transformFunDef(fd: FunDef): Option[FunDef] = None - override def transform(fd: FunDef): FunDef = { + final override def transform(fd: FunDef): FunDef = { if ((fdMap containsB fd) || (tmpFdMap containsB fd)) fd else if (tmpFdMap containsA fd) tmpFdMap.toB(fd) else fdMap.getBorElse(fd, { @@ -54,10 +66,11 @@ class DefinitionTransformer( } protected def transformClassDef(cd: ClassDef): Option[ClassDef] = None - override def transform(cd: ClassDef): ClassDef = { + final override def transform(cd: ClassDef): ClassDef = { if ((cdMap containsB cd) || (tmpCdMap containsB cd)) cd else if (tmpCdMap containsA cd) tmpCdMap.toB(cd) - else cdMap.getBorElse(cd, { + else + cdMap.getBorElse(cd, { transformDefs(cd) cdMap.toB(cd) }) @@ -105,7 +118,8 @@ class DefinitionTransformer( newBody != fd.fullBody }) - case cd: ClassDef => !(transformedCds contains cd) && + case cd: ClassDef => + !(transformedCds contains cd) && (cd.fieldsIds.exists(id => transform(id.getType) != id.getType) || cd.invariant.exists(required)) diff --git a/src/main/scala/leon/solvers/theories/BagEncoder.scala b/src/main/scala/leon/solvers/theories/BagEncoder.scala index 2ae7f2ac8a59b0cc2e96a2cec06b22d1983b79da..d6ba2e33df0c3877628f9d1819e7f008da245720 100644 --- a/src/main/scala/leon/solvers/theories/BagEncoder.scala +++ b/src/main/scala/leon/solvers/theories/BagEncoder.scala @@ -20,52 +20,52 @@ class BagEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { val BagEquals = p.library.lookupUnique[FunDef]("leon.theories.Bag.equals") val encoder = new Encoder { - override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { case FiniteBag(elems, tpe) => val newTpe = transform(tpe) val id = FreshIdentifier("x", newTpe, true) - CaseClass(Bag.typed(Seq(newTpe)), Seq(Lambda(Seq(ValDef(id)), + Some(CaseClass(Bag.typed(Seq(newTpe)), Seq(Lambda(Seq(ValDef(id)), elems.foldRight[Expr](InfiniteIntegerLiteral(0).copiedFrom(e)) { case ((k, v), ite) => IfExpr(Equals(Variable(id), transform(k)), transform(v), ite).copiedFrom(e) - }))).copiedFrom(e) + }))).copiedFrom(e)) case BagAdd(bag, elem) => val BagType(base) = bag.getType - FunctionInvocation(Add.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e) + Some(FunctionInvocation(Add.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e)) case MultiplicityInBag(elem, bag) => val BagType(base) = bag.getType - FunctionInvocation(Get.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e) + Some(FunctionInvocation(Get.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e)) case BagIntersection(b1, b2) => val BagType(base) = b1.getType - FunctionInvocation(Intersect.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + Some(FunctionInvocation(Intersect.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) case BagUnion(b1, b2) => val BagType(base) = b1.getType - FunctionInvocation(Union.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + Some(FunctionInvocation(Union.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) case BagDifference(b1, b2) => val BagType(base) = b1.getType - FunctionInvocation(Difference.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + Some(FunctionInvocation(Difference.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) case Equals(b1, b2) if b1.getType.isInstanceOf[BagType] => val BagType(base) = b1.getType - FunctionInvocation(BagEquals.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + Some(FunctionInvocation(BagEquals.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) - case _ => super.transform(e) + case _ => None } - override def transform(tpe: TypeTree): TypeTree = tpe match { - case BagType(base) => Bag.typed(Seq(transform(base))).copiedFrom(tpe) - case _ => super.transform(tpe) + override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { + case BagType(base) => Some(Bag.typed(Seq(transform(base))).copiedFrom(tpe)) + case _ => None } } val decoder = new Decoder { - override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { case cc @ CaseClass(CaseClassType(Bag, Seq(tpe)), args) => - FiniteBag(args(0) match { + Some(FiniteBag(args(0) match { case FiniteLambda(mapping, dflt, tpe) => if (dflt != InfiniteIntegerLiteral(0)) throw new Unsupported(cc, "Bags can't have default value " + dflt.asString(ctx))(ctx) @@ -81,32 +81,32 @@ class BagEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { rec(body) case f => scala.sys.error("Unexpected function " + f.asString(ctx)) - }, transform(tpe)).copiedFrom(e) + }, transform(tpe)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(Add, Seq(_)), Seq(bag, elem)) => - BagAdd(transform(bag), transform(elem)).copiedFrom(e) + Some(BagAdd(transform(bag), transform(elem)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(Get, Seq(_)), Seq(bag, elem)) => - MultiplicityInBag(transform(elem), transform(bag)).copiedFrom(e) + Some(MultiplicityInBag(transform(elem), transform(bag)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(Intersect, Seq(_)), Seq(b1, b2)) => - BagIntersection(transform(b1), transform(b2)).copiedFrom(e) + Some(BagIntersection(transform(b1), transform(b2)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(Union, Seq(_)), Seq(b1, b2)) => - BagUnion(transform(b1), transform(b2)).copiedFrom(e) + Some(BagUnion(transform(b1), transform(b2)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(Difference, Seq(_)), Seq(b1, b2)) => - BagDifference(transform(b1), transform(b2)).copiedFrom(e) + Some(BagDifference(transform(b1), transform(b2)).copiedFrom(e)) case FunctionInvocation(TypedFunDef(BagEquals, Seq(_)), Seq(b1, b2)) => - Equals(transform(b1), transform(b2)).copiedFrom(e) + Some(Equals(transform(b1), transform(b2)).copiedFrom(e)) - case _ => super.transform(e) + case _ => None } - override def transform(tpe: TypeTree): TypeTree = tpe match { - case CaseClassType(Bag, Seq(base)) => BagType(transform(base)).copiedFrom(tpe) - case _ => super.transform(tpe) + override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { + case CaseClassType(Bag, Seq(base)) => Some(BagType(transform(base)).copiedFrom(tpe)) + case _ => None } override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/leon/solvers/theories/StringEncoder.scala index 5af13c1c0e52e29d3f79654a6cbe30f009630d5d..0519814409d09e717babfd17f37a3a541b88c607 100644 --- a/src/main/scala/leon/solvers/theories/StringEncoder.scala +++ b/src/main/scala/leon/solvers/theories/StringEncoder.scala @@ -37,23 +37,23 @@ class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { } val encoder = new Encoder { - override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { case StringLiteral(v) => - convertFromString(v) + Some(convertFromString(v)) case StringLength(a) => - FunctionInvocation(Size, Seq(transform(a))).copiedFrom(e) + Some(FunctionInvocation(Size, Seq(transform(a))).copiedFrom(e)) case StringConcat(a, b) => - FunctionInvocation(Concat, Seq(transform(a), transform(b))).copiedFrom(e) + Some(FunctionInvocation(Concat, Seq(transform(a), transform(b))).copiedFrom(e)) case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e) + Some(FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e)) case SubString(a, start, end) => - FunctionInvocation(Slice, Seq(transform(a), transform(start), transform(end))).copiedFrom(e) - case _ => super.transform(e) + Some(FunctionInvocation(Slice, Seq(transform(a), transform(start), transform(end))).copiedFrom(e)) + case _ => None } - override def transform(tpe: TypeTree): TypeTree = tpe match { - case StringType => String - case _ => super.transform(tpe) + override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { + case StringType => Some(String) + case _ => None } override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { @@ -68,30 +68,30 @@ class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { } val decoder = new Decoder { - override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, String)=> - StringLiteral(convertToString(cc)).copiedFrom(cc) + Some(StringLiteral(convertToString(cc)).copiedFrom(cc)) case FunctionInvocation(Size, Seq(a)) => - StringLength(transform(a)).copiedFrom(e) + Some(StringLength(transform(a)).copiedFrom(e)) case FunctionInvocation(Concat, Seq(a, b)) => - StringConcat(transform(a), transform(b)).copiedFrom(e) + Some(StringConcat(transform(a), transform(b)).copiedFrom(e)) case FunctionInvocation(Slice, Seq(a, from, to)) => - SubString(transform(a), transform(from), transform(to)).copiedFrom(e) + Some(SubString(transform(a), transform(from), transform(to)).copiedFrom(e)) case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) => val rstart = transform(start) - SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e) + Some(SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e)) case FunctionInvocation(Take, Seq(a, length)) => - SubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e) + Some(SubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e)) case FunctionInvocation(Drop, Seq(a, count)) => val ra = transform(a) - SubString(ra, transform(count), StringLength(ra)).copiedFrom(e) - case _ => super.transform(e) + Some(SubString(ra, transform(count), StringLength(ra)).copiedFrom(e)) + case _ => None } - override def transform(tpe: TypeTree): TypeTree = tpe match { - case String | StringCons | StringNil => StringType - case _ => super.transform(tpe) + override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { + case String | StringCons | StringNil => Some(StringType) + case _ => None } override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { diff --git a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala index 39b5a632c5321220416e8142f30eb90c65b887a5..f3b4d6ea0b026d3296bfe41b7a85c96665a2ba27 100644 --- a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala @@ -40,14 +40,14 @@ trait TheoryEncoder { self => def >>(that: TheoryEncoder): TheoryEncoder = new TheoryEncoder { val encoder = new Encoder { - override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = { val mapSeq = bindings.toSeq val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.encoder.transform(id.getType)) } val e2 = self.encoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) - that.encoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + Some(that.encoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap)) } - override def transform(tpe: TypeTree): TypeTree = that.encoder.transform(self.encoder.transform(tpe)) + override def transformType(tpe: TypeTree): Option[TypeTree] = Some(that.encoder.transform(self.encoder.transform(tpe))) override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { val (pat2, bindings) = self.encoder.transform(pat) @@ -57,14 +57,14 @@ trait TheoryEncoder { self => } val decoder = new Decoder { - override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = { val mapSeq = bindings.toSeq val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.decoder.transform(id.getType)) } val e2 = that.decoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) - self.decoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + Some(self.decoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap)) } - override def transform(tpe: TypeTree): TypeTree = self.decoder.transform(that.decoder.transform(tpe)) + override def transformType(tpe: TypeTree): Option[TypeTree] = Some(self.decoder.transform(that.decoder.transform(tpe))) override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { val (pat2, bindings) = that.decoder.transform(pat) diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala index 063694addccde2f2e4e4df3c399109937f30312d..2b7176fb3dffeeaa4f2c42b8d319176b962265d2 100644 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -43,9 +43,9 @@ object AntiAliasingPhase extends TransformationPhase { // p._1.fields.zip(p._2.fields).filter(pvd => pvd._1.id != pvd._2).map(p => (p._1.id, p._2.id)) //}).toMap val transformer = new DefinitionTransformer { - override def transform(tpe: TypeTree): TypeTree = tpe match { - case (ft: FunctionType) => makeFunctionTypeExplicit(ft) - case _ => super.transform(tpe) + override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { + case (ft: FunctionType) => Some(makeFunctionTypeExplicit(ft)) + case _ => None } //override def transformClassDef(cd: ClassDef): Option[ClassDef] = ccdBijection.getB(cd) }