diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 24f43d7569eb4213fec4f1199a63e0a000d619f2..3d3bd42eee1a79e996153d8029b069739c809996 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -222,21 +222,21 @@ object Constructors { NonemptyArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength) } + /* + * Take a mapping from keys to values and a default expression and return a lambda of the form + * (x1, ..., xn) => + * if ( key1 == (x1, ..., xn) ) value1 + * else if ( key2 == (x1, ..., xn) ) value2 + * ... + * else default + */ def finiteLambda(default: Expr, els: Seq[(Expr, Expr)], inputTypes: Seq[TypeTree]): Lambda = { val args = inputTypes map { tpe => ValDef(FreshIdentifier("x", tpe, true)) } - if (els.isEmpty) { - Lambda(args, default) - } else { - val theMap = NonemptyMap(els) - val theMapVar = FreshIdentifier("pairs", theMap.getType, true) - val argsAsExpr = tupleWrap(args map { _.toVariable }) - val body = Let(theMapVar, theMap, IfExpr( - MapIsDefinedAt(Variable(theMapVar), argsAsExpr), - MapGet(Variable(theMapVar), argsAsExpr), - default - )) - Lambda(args, body) + val argsExpr = tupleWrap(args map { _.toVariable }) + val body = els.foldRight(default) { case ((key, value), default) => + IfExpr(Equals(argsExpr, key), value, default) } + Lambda(args, body) } def application(fn: Expr, realArgs: Seq[Expr]) = fn match { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 8cde59fe888ee25caf203c36da9dd60a8acb88e3..f47672cb38d1518a3321a28ffc09f5c7a4365adf 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -234,26 +234,36 @@ object Extractors { def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) } + /* + * Extract a default expression and key-value pairs from a lambda constructed with + * Constructors.finiteLambda + */ object FiniteLambda { def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { val inSize = lambda.getType.asInstanceOf[FunctionType].from.size - lambda match { - case Lambda(args, Let(theMapVar, FiniteMap(pairs), IfExpr( - MapIsDefinedAt(Variable(theMapVar1), targs2), - MapGet(Variable(theMapVar2), targs3), - default - ))) if { - val args2 = unwrapTuple(targs2, inSize) - val args3 = unwrapTuple(targs3, inSize) - (args map { x: ValDef => x.toVariable }) == args2 && - args2 == args3 && theMapVar == theMapVar1 && - theMapVar == theMapVar2 + val Lambda(args, body) = lambda + def step(e: Expr): (Option[(Expr, Expr)], Expr) = e match { + case IfExpr(Equals(argsExpr, key), value, default) if { + val formal = args.map{ _.id } + val real = unwrapTuple(argsExpr, inSize).collect{ case Variable(id) => id} + formal == real } => - Some(default, pairs) - case Lambda(args, default) if (variablesOf(default) & args.toSet.map{x: ValDef => x.id}).isEmpty => - Some(default, Seq()) - case _ => None + (Some((key, value)), default) + case other => + (None, other) + } + + def rec(e: Expr): (Expr, Seq[(Expr, Expr)]) = { + step(e) match { + case (None, default) => (default, Seq()) + case (Some(pair), default) => + val (defaultRest, pairs) = rec(default) + (defaultRest, pair +: pairs) + } } + + Some(rec(body)) + } }