From 4ae5ae4f22ba875f4d30d3518fda419bb7c27eb9 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Thu, 7 May 2015 14:55:22 +0200 Subject: [PATCH] Solver should not exit when encountering unsupported Tree --- .../frontends/scalac/ExtractionPhase.scala | 2 - .../smtlib/SMTLIBCVC4QuantifiedTarget.scala | 41 ++++-- .../leon/solvers/smtlib/SMTLIBTarget.scala | 134 +++++++++--------- 3 files changed, 93 insertions(+), 84 deletions(-) diff --git a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala index 80a6691a1..c4131034a 100644 --- a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala +++ b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala @@ -6,8 +6,6 @@ package frontends.scalac import purescala.Definitions.Program import purescala.Common.FreshIdentifier -import purescala.ScalaPrinter - import utils._ import scala.tools.nsc.{Settings,CompilerCommand} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala index 808a3b4db..305457527 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala @@ -4,10 +4,11 @@ package leon package solvers.smtlib import purescala.Common.FreshIdentifier -import purescala.Expressions.{FunctionInvocation, BooleanLiteral, Expr, Implies} +import leon.purescala.Expressions._ import purescala.Definitions.TypedFunDef import purescala.Constructors.{application, implies} import purescala.DefOps.typedTransitiveCallees +import smtlib.parser.Commands.Assert import smtlib.parser.Commands._ import smtlib.parser.Terms._ import smtlib.theories.Core.Equals @@ -44,17 +45,23 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target { functions +=(tfd, id2sym(id)) - val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map()))) - - val specAssert = tfd.postcondition map { post => - val term = implies( - tfd.precondition getOrElse BooleanLiteral(true), - application(post, Seq(FunctionInvocation(tfd, Seq()))) - ) - Assert(toSMT(term)(Map())) + try { + val bodyAssert = Assert(Equals(id2sym(id): Term, toSMT(tfd.body.get)(Map()))) + + val specAssert = tfd.postcondition map { post => + val term = implies( + tfd.precondition getOrElse BooleanLiteral(true), + application(post, Seq(FunctionInvocation(tfd, Seq()))) + ) + Assert(toSMT(term)(Map())) + } + + Seq(bodyAssert) ++ specAssert + } catch { + case _ : IllegalArgumentException => + addError() + Seq() } - - Seq(bodyAssert) ++ specAssert } val seen = withParams filterNot functions.containsA @@ -76,9 +83,15 @@ trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target { val smtBodies = smtFunDecls map { case FunDec(sym, _, _) => val tfd = functions.toA(sym) - toSMT(tfd.body.get)(tfd.params.map { p => - (p.id, id2sym(p.id): Term) - }.toMap) + try { + toSMT(tfd.body.get)(tfd.params.map { p => + (p.id, id2sym(p.id): Term) + }.toMap) + } catch { + case i: IllegalArgumentException => + addError() + toSMT(Error(tfd.body.get.getType, ""))(Map()) + } } if (smtFunDecls.nonEmpty) { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 7d694d10d..54ed63031 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -64,6 +64,9 @@ trait SMTLIBTarget { val genericValues = new IncrementalBijection[GenericValue, SSymbol]() val sorts = new IncrementalBijection[TypeTree, Sort]() val functions = new IncrementalBijection[TypedFunDef, SSymbol]() + val errors = new IncrementalBijection[Unit, Boolean]() + protected def hasError = errors.getB(()) contains true + protected def addError() = errors += () -> true protected object OptionManager { lazy val leonOption = program.library.Option.get @@ -526,72 +529,58 @@ trait SMTLIBTarget { case ap @ Application(caller, args) => ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) - case e @ UnaryOperator(u, _) => - e match { - case (_: Not) => Core.Not(toSMT(u)) - case (_: UMinus) => Ints.Neg(toSMT(u)) - case (_: BVUMinus) => FixedSizeBitVectors.Neg(toSMT(u)) - case (_: BVNot) => FixedSizeBitVectors.Not(toSMT(u)) - case _ => reporter.fatalError("Unhandled unary "+e) - } - - case e @ BinaryOperator(a, b, _) => - e match { - case (_: Assert) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) - case (_: Equals) => Core.Equals(toSMT(a), toSMT(b)) - case (_: Implies) => Core.Implies(toSMT(a), toSMT(b)) - case (_: Plus) => Ints.Add(toSMT(a), toSMT(b)) - case (_: Minus) => Ints.Sub(toSMT(a), toSMT(b)) - case (_: Times) => Ints.Mul(toSMT(a), toSMT(b)) - case (_: Division) => Ints.Div(toSMT(a), toSMT(b)) - case (_: Modulo) => Ints.Mod(toSMT(a), toSMT(b)) - case (_: LessThan) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) - case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) - } - case (_: LessEquals) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) - case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) - } - case (_: GreaterThan) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) - case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) - } - case (_: GreaterEquals) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) - case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) - } - case (_: BVPlus) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) - case (_: BVMinus) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) - case (_: BVTimes) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) - case (_: BVDivision) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) - case (_: BVModulo) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) - case (_: BVAnd) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) - case (_: BVOr) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) - case (_: BVXOr) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) - case (_: BVShiftLeft) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) - case (_: BVAShiftRight) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) - case (_: BVLShiftRight) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) - case _ => reporter.fatalError("Unhandled binary "+e) - } - - case e @ NAryOperator(sub, _) => - e match { - case (_: And) => Core.And(sub.map(toSMT): _*) - case (_: Or) => Core.Or(sub.map(toSMT): _*) - case (_: IfExpr) => Core.ITE(toSMT(sub(0)), toSMT(sub(1)), toSMT(sub(2))) - case (f: FunctionInvocation) => - if (sub.isEmpty) declareFunction(f.tfd) else { - FunctionApplication( - declareFunction(f.tfd), - sub.map(toSMT) - ) - } - case _ => reporter.fatalError("Unhandled nary "+e) + case Not(u) => Core.Not(toSMT(u)) + case UMinus(u) => Ints.Neg(toSMT(u)) + case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) + case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) + case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) + case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b)) + case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b)) + case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b)) + case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b)) + case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b)) + case Division(a,b) => Ints.Div(toSMT(a), toSMT(b)) + case Modulo(a,b) => Ints.Mod(toSMT(a), toSMT(b)) + case LessThan(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) + case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) + } + case LessEquals(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) + case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) + } + case GreaterThan(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) + case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) + } + case GreaterEquals(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) + case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) + } + case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) + case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) + case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) + case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) + case BVModulo(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) + case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) + case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) + case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) + case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) + case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) + case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) + case And(sub) => Core.And(sub.map(toSMT): _*) + case Or(sub) => Core.Or(sub.map(toSMT): _*) + case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) + case f@FunctionInvocation(_, sub) => + if (sub.isEmpty) declareFunction(f.tfd) else { + FunctionApplication( + declareFunction(f.tfd), + sub.map(toSMT) + ) } - case o => - unsupported("Tree: " + o) + reporter.warning(s"Unsupported Tree in smt-$targetName: $o") + throw new IllegalArgumentException } } @@ -693,7 +682,7 @@ trait SMTLIBTarget { out.write("\n") out.flush() } - interpreter.eval(cmd) match { + if (hasError) Unsupported else interpreter.eval(cmd) match { case err@ErrorResponse(msg) if !interrupted => reporter.fatalError("Unexpected error from smt-"+targetName+" solver: "+msg) case res => res @@ -702,8 +691,16 @@ trait SMTLIBTarget { override def assertCnstr(expr: Expr): Unit = { variablesOf(expr).foreach(declareVariable) - val term = toSMT(expr)(Map()) - sendCommand(SMTAssert(term)) + try { + val term = toSMT(expr)(Map()) + sendCommand(SMTAssert(term)) + } catch { + case i : IllegalArgumentException => + // Store that there was an error. Now all following check() + // invocations will return None + addError() + } + } override def check: Option[Boolean] = sendCommand(CheckSat()) match { @@ -743,7 +740,7 @@ trait SMTLIBTarget { genericValues.push() sorts.push() functions.push() - + errors.push() sendCommand(Push(1)) } @@ -757,6 +754,7 @@ trait SMTLIBTarget { genericValues.pop() sorts.pop() functions.pop() + errors.pop() sendCommand(Pop(1)) } -- GitLab