diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 1edf989ae8d85f2fd70d874ee09c5bb74a87f4e9..6518d65fbcec5d365164f2b6213e40f563fb3233 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -148,6 +148,14 @@ object Types { case t => Some(Nil, _ => t) } } + + object FirstOrderFunctionType { + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { + case FunctionType(from, to) => + unapply(to).map(p => (from ++ p._1) -> p._2) orElse Some(from -> to) + case _ => None + } + } def optionToType(tp: Option[TypeTree]) = tp getOrElse Untyped diff --git a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala index 29a483aa9113eeb542d80c89c0a6ba8b0e6b607b..ecc7843f7cb5dd2e755c3a12527269ca4abe8443 100644 --- a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala +++ b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala @@ -256,7 +256,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco (Seq(encoder.mkImplies(blocker, typeBlocker)), Map.empty, Map.empty) case None => - val App(caller, tpe @ FunctionType(_, to), args, value) = app + val App(caller, tpe @ FirstOrderFunctionType(_, to), args, value) = app val typeBlocker = encoder.encodeId(FreshIdentifier("t", BooleanType)) typeBlockers += value -> typeBlocker implies(blocker, typeBlocker) diff --git a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala index 37f1a031f1ef401d3ac83fbe132321b6483f0ef0..d1da0d8f17f0d32f7fa343d0d5e8155529fe6dd0 100644 --- a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala @@ -117,6 +117,15 @@ object Template { } } + private def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { + case FunctionType(from, to) => + val (curr, next) = args.splitAt(from.size) + mkApplication(Application(caller, curr), next) + case _ => + assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}") + caller + } + private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -186,7 +195,7 @@ object Template { val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2)))) val optIdApp = optApp.map { case (idT, tpe) => val id = FreshIdentifier("x", tpe, true) - val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(Application(Variable(id), arguments.map(_._1.toVariable))) + val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(mkApplication(Variable(id), arguments.map(_._1.toVariable))) App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) } @@ -229,7 +238,7 @@ object Template { funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg))) appInfos ++= firstOrderAppsOf(e).map { case (c, args) => val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] - App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(Application(c, args))) + App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(mkApplication(c, args))) } matchInfos ++= exprToMatcher.values