From 4ae5ae4f22ba875f4d30d3518fda419bb7c27eb9 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 7 May 2015 14:55:22 +0200
Subject: [PATCH] Solver should not exit when encountering unsupported Tree

---
 .../frontends/scalac/ExtractionPhase.scala    |   2 -
 .../smtlib/SMTLIBCVC4QuantifiedTarget.scala   |  41 ++++--
 .../leon/solvers/smtlib/SMTLIBTarget.scala    | 134 +++++++++---------
 3 files changed, 93 insertions(+), 84 deletions(-)

diff --git a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala
index 80a6691a1..c4131034a 100644
--- a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala
+++ b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala
@@ -6,8 +6,6 @@ package frontends.scalac
 import purescala.Definitions.Program
 import purescala.Common.FreshIdentifier
 
-import purescala.ScalaPrinter
-
 import utils._
 
 import scala.tools.nsc.{Settings,CompilerCommand}
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala
index 808a3b4db..305457527 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala
@@ -4,10 +4,11 @@ package leon
 package solvers.smtlib
 
 import purescala.Common.FreshIdentifier
-import purescala.Expressions.{FunctionInvocation, BooleanLiteral, Expr, Implies}
+import leon.purescala.Expressions._
 import purescala.Definitions.TypedFunDef
 import purescala.Constructors.{application, implies}
 import purescala.DefOps.typedTransitiveCallees
+import smtlib.parser.Commands.Assert
 import smtlib.parser.Commands._
 import smtlib.parser.Terms._
 import smtlib.theories.Core.Equals
@@ -44,17 +45,23 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target {
 
       functions +=(tfd, id2sym(id))
 
-      val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map())))
-
-      val specAssert = tfd.postcondition map { post =>
-        val term = implies(
-          tfd.precondition getOrElse BooleanLiteral(true),
-          application(post, Seq(FunctionInvocation(tfd, Seq())))
-        )
-        Assert(toSMT(term)(Map()))
+      try {
+        val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map())))
+
+        val specAssert = tfd.postcondition map { post =>
+          val term = implies(
+            tfd.precondition getOrElse BooleanLiteral(true),
+            application(post, Seq(FunctionInvocation(tfd, Seq())))
+          )
+          Assert(toSMT(term)(Map()))
+        }
+
+        Seq(bodyAssert) ++ specAssert
+      } catch {
+        case _ : IllegalArgumentException =>
+          addError()
+          Seq()
       }
-
-      Seq(bodyAssert) ++ specAssert
     }
 
     val seen = withParams filterNot functions.containsA
@@ -76,9 +83,15 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target {
 
     val smtBodies = smtFunDecls map { case FunDec(sym, _, _) =>
       val tfd = functions.toA(sym)
-      toSMT(tfd.body.get)(tfd.params.map { p =>
-        (p.id, id2sym(p.id): Term)
-      }.toMap)
+      try {
+        toSMT(tfd.body.get)(tfd.params.map { p =>
+          (p.id, id2sym(p.id): Term)
+        }.toMap)
+      } catch {
+        case i: IllegalArgumentException =>
+          addError()
+          toSMT(Error(tfd.body.get.getType, ""))(Map())
+      }
     }
 
     if (smtFunDecls.nonEmpty) {
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index 7d694d10d..54ed63031 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -64,6 +64,9 @@ trait SMTLIBTarget {
   val genericValues = new IncrementalBijection[GenericValue, SSymbol]()
   val sorts         = new IncrementalBijection[TypeTree, Sort]()
   val functions     = new IncrementalBijection[TypedFunDef, SSymbol]()
+  val errors        = new IncrementalBijection[Unit, Boolean]()
+  protected def hasError = errors.getB(()) contains true
+  protected def addError() = errors += () -> true
 
   protected object OptionManager {
     lazy val leonOption = program.library.Option.get
@@ -526,72 +529,58 @@ trait SMTLIBTarget {
       case ap @ Application(caller, args) =>
         ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args)))
 
-      case e @ UnaryOperator(u, _) =>
-        e match {
-          case (_: Not) => Core.Not(toSMT(u))
-          case (_: UMinus) => Ints.Neg(toSMT(u))
-          case (_: BVUMinus) => FixedSizeBitVectors.Neg(toSMT(u))
-          case (_: BVNot) => FixedSizeBitVectors.Not(toSMT(u))
-          case _ => reporter.fatalError("Unhandled unary "+e)
-        }
-
-      case e @ BinaryOperator(a, b, _) =>
-        e match {
-          case (_: Assert) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed")))
-          case (_: Equals) => Core.Equals(toSMT(a), toSMT(b))
-          case (_: Implies) => Core.Implies(toSMT(a), toSMT(b))
-          case (_: Plus) => Ints.Add(toSMT(a), toSMT(b))
-          case (_: Minus) => Ints.Sub(toSMT(a), toSMT(b))
-          case (_: Times) => Ints.Mul(toSMT(a), toSMT(b))
-          case (_: Division) => Ints.Div(toSMT(a), toSMT(b))
-          case (_: Modulo) => Ints.Mod(toSMT(a), toSMT(b))
-          case (_: LessThan) => a.getType match {
-            case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
-            case IntegerType => Ints.LessThan(toSMT(a), toSMT(b))
-          }
-          case (_: LessEquals) => a.getType match {
-            case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
-            case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b))
-          }
-          case (_: GreaterThan) => a.getType match {
-            case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
-            case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b))
-          }
-          case (_: GreaterEquals) => a.getType match {
-            case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
-            case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b))
-          }
-          case (_: BVPlus) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b))
-          case (_: BVMinus) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
-          case (_: BVTimes) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b))
-          case (_: BVDivision) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
-          case (_: BVModulo) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
-          case (_: BVAnd) => FixedSizeBitVectors.And(toSMT(a), toSMT(b))
-          case (_: BVOr) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b))
-          case (_: BVXOr) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b))
-          case (_: BVShiftLeft) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b))
-          case (_: BVAShiftRight) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b))
-          case (_: BVLShiftRight) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b))
-          case _ => reporter.fatalError("Unhandled binary "+e)
-        }
-
-      case e @ NAryOperator(sub, _) =>
-        e match {
-          case (_: And) => Core.And(sub.map(toSMT): _*)
-          case (_: Or) => Core.Or(sub.map(toSMT): _*)
-          case (_: IfExpr) => Core.ITE(toSMT(sub(0)), toSMT(sub(1)), toSMT(sub(2))) 
-          case (f: FunctionInvocation) => 
-            if (sub.isEmpty) declareFunction(f.tfd) else {
-              FunctionApplication(
-                declareFunction(f.tfd),
-                sub.map(toSMT)
-              )
-            }
-          case _ => reporter.fatalError("Unhandled nary "+e)
+      case Not(u) => Core.Not(toSMT(u))
+      case UMinus(u) => Ints.Neg(toSMT(u))
+      case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u))
+      case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u))
+      case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed")))
+      case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b))
+      case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b))
+      case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b))
+      case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b))
+      case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b))
+      case Division(a,b) => Ints.Div(toSMT(a), toSMT(b))
+      case Modulo(a,b) => Ints.Mod(toSMT(a), toSMT(b))
+      case LessThan(a,b) => a.getType match {
+        case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
+        case IntegerType => Ints.LessThan(toSMT(a), toSMT(b))
+      }
+      case LessEquals(a,b) => a.getType match {
+        case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
+        case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b))
+      }
+      case GreaterThan(a,b) => a.getType match {
+        case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
+        case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b))
+      }
+      case GreaterEquals(a,b) => a.getType match {
+        case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
+        case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b))
+      }
+      case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b))
+      case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
+      case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b))
+      case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
+      case BVModulo(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
+      case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b))
+      case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b))
+      case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b))
+      case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b))
+      case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b))
+      case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b))
+      case And(sub) => Core.And(sub.map(toSMT): _*)
+      case Or(sub) => Core.Or(sub.map(toSMT): _*)
+      case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze))
+      case f@FunctionInvocation(_, sub) =>
+        if (sub.isEmpty) declareFunction(f.tfd) else {
+          FunctionApplication(
+            declareFunction(f.tfd),
+            sub.map(toSMT)
+          )
         }
-
       case o =>
-        unsupported("Tree: " + o)
+        reporter.warning(s"Unsupported Tree in smt-$targetName: $o")
+        throw new IllegalArgumentException
     }
   }
 
@@ -693,7 +682,7 @@ trait SMTLIBTarget {
       out.write("\n")
       out.flush()
     }
-    interpreter.eval(cmd) match {
+    if (hasError) Unsupported else interpreter.eval(cmd) match {
       case err@ErrorResponse(msg) if !interrupted =>
         reporter.fatalError("Unexpected error from smt-"+targetName+" solver: "+msg)
       case res => res
@@ -702,8 +691,16 @@ trait SMTLIBTarget {
 
   override def assertCnstr(expr: Expr): Unit = {
     variablesOf(expr).foreach(declareVariable)
-    val term = toSMT(expr)(Map())
-    sendCommand(SMTAssert(term))
+    try {
+      val term = toSMT(expr)(Map())
+      sendCommand(SMTAssert(term))
+    } catch {
+      case i : IllegalArgumentException =>
+        // Store that there was an error. Now all following check()
+        // invocations will return None
+        addError()
+    }
+
   }
 
   override def check: Option[Boolean] = sendCommand(CheckSat()) match {
@@ -743,7 +740,7 @@ trait SMTLIBTarget {
     genericValues.push()
     sorts.push()
     functions.push()
-
+    errors.push()
     sendCommand(Push(1))
   }
 
@@ -757,6 +754,7 @@ trait SMTLIBTarget {
     genericValues.pop()
     sorts.pop()
     functions.pop()
+    errors.pop()
 
     sendCommand(Pop(1))
   }
-- 
GitLab