Skip to content
Snippets Groups Projects
Commit 8af3310b authored by Emmanouil (Manos) Koukoutos's avatar Emmanouil (Manos) Koukoutos Committed by Etienne Kneuss
Browse files

Default parameters for functions

parent 463b47e1
Branches
Tags
No related merge requests found
......@@ -310,6 +310,18 @@ trait ASTExtractors {
}
}
object ExCompanionObjectSynthetic {
def unapply(cd : ClassDef) : Option[(String, Symbol, Template)] = {
val sym = cd.symbol
cd match {
case ClassDef(_, name, tparams, impl) if sym.isModule && sym.isSynthetic => //FIXME flags?
Some((name.toString, sym, impl))
case _ => None
}
}
}
object ExCaseClassSyntheticJunk {
def unapply(cd: ClassDef): Boolean = cd match {
case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic) => true
......@@ -408,6 +420,34 @@ trait ASTExtractors {
}
}
object ExDefaultValueFunction{
/** Matches a function that defines the default value of a parameter */
def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = {
val sym = dd.symbol
dd match {
case DefDef(_, name, tparams, vparamss, tpt, rhs) if(
vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic
) =>
// Split the name into pieces, to find owner of the parameter + param.index
// Form has to be <owner name>$default$<param index>
val symPieces = sym.name.toString.reverse.split("\\$",3).reverse map { _.reverse }
try {
if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("")
val ownerString = symPieces(0)
val index = symPieces(2).toInt - 1
Some((sym, tparams.map(_.symbol), vparamss.headOption.getOrElse(Nil), tpt.tpe, ownerString, index, rhs))
} catch {
case _ : NumberFormatException | _ : IllegalArgumentException | _ : ArrayIndexOutOfBoundsException =>
None
}
case _ => None
}
}
}
}
......
......@@ -209,7 +209,6 @@ trait CodeExtraction extends ASTExtractors {
!isLib(u)
)
case pd @ PackageDef(refTree, lst) =>
var standaloneDefs = List[Tree]()
......@@ -224,7 +223,22 @@ trait CodeExtraction extends ASTExtractors {
case ExObjectDef(n, templ) if n != "package" =>
Some(TempModule(FreshIdentifier(n), templ.body, false))
/*
case d @ ExCompanionObjectSynthetic(_, sym, templ) =>
// Find default param. implementations
templ.body foreach {
case ExDefaultValueFunction(sym, _, _, _, owner, index, _) =>
val namePieces = sym.toString.reverse.split("\\$", 3).reverse map { _.reverse }
assert(namePieces.length == 3 && namePieces(0)== "$lessinit$greater" && namePieces(1) == "default") // FIXME : maybe $lessinit$greater?
val index = namePieces(2).toInt
val theParam = sym.companionClass.paramss.head(index - 1)
paramsToDefaultValues += theParam -> body
case _ =>
}
None
*/
case d @ ExAbstractClass(_, _, _) =>
standaloneDefs ::= d
None
......@@ -389,6 +403,8 @@ trait CodeExtraction extends ASTExtractors {
outOfSubsetError(pos, "Class "+className+" is not a case class")
}
}
private var paramsToDefaultValues = Map[Symbol,FunDef]()
private def collectClassSymbols(defs: List[Tree]) {
// We collect all defined classes
......@@ -450,6 +466,28 @@ trait CodeExtraction extends ASTExtractors {
private var isMethod = Set[Symbol]()
private var methodToClass = Map[FunDef, LeonClassDef]()
/**
* For the function in $defs with name $owner, find its parameter with index $index,
* and registers $fd as the default value function for this parameter.
*/
private def registerDefaultMethod(
defs : List[Tree],
matcher : PartialFunction[Tree,Symbol],
index : Int,
fd : FunDef
) {
// Search tmpl to find the function that includes this parameter
val paramOwner = defs.collectFirst(matcher).get
// assumes single argument list
if(paramOwner.paramss.length != 1) {
outOfSubsetError(paramOwner.pos, "Multiple argument lists for a function are not allowed")
}
val theParam = paramOwner.paramss.head(index)
paramsToDefaultValues += (theParam -> fd)
}
def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = {
val id = FreshIdentifier(sym.name.toString).setPos(sym.pos)
......@@ -544,6 +582,21 @@ trait CodeExtraction extends ASTExtractors {
cd.registerMethod(fd)
// Default values for parameters
case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) =>
val fd = defineFunDef(fsym)(defCtx)
fd.addAnnotation("synthetic")
isMethod += fsym
methodToClass += fd -> cd
cd.registerMethod(fd)
val matcher : PartialFunction[Tree, Symbol] = {
case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym
}
registerDefaultMethod(tmpl.body, matcher, index, fd )
// Lazy fields
case t @ ExLazyAccessorFunction(fsym, _, _) =>
if (parent.isDefined) {
......@@ -639,6 +692,16 @@ trait CodeExtraction extends ASTExtractors {
case ExFunctionDef(sym, _, _, _, _) =>
defineFunDef(sym)(DefContext())
case t @ ExDefaultValueFunction(sym, _, _, _, owner, index, _) => {
val fd = defineFunDef(sym)(DefContext())
fd.addAnnotation("synthetic")
val matcher : PartialFunction[Tree, Symbol] = {
case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym
}
registerDefaultMethod(defs, matcher, index, fd)
}
case ExLazyAccessorFunction(sym, _, _) =>
defineFieldFunDef(sym,true)(DefContext())
......@@ -673,6 +736,16 @@ trait CodeExtraction extends ASTExtractors {
extractFunBody(fd, params, body)(DefContext(tparamsMap))
}
// Default value functions
case ExDefaultValueFunction(sym, tparams, params, _, _, _, body) =>
val fd = defsToDefs(sym)
val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap
if(body != EmptyTree) {
extractFunBody(fd, params, body)(DefContext(tparamsMap))
}
// Lazy fields
case t @ ExLazyAccessorFunction(sym, _, body) =>
val fd = defsToDefs(sym)
......@@ -722,6 +795,14 @@ trait CodeExtraction extends ASTExtractors {
extractFunBody(fd, params, body)(DefContext(tparamsMap, isExtern = isExtern(sym)))
case ExDefaultValueFunction(sym, tparams, params, _ ,_ , _, body) =>
// Default value functions
val fd = defsToDefs(sym)
val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap
extractFunBody(fd, params, body)(DefContext(tparamsMap))
case ExLazyAccessorFunction(sym, _, body) =>
// Lazy vals
val fd = defsToDefs(sym)
......@@ -763,6 +844,8 @@ trait CodeExtraction extends ASTExtractors {
// Taking accessor functions will duplicate work for strict fields, but we need them in case of lazy fields
case ExFunctionDef(sym, tparams, params, _, body) =>
Some(defsToDefs(sym))
case ExDefaultValueFunction(sym, _, _, _, _, _, _) =>
Some(defsToDefs(sym))
case ExLazyAccessorFunction(sym, _, _) =>
Some(defsToDefs(sym))
case ExFieldDef(sym, _, _) =>
......@@ -780,6 +863,7 @@ trait CodeExtraction extends ASTExtractors {
case ExConstructorDef() =>
case ExFunctionDef(_, _, _, _, _) =>
case ExLazyAccessorFunction(_, _, _) =>
case ExDefaultValueFunction(_, _, _, _, _, _, _ ) =>
case ExFieldDef(_,_,_) =>
case ExLazyFieldDef() =>
case ExFieldAccessorFunction() =>
......@@ -795,6 +879,10 @@ trait CodeExtraction extends ASTExtractors {
private def extractFunBody(funDef: FunDef, params: Seq[ValDef], body0 : Tree)(implicit dctx: DefContext): FunDef = {
currentFunDef = funDef
// Find defining function for params with default value
for ((s,vd) <- params zip funDef.params) {
vd.defaultValue = paramsToDefaultValues.get(s.symbol)
}
val newVars = for ((s, vd) <- params zip funDef.params) yield {
s.symbol -> (() => Variable(vd.id))
......@@ -1115,6 +1203,8 @@ trait CodeExtraction extends ASTExtractors {
rest = None
LetDef(funDefWithBody, restTree)
// FIXME case ExDefaultValueFunction
/**
* XLang Extractors
*/
......@@ -1785,6 +1875,8 @@ trait CodeExtraction extends ASTExtractors {
outOfSubsetError(tpt.typeSymbol.pos, "Could not extract refined type as PureScala: "+tpt+" ("+tpt.getClass+")")
}
case AnnotatedType(_, tpe) => extractType(tpe)
case _ =>
outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")")
}
......
......@@ -55,6 +55,8 @@ object Definitions {
val getType = tpe getOrElse id.getType
var defaultValue : Option[FunDef] = None
def subDefinitions = Seq()
// Warning: the variable will not have the same type as the ValDef, but
......@@ -413,6 +415,8 @@ object Definitions {
def canBeField = canBeLazyField || canBeStrictField
def isRealFunction = !canBeField
def isSynthetic = annotations contains "synthetic"
private var annots: Set[String] = Set.empty[String]
def addAnnotation(as: String*) : FunDef = {
annots = annots ++ as
......
......@@ -303,7 +303,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
p"[${tfd.tps}]"
}
if (tfd.fd.isRealFunction) p"($args)"
// No () for fields
if (tfd.fd.isRealFunction) {
// The non-present arguments are synthetic function invocations
val presentArgs = args filter {
case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false
case other => true
}
p"($presentArgs)"
}
case BinaryMethodCall(a, op, b) =>
optP { p"${a} $op ${b}" }
......@@ -316,7 +325,13 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
}
if (fd.isRealFunction) {
p"($args)"
// The non-present arguments are synthetic function invocations
val presentArgs = args filter {
case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false
case other => true
}
p"($presentArgs)"
}
case FunctionInvocation(tfd, args) =>
......@@ -326,7 +341,15 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
p"[${tfd.tps}]"
}
if (tfd.fd.isRealFunction) p"($args)"
if (tfd.fd.isRealFunction) {
// The non-present arguments are synthetic function invocations
val presentArgs = args filter {
case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false
case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false
case other => true
}
p"($presentArgs)"
}
case Application(caller, args) =>
p"$caller($args)"
......@@ -408,7 +431,10 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
}
case Not(expr) => p"\u00AC$expr"
case vd@ValDef(id, _) => p"$id : ${vd.getType}"
case vd@ValDef(id, _) => vd.defaultValue match {
case Some(fd) => p"$id : ${vd.getType} = ${fd.body.get}"
case None => p"$id : ${vd.getType}"
}
case This(_) => p"this"
case (tfd: TypedFunDef) => p"typed def ${tfd.id}[${tfd.tps}]"
case TypeParameterDef(tp) => p"$tp"
......
......@@ -24,6 +24,12 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex
override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = {
tree match {
case m: ModuleDef =>
// Don't print synthetic functions
super.pp(m.copy(defs = m.defs.filter {
case f:FunDef if f.isSynthetic => false
case _ => true
}))
case Not(Equals(l, r)) => p"$l != $r"
case Implies(l,r) => pp(or(not(l), r))
case Choose(pred, None) => p"choose($pred)"
......
object OptParams {
def foo( x : Int, y : Int = 12 ) = x + y
def bar = foo(42)
def baz = foo(1,2)
abstract class Opt {
def opt( o : Opt = OptChild(), i : Int = 0) : Int = i + 1
def opt2 = opt()
}
case class OptChild() extends Opt
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment