Skip to content
Snippets Groups Projects
Commit 4ae5ae4f authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

Solver should not exit when encountering unsupported Tree

parent eac8e7f2
Branches
Tags
No related merge requests found
......@@ -6,8 +6,6 @@ package frontends.scalac
import purescala.Definitions.Program
import purescala.Common.FreshIdentifier
import purescala.ScalaPrinter
import utils._
import scala.tools.nsc.{Settings,CompilerCommand}
......
......@@ -4,10 +4,11 @@ package leon
package solvers.smtlib
import purescala.Common.FreshIdentifier
import purescala.Expressions.{FunctionInvocation, BooleanLiteral, Expr, Implies}
import leon.purescala.Expressions._
import purescala.Definitions.TypedFunDef
import purescala.Constructors.{application, implies}
import purescala.DefOps.typedTransitiveCallees
import smtlib.parser.Commands.Assert
import smtlib.parser.Commands._
import smtlib.parser.Terms._
import smtlib.theories.Core.Equals
......@@ -44,17 +45,23 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target {
functions +=(tfd, id2sym(id))
val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map())))
val specAssert = tfd.postcondition map { post =>
val term = implies(
tfd.precondition getOrElse BooleanLiteral(true),
application(post, Seq(FunctionInvocation(tfd, Seq())))
)
Assert(toSMT(term)(Map()))
try {
val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map())))
val specAssert = tfd.postcondition map { post =>
val term = implies(
tfd.precondition getOrElse BooleanLiteral(true),
application(post, Seq(FunctionInvocation(tfd, Seq())))
)
Assert(toSMT(term)(Map()))
}
Seq(bodyAssert) ++ specAssert
} catch {
case _ : IllegalArgumentException =>
addError()
Seq()
}
Seq(bodyAssert) ++ specAssert
}
val seen = withParams filterNot functions.containsA
......@@ -76,9 +83,15 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target {
val smtBodies = smtFunDecls map { case FunDec(sym, _, _) =>
val tfd = functions.toA(sym)
toSMT(tfd.body.get)(tfd.params.map { p =>
(p.id, id2sym(p.id): Term)
}.toMap)
try {
toSMT(tfd.body.get)(tfd.params.map { p =>
(p.id, id2sym(p.id): Term)
}.toMap)
} catch {
case i: IllegalArgumentException =>
addError()
toSMT(Error(tfd.body.get.getType, ""))(Map())
}
}
if (smtFunDecls.nonEmpty) {
......
......@@ -64,6 +64,9 @@ trait SMTLIBTarget {
val genericValues = new IncrementalBijection[GenericValue, SSymbol]()
val sorts = new IncrementalBijection[TypeTree, Sort]()
val functions = new IncrementalBijection[TypedFunDef, SSymbol]()
val errors = new IncrementalBijection[Unit, Boolean]()
protected def hasError = errors.getB(()) contains true
protected def addError() = errors += () -> true
protected object OptionManager {
lazy val leonOption = program.library.Option.get
......@@ -526,72 +529,58 @@ trait SMTLIBTarget {
case ap @ Application(caller, args) =>
ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args)))
case e @ UnaryOperator(u, _) =>
e match {
case (_: Not) => Core.Not(toSMT(u))
case (_: UMinus) => Ints.Neg(toSMT(u))
case (_: BVUMinus) => FixedSizeBitVectors.Neg(toSMT(u))
case (_: BVNot) => FixedSizeBitVectors.Not(toSMT(u))
case _ => reporter.fatalError("Unhandled unary "+e)
}
case e @ BinaryOperator(a, b, _) =>
e match {
case (_: Assert) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed")))
case (_: Equals) => Core.Equals(toSMT(a), toSMT(b))
case (_: Implies) => Core.Implies(toSMT(a), toSMT(b))
case (_: Plus) => Ints.Add(toSMT(a), toSMT(b))
case (_: Minus) => Ints.Sub(toSMT(a), toSMT(b))
case (_: Times) => Ints.Mul(toSMT(a), toSMT(b))
case (_: Division) => Ints.Div(toSMT(a), toSMT(b))
case (_: Modulo) => Ints.Mod(toSMT(a), toSMT(b))
case (_: LessThan) => a.getType match {
case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
case IntegerType => Ints.LessThan(toSMT(a), toSMT(b))
}
case (_: LessEquals) => a.getType match {
case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b))
}
case (_: GreaterThan) => a.getType match {
case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b))
}
case (_: GreaterEquals) => a.getType match {
case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b))
}
case (_: BVPlus) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b))
case (_: BVMinus) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
case (_: BVTimes) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b))
case (_: BVDivision) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
case (_: BVModulo) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
case (_: BVAnd) => FixedSizeBitVectors.And(toSMT(a), toSMT(b))
case (_: BVOr) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b))
case (_: BVXOr) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b))
case (_: BVShiftLeft) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b))
case (_: BVAShiftRight) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b))
case (_: BVLShiftRight) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b))
case _ => reporter.fatalError("Unhandled binary "+e)
}
case e @ NAryOperator(sub, _) =>
e match {
case (_: And) => Core.And(sub.map(toSMT): _*)
case (_: Or) => Core.Or(sub.map(toSMT): _*)
case (_: IfExpr) => Core.ITE(toSMT(sub(0)), toSMT(sub(1)), toSMT(sub(2)))
case (f: FunctionInvocation) =>
if (sub.isEmpty) declareFunction(f.tfd) else {
FunctionApplication(
declareFunction(f.tfd),
sub.map(toSMT)
)
}
case _ => reporter.fatalError("Unhandled nary "+e)
case Not(u) => Core.Not(toSMT(u))
case UMinus(u) => Ints.Neg(toSMT(u))
case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u))
case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u))
case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed")))
case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b))
case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b))
case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b))
case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b))
case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b))
case Division(a,b) => Ints.Div(toSMT(a), toSMT(b))
case Modulo(a,b) => Ints.Mod(toSMT(a), toSMT(b))
case LessThan(a,b) => a.getType match {
case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
case IntegerType => Ints.LessThan(toSMT(a), toSMT(b))
}
case LessEquals(a,b) => a.getType match {
case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b))
}
case GreaterThan(a,b) => a.getType match {
case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b))
}
case GreaterEquals(a,b) => a.getType match {
case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b))
}
case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b))
case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b))
case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
case BVModulo(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b))
case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b))
case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b))
case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b))
case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b))
case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b))
case And(sub) => Core.And(sub.map(toSMT): _*)
case Or(sub) => Core.Or(sub.map(toSMT): _*)
case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze))
case f@FunctionInvocation(_, sub) =>
if (sub.isEmpty) declareFunction(f.tfd) else {
FunctionApplication(
declareFunction(f.tfd),
sub.map(toSMT)
)
}
case o =>
unsupported("Tree: " + o)
reporter.warning(s"Unsupported Tree in smt-$targetName: $o")
throw new IllegalArgumentException
}
}
......@@ -693,7 +682,7 @@ trait SMTLIBTarget {
out.write("\n")
out.flush()
}
interpreter.eval(cmd) match {
if (hasError) Unsupported else interpreter.eval(cmd) match {
case err@ErrorResponse(msg) if !interrupted =>
reporter.fatalError("Unexpected error from smt-"+targetName+" solver: "+msg)
case res => res
......@@ -702,8 +691,16 @@ trait SMTLIBTarget {
override def assertCnstr(expr: Expr): Unit = {
variablesOf(expr).foreach(declareVariable)
val term = toSMT(expr)(Map())
sendCommand(SMTAssert(term))
try {
val term = toSMT(expr)(Map())
sendCommand(SMTAssert(term))
} catch {
case i : IllegalArgumentException =>
// Store that there was an error. Now all following check()
// invocations will return None
addError()
}
}
override def check: Option[Boolean] = sendCommand(CheckSat()) match {
......@@ -743,7 +740,7 @@ trait SMTLIBTarget {
genericValues.push()
sorts.push()
functions.push()
errors.push()
sendCommand(Push(1))
}
......@@ -757,6 +754,7 @@ trait SMTLIBTarget {
genericValues.pop()
sorts.pop()
functions.pop()
errors.pop()
sendCommand(Pop(1))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment