From bcf8954ae4f9eaf7c3b7001859dab39db1dd5397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch> Date: Fri, 12 Feb 2016 18:45:42 +0100 Subject: [PATCH] Added draft for converting class definitions. Found out that replaceFunDef is incomplete (i.e. no unapply pattern replacement) --- src/main/scala/leon/purescala/DefOps.scala | 98 ++++++++++++++++--- .../scala/leon/purescala/Definitions.scala | 32 ++++++ .../combinators/Z3StringCapableSolver.scala | 21 +++- 3 files changed, 134 insertions(+), 17 deletions(-) diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 4efc97986..f1a35ce97 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -5,6 +5,8 @@ package leon.purescala import Definitions._ import Expressions._ import ExprOps.{preMap, functionCallsOf} +import leon.purescala.Types.AbstractClassType +import leon.purescala.Types._ object DefOps { @@ -274,13 +276,11 @@ object DefOps { case _ => None } - + /** Clones the given program by replacing some functions by other functions. * * @param p The original program * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. - * May be called once each time a function appears (definition and invocation), - * so make sure to output the same if the argument is the same. * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. * By default it is the function invocation using the new function definition. * @return the new program with a map from the old functions to the new functions */ @@ -288,13 +288,13 @@ object DefOps { fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) : (Program, Map[FunDef, FunDef])= { - var fdMapCache = Map[FunDef, Option[FunDef]]() + var fdMapCache = Map[FunDef, FunDef]() def fdMap(fd: FunDef): FunDef = { if (!(fdMapCache contains fd)) { - fdMapCache += fd -> fdMapF(fd) + fdMapCache += fd -> fdMapF(fd).getOrElse(fd.duplicate()) } - fdMapCache(fd).getOrElse(fd) + fdMapCache(fd) } @@ -304,23 +304,21 @@ object DefOps { case m : ModuleDef => m.copy(defs = for (df <- m.defs) yield { df match { - case f : FunDef => - val newF = fdMap(f) - newF - case d => - d + case f : FunDef => fdMap(f) + case d => d } }) case d => d } ) }) + // TODO: Check for function invocations in unapply patterns. for(fd <- newP.definedFunctions) { - if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) { - fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) + if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case _ => false }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache, fiMapF) } } - (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) + (newP, fdMapCache) } def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { @@ -331,6 +329,78 @@ object DefOps { None }(e) } + + + private def defaultCdMap(cc: CaseClass, ccd: CaseClassDef): Option[Expr] = (cc, ccd) match { + case (CaseClass(old, args), newCcd) if old.classDef != newCcd => + Some(CaseClass(newCcd.typed(old.tps), args)) + case _ => + None + } + + /** Clones the given program by replacing some classes by other classes. + * + * @param p The original program + * @param cdMapF Given c and its cloned parent, returns Some(d) if c should be replaced by d, and None if c should be kept. + * Will always start to call this method for the topmost parents, and then descending. + * @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use. + * By default it is the case class construction using the new case class definition. + * @return the new program with a map from the old case classes to the new case classes */ + def replaceClassDefs(p: Program)(cdMapF: (ClassDef, Option[AbstractClassType]) => Option[ClassDef], + ciMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap): (Program, Map[ClassDef, ClassDef]) = { + var cdMapCache = Map[ClassDef, ClassDef]() + def tpMap(tt: TypeTree): TypeTree = tt match { + case AbstractClassType(asd, targs) => AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs map tpMap) + case CaseClassType(ccd, targs) => CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs map tpMap) + case e => e + } + + def cdMap(cd: ClassDef): ClassDef = { + if (!(cdMapCache contains cd)) { + lazy val parent = cd.parent.map( tpMap(_).asInstanceOf[AbstractClassType] ) + cdMapCache += cd -> cdMapF(cd, parent).getOrElse{ + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => ccd.duplicate(parent = parent) + } + } + } + cdMapCache(cd) + } + + val newP = p.copy(units = for (u <- p.units) yield { + u.copy( + defs = u.defs.map { + case m : ModuleDef => + m.copy(defs = for (df <- m.defs) yield { + df match { + case f : ClassDef => cdMap(f) + case d => d + } + }) + case d => d + } + ) + }) + for(fd <- newP.definedFunctions) { + // TODO: Check for patterns + // TODO: Check for isInstanceOf + // TODO: Check for asInstanceOf + if(ExprOps.exists{ case CaseClass(CaseClassType(ccd, targs), fargs) => cdMapCache.getOrElse(ccd, None) != None case _ => false }(fd.fullBody)) { + fd.fullBody = replaceClassDefsUse(fd.fullBody, cdMap, ciMapF) + } + } + (newP, cdMapCache) + } + + def replaceClassDefsUse(e: Expr, fdMapF: ClassDef => ClassDef, fiMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap) = { + preMap { + case fi @ CaseClass(CaseClassType(cd, tps), args) => + fiMapF(fi, fdMapF(cd).asInstanceOf[CaseClassDef]).map(_.setPos(fi)) + case _ => + None + }(e) + } def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 8fb753d23..dfc78d4c5 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -315,6 +315,20 @@ object Definitions { AbstractClassType(this, tps) } def typed: AbstractClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not add known case class children + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + parent: Option[AbstractClassType] = this.parent + ): AbstractClassDef = { + val acd = new AbstractClassDef(id, tparams, parent) + acd.addFlags(this.flags) + parent.map(_.classDef.ancestors.map(_.registerChild(acd))) + acd.copiedFrom(this) + } } /** Case classes/objects. */ @@ -351,6 +365,24 @@ object Definitions { CaseClassType(this, tps) } def typed: CaseClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not replace recursive case class def calls in [[arguments]] nor the parent abstract class types + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + fields: Seq[ValDef] = this.fields, + parent: Option[AbstractClassType] = this.parent, + isCaseObject: Boolean = this.isCaseObject + ): CaseClassDef = { + val cd = new CaseClassDef(id, tparams, parent, isCaseObject) + cd.setFields(fields) + cd.addFlags(this.flags) + cd.copiedFrom(this) + parent.map(_.classDef.ancestors.map(_.registerChild(cd))) + cd + } } /** Function/method definition. diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index c39c4fe09..df01b574a 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -23,6 +23,22 @@ import leon.utils.Bijection import leon.solvers.z3.StringEcoSystem object Z3StringCapableSolver { + def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t) + def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e) + def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType) + def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id) + def thatShouldBeConverted(fd: FunDef): Boolean = { + (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted) + } + def thatShouldBeConverted(cd: ClassDef): Boolean = cd match { + case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted) + case _ => false + } + def thatShouldBeConverted(p: Program): Boolean = { + (p.definedFunctions exists thatShouldBeConverted) || + (p.definedClasses exists thatShouldBeConverted) + } + def convert(p: Program): (Program, Option[Z3StringConversion]) = { val converter = new Z3StringConversion(p) import converter.Forward._ @@ -31,8 +47,7 @@ object Z3StringCapableSolver { val program_with_strings = converter.getProgram val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => { globalFdMap.get(fd).map(_._2).orElse( - if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || - fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { + if(thatShouldBeConverted(fd)) { val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap val newFdId = convertId(fd.id) val newFd = fd.duplicate(newFdId, @@ -205,7 +220,7 @@ class Z3StringFairZ3Solver(context: LeonContext, program: Program) protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { someConverter match { - case None => underlying.checkAssumptions(assumptions.map(e => this.convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) + case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) case Some(converter) => underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) } -- GitLab