Skip to content
Snippets Groups Projects
Commit bcf8954a authored by Mikaël Mayer's avatar Mikaël Mayer
Browse files

Added draft for converting class definitions.

Found out that replaceFunDef is incomplete (i.e. no unapply pattern replacement)
parent 47ec4515
Branches
Tags
No related merge requests found
......@@ -5,6 +5,8 @@ package leon.purescala
import Definitions._
import Expressions._
import ExprOps.{preMap, functionCallsOf}
import leon.purescala.Types.AbstractClassType
import leon.purescala.Types._
object DefOps {
......@@ -274,13 +276,11 @@ object DefOps {
case _ =>
None
}
/** Clones the given program by replacing some functions by other functions.
*
* @param p The original program
* @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept.
* May be called once each time a function appears (definition and invocation),
* so make sure to output the same if the argument is the same.
* @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use.
* By default it is the function invocation using the new function definition.
* @return the new program with a map from the old functions to the new functions */
......@@ -288,13 +288,13 @@ object DefOps {
fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap)
: (Program, Map[FunDef, FunDef])= {
var fdMapCache = Map[FunDef, Option[FunDef]]()
var fdMapCache = Map[FunDef, FunDef]()
def fdMap(fd: FunDef): FunDef = {
if (!(fdMapCache contains fd)) {
fdMapCache += fd -> fdMapF(fd)
fdMapCache += fd -> fdMapF(fd).getOrElse(fd.duplicate())
}
fdMapCache(fd).getOrElse(fd)
fdMapCache(fd)
}
......@@ -304,23 +304,21 @@ object DefOps {
case m : ModuleDef =>
m.copy(defs = for (df <- m.defs) yield {
df match {
case f : FunDef =>
val newF = fdMap(f)
newF
case d =>
d
case f : FunDef => fdMap(f)
case d => d
}
})
case d => d
}
)
})
// TODO: Check for function invocations in unapply patterns.
for(fd <- newP.definedFunctions) {
if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) {
fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF)
if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case _ => false }(fd.fullBody)) {
fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache, fiMapF)
}
}
(newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd })
(newP, fdMapCache)
}
def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = {
......@@ -331,6 +329,78 @@ object DefOps {
None
}(e)
}
private def defaultCdMap(cc: CaseClass, ccd: CaseClassDef): Option[Expr] = (cc, ccd) match {
case (CaseClass(old, args), newCcd) if old.classDef != newCcd =>
Some(CaseClass(newCcd.typed(old.tps), args))
case _ =>
None
}
/** Clones the given program by replacing some classes by other classes.
*
* @param p The original program
* @param cdMapF Given c and its cloned parent, returns Some(d) if c should be replaced by d, and None if c should be kept.
* Will always start to call this method for the topmost parents, and then descending.
* @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use.
* By default it is the case class construction using the new case class definition.
* @return the new program with a map from the old case classes to the new case classes */
def replaceClassDefs(p: Program)(cdMapF: (ClassDef, Option[AbstractClassType]) => Option[ClassDef],
ciMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap): (Program, Map[ClassDef, ClassDef]) = {
var cdMapCache = Map[ClassDef, ClassDef]()
def tpMap(tt: TypeTree): TypeTree = tt match {
case AbstractClassType(asd, targs) => AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs map tpMap)
case CaseClassType(ccd, targs) => CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs map tpMap)
case e => e
}
def cdMap(cd: ClassDef): ClassDef = {
if (!(cdMapCache contains cd)) {
lazy val parent = cd.parent.map( tpMap(_).asInstanceOf[AbstractClassType] )
cdMapCache += cd -> cdMapF(cd, parent).getOrElse{
cd match {
case acd:AbstractClassDef => acd.duplicate(parent = parent)
case ccd:CaseClassDef => ccd.duplicate(parent = parent)
}
}
}
cdMapCache(cd)
}
val newP = p.copy(units = for (u <- p.units) yield {
u.copy(
defs = u.defs.map {
case m : ModuleDef =>
m.copy(defs = for (df <- m.defs) yield {
df match {
case f : ClassDef => cdMap(f)
case d => d
}
})
case d => d
}
)
})
for(fd <- newP.definedFunctions) {
// TODO: Check for patterns
// TODO: Check for isInstanceOf
// TODO: Check for asInstanceOf
if(ExprOps.exists{ case CaseClass(CaseClassType(ccd, targs), fargs) => cdMapCache.getOrElse(ccd, None) != None case _ => false }(fd.fullBody)) {
fd.fullBody = replaceClassDefsUse(fd.fullBody, cdMap, ciMapF)
}
}
(newP, cdMapCache)
}
def replaceClassDefsUse(e: Expr, fdMapF: ClassDef => ClassDef, fiMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap) = {
preMap {
case fi @ CaseClass(CaseClassType(cd, tps), args) =>
fiMapF(fi, fdMapF(cd).asInstanceOf[CaseClassDef]).map(_.setPos(fi))
case _ =>
None
}(e)
}
def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = {
var found = false
......
......@@ -315,6 +315,20 @@ object Definitions {
AbstractClassType(this, tps)
}
def typed: AbstractClassType = typed(tparams.map(_.tp))
/** Duplication of this [[CaseClassDef]].
* @note This will not add known case class children
*/
def duplicate(
id: Identifier = this.id.freshen,
tparams: Seq[TypeParameterDef] = this.tparams,
parent: Option[AbstractClassType] = this.parent
): AbstractClassDef = {
val acd = new AbstractClassDef(id, tparams, parent)
acd.addFlags(this.flags)
parent.map(_.classDef.ancestors.map(_.registerChild(acd)))
acd.copiedFrom(this)
}
}
/** Case classes/objects. */
......@@ -351,6 +365,24 @@ object Definitions {
CaseClassType(this, tps)
}
def typed: CaseClassType = typed(tparams.map(_.tp))
/** Duplication of this [[CaseClassDef]].
* @note This will not replace recursive case class def calls in [[arguments]] nor the parent abstract class types
*/
def duplicate(
id: Identifier = this.id.freshen,
tparams: Seq[TypeParameterDef] = this.tparams,
fields: Seq[ValDef] = this.fields,
parent: Option[AbstractClassType] = this.parent,
isCaseObject: Boolean = this.isCaseObject
): CaseClassDef = {
val cd = new CaseClassDef(id, tparams, parent, isCaseObject)
cd.setFields(fields)
cd.addFlags(this.flags)
cd.copiedFrom(this)
parent.map(_.classDef.ancestors.map(_.registerChild(cd)))
cd
}
}
/** Function/method definition.
......
......@@ -23,6 +23,22 @@ import leon.utils.Bijection
import leon.solvers.z3.StringEcoSystem
object Z3StringCapableSolver {
def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t)
def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e)
def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType)
def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id)
def thatShouldBeConverted(fd: FunDef): Boolean = {
(fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted)
}
def thatShouldBeConverted(cd: ClassDef): Boolean = cd match {
case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted)
case _ => false
}
def thatShouldBeConverted(p: Program): Boolean = {
(p.definedFunctions exists thatShouldBeConverted) ||
(p.definedClasses exists thatShouldBeConverted)
}
def convert(p: Program): (Program, Option[Z3StringConversion]) = {
val converter = new Z3StringConversion(p)
import converter.Forward._
......@@ -31,8 +47,7 @@ object Z3StringCapableSolver {
val program_with_strings = converter.getProgram
val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => {
globalFdMap.get(fd).map(_._2).orElse(
if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) ||
fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) {
if(thatShouldBeConverted(fd)) {
val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap
val newFdId = convertId(fd.id)
val newFd = fd.duplicate(newFdId,
......@@ -205,7 +220,7 @@ class Z3StringFairZ3Solver(context: LeonContext, program: Program)
protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg
override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
someConverter match {
case None => underlying.checkAssumptions(assumptions.map(e => this.convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map()))))
case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map()))))
case Some(converter) =>
underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment