From 916be314328b2c20ce83164cd4044b54157e279f Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Sat, 16 Apr 2016 16:46:41 +0200
Subject: [PATCH] Fix for first-class functions returning more functions

---
 src/main/scala/leon/purescala/Types.scala           |  8 ++++++++
 .../leon/solvers/unrolling/LambdaManager.scala      |  2 +-
 .../leon/solvers/unrolling/TemplateManager.scala    | 13 +++++++++++--
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala
index 1edf989ae..6518d65fb 100644
--- a/src/main/scala/leon/purescala/Types.scala
+++ b/src/main/scala/leon/purescala/Types.scala
@@ -148,6 +148,14 @@ object Types {
       case t => Some(Nil, _ => t)
     }
   }
+
+  object FirstOrderFunctionType {
+    def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match {
+      case FunctionType(from, to) =>
+        unapply(to).map(p => (from ++ p._1) -> p._2) orElse Some(from -> to)
+      case _ => None
+    }
+  }
   
   def optionToType(tp: Option[TypeTree]) = tp getOrElse Untyped
 
diff --git a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala
index 29a483aa9..ecc7843f7 100644
--- a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala
+++ b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala
@@ -256,7 +256,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco
       (Seq(encoder.mkImplies(blocker, typeBlocker)), Map.empty, Map.empty)
 
     case None =>
-      val App(caller, tpe @ FunctionType(_, to), args, value) = app
+      val App(caller, tpe @ FirstOrderFunctionType(_, to), args, value) = app
       val typeBlocker = encoder.encodeId(FreshIdentifier("t", BooleanType))
       typeBlockers += value -> typeBlocker
       implies(blocker, typeBlocker)
diff --git a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala
index 37f1a031f..d1da0d8f1 100644
--- a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala
+++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala
@@ -117,6 +117,15 @@ object Template {
     }
   }
 
+  private def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match {
+    case FunctionType(from, to) =>
+      val (curr, next) = args.splitAt(from.size)
+      mkApplication(Application(caller, curr), next)
+    case _ =>
+      assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}")
+      caller
+  }
+
   private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = {
     assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs")
 
@@ -186,7 +195,7 @@ object Template {
     val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2))))
     val optIdApp = optApp.map { case (idT, tpe) =>
       val id = FreshIdentifier("x", tpe, true)
-      val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(Application(Variable(id), arguments.map(_._1.toVariable)))
+      val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(mkApplication(Variable(id), arguments.map(_._1.toVariable)))
       App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded)
     }
 
@@ -229,7 +238,7 @@ object Template {
           funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg)))
           appInfos ++= firstOrderAppsOf(e).map { case (c, args) =>
             val tpe = bestRealType(c.getType).asInstanceOf[FunctionType]
-            App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(Application(c, args)))
+            App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(mkApplication(c, args)))
           }
 
           matchInfos ++= exprToMatcher.values
-- 
GitLab