From e445f6d34fd1113a88902cc42f4ae4a7bf7114ba Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Tue, 24 Mar 2015 18:17:50 +0100
Subject: [PATCH] SMTLIB-CVC4 with define-funs-rec

---
 src/main/scala/leon/purescala/DefOps.scala    | 26 +++++++--
 .../scala/leon/purescala/Definitions.scala    |  2 +
 src/main/scala/leon/purescala/ExprOps.scala   |  1 +
 .../scala/leon/solvers/SolverFactory.scala    |  5 +-
 .../smtlib/SMTLIBUnrollingCVC4Target.scala    | 55 +++++++++++++++++++
 src/main/scala/leon/utils/SearchSpace.scala   | 22 ++++++++
 6 files changed, 104 insertions(+), 7 deletions(-)
 create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala
 create mode 100644 src/main/scala/leon/utils/SearchSpace.scala

diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala
index 5175737f5..8edf77f51 100644
--- a/src/main/scala/leon/purescala/DefOps.scala
+++ b/src/main/scala/leon/purescala/DefOps.scala
@@ -2,6 +2,7 @@ package leon.purescala
 
 import Definitions._
 import Expressions._
+import ExprOps.{preMap, postMap, functionCallsOf}
 
 object DefOps {
 
@@ -262,11 +263,7 @@ object DefOps {
     
   }
   
-  
-  
-  import Expressions.Expr
-  import ExprOps.{preMap, postMap}
-  
+
   /*
    * Apply an expression operation on all expressions contained in a FunDef
    */
@@ -385,4 +382,21 @@ object DefOps {
     }(e)
   }
 
-}
+  /**
+   * Returns a call graph starting from the given sources, taking into account
+   * instantiations of function type parameters,
+   * If given limit of explored nodes reached, it returns a partial set of reached TypedFunDefs
+   * and the boolean set to "false".
+   * Otherwise, it returns the full set of reachable TypedFunDefs and "true"
+   */
+
+  def typedTransitiveCallees(sources: Set[TypedFunDef], limit: Option[Int] = None): (Set[TypedFunDef], Boolean) = {
+    import leon.utils.SearchSpace.reachable
+    reachable(
+      sources,
+      (tfd: TypedFunDef) => functionCallsOf(tfd.fd.fullBody) map { _.tfd },
+      limit
+    )
+  }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 07769dbad..ee0f6a03c 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -437,6 +437,8 @@ object Definitions {
       TypedFunDef(this, tparams.map(_.tp))
     }
 
+    def isRecursive(p: Program) = p.callGraph.transitiveCallees(this) contains this
+
     setSubDefOwners()
     // Deprecated, old API
     @deprecated("Use .body instead", "2.3")
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 566a50e21..a91daaa7c 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1955,6 +1955,7 @@ object ExprOps {
     case _ => None
   }
 
+
   /**
    * Deprecated API
    * ========
diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala
index 360c48d6c..e5c289b8c 100644
--- a/src/main/scala/leon/solvers/SolverFactory.scala
+++ b/src/main/scala/leon/solvers/SolverFactory.scala
@@ -19,7 +19,7 @@ object SolverFactory {
     }
   }
 
-  val definedSolvers = Set("fairz3", "unrollz3", "enum", "smt", "smt-z3", "smt-cvc4")
+  val definedSolvers = Set("fairz3", "unrollz3", "enum", "smt", "smt-z3", "smt-cvc4", "smt-2.5-cvc4")
 
   def getFromSettings[S](ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
     import combinators._
@@ -42,6 +42,9 @@ object SolverFactory {
       case "smt-cvc4" =>
         SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) with TimeoutSolver)
 
+      case "smt-2.5-cvc4" =>
+        SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBUnrollingCVC4Target with TimeoutSolver)
+
       case _ =>
         ctx.reporter.fatalError("Unknown solver "+name)
     }
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala
new file mode 100644
index 000000000..9fd5eeba2
--- /dev/null
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala
@@ -0,0 +1,55 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+
+package leon
+package solvers.smtlib
+
+import purescala.Definitions.TypedFunDef
+import purescala.DefOps.typedTransitiveCallees
+import purescala.ExprOps.matchToIfThenElse
+import smtlib.parser.Commands._
+import smtlib.parser.Terms._
+
+trait SMTLIBUnrollingCVC4Target extends SMTLIBCVC4Target {
+  this: SMTLIBSolver =>
+
+  private val typedFunDefExplorationLimit = 10000
+
+  override def targetName = "2.5-cvc4"
+  override def declareFunction(tfd: TypedFunDef): SSymbol = {
+    if (tfd.params.isEmpty) {
+      super[SMTLIBCVC4Target].declareFunction(tfd)
+    } else {
+      val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit))
+      if (!exploredAll) {
+        reporter.warning(
+          s"Did not manage to explore the space of typed functions called from ${tfd.id}. The solver may fail"
+        )
+      }
+
+      val (smtFunDecls, smtBodies) = funs.toSeq.collect {
+        case tfd if !functions.containsA(tfd) && tfd.params.nonEmpty =>
+          val id = if (tfd.tps.isEmpty) {
+            tfd.id
+          } else {
+            tfd.id.freshen
+          }
+          val sym = id2sym(id)
+          functions +=(tfd, sym)
+          (
+            FunDec(
+              sym,
+              tfd.params map { p => SortedVar(id2sym(p.id), declareSort(p.getType)) },
+              declareSort(tfd.returnType)
+            ),
+            toSMT(matchToIfThenElse(tfd.body.get))(tfd.params.map { p =>
+              (p.id, id2sym(p.id): Term)
+            }.toMap)
+            )
+      }.unzip
+
+      if (smtFunDecls.nonEmpty) sendCommand(DefineFunsRec(smtFunDecls, smtBodies))
+      functions.toB(tfd)
+    }
+  }
+
+}
diff --git a/src/main/scala/leon/utils/SearchSpace.scala b/src/main/scala/leon/utils/SearchSpace.scala
new file mode 100644
index 000000000..ee0424210
--- /dev/null
+++ b/src/main/scala/leon/utils/SearchSpace.scala
@@ -0,0 +1,22 @@
+package leon.utils
+
+object SearchSpace {
+
+  def reachable[A](sources: Set[A], generateNeighbors: A => Set[A], limit: Option[Int] = None): (Set[A], Boolean) = {
+    require(limit forall(_ >= 0))
+    def rec(seen: Set[A], toSee: Set[A], limit: Option[Int]): (Set[A], Boolean) = {
+      (toSee.headOption, limit) match {
+        case (None, _) =>
+          (seen, true)
+        case (_, Some(0)) =>
+          (seen, false)
+        case (Some(hd), _) =>
+          val neighbors = generateNeighbors(hd)
+          rec(seen + hd, toSee ++ neighbors -- (seen + hd), limit map {_ - 1})
+      }
+    }
+
+    rec(Set(), sources, limit)
+  }
+
+}
-- 
GitLab