diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 42c7af929ec57d654640b502820fdb6ac659ccf1..a89fa0da394415a5ae51494cf4bc077ce81e7710 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -15,7 +15,13 @@ import utils.IncrementalBijection import _root_.smtlib.common._ import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter} import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, FunDef => _, Assert => SMTAssert, _} -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Let => SMTLet, _} +import _root_.smtlib.parser.Terms.{ + Identifier => SMTIdentifier, + Let => SMTLet, + ForAll => SMTForall, + Exists => SMTExists, + _ +} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} import _root_.smtlib.theories._ import _root_.smtlib.{Interpreter => SMTInterpreter} @@ -108,6 +114,28 @@ trait SMTLIBTarget { case _ => t } + protected def quantifiedTerm( + quantifier: (SortedVar, Seq[SortedVar], Term) => Term, + vars: Seq[Identifier], + body: Expr + ) : Term = { + if (vars.isEmpty) toSMT(body)(Map()) + else { + val sortedVars = vars map { id => + SortedVar(id2sym(id), declareSort(id.getType)) + } + quantifier( + sortedVars.head, + sortedVars.tail, + toSMT(body)(vars.map{ id => id -> (id2sym(id): Term)}.toMap) + ) + } + } + + // Returns a quantified term where all free variables in the body have been quantified + protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr): Term = + quantifiedTerm(quantifier, variablesOf(body).toSeq, body) + // Corresponds to a smt map, not a leon/scala array // Should NEVER escape past SMT-world private[smtlib] case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala index ea6f20d2b303f20fa23843bd7e3f5a8454b4773b..4a72c468d49e30b992549e483e4ceb3d5aa11ef5 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBUnrollingCVC4Target.scala @@ -3,11 +3,12 @@ package leon package solvers.smtlib -import leon.purescala.Common.FreshIdentifier -import leon.purescala.Expressions.Expr -import leon.purescala.Definitions.{ValDef, TypedFunDef} +import purescala.Common.FreshIdentifier +import purescala.Expressions.{FunctionInvocation, BooleanLiteral, Expr, Implies} +import purescala.Definitions.TypedFunDef +import purescala.Constructors.application import purescala.DefOps.typedTransitiveCallees -import leon.purescala.ExprOps.{variablesOf, matchToIfThenElse} +import leon.purescala.ExprOps.matchToIfThenElse import smtlib.parser.Commands._ import smtlib.parser.Terms._ @@ -28,34 +29,38 @@ trait SMTLIBUnrollingCVC4Target extends SMTLIBCVC4Target { ) } - val smtFunDecls = funs.toSeq.flatMap { - case tfd if tfd.params.isEmpty => - // FIXME: Here we actually want to call super[SMTLIBCVC4Target].declareFunction(tfd), - // but we inline it to work around a freakish compiler bug - if (!functions.containsA(tfd)) { - val id = if (tfd.tps.isEmpty) { - tfd.id - } else { - FreshIdentifier(tfd.id.name) - } - sendCommand( DeclareFun(id2sym(id),Seq(),declareSort(tfd.returnType)) ) - } - None - case tfd if !functions.containsA(tfd) && tfd.params.nonEmpty => + val (withParams, withoutParams) = funs.toSeq partition( _.params.nonEmpty) + + withoutParams foreach { tfd => + // FIXME: Here we actually want to call super[SMTLIBCVC4Target].declareFunction(tfd), + // but we inline it to work around a freakish compiler bug + if (!functions.containsA(tfd)) { val id = if (tfd.tps.isEmpty) { tfd.id } else { - tfd.id.freshen + FreshIdentifier(tfd.id.name) } - val sym = id2sym(id) - functions +=(tfd, sym) - Some(FunDec( - sym, - tfd.params map { p => SortedVar(id2sym(p.id), declareSort(p.getType)) }, - declareSort(tfd.returnType) - )) - case _ => None + sendCommand(DeclareFun(id2sym(id), Seq(), declareSort(tfd.returnType))) + } + } + + val seen = withParams filterNot functions.containsA + + val smtFunDecls = seen map { tfd => + val id = if (tfd.tps.isEmpty) { + tfd.id + } else { + tfd.id.freshen + } + val sym = id2sym(id) + functions +=(tfd, sym) + FunDec( + sym, + tfd.params map { p => SortedVar(id2sym(p.id), declareSort(p.getType)) }, + declareSort(tfd.returnType) + ) } + val smtBodies = smtFunDecls map { case FunDec(sym, _, _) => val tfd = functions.toA(sym) toSMT(matchToIfThenElse(tfd.body.get))(tfd.params.map { p => @@ -63,26 +68,26 @@ trait SMTLIBUnrollingCVC4Target extends SMTLIBCVC4Target { }.toMap) } - if (smtFunDecls.nonEmpty) sendCommand(DefineFunsRec(smtFunDecls, smtBodies)) + if (smtFunDecls.nonEmpty) { + sendCommand(DefineFunsRec(smtFunDecls, smtBodies)) + // Assert contracts for defined functions + for { + tfd <- seen + post <- tfd.postcondition + } { + val term = Implies( + tfd.precondition getOrElse BooleanLiteral(true), + application(post, Seq(FunctionInvocation(tfd, tfd.params map { _.toVariable}))) + ) + sendCommand(Assert(quantifiedTerm(ForAll, term))) + } + } functions.toB(tfd) } } // For this solver, we prefer the variables of assert() commands to be exist. quantified instead of free - override def assertCnstr(expr: Expr): Unit = { - val existentials = variablesOf(expr).toSeq - - val term = if (existentials.isEmpty) toSMT(expr)(Map()) else { - val es = existentials map { id => - SortedVar(id2sym(id), declareSort(id.getType)) - } - Exists( - es.head, - es.tail, - toSMT(expr)(existentials.map { id => id -> (id2sym(id): Term)}.toMap) - ) - } - sendCommand(Assert(term)) - } + override def assertCnstr(expr: Expr) = + sendCommand(Assert(quantifiedTerm(Exists, expr))) }