diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 63ec4a7d1d38245ba7b835524b6c5cd2aec4efed..8b9ecba6ce5d18b7e431500cf7ce86f4e2dd6d10 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -70,6 +70,10 @@ object Common { def toVariable: Variable = Variable(this) def freshen: Identifier = FreshIdentifier(name, tpe, alwaysShowUniqueID).copiedFrom(this) + + def duplicate(name: String = name, tpe: TypeTree = tpe, alwaysShowUniqueID: Boolean = alwaysShowUniqueID) = { + FreshIdentifier(name, tpe, alwaysShowUniqueID) + } override def compare(that: Identifier): Int = { val ord = implicitly[Ordering[(String, Int, Int)]] diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index f1a35ce979e524592bb0f87869b24342927aa371..7cc576adbd6dac0a277d40ade8c7dc06a4836602 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -4,6 +4,7 @@ package leon.purescala import Definitions._ import Expressions._ +import Common.Identifier import ExprOps.{preMap, functionCallsOf} import leon.purescala.Types.AbstractClassType import leon.purescala.Types._ @@ -331,9 +332,9 @@ object DefOps { } - private def defaultCdMap(cc: CaseClass, ccd: CaseClassDef): Option[Expr] = (cc, ccd) match { + private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match { case (CaseClass(old, args), newCcd) if old.classDef != newCcd => - Some(CaseClass(newCcd.typed(old.tps), args)) + Some(CaseClass(newCcd, args)) case _ => None } @@ -347,26 +348,48 @@ object DefOps { * By default it is the case class construction using the new case class definition. * @return the new program with a map from the old case classes to the new case classes */ def replaceClassDefs(p: Program)(cdMapF: (ClassDef, Option[AbstractClassType]) => Option[ClassDef], - ciMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap): (Program, Map[ClassDef, ClassDef]) = { + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { var cdMapCache = Map[ClassDef, ClassDef]() - def tpMap(tt: TypeTree): TypeTree = tt match { - case AbstractClassType(asd, targs) => AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs map tpMap) - case CaseClassType(ccd, targs) => CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs map tpMap) - case e => e - } + var idMapCache = Map[Identifier, Identifier]() + var fdMapCache = Map[FunDef, FunDef]() + def tpMap(tt: TypeTree): TypeTree = TypeOps.postMap{ + case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) + case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) + case e => None + }(tt) def cdMap(cd: ClassDef): ClassDef = { if (!(cdMapCache contains cd)) { lazy val parent = cd.parent.map( tpMap(_).asInstanceOf[AbstractClassType] ) - cdMapCache += cd -> cdMapF(cd, parent).getOrElse{ - cd match { - case acd:AbstractClassDef => acd.duplicate(parent = parent) - case ccd:CaseClassDef => ccd.duplicate(parent = parent) - } + val ncd = cdMapF(cd, parent) match { + case Some(new_ccd) => + for((old_id, new_id) <- cd.fieldsIds.zip(new_ccd.fieldsIds)) { + idMapCache += old_id -> new_id + } + new_ccd + case None => + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. + } } + cdMapCache += cd -> ncd } cdMapCache(cd) } + def idMap(id: Identifier): Identifier = { + if (!(idMapCache contains id)) { + idMapCache += id -> id.duplicate(tpe = tpMap(id.getType)) + } + idMapCache(id) + } + def fdMap(fd: FunDef): FunDef = { + if (!(fdMapCache contains fd)) { + fdMapCache += fd -> fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType)) + } + fdMapCache(fd) + } val newP = p.copy(units = for (u <- p.units) yield { u.copy( @@ -374,7 +397,8 @@ object DefOps { case m : ModuleDef => m.copy(defs = for (df <- m.defs) yield { df match { - case f : ClassDef => cdMap(f) + case cd : ClassDef => cdMap(cd) + case fd : FunDef => fdMap(fd) case d => d } }) @@ -382,25 +406,63 @@ object DefOps { } ) }) + object ToTransform { + def unapply(c: ClassType): Option[ClassDef] = Some(cdMap(c.classDef)) + } + trait Transformed[T <: TypeTree] { + def unapply(c: T): Option[T] = Some(TypeOps.postMap({ + case c: ClassType => + val newClassDef = cdMap(c.classDef) + Some((c match { + case CaseClassType(ccd, tps) => + CaseClassType(newClassDef.asInstanceOf[CaseClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e))) + case AbstractClassType(acd, tps) => + AbstractClassType(newClassDef.asInstanceOf[AbstractClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e))) + }).asInstanceOf[T]) + case _ => None + })(c).asInstanceOf[T]) + } + object CaseClassTransformed extends Transformed[CaseClassType] + object ClassTransformed extends Transformed[ClassType] + object TypeTransformed extends Transformed[TypeTree] + def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ + case CaseClassPattern(optId, CaseClassTransformed(ct), sub) => Some(CaseClassPattern(optId.map(idMap), ct, sub)) + case InstanceOfPattern(optId, ClassTransformed(ct)) => Some(InstanceOfPattern(optId.map(idMap), ct)) + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) + case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) + case e => None + }(e) + + def replaceClassDefsUse(e: Expr): Expr = { + ExprOps.postMap { + case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) + case Variable(id) => Some(Variable(idMap(id))) + case ci @ CaseClass(CaseClassTransformed(ct), args) => + ciMapF(ci, ct).map(_.setPos(ci)) + //case IsInstanceOf(e, ToTransform()) => + case CaseClassSelector(CaseClassTransformed(cct), expr, identifier) => + Some(CaseClassSelector(cct, expr, idMap(identifier))) + case IsInstanceOf(e, ClassTransformed(ct)) => Some(IsInstanceOf(e, ct)) + case AsInstanceOf(e, ClassTransformed(ct)) => Some(AsInstanceOf(e, ct)) + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map{ + case MatchCase(pattern, optGuard, rhs) => + MatchCase(replaceClassDefUse(pattern), optGuard, rhs) + })) + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + defaultFiMap(fi, fdMap(fd)).map(_.setPos(fi)) + case _ => + None + }(e) + } + for(fd <- newP.definedFunctions) { - // TODO: Check for patterns - // TODO: Check for isInstanceOf - // TODO: Check for asInstanceOf - if(ExprOps.exists{ case CaseClass(CaseClassType(ccd, targs), fargs) => cdMapCache.getOrElse(ccd, None) != None case _ => false }(fd.fullBody)) { - fd.fullBody = replaceClassDefsUse(fd.fullBody, cdMap, ciMapF) - } + fd.fullBody = replaceClassDefsUse(fd.fullBody) } - (newP, cdMapCache) + (newP, cdMapCache, idMapCache, fdMapCache) } - def replaceClassDefsUse(e: Expr, fdMapF: ClassDef => ClassDef, fiMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap) = { - preMap { - case fi @ CaseClass(CaseClassType(cd, tps), args) => - fiMapF(fi, fdMapF(cd).asInstanceOf[CaseClassDef]).map(_.setPos(fi)) - case _ => - None - }(e) - } + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 609f5cda04dcd1a1e5e0e55d1c5f01475236fd94..3c597d26e76976d44a939754fdea8c1caf45d29e 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -361,6 +361,7 @@ object Expressions { ) } + // Extracts without taking care of the binder. (contrary to Extractos.Pattern) object PatternExtractor extends SubTreeOps.Extractor[Pattern] { def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index df01b574ac1ad79f4a09a53b11d64c50e014ed78..2807c2394c0b1540094ddeffd0bccd2071be9b3f 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -42,13 +42,30 @@ object Z3StringCapableSolver { def convert(p: Program): (Program, Option[Z3StringConversion]) = { val converter = new Z3StringConversion(p) import converter.Forward._ - var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() var hasStrings = false val program_with_strings = converter.getProgram - val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => { + val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { + val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceClassDefs(program_with_strings)((cd: ClassDef, parent: Option[AbstractClassType]) => { + cd match { + case acd:AbstractClassDef => None + case ccd:CaseClassDef => + if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { + Some(ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) + } else None + } + }) + converter.mappedVariables.clear() // We will compose them later, they have been stored in idMap + res + } else { + (program_with_strings, Map[ClassDef, ClassDef](), Map[Identifier, Identifier](), Map[FunDef, FunDef]()) + } + val fdMapInverse = fdMap.map(kv => kv._2 -> kv._1).toMap + val idMapInverse = idMap.map(kv => kv._2 -> kv._1).toMap + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + val (new_program, _) = DefOps.replaceFunDefs(program_with_correct_classes)((fd: FunDef) => { globalFdMap.get(fd).map(_._2).orElse( if(thatShouldBeConverted(fd)) { - val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap + val idMap = fd.params.zip(fd.params).map(origvd_vd => origvd_vd._1.id -> convertId(origvd_vd._2.id)).toMap val newFdId = convertId(fd.id) val newFd = fd.duplicate(newFdId, fd.tparams, @@ -68,6 +85,9 @@ object Z3StringCapableSolver { implicit val idVarMap = idMap.mapValues(id => Variable(id)) newFd.fullBody = convertExpr(newFd.fullBody) } + converter.mappedVariables.composeA(id => idMapInverse.getOrElse(id, id)) + converter.globalFdMap.composeA(fd => fdMapInverse.getOrElse(fd, fd)) + converter.globalClassMap ++= cdMap (new_program, Some(converter)) } } diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 1c713fdf1ceb600cca902b38fd1dbd64f1754f4a..b644f687af3950d03bb062f0ef0f030c3691f6b4 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -124,7 +124,9 @@ class Z3StringConversion(val p: Program) extends Z3StringConverters { trait Z3StringConverters { import StringEcoSystem._ val mappedVariables = new Bijection[Identifier, Identifier]() - + + val globalClassMap = new Bijection[ClassDef, ClassDef]() // To be added manually + val globalFdMap = new Bijection[FunDef, FunDef]() val stringBijection = new Bijection[String, Expr]() @@ -147,11 +149,19 @@ trait Z3StringConverters { def convertFunDef(fd: FunDef): FunDef def hasIdConversion(id: Identifier): Boolean def convertId(id: Identifier): Identifier + def convertClassDef(d: ClassDef): ClassDef def isTypeToConvert(tpe: TypeTree): Boolean def convertType(tpe: TypeTree): TypeTree def convertPattern(pattern: Pattern): Pattern def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr - + object TypeConverted { + def unapply(t: TypeTree): Option[TypeTree] = Some(t match { + case cct@CaseClassType(ccd, args) => CaseClassType(convertClassDef(ccd).asInstanceOf[CaseClassDef], args) + case act@AbstractClassType(acd, args) => AbstractClassType(convertClassDef(acd).asInstanceOf[AbstractClassDef], args) + case NAryType(es, builder) => + builder(es map convertType) + }) + } object PatternConverted { def unapply(e: Pattern): Option[Pattern] = Some(e match { case InstanceOfPattern(binder, ct) => @@ -249,6 +259,10 @@ trait Z3StringConverters { def convertFunDef(fd: FunDef): FunDef = { globalFdMap.getBorElse(fd, fd) } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getBorElse(cd, cd) + } def hasIdConversion(id: Identifier): Boolean = { mappedVariables.containsA(id) } @@ -265,8 +279,10 @@ trait Z3StringConverters { } def isTypeToConvert(tpe: TypeTree): Boolean = TypeOps.exists(StringType == _)(tpe) - def convertType(tpe: TypeTree): TypeTree = - TypeOps.preMap{ case StringType => Some(StringList.typed) case e => None}(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringType => StringList.typed + case TypeConverted(t) => t + } def convertPattern(e: Pattern): Pattern = e match { case LiteralPattern(binder, StringLiteral(s)) => s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { @@ -303,6 +319,10 @@ trait Z3StringConverters { def convertFunDef(fd: FunDef): FunDef = { globalFdMap.getAorElse(fd, fd) } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getAorElse(cd, cd) + } def hasIdConversion(id: Identifier): Boolean = { mappedVariables.containsB(id) } @@ -323,36 +343,35 @@ trait Z3StringConverters { } def isTypeToConvert(tpe: TypeTree): Boolean = TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) - def convertType(tpe: TypeTree): TypeTree = { - TypeOps.preMap{ - case StringList | StringCons | StringNil => Some(StringType) - case e => None}(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringList | StringCons | StringNil => StringType + case TypeConverted(t) => t } def convertPattern(e: Pattern): Pattern = e match { - case CaseClassPattern(b, StringNilTyped, Seq()) => - LiteralPattern(b.map(convertId), StringLiteral("")) - case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => - convertPattern(subpattern) match { - case LiteralPattern(_, StringLiteral(s)) - => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) - case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e + } + + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e } - case PatternConverted(e) => e - } - - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = - e match { - case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> - StringLiteral(convertToString(cc)) - case FunctionInvocation(StringSize, Seq(a)) => - StringLength(convertExpr(a)).copiedFrom(e) - case FunctionInvocation(StringListConcat, Seq(a, b)) => - StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) - case FunctionInvocation(StringTake, - Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => - val rstart = convertExpr(start) - SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) - case ExprConverted(e) => e } - } } diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 3680930639a2cfba46490d4a21bab7772d7fd0c8..380799d25e1f73ddbbb57d7989706fd03e5f1821 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -2,9 +2,16 @@ package leon.utils -class Bijection[A, B] { +object Bijection { + def apply[A, B](a: Iterable[(A, B)]): Bijection[A, B] = new Bijection[A, B] ++= a + def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) +} + +class Bijection[A, B] extends Iterable[(A, B)] { protected var a2b = Map[A, B]() protected var b2a = Map[B, A]() + + def iterator = a2b.iterator def +=(a: A, b: B): Unit = { a2b += a -> b @@ -16,7 +23,7 @@ class Bijection[A, B] { this } - def ++=(t: Iterable[(A,B)]) = { + def ++=(t: Iterable[(A, B)]) = { (this /: t){ case (b, elem) => b += elem } } @@ -58,4 +65,11 @@ class Bijection[A, B] { def aSet = a2b.keySet def bSet = b2a.keySet + + def composeA[C](c: A => C): Bijection[C, B] = { + new Bijection[C, B] ++= this.a2b.map(kv => c(kv._1) -> kv._2) + } + def composeB[C](c: B => C): Bijection[A, C] = { + new Bijection[A, C] ++= this.a2b.map(kv => kv._1 -> c(kv._2)) + } }