diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 63ec4a7d1d38245ba7b835524b6c5cd2aec4efed..8b9ecba6ce5d18b7e431500cf7ce86f4e2dd6d10 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -70,6 +70,10 @@ object Common { def toVariable: Variable = Variable(this) def freshen: Identifier = FreshIdentifier(name, tpe, alwaysShowUniqueID).copiedFrom(this) + + def duplicate(name: String = name, tpe: TypeTree = tpe, alwaysShowUniqueID: Boolean = alwaysShowUniqueID) = { + FreshIdentifier(name, tpe, alwaysShowUniqueID) + } override def compare(that: Identifier): Int = { val ord = implicitly[Ordering[(String, Int, Int)]] diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 4efc97986400f312afeb19674615e9822c56c8e3..78c2b93edecf113c717a625a3c974c8dcf43e474 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -4,7 +4,10 @@ package leon.purescala import Definitions._ import Expressions._ +import Common.Identifier import ExprOps.{preMap, functionCallsOf} +import leon.purescala.Types.AbstractClassType +import leon.purescala.Types._ object DefOps { @@ -274,13 +277,11 @@ object DefOps { case _ => None } - + /** Clones the given program by replacing some functions by other functions. * * @param p The original program * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. - * May be called once each time a function appears (definition and invocation), - * so make sure to output the same if the argument is the same. * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. * By default it is the function invocation using the new function definition. * @return the new program with a map from the old functions to the new functions */ @@ -288,49 +289,329 @@ object DefOps { fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) : (Program, Map[FunDef, FunDef])= { - var fdMapCache = Map[FunDef, Option[FunDef]]() + var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache + var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = fdMapF(fd) + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) + } + case _ => + } + } + def fdMap(fd: FunDef): FunDef = { - if (!(fdMapCache contains fd)) { - fdMapCache += fd -> fdMapF(fd) + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { // Verify that for all + fdMapCache += fd -> None + } + fdMapCache(fd).getOrElse(fd) } - - fdMapCache(fd).getOrElse(fd) } - - + val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { case m : ModuleDef => m.copy(defs = for (df <- m.defs) yield { df match { - case f : FunDef => - val newF = fdMap(f) - newF - case d => - d + case f : FunDef => fdMap(f) + case d => d } }) case d => d } ) }) + for(fd <- newP.definedFunctions) { - if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) { + if(ExprOps.exists{ + case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd + case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd + case _ => false + }(c.pattern)) + case _ => false + }(fd.fullBody)) { fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) } } (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } - def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { + def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { preMap { + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map(matchcase => matchcase match { + case MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs) + }))) case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => fiMapF(fi, fdMapF(fd)).map(_.setPos(fi)) case _ => None }(e) } + + def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp)) + case _ => None + }(p) + + private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match { + case (CaseClass(old, args), newCcd) if old.classDef != newCcd => + Some(CaseClass(newCcd, args)) + case _ => + None + } + + /** Clones the given program by replacing some classes by other classes. + * + * @param p The original program + * @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept. + * @param ciMapF Given a previous case class invocation and its new case class definition, returns the expression to use. + * By default it is the case class construction using the new case class definition. + * @return the new program with a map from the old case classes to the new case classes, with maps concerning identifiers and function definitions. */ + def replaceCaseClassDefs(p: Program)(_cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { + var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() + var cdMapCache = Map[ClassDef, Option[ClassDef]]() + var idMapCache = Map[Identifier, Identifier]() + var fdMapFCache = Map[FunDef, Option[FunDef]]() + var fdMapCache = Map[FunDef, Option[FunDef]]() + def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { + cd match { + case ccd: CaseClassDef => + cdMapFCache.getOrElse(ccd, { + val new_cd_potential = cdMapF(ccd) + cdMapFCache += ccd -> new_cd_potential + new_cd_potential + }) + case acd: AbstractClassDef => None + } + } + def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ + case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) + case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) + case e => None + }(tt).asInstanceOf[T] + + def duplicateClassDef(cd: ClassDef): Unit = { + cdMapCache.get(cd) match { + case Some(new_cd) => + case None => + val parent = cd.parent.map(duplicateAbstractClassType) + val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{ + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. + } + } + cdMapCache += cd -> Some(new_cd) + } + } + + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { + TypeOps.postMap{ + case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) + case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) + case _ => None + }(act).asInstanceOf[AbstractClassType] + } + + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. + // If something extends List[A] and A is modified, then the first something should be modified. + def dependencies(s: ClassDef): Set[ClassDef] = { + Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ + case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants + case CaseClassType(ccd, _) => Set(ccd:ClassDef) + }(p)) + } + + def cdMap(cd: ClassDef): ClassDef = { + cdMapCache.get(cd) match { + case Some(Some(new_cd)) => new_cd + case Some(None) => cd + case None => + if(cdMapF(cd).isDefined || dependencies(cd).exists(cd => cdMapF(cd).isDefined)) { // Needs replacement in any case. + duplicateClassDef(cd) + } else { + cdMapCache += cd -> None + } + cdMapCache(cd).getOrElse(cd) + } + } + def idMap(id: Identifier): Identifier = { + if (!(idMapCache contains id)) { + idMapCache += id -> id.duplicate(tpe = tpMap(id.getType)) + } + idMapCache(id) + } + + def idHasToChange(id: Identifier): Boolean = { + typeHasToChange(id.getType) + } + + def typeHasToChange(tp: TypeTree): Boolean = { + TypeOps.exists{ + case AbstractClassType(acd, _) => cdMap(acd) != acd + case CaseClassType(ccd, _) => cdMap(ccd) != ccd + }(tp) + } + + def patternHasToChange(p: Pattern): Boolean = { + PatternOps.exists { + case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) + case InstanceOfPattern(optId, cct) => optId.exists(idHasToChange) || typeHasToChange(cct) + case Extractors.Pattern(optId, subp, builder) => optId.exists(idHasToChange) + case e => false + } (p) + } + + def exprHasToChange(e: Expr): Boolean = { + ExprOps.exists{ + case Let(id, expr, body) => idHasToChange(id) + case Variable(id) => idHasToChange(id) + case ci @ CaseClass(cct, args) => typeHasToChange(cct) + case CaseClassSelector(cct, expr, identifier) => typeHasToChange(cct) || idHasToChange(identifier) + case IsInstanceOf(e, cct) => typeHasToChange(cct) + case AsInstanceOf(e, cct) => typeHasToChange(cct) + case MatchExpr(scrut, cases) => + cases.exists{ + case MatchCase(pattern, optGuard, rhs) => + patternHasToChange(pattern) + } + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + tps.exists(typeHasToChange) + case _ => + false + }(e) + } + + def funDefHasToChange(fd: FunDef): Boolean = { + exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) + } + + def funHasToChange(fd: FunDef): Boolean = { + funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => + fdMapFCache.get(fd) match { + case Some(Some(_)) => true + case Some(None) => false + case None => funDefHasToChange(fd) + }) + } + + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = if(funHasToChange(fd)) { + Some(fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType))) + } else { + None + } + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) + } + case _ => + } + } + + def fdMap(fd: FunDef): FunDef = { + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { // Verify that for all + fdMapCache += fd -> None + } + fdMapCache(fd).getOrElse(fd) + } + } + + val newP = p.copy(units = for (u <- p.units) yield { + u.copy( + defs = u.defs.map { + case m : ModuleDef => + m.copy(defs = for (df <- m.defs) yield { + df match { + case cd : ClassDef => cdMap(cd) + case fd : FunDef => fdMap(fd) + case d => d + } + }) + case d => d + } + ) + }) + def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ + case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) + case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) + case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) + case e => None + }(e) + + def replaceClassDefsUse(e: Expr): Expr = { + ExprOps.postMap { + case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) + case Variable(id) => Some(Variable(idMap(id))) + case ci @ CaseClass(ct, args) => + ciMapF(ci, tpMap(ct)).map(_.setPos(ci)) + case CaseClassSelector(cct, expr, identifier) => + Some(CaseClassSelector(tpMap(cct), expr, idMap(identifier))) + case IsInstanceOf(e, ct) => Some(IsInstanceOf(e, tpMap(ct))) + case AsInstanceOf(e, ct) => Some(AsInstanceOf(e, tpMap(ct))) + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map{ + case MatchCase(pattern, optGuard, rhs) => + MatchCase(replaceClassDefUse(pattern), optGuard, rhs) + })) + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + defaultFiMap(fi, fdMap(fd)).map(_.setPos(fi)) + case _ => + None + }(e) + } + + for(fd <- newP.definedFunctions) { + fd.fullBody = replaceClassDefsUse(fd.fullBody) + } + (newP, + cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, + idMapCache, + fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) + } + + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false @@ -356,7 +637,12 @@ object DefOps { ) }) if (!found) { - println("addDefs could not find anchor definition!") + println(s"addDefs could not find anchor definition! Not found: $after") + p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match { + case Nil => + case e => println("Did you mean " + e) + } + println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) } res } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 8fb753d23d35364059115f7032bef3310c492d82..dfc78d4c547ddf40fb6c2a1e39bdab40611dcbd1 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -315,6 +315,20 @@ object Definitions { AbstractClassType(this, tps) } def typed: AbstractClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not add known case class children + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + parent: Option[AbstractClassType] = this.parent + ): AbstractClassDef = { + val acd = new AbstractClassDef(id, tparams, parent) + acd.addFlags(this.flags) + parent.map(_.classDef.ancestors.map(_.registerChild(acd))) + acd.copiedFrom(this) + } } /** Case classes/objects. */ @@ -351,6 +365,24 @@ object Definitions { CaseClassType(this, tps) } def typed: CaseClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not replace recursive case class def calls in [[arguments]] nor the parent abstract class types + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + fields: Seq[ValDef] = this.fields, + parent: Option[AbstractClassType] = this.parent, + isCaseObject: Boolean = this.isCaseObject + ): CaseClassDef = { + val cd = new CaseClassDef(id, tparams, parent, isCaseObject) + cd.setFields(fields) + cd.addFlags(this.flags) + cd.copiedFrom(this) + parent.map(_.classDef.ancestors.map(_.registerChild(cd))) + cd + } } /** Function/method definition. diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 609f5cda04dcd1a1e5e0e55d1c5f01475236fd94..3c597d26e76976d44a939754fdea8c1caf45d29e 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -361,6 +361,7 @@ object Expressions { ) } + // Extracts without taking care of the binder. (contrary to Extractos.Pattern) object PatternExtractor extends SubTreeOps.Extractor[Pattern] { def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index 2b6984d00d989e6958da3c43930b9614436cffbf..b94233f285e1fe63c486dd2711c63e115ef1dd7b 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -23,17 +23,49 @@ import leon.utils.Bijection import leon.solvers.z3.StringEcoSystem object Z3StringCapableSolver { - def convert(p: Program, force: Boolean = false): (Program, Option[Z3StringConversion]) = { + def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t) + def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e) + def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType) + def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id) + def thatShouldBeConverted(fd: FunDef): Boolean = { + (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted) + } + def thatShouldBeConverted(cd: ClassDef): Boolean = cd match { + case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted) + case _ => false + } + def thatShouldBeConverted(p: Program): Boolean = { + (p.definedFunctions exists thatShouldBeConverted) || + (p.definedClasses exists thatShouldBeConverted) + } + + def convert(p: Program): (Program, Option[Z3StringConversion]) = { val converter = new Z3StringConversion(p) import converter.Forward._ - var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() var hasStrings = false val program_with_strings = converter.getProgram - val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => { + val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { + val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceCaseClassDefs(program_with_strings)((cd: ClassDef) => { + cd match { + case acd:AbstractClassDef => None + case ccd:CaseClassDef => + if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { + Some((parent: Option[AbstractClassType]) => ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) + } else None + } + }) + converter.mappedVariables.clear() // We will compose them later, they have been stored in idMap + res + } else { + (program_with_strings, Map[ClassDef, ClassDef](), Map[Identifier, Identifier](), Map[FunDef, FunDef]()) + } + val fdMapInverse = fdMap.map(kv => kv._2 -> kv._1).toMap + val idMapInverse = idMap.map(kv => kv._2 -> kv._1).toMap + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + val (new_program, _) = DefOps.replaceFunDefs(program_with_correct_classes)((fd: FunDef) => { globalFdMap.get(fd).map(_._2).orElse( - if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || - fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { - val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap + if(thatShouldBeConverted(fd)) { + val idMap = fd.params.zip(fd.params).map(origvd_vd => origvd_vd._1.id -> convertId(origvd_vd._2.id)).toMap val newFdId = convertId(fd.id) val newFd = fd.duplicate(newFdId, fd.tparams, @@ -45,7 +77,7 @@ object Z3StringCapableSolver { } else None ) }) - if(!hasStrings && !force) { + if(!hasStrings) { (p, None) } else { converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) @@ -53,26 +85,24 @@ object Z3StringCapableSolver { implicit val idVarMap = idMap.mapValues(id => Variable(id)) newFd.fullBody = convertExpr(newFd.fullBody) } + converter.mappedVariables.composeA(id => idMapInverse.getOrElse(id, id)) + converter.globalFdMap.composeA(fd => fdMapInverse.getOrElse(fd, fd)) + converter.globalClassMap ++= cdMap (new_program, Some(converter)) } } } -trait ForcedProgramConversion { self: Z3StringCapableSolver[_] => - override def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = { - Z3StringCapableSolver.convert(p, true) - } -} - abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( val context: LeonContext, val program: Program, val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) extends Solver { - def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = Z3StringCapableSolver.convert(p) - protected val (new_program, someConverter) = convertProgram(program) + protected val (new_program, optConverter) = Z3StringCapableSolver.convert(program) + var someConverter = optConverter val underlying = underlyingConstructor(new_program, someConverter) + var solverInvokedWithStrings = false def getModel: leon.solvers.Model = { val model = underlying.getModel @@ -98,24 +128,40 @@ abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( new PartialModel(original_ids.zip(original_exprs).toMap, new_domain) case _ => - new Model(original_ids.zip(original_exprs).toMap) - } + new Model(original_ids.zip(original_exprs).toMap) } } + } // Members declared in leon.utils.Interruptible def interrupt(): Unit = underlying.interrupt() def recoverInterrupt(): Unit = underlying.recoverInterrupt() + // Converts expression on the fly if needed, creating a string converter if needed. + def convertExprOnTheFly(expression: Expr, withConverter: Z3StringConversion => Expr): Expr = { + someConverter match { + case None => + if(solverInvokedWithStrings || exists(e => TypeOps.exists(StringType == _)(e.getType))(expression)) { // On the fly conversion + solverInvokedWithStrings = true + val c = new Z3StringConversion(program) + someConverter = Some(c) + withConverter(c) + } else expression + case Some(converter) => + withConverter(converter) + } + } + // Members declared in leon.solvers.Solver def assertCnstr(expression: Expr): Unit = { someConverter.map{converter => import converter.Forward._ val newExpression = convertExpr(expression)(Map()) underlying.assertCnstr(newExpression) - }.getOrElse(underlying.assertCnstr(expression)) + }.getOrElse{ + underlying.assertCnstr(convertExprOnTheFly(expression, _.Forward.convertExpr(expression)(Map()))) + } } - def getUnsatCore: Set[Expr] = { someConverter.map{converter => import converter.Backward._ @@ -150,7 +196,7 @@ class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { import converter._ super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) - .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) + .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) ) } } @@ -190,36 +236,36 @@ class Z3StringFairZ3Solver(context: LeonContext, program: Program) new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) with Z3StringEvaluatingSolver[FairZ3Solver] { - // Members declared in leon.solvers.z3.AbstractZ3Solver - protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + // Members declared in leon.solvers.z3.AbstractZ3Solver + protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } } - } } class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => new UnrollingSolver(context, program, underlyingSolverConstructor(program))) - with Z3StringNaiveAssumptionSolver[UnrollingSolver] + with Z3StringNaiveAssumptionSolver[UnrollingSolver] with Z3StringEvaluatingSolver[UnrollingSolver] { - override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore + override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore } class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } } - } } diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 6d41b5fd8f45def478849c6c2d0623ed85985ffb..b644f687af3950d03bb062f0ef0f030c3691f6b4 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -112,28 +112,26 @@ object StringEcoSystem { } class Z3StringConversion(val p: Program) extends Z3StringConverters { - val stringBijection = new Bijection[String, Expr]() - import StringEcoSystem._ - - lazy val listchar = StringList.typed - lazy val conschar = StringCons.typed - lazy val nilchar = StringNil.typed - - lazy val list_size = StringSize.typed - lazy val list_++ = StringListConcat.typed - lazy val list_take = StringTake.typed - lazy val list_drop = StringDrop.typed - lazy val list_slice = StringSlice.typed - def getProgram = program_with_string_methods lazy val program_with_string_methods = { val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } +} - def convertToString(e: Expr)(implicit p: Program): String = +trait Z3StringConverters { + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalClassMap = new Bijection[ClassDef, ClassDef]() // To be added manually + + val globalFdMap = new Bijection[FunDef, FunDef]() + + val stringBijection = new Bijection[String, Expr]() + + def convertToString(e: Expr): String = stringBijection.cachedA(e) { e match { case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) @@ -142,27 +140,28 @@ class Z3StringConversion(val p: Program) extends Z3StringConverters { } def convertFromString(v: String): Expr = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(nilchar, Seq())){ - case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) + v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ + case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) } } -} - -trait Z3StringConverters { self: Z3StringConversion => - import StringEcoSystem._ - val mappedVariables = new Bijection[Identifier, Identifier]() - - val globalFdMap = new Bijection[FunDef, FunDef]() trait BidirectionalConverters { def convertFunDef(fd: FunDef): FunDef def hasIdConversion(id: Identifier): Boolean def convertId(id: Identifier): Identifier + def convertClassDef(d: ClassDef): ClassDef def isTypeToConvert(tpe: TypeTree): Boolean def convertType(tpe: TypeTree): TypeTree def convertPattern(pattern: Pattern): Pattern def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr - + object TypeConverted { + def unapply(t: TypeTree): Option[TypeTree] = Some(t match { + case cct@CaseClassType(ccd, args) => CaseClassType(convertClassDef(ccd).asInstanceOf[CaseClassDef], args) + case act@AbstractClassType(acd, args) => AbstractClassType(convertClassDef(acd).asInstanceOf[AbstractClassDef], args) + case NAryType(es, builder) => + builder(es map convertType) + }) + } object PatternConverted { def unapply(e: Pattern): Option[Pattern] = Some(e match { case InstanceOfPattern(binder, ct) => @@ -260,6 +259,10 @@ trait Z3StringConverters { self: Z3StringConversion => def convertFunDef(fd: FunDef): FunDef = { globalFdMap.getBorElse(fd, fd) } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getBorElse(cd, cd) + } def hasIdConversion(id: Identifier): Boolean = { mappedVariables.containsA(id) } @@ -276,8 +279,10 @@ trait Z3StringConverters { self: Z3StringConversion => } def isTypeToConvert(tpe: TypeTree): Boolean = TypeOps.exists(StringType == _)(tpe) - def convertType(tpe: TypeTree): TypeTree = - TypeOps.preMap{ case StringType => Some(StringList.typed) case e => None}(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringType => StringList.typed + case TypeConverted(t) => t + } def convertPattern(e: Pattern): Pattern = e match { case LiteralPattern(binder, StringLiteral(s)) => s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { @@ -291,18 +296,17 @@ trait Z3StringConverters { self: Z3StringConversion => def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => - // No string support for z3 at this moment. val stringEncoding = convertFromString(v) convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - FunctionInvocation(list_size, Seq(convertExpr(a))).copiedFrom(e) + FunctionInvocation(StringSize.typed, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - FunctionInvocation(list_++, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) + FunctionInvocation(StringListConcat.typed, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(list_take, - Seq(FunctionInvocation(list_drop, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - FunctionInvocation(list_slice, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + FunctionInvocation(StringSlice.typed, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) case MatchExpr(scrutinee, cases) => MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) @@ -315,6 +319,10 @@ trait Z3StringConverters { self: Z3StringConversion => def convertFunDef(fd: FunDef): FunDef = { globalFdMap.getAorElse(fd, fd) } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getAorElse(cd, cd) + } def hasIdConversion(id: Identifier): Boolean = { mappedVariables.containsB(id) } @@ -335,38 +343,35 @@ trait Z3StringConverters { self: Z3StringConversion => } def isTypeToConvert(tpe: TypeTree): Boolean = TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) - def convertType(tpe: TypeTree): TypeTree = { - TypeOps.preMap{ - case StringList | StringCons | StringNil => Some(StringType) - case e => None}(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringList | StringCons | StringNil => StringType + case TypeConverted(t) => t } def convertPattern(e: Pattern): Pattern = e match { - case CaseClassPattern(b, StringNilTyped, Seq()) => - LiteralPattern(b.map(convertId), StringLiteral("")) - case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => - convertPattern(subpattern) match { - case LiteralPattern(_, StringLiteral(s)) - => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) - case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) - } - case PatternConverted(e) => e - } - - + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e + } - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = - e match { - case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> - StringLiteral(convertToString(cc)(self.p)) - case FunctionInvocation(StringSize, Seq(a)) => - StringLength(convertExpr(a)).copiedFrom(e) - case FunctionInvocation(StringListConcat, Seq(a, b)) => - StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) - case FunctionInvocation(StringTake, - Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => - val rstart = convertExpr(start) - SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) - case ExprConverted(e) => e + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e + } } - } } diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 3680930639a2cfba46490d4a21bab7772d7fd0c8..380799d25e1f73ddbbb57d7989706fd03e5f1821 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -2,9 +2,16 @@ package leon.utils -class Bijection[A, B] { +object Bijection { + def apply[A, B](a: Iterable[(A, B)]): Bijection[A, B] = new Bijection[A, B] ++= a + def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) +} + +class Bijection[A, B] extends Iterable[(A, B)] { protected var a2b = Map[A, B]() protected var b2a = Map[B, A]() + + def iterator = a2b.iterator def +=(a: A, b: B): Unit = { a2b += a -> b @@ -16,7 +23,7 @@ class Bijection[A, B] { this } - def ++=(t: Iterable[(A,B)]) = { + def ++=(t: Iterable[(A, B)]) = { (this /: t){ case (b, elem) => b += elem } } @@ -58,4 +65,11 @@ class Bijection[A, B] { def aSet = a2b.keySet def bSet = b2a.keySet + + def composeA[C](c: A => C): Bijection[C, B] = { + new Bijection[C, B] ++= this.a2b.map(kv => c(kv._1) -> kv._2) + } + def composeB[C](c: B => C): Bijection[A, C] = { + new Bijection[A, C] ++= this.a2b.map(kv => kv._1 -> c(kv._2)) + } } diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index c95fb39d33c86dac08d5d2465c99772fc88dc58f..d2af2030bded5ae60375180661107346164ca0c6 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -22,13 +22,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm) with ForcedProgramConversion ) + ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm)) with ForcedProgramConversion ) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm)) with ForcedProgramConversion ) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) }