From 8af3310b1871c8e150744c2bf0895e52181108c2 Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Fri, 11 Apr 2014 17:44:14 +0200
Subject: [PATCH] Default parameters for functions

---
 .../leon/frontends/scalac/ASTExtractors.scala | 40 ++++++++
 .../frontends/scalac/CodeExtraction.scala     | 96 ++++++++++++++++++-
 .../scala/leon/purescala/Definitions.scala    |  4 +
 .../scala/leon/purescala/PrettyPrinter.scala  | 34 ++++++-
 .../scala/leon/purescala/ScalaPrinter.scala   |  6 ++
 .../regression/frontends/OptParams.scala      | 15 +++
 6 files changed, 189 insertions(+), 6 deletions(-)
 create mode 100644 src/test/resources/regression/frontends/OptParams.scala

diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
index 840dd6fa4..ea6309c8c 100644
--- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
+++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
@@ -310,6 +310,18 @@ trait ASTExtractors {
       }
     }
     
+    object ExCompanionObjectSynthetic {
+      def unapply(cd : ClassDef) : Option[(String, Symbol, Template)] = {
+        val sym = cd.symbol 
+        cd match {
+         case ClassDef(_, name, tparams, impl) if sym.isModule && sym.isSynthetic => //FIXME flags?
+           Some((name.toString, sym, impl))
+         case _ => None
+        }
+        
+      }
+    }
+
     object ExCaseClassSyntheticJunk {
       def unapply(cd: ClassDef): Boolean = cd match {
         case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic) => true
@@ -408,6 +420,34 @@ trait ASTExtractors {
       }
       
     }
+    
+    object ExDefaultValueFunction{
+      /** Matches a function that defines the default value of a parameter */
+      def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = {
+        val sym = dd.symbol
+        dd match {
+          case DefDef(_, name, tparams, vparamss, tpt, rhs) if(
+            vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic 
+          ) => 
+            
+            // Split the name into pieces, to find owner of the parameter + param.index
+            // Form has to be <owner name>$default$<param index>
+            val symPieces = sym.name.toString.reverse.split("\\$",3).reverse map { _.reverse }
+            
+            try {
+              if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("")
+              val ownerString = symPieces(0)
+              val index = symPieces(2).toInt - 1
+              Some((sym, tparams.map(_.symbol), vparamss.headOption.getOrElse(Nil), tpt.tpe, ownerString, index, rhs))
+            } catch {
+              case _ : NumberFormatException | _ : IllegalArgumentException | _ : ArrayIndexOutOfBoundsException =>
+                None 
+            }
+              
+          case _ => None
+        }
+      }
+    } 
 
   }
 
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 8166e69f8..630f688a5 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -209,7 +209,6 @@ trait CodeExtraction extends ASTExtractors {
               !isLib(u)
             )
    
-        
           case pd @ PackageDef(refTree, lst) =>
                          
             var standaloneDefs = List[Tree]()
@@ -224,7 +223,22 @@ trait CodeExtraction extends ASTExtractors {
                 
               case ExObjectDef(n, templ) if n != "package" =>
                 Some(TempModule(FreshIdentifier(n), templ.body, false))
-  
+
+              /*
+              case d @ ExCompanionObjectSynthetic(_, sym, templ) => 
+                // Find default param. implementations
+                templ.body foreach { 
+                  case ExDefaultValueFunction(sym, _, _, _, owner, index, _) =>
+                    val namePieces = sym.toString.reverse.split("\\$", 3).reverse map { _.reverse }
+                    assert(namePieces.length == 3 && namePieces(0)== "$lessinit$greater" && namePieces(1) == "default") // FIXME : maybe $lessinit$greater?
+                    val index = namePieces(2).toInt
+                    val theParam = sym.companionClass.paramss.head(index - 1)
+                    paramsToDefaultValues += theParam -> body
+                  case _ => 
+                } 
+                None 
+              */
+
               case d @ ExAbstractClass(_, _, _) =>
                 standaloneDefs ::= d
                 None
@@ -389,6 +403,8 @@ trait CodeExtraction extends ASTExtractors {
           outOfSubsetError(pos, "Class "+className+" is not a case class")
       }
     }
+    
+    private var paramsToDefaultValues = Map[Symbol,FunDef]()
 
     private def collectClassSymbols(defs: List[Tree]) {
       // We collect all defined classes
@@ -450,6 +466,28 @@ trait CodeExtraction extends ASTExtractors {
     private var isMethod = Set[Symbol]()
     private var methodToClass = Map[FunDef, LeonClassDef]()
 
+    /**
+     * For the function in $defs with name $owner, find its parameter with index $index, 
+     * and registers $fd as the default value function for this parameter.  
+     */
+    private def registerDefaultMethod(
+        defs : List[Tree],
+        matcher : PartialFunction[Tree,Symbol],
+        index : Int,
+        fd : FunDef
+    )  {
+      // Search tmpl to find the function that includes this parameter
+      val paramOwner = defs.collectFirst(matcher).get
+      
+      // assumes single argument list
+      if(paramOwner.paramss.length != 1) {
+        outOfSubsetError(paramOwner.pos, "Multiple argument lists for a function are not allowed")
+      }
+      val theParam = paramOwner.paramss.head(index)
+      paramsToDefaultValues += (theParam -> fd)
+    }
+    
+    
     def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = {
       val id = FreshIdentifier(sym.name.toString).setPos(sym.pos)
 
@@ -544,6 +582,21 @@ trait CodeExtraction extends ASTExtractors {
 
           cd.registerMethod(fd)
 
+        // Default values for parameters
+        case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) =>          
+          val fd = defineFunDef(fsym)(defCtx)
+          fd.addAnnotation("synthetic")
+                    
+          isMethod += fsym
+          methodToClass += fd -> cd
+
+          cd.registerMethod(fd)       
+          val matcher : PartialFunction[Tree, Symbol] = { 
+            case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym 
+          } 
+          registerDefaultMethod(tmpl.body, matcher, index, fd )
+                   
+          
         // Lazy fields
         case t @ ExLazyAccessorFunction(fsym, _, _)  =>
           if (parent.isDefined) {
@@ -639,6 +692,16 @@ trait CodeExtraction extends ASTExtractors {
         case ExFunctionDef(sym, _, _, _, _) =>
           defineFunDef(sym)(DefContext())
 
+        case t @ ExDefaultValueFunction(sym, _, _, _, owner, index, _) => { 
+          
+          val fd = defineFunDef(sym)(DefContext())
+          fd.addAnnotation("synthetic")
+          val matcher : PartialFunction[Tree, Symbol] = { 
+            case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym 
+          } 
+          registerDefaultMethod(defs, matcher, index, fd)
+           
+        }
         case ExLazyAccessorFunction(sym, _, _)  =>
           defineFieldFunDef(sym,true)(DefContext())
           
@@ -673,6 +736,16 @@ trait CodeExtraction extends ASTExtractors {
               extractFunBody(fd, params, body)(DefContext(tparamsMap))
             }
             
+          // Default value functions
+          case ExDefaultValueFunction(sym, tparams, params, _, _, _, body) =>
+            val fd = defsToDefs(sym)
+
+            val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap
+
+            if(body != EmptyTree) {
+              extractFunBody(fd, params, body)(DefContext(tparamsMap))
+            }
+            
           // Lazy fields
           case t @ ExLazyAccessorFunction(sym, _, body) =>
             val fd = defsToDefs(sym)
@@ -722,6 +795,14 @@ trait CodeExtraction extends ASTExtractors {
 
           extractFunBody(fd, params, body)(DefContext(tparamsMap, isExtern = isExtern(sym)))
 
+        case ExDefaultValueFunction(sym, tparams, params, _ ,_ , _, body) =>
+          // Default value functions
+          val fd = defsToDefs(sym)
+
+          val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap
+
+          extractFunBody(fd, params, body)(DefContext(tparamsMap))
+          
         case ExLazyAccessorFunction(sym, _, body)  =>
           // Lazy vals
           val fd = defsToDefs(sym)
@@ -763,6 +844,8 @@ trait CodeExtraction extends ASTExtractors {
         // Taking accessor functions will duplicate work for strict fields, but we need them in case of lazy fields
         case ExFunctionDef(sym, tparams, params, _, body) =>
           Some(defsToDefs(sym))
+        case ExDefaultValueFunction(sym, _, _, _, _, _, _) =>
+          Some(defsToDefs(sym))
         case ExLazyAccessorFunction(sym, _, _) =>
           Some(defsToDefs(sym))
         case ExFieldDef(sym, _, _) =>
@@ -780,6 +863,7 @@ trait CodeExtraction extends ASTExtractors {
         case ExConstructorDef() =>
         case ExFunctionDef(_, _, _, _, _) =>
         case ExLazyAccessorFunction(_, _, _) =>
+        case ExDefaultValueFunction(_, _, _, _, _, _, _ ) =>
         case ExFieldDef(_,_,_) =>
         case ExLazyFieldDef() => 
         case ExFieldAccessorFunction() => 
@@ -795,6 +879,10 @@ trait CodeExtraction extends ASTExtractors {
     private def extractFunBody(funDef: FunDef, params: Seq[ValDef], body0 : Tree)(implicit dctx: DefContext): FunDef = {
       currentFunDef = funDef
       
+      // Find defining function for params with default value
+      for ((s,vd) <- params zip funDef.params) {
+        vd.defaultValue = paramsToDefaultValues.get(s.symbol) 
+      }
       
       val newVars = for ((s, vd) <- params zip funDef.params) yield {
         s.symbol -> (() => Variable(vd.id))
@@ -1115,6 +1203,8 @@ trait CodeExtraction extends ASTExtractors {
           rest = None
           LetDef(funDefWithBody, restTree)
 
+        // FIXME case ExDefaultValueFunction
+        
         /**
          * XLang Extractors
          */
@@ -1785,6 +1875,8 @@ trait CodeExtraction extends ASTExtractors {
             outOfSubsetError(tpt.typeSymbol.pos, "Could not extract refined type as PureScala: "+tpt+" ("+tpt.getClass+")")
         }
 
+      case AnnotatedType(_, tpe) => extractType(tpe)
+
       case _ =>
         outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")")
     }
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index ef3505476..f056671da 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -55,6 +55,8 @@ object Definitions {
 
     val getType = tpe getOrElse id.getType
 
+    var defaultValue : Option[FunDef] = None
+      
     def subDefinitions = Seq()
 
     // Warning: the variable will not have the same type as the ValDef, but 
@@ -413,6 +415,8 @@ object Definitions {
     def canBeField       = canBeLazyField || canBeStrictField
     def isRealFunction   = !canBeField
 
+    def isSynthetic = annotations contains "synthetic"
+    
     private var annots: Set[String] = Set.empty[String]
     def addAnnotation(as: String*) : FunDef = {
       annots = annots ++ as
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index 80da11205..f43546154 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -303,7 +303,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
           p"[${tfd.tps}]"
         }
 
-        if (tfd.fd.isRealFunction) p"($args)"
+        // No () for fields
+        if (tfd.fd.isRealFunction) {
+          // The non-present arguments are synthetic function invocations
+          val presentArgs = args filter {
+            case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
+            case FunctionInvocation(tfd, _)     if tfd.fd.isSynthetic => false
+            case other => true
+          }
+          p"($presentArgs)"
+        }
 
       case BinaryMethodCall(a, op, b) =>
         optP { p"${a} $op ${b}" }
@@ -316,7 +325,13 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
         }
 
         if (fd.isRealFunction) {
-          p"($args)"
+          // The non-present arguments are synthetic function invocations
+          val presentArgs = args filter {
+            case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
+            case FunctionInvocation(tfd, _)     if tfd.fd.isSynthetic => false
+            case other => true
+          }
+          p"($presentArgs)"
         }
 
       case FunctionInvocation(tfd, args) =>
@@ -326,7 +341,15 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
           p"[${tfd.tps}]"
         }
 
-        if (tfd.fd.isRealFunction) p"($args)"
+        if (tfd.fd.isRealFunction) { 
+          // The non-present arguments are synthetic function invocations
+          val presentArgs = args filter {
+            case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
+            case FunctionInvocation(tfd, _)     if tfd.fd.isSynthetic => false
+            case other => true
+          }
+          p"($presentArgs)"
+        }
 
       case Application(caller, args) =>
         p"$caller($args)"
@@ -408,7 +431,10 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
       }
 
       case Not(expr)                 => p"\u00AC$expr"
-      case vd@ValDef(id, _)          => p"$id : ${vd.getType}"
+      case vd@ValDef(id, _)          => vd.defaultValue match {
+        case Some(fd) => p"$id : ${vd.getType} = ${fd.body.get}"
+        case None => p"$id : ${vd.getType}"
+      }
       case This(_)                   => p"this"
       case (tfd: TypedFunDef)        => p"typed def ${tfd.id}[${tfd.tps}]"
       case TypeParameterDef(tp)      => p"$tp"
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index cf58ee34d..666f728d6 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -24,6 +24,12 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex
   override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = {
    
     tree match {
+      case m: ModuleDef =>
+        // Don't print synthetic functions
+        super.pp(m.copy(defs = m.defs.filter {
+          case f:FunDef if f.isSynthetic => false
+          case _ => true
+        }))
       case Not(Equals(l, r))    => p"$l != $r"
       case Implies(l,r)         => pp(or(not(l), r))
       case Choose(pred, None) => p"choose($pred)"
diff --git a/src/test/resources/regression/frontends/OptParams.scala b/src/test/resources/regression/frontends/OptParams.scala
new file mode 100644
index 000000000..535c51727
--- /dev/null
+++ b/src/test/resources/regression/frontends/OptParams.scala
@@ -0,0 +1,15 @@
+object OptParams {
+
+  def foo( x : Int, y : Int = 12 ) = x + y
+
+  def bar = foo(42)
+  def baz = foo(1,2)
+
+
+
+  abstract class Opt  {
+    def opt( o : Opt = OptChild(), i : Int = 0) : Int = i + 1 
+    def opt2 = opt()
+  }
+  case class OptChild() extends Opt
+}
-- 
GitLab