From e68a291ae5adb759820dfa2de0c3be4888b7c94f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Wed, 29 Jan 2014 15:23:28 +0100
Subject: [PATCH] Base implementation of an SMTLib printer/solver

---
 project/Build.scala                           |   5 +-
 src/main/scala/leon/Main.scala                |   2 +
 src/main/scala/leon/smtlib/ExprToSExpr.scala  | 107 +++++++++++++
 .../scala/leon/smtlib/PrettyPrinter.scala     |  46 ++++++
 src/main/scala/leon/smtlib/SExprToExpr.scala  |  48 ++++++
 src/main/scala/leon/smtlib/SMTLIBSolver.scala | 147 ++++++++++++++++++
 src/main/scala/leon/smtlib/package.scala      |  30 ++++
 .../solvers/combinators/UnrollingSolver.scala |   1 +
 .../leon/verification/AnalysisPhase.scala     |   4 +-
 9 files changed, 387 insertions(+), 3 deletions(-)
 create mode 100644 src/main/scala/leon/smtlib/ExprToSExpr.scala
 create mode 100644 src/main/scala/leon/smtlib/PrettyPrinter.scala
 create mode 100644 src/main/scala/leon/smtlib/SExprToExpr.scala
 create mode 100644 src/main/scala/leon/smtlib/SMTLIBSolver.scala
 create mode 100644 src/main/scala/leon/smtlib/package.scala

diff --git a/project/Build.scala b/project/Build.scala
index 1b805f4e4..9347b1536 100644
--- a/project/Build.scala
+++ b/project/Build.scala
@@ -74,9 +74,12 @@ object Leon extends Build {
     id = "leon",
     base = file("."),
     settings = Project.defaultSettings ++ LeonProject.settings
-  ).dependsOn(Github.bonsai)
+  ).dependsOn(Github.bonsai, Github.scalaSmtLib)
 
   object Github {
     lazy val bonsai = RootProject(uri("git://github.com/colder/bonsai.git#8f485605785bda98ac61885b0c8036133783290a"))
+
+    private val scalaSmtLibVersion = "160a635e3677a185e2d5bd84669be98fcda8c574"
+    lazy val scalaSmtLib = RootProject(uri("git://github.com/regb/scala-smtlib.git#%s".format(scalaSmtLibVersion)))
   }
 }
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index 461db366c..3195b5cfb 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -276,6 +276,8 @@ object Main {
         ctx.timers.outputTable(debug)
       }
 
+      println("time: " + smtlib.SMTLIBSolver.time)
+
     } catch {
       case LeonFatalError(None) =>
         sys.exit(1)
diff --git a/src/main/scala/leon/smtlib/ExprToSExpr.scala b/src/main/scala/leon/smtlib/ExprToSExpr.scala
new file mode 100644
index 000000000..b8e1be22e
--- /dev/null
+++ b/src/main/scala/leon/smtlib/ExprToSExpr.scala
@@ -0,0 +1,107 @@
+package leon
+package smtlib
+
+import purescala._
+import Common._
+import Trees._
+import Extractors._
+import TreeOps._
+import TypeTrees._
+import Definitions._
+
+import _root_.smtlib.sexpr
+import sexpr.SExprs._
+import _root_.smtlib.Commands.{Identifier => SmtLibIdentifier, _}
+
+/** This pretty-printer prints an SMTLIB2 representation of the Purescala program */
+object ExprToSExpr {
+
+  //returns the set of free identifier
+  def apply(tree: Expr): (SExpr, Set[Identifier]) = {
+
+    var freeVars: Set[Identifier] = Set()
+    
+    def rec(t: Expr): SExpr = t match {
+      case Variable(id) => {
+        val sym = id2sym(id)
+        freeVars += id
+        sym
+      }
+      //case LetTuple(ids,d,e) => SList(List(
+      //  SSymbol("let"),
+      //  SList(List(ids))
+      case Let(b,d,e) => {
+        val id = id2sym(b)
+        val value = rec(d)
+        val newBody = rec(e)
+        freeVars -= b
+        
+        SList(
+          SSymbol("let"),
+          SList(
+            SList(id, value)
+          ),
+          newBody
+        )
+      }
+
+      case er@Error(_) => {
+        val id = FreshIdentifier("error_value").setType(er.getType)
+        val sym = id2sym(id)
+        freeVars += id
+        sym
+      }
+
+      case And(exprs) => SList(SSymbol("and") :: exprs.map(rec).toList)
+      case Or(exprs) => SList(SSymbol("or") :: exprs.map(rec).toList)
+      case Not(expr) => SList(SSymbol("not"), rec(expr))
+      case Equals(l,r) => SList(SSymbol("="), rec(l), rec(r))
+      case IntLiteral(v) => SInt(v)
+      case BooleanLiteral(v) => SSymbol(v.toString) //TODO: maybe need some builtin type here
+      case StringLiteral(s) => SString(s)
+
+      case Implies(l,r) => SList(SSymbol("=>"), rec(l), rec(r))
+      case Iff(l,r) => SList(SSymbol("="), rec(l), rec(r))
+
+      case Plus(l,r) => SList(SSymbol("+"), rec(l), rec(r))
+      case UMinus(expr) => SList(SSymbol("-"), rec(expr))
+      case Minus(l,r) => SList(SSymbol("-"), rec(l), rec(r))
+      case Times(l,r) => SList(SSymbol("*"), rec(l), rec(r))
+      case Division(l,r) => SList(SSymbol("div"), rec(l), rec(r))
+      case Modulo(l,r) => SList(SSymbol("mod"), rec(l), rec(r))
+      case LessThan(l,r) => SList(SSymbol("<"), rec(l), rec(r))
+      case LessEquals(l,r) => SList(SSymbol("<="), rec(l), rec(r))
+      case GreaterThan(l,r) => SList(SSymbol(">"), rec(l), rec(r))
+      case GreaterEquals(l,r) => SList(SSymbol(">="), rec(l), rec(r))
+
+      case IfExpr(c, t, e) => SList(SSymbol("ite"), rec(c), rec(t), rec(e))
+
+      case FunctionInvocation(fd, args) => SList(id2sym(fd.id) :: args.map(rec).toList)
+
+      //case ArrayFill(length, defaultValue) => SList(
+      //  SList(SSymbol("as"), SSymbol("const"), tpe2sort(tree.getType)),
+      //  rec(defaultValue)
+      //)
+      //case ArrayMake(defaultValue) => SList(
+      //  SList(SSymbol("as"), SSymbol("const"), tpe2sort(tree.getType)),
+      //  rec(defaultValue)
+      //)
+      //case ArraySelect(array, index) => SList(SSymbol("select"), rec(array), rec(index))
+      //case ArrayUpdated(array, index, newValue) => SList(SSymbol("store"), rec(array), rec(index), rec(newValue))
+
+      case CaseClass(ccd, args) if args.isEmpty => id2sym(ccd.id)
+      case CaseClass(ccd, args) => SList(id2sym(ccd.id) :: args.map(rec(_)).toList)
+      case CaseClassSelector(_, arg, field) => SList(id2sym(field), rec(arg))
+
+      case CaseClassInstanceOf(ccd, arg) => {
+        val name = id2sym(ccd.id)
+        val testerName = SSymbol("is-" + name.s)
+        SList(testerName, rec(arg))
+      }
+      case o => sys.error("TODO converting to smtlib: " + o)
+    }
+
+    val res = rec(tree)
+    (res, freeVars)
+  }
+}
diff --git a/src/main/scala/leon/smtlib/PrettyPrinter.scala b/src/main/scala/leon/smtlib/PrettyPrinter.scala
new file mode 100644
index 000000000..d13d4cc45
--- /dev/null
+++ b/src/main/scala/leon/smtlib/PrettyPrinter.scala
@@ -0,0 +1,46 @@
+//package leon
+//package smtlib
+//
+//import purescala._
+//import Common._
+//import Trees._
+//import Extractors._
+//import TreeOps._
+//import TypeTrees._
+//import Definitions._
+//
+//import _root_.smtlib.sexpr
+//import sexpr.SExprs._
+//import _root_.smtlib.Commands.{Identifier => SmtLibIdentifier, _}
+//
+//  // prec: there should be no lets and no pattern-matching in this expression
+//  def collectWithPathCondition(matcher: Expr=>Boolean, expression: Expr) : Set[(Seq[Expr],Expr)] = {
+//    var collected : Set[(Seq[Expr],Expr)] = Set.empty
+//
+//      def rec(expr: Expr, path: List[Expr]) : Unit = {
+//        if(matcher(expr)) {
+//          collected = collected + ((path.reverse, expr))
+//        }
+//
+//        expr match {
+//          case Let(i,e,b) => {
+//            rec(e, path)
+//              rec(b, Equals(Variable(i), e) :: path)
+//          }
+//          case IfExpr(cond, thn, els) => {
+//            rec(cond, path)
+//              rec(thn, cond :: path)
+//              rec(els, Not(cond) :: path)
+//          }
+//          case NAryOperator(args, _) => args.foreach(rec(_, path))
+//            case BinaryOperator(t1, t2, _) => rec(t1, path); rec(t2, path)
+//            case UnaryOperator(t, _) => rec(t, path)
+//          case t : Terminal => ;
+//          case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr)
+//        }
+//      }
+//
+//    rec(expression, Nil)
+//      collected
+//  }
+//}
diff --git a/src/main/scala/leon/smtlib/SExprToExpr.scala b/src/main/scala/leon/smtlib/SExprToExpr.scala
new file mode 100644
index 000000000..9dc8690a0
--- /dev/null
+++ b/src/main/scala/leon/smtlib/SExprToExpr.scala
@@ -0,0 +1,48 @@
+package leon
+package smtlib
+
+import purescala._
+import Common._
+import Trees._
+import Extractors._
+import TreeOps._
+import TypeTrees._
+import Definitions._
+
+import _root_.smtlib.sexpr
+import sexpr.SExprs._
+import _root_.smtlib.Commands.{Identifier => SmtLibIdentifier, _}
+
+object SExprToExpr {
+
+  def apply(sexpr: SExpr, context: Map[String, Expr], constructors: Map[String, CaseClassDef]): Expr = sexpr match {
+    case SInt(n) => IntLiteral(n.toInt)
+    case SSymbol("TRUE") => BooleanLiteral(true)
+    case SSymbol("FALSE") => BooleanLiteral(false)
+    case SSymbol(s) => constructors.get(s) match {
+      case Some(app) => CaseClass(app, Seq())
+      case None => context(s)
+    }
+    case SList(SSymbol(app) :: args) if(constructors.isDefinedAt(app)) => 
+      CaseClass(constructors(app), args.map(apply(_, context, constructors)))
+
+    case SList(List(SSymbol("LET"), SList(defs), body)) => {
+      val leonDefs: Seq[(Identifier, Expr, String)] = defs.map {
+        case SList(List(SSymbol(sym), value)) => (FreshIdentifier(sym), apply(value, context, constructors), sym)
+      }
+      val recBody = apply(body, context ++ leonDefs.map(p => (p._3, p._1.toVariable)), constructors)
+      leonDefs.foldRight(recBody)((binding, b) => Let(binding._1, binding._2, b))
+    }
+    case SList(SSymbol(app) :: args) => {
+      val recArgs = args.map(arg => apply(arg, context, constructors))
+      app match {
+        case "-" => recArgs match {
+          case List(a) => UMinus(a)
+          case List(a, b) => Minus(a, b)
+        }
+      }
+    }
+    case o => sys.error("TODO converting from s-expr: " + o)
+  }
+
+}
diff --git a/src/main/scala/leon/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/smtlib/SMTLIBSolver.scala
new file mode 100644
index 000000000..510b00ff2
--- /dev/null
+++ b/src/main/scala/leon/smtlib/SMTLIBSolver.scala
@@ -0,0 +1,147 @@
+/* Copyright 2009-2013 EPFL, Lausanne */
+
+package leon
+package smtlib
+
+import solvers.IncrementalSolver
+import utils.Interruptible
+import purescala._
+import Common._
+import Trees._
+import Extractors._
+import TreeOps._
+import TypeTrees._
+import Definitions._
+
+import _root_.smtlib.{PrettyPrinter => CommandPrinter, Commands, CommandResponses, sexpr, Interpreter => SMTLIBInterpreter}
+import Commands.{Identifier => SMTLIBIdentifier, _}
+import CommandResponses.{Error => ErrorResponse, _}
+import sexpr.SExprs._
+import _root_.smtlib.Commands.{Identifier => SmtLibIdentifier, _}
+
+
+class SMTLIBSolver(override val context: LeonContext, 
+                   val program: Program, 
+                   smtlibInterpreterFactory: () => SMTLIBInterpreter) 
+  extends IncrementalSolver with Interruptible {
+
+  override def interrupt: Unit = {}
+  override def recoverInterrupt(): Unit = {}
+
+
+  override def name: String = "smtlib-solver"
+  val out = new java.io.FileWriter("vcs/vc_" + SMTLIBSolver.counter)
+
+  //using a factory, so that creation and call to free are local to this class
+  private val smtlibInterpreter = smtlibInterpreterFactory()
+
+  private var errorConstants: Set[(SSymbol, SExpr)] = Set()
+
+  val defs: Seq[ClassTypeDef] = program.definedClasses
+
+  val partition: Seq[(AbstractClassDef, Seq[CaseClassDef])] = {
+    val parents: Seq[AbstractClassDef] = defs.filter(!_.hasParent).asInstanceOf[Seq[AbstractClassDef]]
+    parents.map(p => (p, defs.filter(c => c.parent match {
+      case Some(p2) => p == p2
+      case None => false
+    }).asInstanceOf[Seq[CaseClassDef]]))
+  }
+
+  val sorts: Seq[SExpr] = partition.map{ case (parent, children) => {
+    val name = id2sym(parent.id)
+    val constructors: List[SExpr] = children.map(child => {
+      val fields: List[SExpr] = child.fields.map{case VarDecl(id, tpe) => SList(id2sym(id), tpe2sort(tpe)) }.toList
+      if(fields.isEmpty) id2sym(child.id) else SList(id2sym(child.id) :: fields)
+    }).toList
+
+    SList(name :: constructors)
+  }}
+  val constructors: Map[String, CaseClassDef] = 
+    partition.unzip._2.flatMap(ccds => ccds.map(ccd => (ccd.id.uniqueName.toUpperCase, ccd))).toMap
+  val sortDecls = NonStandardCommand(SList(SSymbol("declare-datatypes"), SList(), SList(sorts.toList)))
+  val funDefDecls: Seq[Command] = program.definedFunctions.flatMap(fd2sexp)
+  val sortErrors: List[Command] = errorConstants.map(p => DeclareFun(p._1.s, Seq(), p._2)).toList
+
+  def declareConst(id: Identifier): Command = DeclareFun(id2sym(id).s, Seq(), tpe2sort(id.getType))
+
+    //SComment("! THEORY=1") +:
+    //SComment("Generated by Leon") +:
+  sendCommand(sortDecls)
+  sortErrors.foreach(sendCommand(_)) 
+  funDefDecls.foreach(sendCommand)
+
+      //convertedFunDefs.flatMap(_._1) ++
+      //convertedFunDefs.map(_._2) ++
+      //convertedFunDefs.flatMap(_._3) ++
+      //Seq(SList(SSymbol("check-sat")))
+
+  private var freeVars: Set[Identifier] = program.definedFunctions.flatMap(fd => fd.args.map(_.id)).toSet
+  override def assertCnstr(expression: Expr): Unit = {
+
+    val (sexpr, exprFreevars) = ExprToSExpr(expression)
+
+    val newFreeVars: Set[Identifier] = exprFreevars.diff(freeVars)
+    freeVars ++= newFreeVars
+    newFreeVars.foreach(id => sendCommand(declareConst(id)))
+
+    sendCommand(Assert(sexpr))
+  }
+
+  override def check: Option[Boolean] = sendCommand(CheckSat) match {
+    case CheckSatResponse(SatStatus) => Some(true)
+    case CheckSatResponse(UnsatStatus) => Some(false)
+    case CheckSatResponse(UnknownStatus) => None
+  }
+
+  override def getModel: Map[Identifier, Expr] = {
+    val ids: List[Identifier] = freeVars.toList
+    val sexprs: List[SSymbol] = ids.map(id => id2sym(id))
+    val cmd: Command = GetValue(sexprs.head, sexprs.tail)
+
+    val symToIds: Map[SExpr, Identifier] = sexprs.map(s => SSymbol(s.s.toUpperCase)).zip(ids).toMap
+    val GetValueResponse(valuationPairs) = sendCommand(cmd)
+    println("got valuation pairs: " + valuationPairs)
+    valuationPairs.map{ case (sym, value) => (symToIds(sym), SExprToExpr(value, Map(), constructors)) }.toMap
+  }
+
+  override def free() = {
+    sendCommand(Exit)
+    smtlibInterpreter.free()
+    out.close
+    SMTLIBSolver.counter += 1
+  }
+
+  override def push(): Unit = {
+    sendCommand(Push(1))
+  }
+  override def pop(lvl: Int = 1): Unit = {
+    sendCommand(Pop(1))
+  }
+
+  def sendCommand(cmd: Command): CommandResponse = {
+    val startTime = System.currentTimeMillis
+    CommandPrinter(cmd, out)
+    out.write("\n")
+    val response = smtlibInterpreter.eval(cmd)
+    assert(!response.isInstanceOf[Error])
+    SMTLIBSolver.time = SMTLIBSolver.time + (System.currentTimeMillis - startTime)
+    response
+  }
+
+  private def fd2sexp(funDef: FunDef): List[Command] = {
+    val name = id2sym(funDef.id)
+    val returnSort = tpe2sort(funDef.returnType)
+
+    val varDecls: List[(SSymbol, SExpr)] = funDef.args.map(vd => (id2sym(vd.id), tpe2sort(vd.tpe))).toList
+
+    val topLevelVarDecl: List[Command] = varDecls.map{ case (name, tpe) => DeclareFun(name.s, Seq(), tpe) }
+    val funDecl: Command = DeclareFun(name.s, varDecls.map(_._2), returnSort)
+
+    funDecl :: topLevelVarDecl
+  }
+}
+
+object SMTLIBSolver {
+  var counter = 0
+  var time: Long = 0
+}
diff --git a/src/main/scala/leon/smtlib/package.scala b/src/main/scala/leon/smtlib/package.scala
new file mode 100644
index 000000000..63f6b6c0f
--- /dev/null
+++ b/src/main/scala/leon/smtlib/package.scala
@@ -0,0 +1,30 @@
+package leon
+
+import purescala._
+import Common._
+import Trees._
+import Extractors._
+import TreeOps._
+import TypeTrees._
+import Definitions._
+
+import _root_.smtlib.sexpr
+import sexpr.SExprs._
+import _root_.smtlib.Commands.{Identifier => SmtLibIdentifier, _}
+
+package object smtlib {
+
+  private[smtlib] def id2sym(id: Identifier): SSymbol = SSymbol(id.uniqueName)
+  //return a series of declarations, an expression that defines the function, 
+  //and the seq of asserts for pre/post-conditions
+
+
+  def tpe2sort[smtlib](tpe: TypeTree): SExpr = tpe match {
+    case Int32Type => SSymbol("Int")
+    case BooleanType => SSymbol("Bool")
+    case ArrayType(baseTpe) => SList(SSymbol("Array"), SSymbol("Int"), tpe2sort(baseTpe))
+    case AbstractClassType(abs) => id2sym(abs.id)
+    case _ => sys.error("TODO tpe2sort: " + tpe)
+  }
+
+}
diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
index 8841c2b94..728ae07eb 100644
--- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
@@ -118,6 +118,7 @@ class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[Incre
           theModel = Some(model)
       }
     }
+    solver.free
     result
 
   } getOrElse {
diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala
index 922ca045f..2d8f544bc 100644
--- a/src/main/scala/leon/verification/AnalysisPhase.scala
+++ b/src/main/scala/leon/verification/AnalysisPhase.scala
@@ -151,7 +151,8 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
 
     val allSolvers = Map(
       "fairz3" -> SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver),
-      "enum"   -> SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver)
+      "enum"   -> SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver),
+      "smt"    -> SolverFactory(() => new smtlib.SMTLIBSolver(ctx, program, new _root_.smtlib.interpreters.Z3Interpreter))
     )
 
     val reporter = ctx.reporter
@@ -184,7 +185,6 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
       SolverFactory( () => new PortfolioSolver(ctx, solversToUse.values.toSeq) with TimeoutSolver)
     }
 
-
     val mainSolver = timeout match {
       case Some(sec) =>
         new TimeoutSolverFactory(entrySolver, sec*1000L)
-- 
GitLab