From 8af3310b1871c8e150744c2bf0895e52181108c2 Mon Sep 17 00:00:00 2001 From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch> Date: Fri, 11 Apr 2014 17:44:14 +0200 Subject: [PATCH] Default parameters for functions --- .../leon/frontends/scalac/ASTExtractors.scala | 40 ++++++++ .../frontends/scalac/CodeExtraction.scala | 96 ++++++++++++++++++- .../scala/leon/purescala/Definitions.scala | 4 + .../scala/leon/purescala/PrettyPrinter.scala | 34 ++++++- .../scala/leon/purescala/ScalaPrinter.scala | 6 ++ .../regression/frontends/OptParams.scala | 15 +++ 6 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 src/test/resources/regression/frontends/OptParams.scala diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 840dd6fa4..ea6309c8c 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -310,6 +310,18 @@ trait ASTExtractors { } } + object ExCompanionObjectSynthetic { + def unapply(cd : ClassDef) : Option[(String, Symbol, Template)] = { + val sym = cd.symbol + cd match { + case ClassDef(_, name, tparams, impl) if sym.isModule && sym.isSynthetic => //FIXME flags? + Some((name.toString, sym, impl)) + case _ => None + } + + } + } + object ExCaseClassSyntheticJunk { def unapply(cd: ClassDef): Boolean = cd match { case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic) => true @@ -408,6 +420,34 @@ trait ASTExtractors { } } + + object ExDefaultValueFunction{ + /** Matches a function that defines the default value of a parameter */ + def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = { + val sym = dd.symbol + dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if( + vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic + ) => + + // Split the name into pieces, to find owner of the parameter + param.index + // Form has to be <owner name>$default$<param index> + val symPieces = sym.name.toString.reverse.split("\\$",3).reverse map { _.reverse } + + try { + if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("") + val ownerString = symPieces(0) + val index = symPieces(2).toInt - 1 + Some((sym, tparams.map(_.symbol), vparamss.headOption.getOrElse(Nil), tpt.tpe, ownerString, index, rhs)) + } catch { + case _ : NumberFormatException | _ : IllegalArgumentException | _ : ArrayIndexOutOfBoundsException => + None + } + + case _ => None + } + } + } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 8166e69f8..630f688a5 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -209,7 +209,6 @@ trait CodeExtraction extends ASTExtractors { !isLib(u) ) - case pd @ PackageDef(refTree, lst) => var standaloneDefs = List[Tree]() @@ -224,7 +223,22 @@ trait CodeExtraction extends ASTExtractors { case ExObjectDef(n, templ) if n != "package" => Some(TempModule(FreshIdentifier(n), templ.body, false)) - + + /* + case d @ ExCompanionObjectSynthetic(_, sym, templ) => + // Find default param. implementations + templ.body foreach { + case ExDefaultValueFunction(sym, _, _, _, owner, index, _) => + val namePieces = sym.toString.reverse.split("\\$", 3).reverse map { _.reverse } + assert(namePieces.length == 3 && namePieces(0)== "$lessinit$greater" && namePieces(1) == "default") // FIXME : maybe $lessinit$greater? + val index = namePieces(2).toInt + val theParam = sym.companionClass.paramss.head(index - 1) + paramsToDefaultValues += theParam -> body + case _ => + } + None + */ + case d @ ExAbstractClass(_, _, _) => standaloneDefs ::= d None @@ -389,6 +403,8 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(pos, "Class "+className+" is not a case class") } } + + private var paramsToDefaultValues = Map[Symbol,FunDef]() private def collectClassSymbols(defs: List[Tree]) { // We collect all defined classes @@ -450,6 +466,28 @@ trait CodeExtraction extends ASTExtractors { private var isMethod = Set[Symbol]() private var methodToClass = Map[FunDef, LeonClassDef]() + /** + * For the function in $defs with name $owner, find its parameter with index $index, + * and registers $fd as the default value function for this parameter. + */ + private def registerDefaultMethod( + defs : List[Tree], + matcher : PartialFunction[Tree,Symbol], + index : Int, + fd : FunDef + ) { + // Search tmpl to find the function that includes this parameter + val paramOwner = defs.collectFirst(matcher).get + + // assumes single argument list + if(paramOwner.paramss.length != 1) { + outOfSubsetError(paramOwner.pos, "Multiple argument lists for a function are not allowed") + } + val theParam = paramOwner.paramss.head(index) + paramsToDefaultValues += (theParam -> fd) + } + + def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = { val id = FreshIdentifier(sym.name.toString).setPos(sym.pos) @@ -544,6 +582,21 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) + // Default values for parameters + case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) => + val fd = defineFunDef(fsym)(defCtx) + fd.addAnnotation("synthetic") + + isMethod += fsym + methodToClass += fd -> cd + + cd.registerMethod(fd) + val matcher : PartialFunction[Tree, Symbol] = { + case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym + } + registerDefaultMethod(tmpl.body, matcher, index, fd ) + + // Lazy fields case t @ ExLazyAccessorFunction(fsym, _, _) => if (parent.isDefined) { @@ -639,6 +692,16 @@ trait CodeExtraction extends ASTExtractors { case ExFunctionDef(sym, _, _, _, _) => defineFunDef(sym)(DefContext()) + case t @ ExDefaultValueFunction(sym, _, _, _, owner, index, _) => { + + val fd = defineFunDef(sym)(DefContext()) + fd.addAnnotation("synthetic") + val matcher : PartialFunction[Tree, Symbol] = { + case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym + } + registerDefaultMethod(defs, matcher, index, fd) + + } case ExLazyAccessorFunction(sym, _, _) => defineFieldFunDef(sym,true)(DefContext()) @@ -673,6 +736,16 @@ trait CodeExtraction extends ASTExtractors { extractFunBody(fd, params, body)(DefContext(tparamsMap)) } + // Default value functions + case ExDefaultValueFunction(sym, tparams, params, _, _, _, body) => + val fd = defsToDefs(sym) + + val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap + + if(body != EmptyTree) { + extractFunBody(fd, params, body)(DefContext(tparamsMap)) + } + // Lazy fields case t @ ExLazyAccessorFunction(sym, _, body) => val fd = defsToDefs(sym) @@ -722,6 +795,14 @@ trait CodeExtraction extends ASTExtractors { extractFunBody(fd, params, body)(DefContext(tparamsMap, isExtern = isExtern(sym))) + case ExDefaultValueFunction(sym, tparams, params, _ ,_ , _, body) => + // Default value functions + val fd = defsToDefs(sym) + + val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap + + extractFunBody(fd, params, body)(DefContext(tparamsMap)) + case ExLazyAccessorFunction(sym, _, body) => // Lazy vals val fd = defsToDefs(sym) @@ -763,6 +844,8 @@ trait CodeExtraction extends ASTExtractors { // Taking accessor functions will duplicate work for strict fields, but we need them in case of lazy fields case ExFunctionDef(sym, tparams, params, _, body) => Some(defsToDefs(sym)) + case ExDefaultValueFunction(sym, _, _, _, _, _, _) => + Some(defsToDefs(sym)) case ExLazyAccessorFunction(sym, _, _) => Some(defsToDefs(sym)) case ExFieldDef(sym, _, _) => @@ -780,6 +863,7 @@ trait CodeExtraction extends ASTExtractors { case ExConstructorDef() => case ExFunctionDef(_, _, _, _, _) => case ExLazyAccessorFunction(_, _, _) => + case ExDefaultValueFunction(_, _, _, _, _, _, _ ) => case ExFieldDef(_,_,_) => case ExLazyFieldDef() => case ExFieldAccessorFunction() => @@ -795,6 +879,10 @@ trait CodeExtraction extends ASTExtractors { private def extractFunBody(funDef: FunDef, params: Seq[ValDef], body0 : Tree)(implicit dctx: DefContext): FunDef = { currentFunDef = funDef + // Find defining function for params with default value + for ((s,vd) <- params zip funDef.params) { + vd.defaultValue = paramsToDefaultValues.get(s.symbol) + } val newVars = for ((s, vd) <- params zip funDef.params) yield { s.symbol -> (() => Variable(vd.id)) @@ -1115,6 +1203,8 @@ trait CodeExtraction extends ASTExtractors { rest = None LetDef(funDefWithBody, restTree) + // FIXME case ExDefaultValueFunction + /** * XLang Extractors */ @@ -1785,6 +1875,8 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(tpt.typeSymbol.pos, "Could not extract refined type as PureScala: "+tpt+" ("+tpt.getClass+")") } + case AnnotatedType(_, tpe) => extractType(tpe) + case _ => outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index ef3505476..f056671da 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -55,6 +55,8 @@ object Definitions { val getType = tpe getOrElse id.getType + var defaultValue : Option[FunDef] = None + def subDefinitions = Seq() // Warning: the variable will not have the same type as the ValDef, but @@ -413,6 +415,8 @@ object Definitions { def canBeField = canBeLazyField || canBeStrictField def isRealFunction = !canBeField + def isSynthetic = annotations contains "synthetic" + private var annots: Set[String] = Set.empty[String] def addAnnotation(as: String*) : FunDef = { annots = annots ++ as diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 80da11205..f43546154 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -303,7 +303,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe p"[${tfd.tps}]" } - if (tfd.fd.isRealFunction) p"($args)" + // No () for fields + if (tfd.fd.isRealFunction) { + // The non-present arguments are synthetic function invocations + val presentArgs = args filter { + case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false + case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false + case other => true + } + p"($presentArgs)" + } case BinaryMethodCall(a, op, b) => optP { p"${a} $op ${b}" } @@ -316,7 +325,13 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe } if (fd.isRealFunction) { - p"($args)" + // The non-present arguments are synthetic function invocations + val presentArgs = args filter { + case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false + case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false + case other => true + } + p"($presentArgs)" } case FunctionInvocation(tfd, args) => @@ -326,7 +341,15 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe p"[${tfd.tps}]" } - if (tfd.fd.isRealFunction) p"($args)" + if (tfd.fd.isRealFunction) { + // The non-present arguments are synthetic function invocations + val presentArgs = args filter { + case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false + case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false + case other => true + } + p"($presentArgs)" + } case Application(caller, args) => p"$caller($args)" @@ -408,7 +431,10 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe } case Not(expr) => p"\u00AC$expr" - case vd@ValDef(id, _) => p"$id : ${vd.getType}" + case vd@ValDef(id, _) => vd.defaultValue match { + case Some(fd) => p"$id : ${vd.getType} = ${fd.body.get}" + case None => p"$id : ${vd.getType}" + } case This(_) => p"this" case (tfd: TypedFunDef) => p"typed def ${tfd.id}[${tfd.tps}]" case TypeParameterDef(tp) => p"$tp" diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index cf58ee34d..666f728d6 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -24,6 +24,12 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { tree match { + case m: ModuleDef => + // Don't print synthetic functions + super.pp(m.copy(defs = m.defs.filter { + case f:FunDef if f.isSynthetic => false + case _ => true + })) case Not(Equals(l, r)) => p"$l != $r" case Implies(l,r) => pp(or(not(l), r)) case Choose(pred, None) => p"choose($pred)" diff --git a/src/test/resources/regression/frontends/OptParams.scala b/src/test/resources/regression/frontends/OptParams.scala new file mode 100644 index 000000000..535c51727 --- /dev/null +++ b/src/test/resources/regression/frontends/OptParams.scala @@ -0,0 +1,15 @@ +object OptParams { + + def foo( x : Int, y : Int = 12 ) = x + y + + def bar = foo(42) + def baz = foo(1,2) + + + + abstract class Opt { + def opt( o : Opt = OptChild(), i : Int = 0) : Int = i + 1 + def opt2 = opt() + } + case class OptChild() extends Opt +} -- GitLab