From bcf8954ae4f9eaf7c3b7001859dab39db1dd5397 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Fri, 12 Feb 2016 18:45:42 +0100
Subject: [PATCH] Added draft for converting class definitions. Found out that
 replaceFunDef is incomplete (i.e. no unapply pattern replacement)

---
 src/main/scala/leon/purescala/DefOps.scala    | 98 ++++++++++++++++---
 .../scala/leon/purescala/Definitions.scala    | 32 ++++++
 .../combinators/Z3StringCapableSolver.scala   | 21 +++-
 3 files changed, 134 insertions(+), 17 deletions(-)

diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala
index 4efc97986..f1a35ce97 100644
--- a/src/main/scala/leon/purescala/DefOps.scala
+++ b/src/main/scala/leon/purescala/DefOps.scala
@@ -5,6 +5,8 @@ package leon.purescala
 import Definitions._
 import Expressions._
 import ExprOps.{preMap, functionCallsOf}
+import leon.purescala.Types.AbstractClassType
+import leon.purescala.Types._
 
 object DefOps {
 
@@ -274,13 +276,11 @@ object DefOps {
     case _ =>
       None
   }
-
+  
   /** Clones the given program by replacing some functions by other functions.
     * 
     * @param p The original program
     * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept.
-    *        May be called once each time a function appears (definition and invocation),
-    *        so make sure to output the same if the argument is the same.
     * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use.
     *               By default it is the function invocation using the new function definition.
     * @return the new program with a map from the old functions to the new functions */
@@ -288,13 +288,13 @@ object DefOps {
                                  fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap)
                                  : (Program, Map[FunDef, FunDef])= {
 
-    var fdMapCache = Map[FunDef, Option[FunDef]]()
+    var fdMapCache = Map[FunDef, FunDef]()
     def fdMap(fd: FunDef): FunDef = {
       if (!(fdMapCache contains fd)) {
-        fdMapCache += fd -> fdMapF(fd)
+        fdMapCache += fd -> fdMapF(fd).getOrElse(fd.duplicate())
       }
 
-      fdMapCache(fd).getOrElse(fd)
+      fdMapCache(fd)
     }
 
 
@@ -304,23 +304,21 @@ object DefOps {
           case m : ModuleDef =>
             m.copy(defs = for (df <- m.defs) yield {
               df match {
-                case f : FunDef =>
-                  val newF = fdMap(f)
-                  newF
-                case d =>
-                  d
+                case f : FunDef => fdMap(f)
+                case d => d
               }
           })
           case d => d
         }
       )
     })
+    // TODO: Check for function invocations in unapply patterns.
     for(fd <- newP.definedFunctions) {
-      if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) {
-        fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF)
+      if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case _ => false }(fd.fullBody)) {
+        fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache, fiMapF)
       }
     }
-    (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd })
+    (newP, fdMapCache)
   }
 
   def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = {
@@ -331,6 +329,78 @@ object DefOps {
         None
     }(e)
   }
+  
+
+  private def defaultCdMap(cc: CaseClass, ccd: CaseClassDef): Option[Expr] = (cc, ccd) match {
+    case (CaseClass(old, args), newCcd) if old.classDef != newCcd =>
+      Some(CaseClass(newCcd.typed(old.tps), args))
+    case _ =>
+      None
+  }
+  
+  /** Clones the given program by replacing some classes by other classes.
+    * 
+    * @param p The original program
+    * @param cdMapF Given c and its cloned parent, returns Some(d) if c should be replaced by d, and None if c should be kept.
+    *        Will always start to call this method for the topmost parents, and then descending.
+    * @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use.
+    *               By default it is the case class construction using the new case class definition.
+    * @return the new program with a map from the old case classes to the new case classes */
+  def replaceClassDefs(p: Program)(cdMapF: (ClassDef, Option[AbstractClassType]) => Option[ClassDef],
+                                   ciMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap): (Program, Map[ClassDef, ClassDef]) = {
+    var cdMapCache = Map[ClassDef, ClassDef]()
+    def tpMap(tt: TypeTree): TypeTree = tt match {
+      case AbstractClassType(asd, targs) => AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs map tpMap)
+      case CaseClassType(ccd, targs) => CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs map tpMap)
+      case e => e
+    }
+    
+    def cdMap(cd: ClassDef): ClassDef = {
+      if (!(cdMapCache contains cd)) {
+        lazy val parent = cd.parent.map( tpMap(_).asInstanceOf[AbstractClassType] )
+        cdMapCache += cd -> cdMapF(cd, parent).getOrElse{
+          cd match {
+            case acd:AbstractClassDef => acd.duplicate(parent = parent)
+            case ccd:CaseClassDef => ccd.duplicate(parent = parent)
+          }
+        }
+      }
+      cdMapCache(cd)
+    }
+    
+    val newP = p.copy(units = for (u <- p.units) yield {
+      u.copy(
+        defs = u.defs.map {
+          case m : ModuleDef =>
+            m.copy(defs = for (df <- m.defs) yield {
+              df match {
+                case f : ClassDef => cdMap(f)
+                case d => d
+              }
+          })
+          case d => d
+        }
+      )
+    })
+    for(fd <- newP.definedFunctions) {
+      // TODO: Check for patterns
+      // TODO: Check for isInstanceOf
+      // TODO: Check for asInstanceOf
+      if(ExprOps.exists{ case CaseClass(CaseClassType(ccd, targs), fargs) => cdMapCache.getOrElse(ccd, None) != None case _ => false }(fd.fullBody)) {
+        fd.fullBody = replaceClassDefsUse(fd.fullBody, cdMap, ciMapF)
+      }
+    }
+    (newP, cdMapCache)
+  }
+  
+  def replaceClassDefsUse(e: Expr, fdMapF: ClassDef => ClassDef, fiMapF: (CaseClass, CaseClassDef) => Option[Expr] = defaultCdMap) = {
+    preMap {
+      case fi @ CaseClass(CaseClassType(cd, tps), args) =>
+        fiMapF(fi, fdMapF(cd).asInstanceOf[CaseClassDef]).map(_.setPos(fi))
+      case _ =>
+        None
+    }(e)
+  }
 
   def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = {
     var found = false
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 8fb753d23..dfc78d4c5 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -315,6 +315,20 @@ object Definitions {
       AbstractClassType(this, tps)
     }
     def typed: AbstractClassType = typed(tparams.map(_.tp))
+    
+    /** Duplication of this [[CaseClassDef]].
+      * @note This will not add known case class children
+      */
+    def duplicate(
+      id: Identifier                    = this.id.freshen,
+      tparams: Seq[TypeParameterDef]    = this.tparams,
+      parent: Option[AbstractClassType] = this.parent
+    ): AbstractClassDef = {
+      val acd = new AbstractClassDef(id, tparams, parent)
+      acd.addFlags(this.flags)
+      parent.map(_.classDef.ancestors.map(_.registerChild(acd)))
+      acd.copiedFrom(this)
+    }
   }
 
   /** Case classes/objects. */
@@ -351,6 +365,24 @@ object Definitions {
       CaseClassType(this, tps)
     }
     def typed: CaseClassType = typed(tparams.map(_.tp))
+    
+    /** Duplication of this [[CaseClassDef]].
+      * @note This will not replace recursive case class def calls in [[arguments]] nor the parent abstract class types
+      */
+    def duplicate(
+      id: Identifier                    = this.id.freshen,
+      tparams: Seq[TypeParameterDef]    = this.tparams,
+      fields: Seq[ValDef]               = this.fields,
+      parent: Option[AbstractClassType] = this.parent,
+      isCaseObject: Boolean             = this.isCaseObject
+    ): CaseClassDef = {
+      val cd = new CaseClassDef(id, tparams, parent, isCaseObject)
+      cd.setFields(fields)
+      cd.addFlags(this.flags)
+      cd.copiedFrom(this)
+      parent.map(_.classDef.ancestors.map(_.registerChild(cd)))
+      cd
+    }
   }
 
   /** Function/method definition.
diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
index c39c4fe09..df01b574a 100644
--- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala
@@ -23,6 +23,22 @@ import leon.utils.Bijection
 import leon.solvers.z3.StringEcoSystem
 
 object Z3StringCapableSolver {
+  def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t)
+  def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e)
+  def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType)
+  def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id)
+  def thatShouldBeConverted(fd: FunDef): Boolean = {
+    (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted)
+  }
+  def thatShouldBeConverted(cd: ClassDef): Boolean = cd match {
+    case ccd:CaseClassDef =>  ccd.fields.exists(thatShouldBeConverted)
+    case _ => false
+  }
+  def thatShouldBeConverted(p: Program): Boolean = {
+    (p.definedFunctions exists thatShouldBeConverted) ||
+    (p.definedClasses exists thatShouldBeConverted)
+  }
+  
   def convert(p: Program): (Program, Option[Z3StringConversion]) = {
     val converter = new Z3StringConversion(p)
     import converter.Forward._
@@ -31,8 +47,7 @@ object Z3StringCapableSolver {
     val program_with_strings = converter.getProgram
     val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => {
       globalFdMap.get(fd).map(_._2).orElse(
-          if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) ||
-              fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) {
+          if(thatShouldBeConverted(fd)) {
             val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap
             val newFdId = convertId(fd.id)
             val newFd = fd.duplicate(newFdId,
@@ -205,7 +220,7 @@ class Z3StringFairZ3Solver(context: LeonContext, program: Program)
     protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg
     override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
       someConverter match {
-        case None => underlying.checkAssumptions(assumptions.map(e => this.convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map()))))
+        case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map()))))
         case Some(converter) =>
           underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map())))
       }
-- 
GitLab