From f22dad2d31df8227f60928eb5137ee9892acb063 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Tue, 22 Sep 2015 14:48:34 +0200
Subject: [PATCH] Add support for forall to SMTQuantifiedSolver

---
 .../leon/solvers/smtlib/SMTLIBCVC4Solver.scala |  6 +++---
 .../smtlib/SMTLIBQuantifiedSolver.scala        | 10 +++++++++-
 .../leon/solvers/smtlib/SMTLIBZ3Solver.scala   | 18 +++++++++---------
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala
index 7df712491..2ffd6fa75 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala
@@ -46,7 +46,7 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol
     new CVC4Interpreter("cvc4", opts.toArray)
   }
 
-  override def declareSort(t: TypeTree): Sort = {
+  override protected def declareSort(t: TypeTree): Sort = {
     val tpe = normalizeType(t)
     sorts.cachedB(tpe) {
       tpe match {
@@ -65,7 +65,7 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol
     }
   }
 
-  override def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
+  override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
     case (SimpleSymbol(s), tp: TypeParameter) =>
       val n = s.name.split("_").toList.last
       GenericValue(tp, n.toInt)
@@ -113,7 +113,7 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol
       super.fromSMT(s, tpe)
   }
 
-  override def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match {
+  override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match {
     /**
      * ===== Set operations =====
      */
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
index f27838071..f5d37db92 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala
@@ -6,8 +6,9 @@ import leon.purescala.Common.Identifier
 import leon.purescala.Constructors._
 import leon.purescala.Definitions.FunDef
 import leon.purescala.ExprOps._
-import leon.purescala.Expressions.{BooleanLiteral, FunctionInvocation, Expr}
+import leon.purescala.Expressions._
 import leon.verification.VC
+import smtlib.parser.Terms.{ Term, Forall => SMTForall, _ }
 
 trait SMTLIBQuantifiedSolver extends SMTLIBSolver {
 
@@ -44,6 +45,13 @@ trait SMTLIBQuantifiedSolver extends SMTLIBSolver {
 
   }
 
+  override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match {
+    case Forall(vs, bd) =>
+      quantifiedTerm(SMTForall, vs map { _.id }, bd)
+    case _ =>
+      super.toSMT(e)(bindings)
+  }
+
   // We need to know the function context.
   // The reason is we do not want to assume postconditions of functions referring to
   // the current function, as this may make the proof unsound
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
index ec91138e3..eff01e304 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
@@ -31,11 +31,11 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
 
   def getNewInterpreter(ctx: LeonContext) = new Z3Interpreter("z3", interpreterOps(ctx).toArray)
 
-  val extSym = SSymbol("_")
+  protected val extSym = SSymbol("_")
 
-  var setSort: Option[SSymbol] = None
+  protected var setSort: Option[SSymbol] = None
 
-  override def declareSort(t: TypeTree): Sort = {
+  override protected def declareSort(t: TypeTree): Sort = {
     val tpe = normalizeType(t)
     sorts.cachedB(tpe) {
       tpe match {
@@ -48,7 +48,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
     }
   }
 
-  def declareSetSort(of: TypeTree): Sort = {
+  protected def declareSetSort(of: TypeTree): Sort = {
     setSort match {
       case None =>
         val s = SSymbol("Set")
@@ -66,7 +66,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
     Sort(SMTIdentifier(setSort.get), Seq(declareSort(of)))
   }
 
-  override def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
+  override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
     case (SimpleSymbol(s), tp: TypeParameter) =>
       val n = s.name.split("!").toList.last
       GenericValue(tp, n.toInt)
@@ -93,7 +93,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
       super.fromSMT(s, tpe)
   }
 
-  override def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match {
+  override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match {
 
     /**
      * ===== Set operations =====
@@ -131,7 +131,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
       super.toSMT(e)
   }
 
-  def extractRawArray(s: DefineFun, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match {
+  protected def extractRawArray(s: DefineFun, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match {
     case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) =>
       val (argTpe, retTpe) = tpe match {
         case SetType(base) => (base, BooleanType)
@@ -202,7 +202,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
     new Model(model)
   }
 
-  object ArrayMap {
+  protected object ArrayMap {
     def apply(op: SSymbol, arrs: Term*) = {
       FunctionApplication(
         QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))),
@@ -211,7 +211,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
     }
   }
 
-  object ArrayConst {
+  protected object ArrayConst {
     def apply(sort: Sort, default: Term) = {
       FunctionApplication(
         QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(sort)),
-- 
GitLab