From 900e027d2a4ee55f97a2252590d1660c5f27d8a3 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 8 Oct 2015 15:38:01 +0200
Subject: [PATCH] Add paramIds to TypedFunDef, use it through the code

---
 src/main/scala/leon/codegen/CodeGeneration.scala          | 2 +-
 src/main/scala/leon/purescala/Definitions.scala           | 4 +++-
 src/main/scala/leon/purescala/ExprOps.scala               | 8 ++++----
 src/main/scala/leon/purescala/Quantification.scala        | 2 +-
 src/main/scala/leon/repair/Repairman.scala                | 4 ++--
 .../leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala      | 2 +-
 .../leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala      | 2 +-
 .../leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala    | 2 +-
 .../scala/leon/solvers/templates/TemplateGenerator.scala  | 2 +-
 9 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index a616a1a0b..33975a011 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -183,7 +183,7 @@ trait CodeGeneration {
     // An offset we introduce to the parameters:
     // 1 if this is a method, so we need "this" in position 0 of the stack
     // 1 if we are monitoring
-    val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.params.map(_.id)
+    val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.paramIds
     val newMapping = idParams.zipWithIndex.toMap.mapValues(_ + (if (!isStatic) 1 else 0))
 
     val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name))
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 6467226a0..ba6b11bcf 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -483,7 +483,7 @@ object Definitions {
 
     def paramSubst(realArgs: Seq[Expr]) = {
       require(realArgs.size == params.size)
-      (params map { _.id } zip realArgs).toMap
+      (paramIds zip realArgs).toMap
     }
 
     def withParamSubst(realArgs: Seq[Expr], e: Expr) = {
@@ -522,6 +522,8 @@ object Definitions {
 
     lazy val returnType: TypeTree = translated(fd.returnType)
 
+    lazy val paramIds = params map { _.id }
+
     private var trCache = Map[Expr, Expr]()
 
     private def cached(e: Expr): Expr = {
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 30999a647..2f703312a 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -513,7 +513,7 @@ object ExprOps {
       (expr, idSeqs) => idSeqs.foldLeft(expr match {
         case Lambda(args, _) => args.map(_.id)
         case Forall(args, _) => args.map(_.id)
-        case LetDef(fd, _) => fd.params.map(_.id)
+        case LetDef(fd, _) => fd.paramIds
         case Let(i, _, _) => Seq(i)
         case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders)
         case Passes(_, _, cses) => cses.flatMap(_.pattern.binders)
@@ -1511,7 +1511,7 @@ object ExprOps {
       (fd1.params.size == fd2.params.size) && {
          val newMap = map +
            (fd1.id -> fd2.id) ++
-           (fd1.params zip fd2.params).map{ case (vd1, vd2) => (vd1.id, vd2.id) }
+           (fd1.paramIds zip fd2.paramIds)
          isHomo(fd1.fullBody, fd2.fullBody)(newMap)
       }
     }
@@ -1777,14 +1777,14 @@ object ExprOps {
   def flattenFunctions(fdOuter: FunDef, ctx: LeonContext, p: Program): FunDef = {
     fdOuter.body match {
       case Some(LetDef(fdInner, FunctionInvocation(tfdInner2, args))) if fdInner == tfdInner2.fd =>
-        val argsDef  = fdOuter.params.map(_.id)
+        val argsDef  = fdOuter.paramIds
         val argsCall = args.collect { case Variable(id) => id }
 
         if (argsDef.toSet == argsCall.toSet) {
           val defMap = argsDef.zipWithIndex.toMap
           val rewriteMap = argsCall.map(defMap)
 
-          val innerIdsToOuterIds = (fdInner.params.map(_.id) zip argsCall).toMap
+          val innerIdsToOuterIds = (fdInner.paramIds zip argsCall).toMap
 
           def pre(e: Expr) = e match {
             case FunctionInvocation(tfd, args) if tfd.fd == fdInner =>
diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala
index 1214af2aa..38fa0ae9f 100644
--- a/src/main/scala/leon/purescala/Quantification.scala
+++ b/src/main/scala/leon/purescala/Quantification.scala
@@ -108,7 +108,7 @@ object Quantification {
           case _ => Set.empty
         } (fd.fullBody)
 
-        val free = fd.params.map(_.id).toSet ++ (fd.postcondition match {
+        val free = fd.paramIds.toSet ++ (fd.postcondition match {
           case Some(Lambda(args, _)) => args.map(_.id)
           case _ => Seq.empty
         })
diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala
index 47dc7abac..e5a1eaafe 100644
--- a/src/main/scala/leon/repair/Repairman.scala
+++ b/src/main/scala/leon/repair/Repairman.scala
@@ -180,7 +180,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou
       val vrs = report.vrs
 
       vrs.collect { case (_, VCResult(VCStatus.Invalid(ex), _, _)) =>
-        InExample(fd.params.map{vd => ex(vd.id)})
+        InExample(fd.paramIds map ex)
       }
     } finally {
       solverf.shutdown()
@@ -203,7 +203,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou
       case None =>
         _ => true
       case Some(pre) =>
-        val argIds = fd.params.map(_.id)
+        val argIds = fd.paramIds
         evaluator.compile(pre, argIds) match {
           case Some(evalFun) =>
             val sat = EvaluationResults.Successful(BooleanLiteral(true));
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
index ba0187239..bfcb70668 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
@@ -21,7 +21,7 @@ trait SMTLIBQuantifiedSolver {
   // Normally, UnrollingSolver tracks the input variable, but this one
   // is invoked alone so we have to filter them here
   override def getModel: Model = {
-    val filter = currentFunDef.map{ _.params.map{_.id}.toSet }.getOrElse( (_:Identifier) => true )
+    val filter = currentFunDef.map{ _.paramIds.toSet }.getOrElse( (_:Identifier) => true )
     getModel(filter)
   }
 
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala
index b6d6fbfb3..d49e33167 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala
@@ -30,7 +30,7 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget {
     val inductiveHyps = for {
       fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq
     } yield {
-      val formalToRealArgs = tfd.params.map{ _.id}.zip(args).toMap
+      val formalToRealArgs = tfd.paramIds.zip(args).toMap
       val post = tfd.postcondition map { post =>
         application(
           replaceFromIDs(formalToRealArgs, post),
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala
index 4c8eedce5..f3c7cb69e 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala
@@ -37,7 +37,7 @@ trait SMTLIBZ3QuantifiedTarget extends SMTLIBZ3Target with SMTLIBQuantifiedTarge
       val tfd = functions.toA(sym)
       val term = quantifiedTerm(
         SMTForall,
-        tfd.params map { _.id },
+        tfd.paramIds,
         Equals(
           FunctionInvocation(tfd, tfd.params.map {_.toVariable}),
           tfd.body.get
diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
index 5c746db2f..f6484ff7b 100644
--- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
+++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala
@@ -44,7 +44,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T],
     val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b))
     val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b))
 
-    val funDefArgs: Seq[Identifier] = tfd.params.map(_.id)
+    val funDefArgs: Seq[Identifier] = tfd.paramIds
     val lambdaArguments: Seq[Identifier] = lambdaBody.map(lambdaArgs).toSeq.flatten
     val invocation : Expr = FunctionInvocation(tfd, funDefArgs.map(_.toVariable))
 
-- 
GitLab