diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index e5c289b8cf6f3ebdab81e5d6eb3690a77d64c816..474532e19c77ea60f82aa8e502627353cdd4fe93 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -19,7 +19,7 @@ object SolverFactory { } } - val definedSolvers = Set("fairz3", "unrollz3", "enum", "smt", "smt-z3", "smt-cvc4", "smt-2.5-cvc4") + val definedSolvers = Set("fairz3", "unrollz3", "enum", "smt", "smt-z3", "smt-z3-quantified", "smt-cvc4", "smt-2.5-cvc4") def getFromSettings[S](ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = { import combinators._ @@ -39,6 +39,9 @@ object SolverFactory { case "smt" | "smt-z3" => SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBSolver(ctx, program) with SMTLIBZ3Target) with TimeoutSolver) + case "smt-z3-quantified" => + SolverFactory(() => new SMTLIBSolver(ctx, program) with SMTLIBZ3QuantifiedTarget with TimeoutSolver) + case "smt-cvc4" => SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBSolver(ctx, program) with SMTLIBCVC4Target) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala new file mode 100644 index 0000000000000000000000000000000000000000..ab38c5203358bbd0d4a6dadf65c2a783cca418b3 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala @@ -0,0 +1,76 @@ +package leon.solvers.smtlib + +import leon.purescala.DefOps._ +import leon.purescala.Definitions.TypedFunDef +import leon.purescala.ExprOps._ +import leon.purescala.Expressions.{Equals, FunctionInvocation} +import smtlib.parser.Commands.{DeclareFun, Assert} +import smtlib.parser.Terms.{Term, SortedVar, SSymbol, ForAll} + +/** + * This solver models function definitions as universally quantified formulas. + * It is not meant as an underlying solver to UnrollingSolver. + */ +trait SMTLIBZ3QuantifiedTarget extends SMTLIBZ3Target { + + this: SMTLIBSolver => + + private val typedFunDefExplorationLimit = 10000 + + override def targetName = "z3-quantified" + + override def declareFunction(tfd: TypedFunDef): SSymbol = { + if (tfd.params.isEmpty) { + super[SMTLIBZ3Target].declareFunction(tfd) + } else { + val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit)) + if (!exploredAll) { + reporter.warning( + s"Did not manage to explore the space of typed functions called from ${tfd.id}. The solver may fail" + ) + } + + val smtFunDecls = funs.toSeq.collect { + case tfd if !functions.containsA(tfd) && tfd.params.nonEmpty => + val id = if (tfd.tps.isEmpty) { + tfd.id + } else { + tfd.id.freshen + } + val sym = id2sym(id) + functions +=(tfd, sym) + sendCommand(DeclareFun( + sym, + tfd.params map { p => declareSort(p.getType) }, + declareSort(tfd.returnType) + )) + sym + } + smtFunDecls foreach { sym => + val tfd = functions.toA(sym) + val sortedVars = tfd.params.map { p => + SortedVar(id2sym(p.id), declareSort(p.getType)) + } + val term = + if (sortedVars.isEmpty) { + toSMT(Equals(FunctionInvocation(tfd, Seq()), tfd.body.get))(Map()) + } else { + ForAll( + sortedVars.head, + sortedVars.tail, + toSMT(Equals( + FunctionInvocation(tfd, tfd.params.map {_.toVariable}), + matchToIfThenElse(tfd.body.get) + ))( + tfd.params.map { p => (p.id, id2sym(p.id): Term) }.toMap + ) + ) + } + sendCommand(Assert(term)) + } + + functions.toB(tfd) + } + } + +}