From 5dad01db7d13050f6fefe82282f86ea2e4311c84 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Mon, 7 Sep 2015 14:27:08 +0200 Subject: [PATCH] Implement basic sygus, only possible in --manual mode for now Split SMTLIBSolver to Solver+Target to allow things other than leon solvers to talk SMT. --- build.sbt | 2 +- src/main/scala/leon/solvers/RawArray.scala | 18 +- .../smtlib/SMTLIBCVC4ProofSolver.scala | 16 +- .../smtlib/SMTLIBCVC4QuantifiedSolver.scala | 109 +-- .../smtlib/SMTLIBCVC4QuantifiedTarget.scala | 118 +++ .../solvers/smtlib/SMTLIBCVC4Solver.scala | 132 +-- .../solvers/smtlib/SMTLIBCVC4Target.scala | 133 +++ .../smtlib/SMTLIBQuantifiedSolver.scala | 59 +- .../smtlib/SMTLIBQuantifiedTarget.scala | 52 ++ .../leon/solvers/smtlib/SMTLIBSolver.scala | 759 +--------------- .../leon/solvers/smtlib/SMTLIBTarget.scala | 818 ++++++++++++++++++ .../smtlib/SMTLIBUnsupportedError.scala | 10 + .../smtlib/SMTLIBZ3QuantifiedSolver.scala | 68 +- .../smtlib/SMTLIBZ3QuantifiedTarget.scala | 76 ++ .../leon/solvers/smtlib/SMTLIBZ3Solver.scala | 171 +--- .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 183 ++++ .../leon/solvers/sygus/CVC4SygusSolver.scala | 29 + .../leon/solvers/sygus/SygusSolver.scala | 131 +++ src/main/scala/leon/synthesis/Rules.scala | 2 +- src/main/scala/leon/synthesis/Solution.scala | 4 + .../scala/leon/synthesis/SynthesisPhase.scala | 2 +- .../scala/leon/synthesis/rules/Sygus.scala | 36 + src/main/scala/leon/utils/DebugFiles.scala | 3 + testcases/synthesis/sygus/listqueue.scala | 64 ++ testcases/synthesis/sygus/max2.scala | 9 + testcases/synthesis/sygus/numerals1.scala | 16 + testcases/synthesis/sygus/plusone.scala | 9 + 27 files changed, 1772 insertions(+), 1257 deletions(-) create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala create mode 100644 src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala create mode 100644 src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala create mode 100644 src/main/scala/leon/solvers/sygus/SygusSolver.scala create mode 100644 src/main/scala/leon/synthesis/rules/Sygus.scala create mode 100644 src/main/scala/leon/utils/DebugFiles.scala create mode 100644 testcases/synthesis/sygus/listqueue.scala create mode 100644 testcases/synthesis/sygus/max2.scala create mode 100644 testcases/synthesis/sygus/numerals1.scala create mode 100644 testcases/synthesis/sygus/plusone.scala diff --git a/build.sbt b/build.sbt index c6449492b..c5ac64d0e 100644 --- a/build.sbt +++ b/build.sbt @@ -141,7 +141,7 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi lazy val bonsai = ghProject("git://github.com/colder/bonsai.git", "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539") -lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "8aa4a5588653ce4986e3721115a62cc386714cc2") +lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "204018c3f413fc8191c99ae9ccc4806912e02a83") lazy val root = (project in file(".")). configs(RegressionTest, IsabelleTest, IntegrTest). diff --git a/src/main/scala/leon/solvers/RawArray.scala b/src/main/scala/leon/solvers/RawArray.scala index 05ec28a86..a527a98fa 100644 --- a/src/main/scala/leon/solvers/RawArray.scala +++ b/src/main/scala/leon/solvers/RawArray.scala @@ -3,23 +3,23 @@ package leon package solvers -import leon.purescala.{PrinterContext, PrettyPrintable} -import leon.purescala.PrinterHelpers._ import purescala.Types._ import purescala.Expressions._ -// Corresponds to a smt map, not a leon/scala array -private[solvers] case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree with PrettyPrintable { - override def printWith(implicit pctx: PrinterContext): Unit = { - p"RawArrayType[$from, $to]" +// Corresponds to a complete map (SMT Array), not a Leon/Scala array +// Only used within solvers or SMT for encoding purposes +case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree { + override def asString(implicit ctx: LeonContext) = { + s"RawArrayType[${from.asString}, ${to.asString}]" } } // Corresponds to a raw array value, which is coerced to a Leon expr depending on target type (set/array) -private[solvers] case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr with PrettyPrintable{ +case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr { val getType = RawArrayType(keyTpe, default.getType) - override def printWith(implicit pctx: PrinterContext): Unit = { - p"RawArrayValue[$keyTpe](${nary(elems.toSeq, ", ")}, default = $default)" + override def asString(implicit ctx: LeonContext) = { + val elemsString = elems.map { case (k, v) => k.asString+" -> "+v.asString } mkString(", ") + s"RawArray[${keyTpe.asString}]($elemsString, default = ${default.asString})" } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala index 281310a81..830216880 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala @@ -3,12 +3,12 @@ package leon package solvers.smtlib -import leon.purescala.Common.Identifier -import leon.purescala.Definitions.Program -import leon.purescala.Expressions.Expr -import leon.solvers.SolverUnsupportedError -import smtlib.parser.Commands.{Assert => SMTAssert} -import smtlib.parser.Terms.{Exists => SMTExists} +import purescala.Common.Identifier +import purescala.Definitions.Program +import purescala.Expressions.Expr + +import _root_.smtlib.parser.Commands.{Assert => SMTAssert} +import _root_.smtlib.parser.Terms.{Exists => SMTExists} class SMTLIBCVC4ProofSolver(context: LeonContext, program: Program) extends SMTLIBCVC4QuantifiedSolver(context, program) { @@ -30,9 +30,9 @@ class SMTLIBCVC4ProofSolver(context: LeonContext, program: Program) extends SMTL // For this solver, we prefer the variables of assert() commands to be exist. quantified instead of free override def assertCnstr(e: Expr) = try { - sendCommand(SMTAssert(quantifiedTerm(SMTExists, e)(Map()))) + emit(SMTAssert(quantifiedTerm(SMTExists, e)(Map()))) } catch { - case _ : SolverUnsupportedError => + case _ : SMTLIBUnsupportedError => addError() } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala index 9704fceba..16068794f 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala @@ -3,116 +3,11 @@ package leon package solvers.smtlib -import leon.solvers.SolverUnsupportedError -import purescala._ -import Expressions._ -import Definitions._ -import Constructors._ -import DefOps.typedTransitiveCallees -import smtlib.parser.Commands.{Assert => SMTAssert, FunDef => _, _} -import smtlib.parser.Terms.{Exists => SMTExists, Forall => SMTForall, _ } -import smtlib.theories.Core.Equals +import purescala.Definitions.Program // This solver utilizes the define-funs-rec command of SMTLIB-2.5 to define mutually recursive functions. // It is not meant as an underlying solver to UnrollingSolver, and does not handle HOFs. abstract class SMTLIBCVC4QuantifiedSolver(context: LeonContext, program: Program) extends SMTLIBCVC4Solver(context, program) with SMTLIBQuantifiedSolver -{ - - override def targetName = "cvc4-quantified" - - override def declareFunction(tfd: TypedFunDef): SSymbol = { - val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit)) - if (!exploredAll) { - reporter.warning( - "Did not manage to explore the space of typed functions " + - s"transitively called from ${tfd.id}. The solver may fail" - ) - } - - // define-funs-rec does not accept parameterless functions, so we have to treat them differently: - // we declare-fun each one and assert it is equal to its body - val (withParams, withoutParams) = funs.toSeq filterNot functions.containsA partition(_.params.nonEmpty) - - // FIXME this may introduce dependency errors - val parameterlessAssertions = withoutParams flatMap { tfd => - val s = super.declareFunction(tfd) - - try { - val bodyAssert = tfd.body map { bd => - SMTAssert(Equals(s: Term, toSMT(bd)(Map()))) - } - - val specAssert = tfd.postcondition map { post => - val term = implies( - tfd.precondition getOrElse BooleanLiteral(true), - application(post, Seq(FunctionInvocation(tfd, Seq()))) - ) - SMTAssert(toSMT(term)(Map())) - } - - bodyAssert ++ specAssert - } catch { - case _ : SolverUnsupportedError => - addError() - Seq() - } - } - - val smtFunDecls = withParams map { tfd => - val id = if (tfd.tps.isEmpty) { - tfd.id - } else { - tfd.id.freshen - } - val sym = id2sym(id) - functions +=(tfd, sym) - FunDec( - sym, - tfd.params map { p => SortedVar(id2sym(p.id), declareSort(p.getType)) }, - declareSort(tfd.returnType) - ) - } - - val smtBodies = smtFunDecls map { case f => - val tfd = functions.toA(f.name) - try { - toSMT(tfd.body.get)(tfd.params.map { p => - (p.id, id2sym(p.id): Term) - }.toMap) - } catch { - case _: SolverUnsupportedError => - addError() - toSMT(Error(tfd.body.get.getType, ""))(Map()) - } - } - - if (smtFunDecls.nonEmpty) { - sendCommand(DefineFunsRec(smtFunDecls, smtBodies)) - // Assert contracts for defined functions - if (allowQuantifiedAssertions) for { - // If we encounter a function that does not refer to the current function, - // it is sound to assume its contracts for all inputs - tfd <- withParams if !refersToCurrent(tfd.fd) - post <- tfd.postcondition - } { - val term = implies( - tfd.precondition getOrElse BooleanLiteral(true), - application(post, Seq(tfd.applied)) - ) - try { - sendCommand(SMTAssert(quantifiedTerm(SMTForall, term)(Map()))) - } catch { - case _ : SolverUnsupportedError => - addError() - } - } - } - - parameterlessAssertions.foreach(a => sendCommand(a)) - - functions.toB(tfd) - } - -} + with SMTLIBCVC4QuantifiedTarget diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala new file mode 100644 index 000000000..a7215bc97 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala @@ -0,0 +1,118 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Common._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Types._ +import purescala.Definitions._ +import purescala.DefOps.typedTransitiveCallees + +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => _, _} +import _root_.smtlib.parser.Terms.{Exists => SMTExists, Forall => SMTForall, _ } +import _root_.smtlib.theories.Core.Equals +import _root_.smtlib.parser.Commands._ + +trait SMTLIBCVC4QuantifiedTarget extends SMTLIBCVC4Target with SMTLIBQuantifiedTarget { + + override def targetName = "cvc4-quantified" + + override def declareFunction(tfd: TypedFunDef): SSymbol = { + val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit)) + + if (!exploredAll) { + reporter.warning( + "Did not manage to explore the space of typed functions " + + s"transitively called from ${tfd.id}. The solver may fail" + ) + } + + // define-funs-rec does not accept parameterless functions, so we have to treat them differently: + // we declare-fun each one and assert it is equal to its body + val (withParams, withoutParams) = funs.toSeq filterNot functions.containsA partition(_.params.nonEmpty) + + // FIXME this may introduce dependency errors + val parameterlessAssertions = withoutParams flatMap { tfd => + val s = super.declareFunction(tfd) + + try { + val bodyAssert = tfd.body.map { body => + SMTAssert(Equals(s: Term, toSMT(body)(Map()))) + } + + val specAssert = tfd.postcondition map { post => + val term = implies( + tfd.precondition getOrElse BooleanLiteral(true), + application(post, Seq(FunctionInvocation(tfd, Seq()))) + ) + SMTAssert(toSMT(term)(Map())) + } + + bodyAssert ++ specAssert + } catch { + case _ : SMTLIBUnsupportedError => + addError() + Seq() + } + } + + val smtFunDecls = withParams map { tfd => + val id = if (tfd.tps.isEmpty) { + tfd.id + } else { + tfd.id.freshen + } + val sym = id2sym(id) + functions +=(tfd, sym) + FunDec( + sym, + tfd.params map { p => SortedVar(id2sym(p.id), declareSort(p.getType)) }, + declareSort(tfd.returnType) + ) + } + + val smtBodies = smtFunDecls map { case f => + val tfd = functions.toA(f.name) + try { + toSMT(tfd.body.get)(tfd.params.map { p => + (p.id, id2sym(p.id): Term) + }.toMap) + } catch { + case _: SMTLIBUnsupportedError => + addError() + toSMT(Error(tfd.body.get.getType, ""))(Map()) + } + } + + if (smtFunDecls.nonEmpty) { + emit(DefineFunsRec(smtFunDecls, smtBodies)) + // Assert contracts for defined functions + if (allowQuantifiedAssertions) for { + // If we encounter a function that does not refer to the current function, + // it is sound to assume its contracts for all inputs + tfd <- withParams if !refersToCurrent(tfd.fd) + post <- tfd.postcondition + } { + val term = implies( + tfd.precondition getOrElse BooleanLiteral(true), + application(post, Seq(tfd.applied)) + ) + try { + emit(SMTAssert(quantifiedTerm(SMTForall, term)(Map()))) + } catch { + case _ : SMTLIBUnsupportedError => + addError() + } + } + } + + parameterlessAssertions.foreach(a => emit(a)) + + functions.toB(tfd) + } +} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index 2ffd6fa75..dfcfbcc3f 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -1,25 +1,13 @@ /* Copyright 2009-2015 EPFL, Lausanne */ package leon -package solvers -package smtlib +package solvers.smtlib import OptionParsers._ -import purescala._ -import Definitions.Program -import Common._ -import Expressions.{Assert => _, _} -import Extractors._ -import Constructors._ -import Types._ +import purescala.Definitions.Program -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTForall, _} -import _root_.smtlib.parser.Commands._ -import _root_.smtlib.interpreters.CVC4Interpreter -import _root_.smtlib.theories._ - -class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) { +class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBCVC4Target { def targetName = "cvc4" @@ -34,124 +22,10 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol "--no-incremental", "--tear-down-incremental", "--rewrite-divk", -// "--dt-rewrite-error-sel", // Removing since it causes CVC4 to segfault on some inputs "--print-success", "--lang", "smt" ) ++ userDefinedOps(ctx).toSeq } - - def getNewInterpreter(ctx: LeonContext) = { - val opts = interpreterOps(ctx) - reporter.debug("Invoking solver with "+opts.mkString(" ")) - new CVC4Interpreter("cvc4", opts.toArray) - } - - override protected def declareSort(t: TypeTree): Sort = { - val tpe = normalizeType(t) - sorts.cachedB(tpe) { - tpe match { - case TypeParameter(id) => - val s = id2sym(id) - val cmd = DeclareSort(s, 0) - sendCommand(cmd) - Sort(SMTIdentifier(s)) - - case SetType(base) => - Sort(SMTIdentifier(SSymbol("Set")), Seq(declareSort(base))) - - case _ => - super.declareSort(t) - } - } - } - - override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - case (SimpleSymbol(s), tp: TypeParameter) => - val n = s.name.split("_").toList.last - GenericValue(tp, n.toInt) - - case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), SetType(base)) => - FiniteSet(Set(), base) - - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => - RawArrayValue(k, Map(), fromSMT(elem, v)) - - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) - - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => - val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) => - val RawArrayValue(k, elems, base) = fromSMT(arr, tpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base) - - case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => - FiniteSet(elems.map(fromSMT(_, base)).toSet, base) - - case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), SetType(base)) => - val selems = elems.init.map(fromSMT(_, base)) - val FiniteSet(se, _) = fromSMT(elems.last, tpe) - FiniteSet(se ++ selems, base) - - case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), SetType(base)) => - FiniteSet(elems.flatMap(fromSMT(_, tpe) match { - case FiniteSet(elems, _) => elems - }).toSet, base) - - // FIXME (nicolas) - // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ - // one I've witnessed up to now. Don't know why this is happening... - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) - - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => - RawArrayValue(k, Map(), fromSMT(elem, v)) - - case _ => - super.fromSMT(s, tpe) - } - - override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match { - /** - * ===== Set operations ===== - */ - - - case fs @ FiniteSet(elems, _) => - if (elems.isEmpty) { - QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset")), Some(declareSort(fs.getType))) - } else { - val selems = elems.toSeq.map(toSMT) - - val sgt = FunctionApplication(SSymbol("singleton"), Seq(selems.head)) - - if (selems.size > 1) { - FunctionApplication(SSymbol("insert"), selems.tail :+ sgt) - } else { - sgt - } - } - - case SubsetOf(ss, s) => - FunctionApplication(SSymbol("subset"), Seq(toSMT(ss), toSMT(s))) - - case ElementOfSet(e, s) => - FunctionApplication(SSymbol("member"), Seq(toSMT(e), toSMT(s))) - - case SetDifference(a, b) => - FunctionApplication(SSymbol("setminus"), Seq(toSMT(a), toSMT(b))) - - case SetUnion(a, b) => - FunctionApplication(SSymbol("union"), Seq(toSMT(a), toSMT(b))) - - case SetIntersection(a, b) => - FunctionApplication(SSymbol("intersection"), Seq(toSMT(a), toSMT(b))) - - case _ => - super.toSMT(e) - } } object SMTLIBCVC4Component extends LeonComponent { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala new file mode 100644 index 000000000..679e42ac7 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -0,0 +1,133 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Types._ + +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTForall, _} +import _root_.smtlib.parser.Commands._ +import _root_.smtlib.interpreters.CVC4Interpreter + +trait SMTLIBCVC4Target extends SMTLIBTarget { + + override def getNewInterpreter(ctx: LeonContext) = { + val opts = interpreterOps(ctx) + reporter.debug("Invoking solver with "+opts.mkString(" ")) + + new CVC4Interpreter("cvc4", opts.toArray) + } + + override protected def declareSort(t: TypeTree): Sort = { + val tpe = normalizeType(t) + sorts.cachedB(tpe) { + tpe match { + case SetType(base) => + Sort(SMTIdentifier(SSymbol("Set")), Seq(declareSort(base))) + + case _ => + super.declareSort(t) + } + } + } + + override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { + // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals + case (SimpleSymbol(SSymbol(v)), IntegerType) if v.startsWith("-") => + try { + InfiniteIntegerLiteral(v.toInt) + } catch { + case t: Throwable => + super.fromSMT(s, tpe) + } + + case (SimpleSymbol(s), tp: TypeParameter) => + val n = s.name.split("_").toList.last + GenericValue(tp, n.toInt) + + case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), SetType(base)) => + FiniteSet(Set(), base) + + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => + RawArrayValue(k, Map(), fromSMT(elem, v)) + + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => + val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) + + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) => + val RawArrayValue(k, elems, base) = fromSMT(arr, tpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base) + + case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => + FiniteSet(elems.map(fromSMT(_, base)).toSet, base) + + case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), SetType(base)) => + val selems = elems.init.map(fromSMT(_, base)) + val FiniteSet(se, _) = fromSMT(elems.last, tpe) + FiniteSet(se ++ selems, base) + + case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), SetType(base)) => + FiniteSet(elems.flatMap(fromSMT(_, tpe) match { + case FiniteSet(elems, _) => elems + }).toSet, base) + + case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => + RawArrayValue(k, Map(), fromSMT(elem, v)) + + // FIXME (nicolas) + // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ + // one I've witnessed up to now. Don't know why this is happening... + case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + + case _ => + super.fromSMT(s, tpe) + } + + override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match { + /** + * ===== Set operations ===== + */ + case fs @ FiniteSet(elems, _) => + if (elems.isEmpty) { + QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset")), Some(declareSort(fs.getType))) + } else { + val selems = elems.toSeq.map(toSMT) + + val sgt = FunctionApplication(SSymbol("singleton"), Seq(selems.head)) + + if (selems.size > 1) { + FunctionApplication(SSymbol("insert"), selems.tail :+ sgt) + } else { + sgt + } + } + + case SubsetOf(ss, s) => + FunctionApplication(SSymbol("subset"), Seq(toSMT(ss), toSMT(s))) + + case ElementOfSet(e, s) => + FunctionApplication(SSymbol("member"), Seq(toSMT(e), toSMT(s))) + + case SetDifference(a, b) => + FunctionApplication(SSymbol("setminus"), Seq(toSMT(a), toSMT(b))) + + case SetUnion(a, b) => + FunctionApplication(SSymbol("union"), Seq(toSMT(a), toSMT(b))) + + case SetIntersection(a, b) => + FunctionApplication(SSymbol("intersection"), Seq(toSMT(a), toSMT(b))) + + case _ => + super.toSMT(e) + } +} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala index f5d37db92..fbaa73cc2 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala @@ -1,56 +1,15 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon.solvers.smtlib +package leon +package solvers +package smtlib -import leon.purescala.Common.Identifier -import leon.purescala.Constructors._ -import leon.purescala.Definitions.FunDef -import leon.purescala.ExprOps._ -import leon.purescala.Expressions._ -import leon.verification.VC -import smtlib.parser.Terms.{ Term, Forall => SMTForall, _ } +import purescala.Common.Identifier +import purescala.Expressions.Expr +import verification.VC -trait SMTLIBQuantifiedSolver extends SMTLIBSolver { - - private var currentFunDef: Option[FunDef] = None - protected def refersToCurrent(fd: FunDef) = { - (currentFunDef contains fd) || (currentFunDef exists { - program.callGraph.transitivelyCalls(fd, _) - }) - } - - protected val allowQuantifiedAssertions: Boolean - - protected val typedFunDefExplorationLimit = 10000 - - protected def withInductiveHyp(cond: Expr): Expr = { - - val inductiveHyps = for { - fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq - } yield { - val formalToRealArgs = tfd.params.map{ _.id}.zip(args).toMap - val post = tfd.postcondition map { post => - application( - replaceFromIDs(formalToRealArgs, post), - Seq(fi) - ) - } getOrElse BooleanLiteral(true) - val pre = tfd.precondition getOrElse BooleanLiteral(true) - and(pre, post) - } - - // We want to check if the negation of the vc is sat under inductive hyp. - // So we need to see if (indHyp /\ !vc) is satisfiable - liftLets(matchToIfThenElse(andJoin(inductiveHyps :+ not(cond)))) - - } - - override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { - case Forall(vs, bd) => - quantifiedTerm(SMTForall, vs map { _.id }, bd) - case _ => - super.toSMT(e)(bindings) - } +trait SMTLIBQuantifiedSolver { + this: SMTLIBSolver with SMTLIBQuantifiedTarget => // We need to know the function context. // The reason is we do not want to assume postconditions of functions referring to @@ -62,7 +21,7 @@ trait SMTLIBQuantifiedSolver extends SMTLIBSolver { // Normally, UnrollingSolver tracks the input variable, but this one // is invoked alone so we have to filter them here - override def getModel: leon.solvers.Model = { + override def getModel: Model = { val filter = currentFunDef.map{ _.params.map{_.id}.toSet }.getOrElse( (_:Identifier) => true ) getModel(filter) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala new file mode 100644 index 000000000..0de742f5e --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala @@ -0,0 +1,52 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Constructors._ +import purescala.ExprOps._ +import purescala.DefOps.typedTransitiveCallees + +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => _, _} +import _root_.smtlib.parser.Terms.{Exists => SMTExists, Forall => SMTForall, _ } +import _root_.smtlib.theories.Core.Equals + +trait SMTLIBQuantifiedTarget extends SMTLIBTarget { + + protected var currentFunDef: Option[FunDef] = None + + protected def refersToCurrent(fd: FunDef) = { + (currentFunDef contains fd) || (currentFunDef exists { + program.callGraph.transitivelyCalls(fd, _) + }) + } + + protected val allowQuantifiedAssertions: Boolean + + protected val typedFunDefExplorationLimit = 10000 + + protected def withInductiveHyp(cond: Expr): Expr = { + + val inductiveHyps = for { + fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq + } yield { + val formalToRealArgs = tfd.params.map{ _.id}.zip(args).toMap + val post = tfd.postcondition map { post => + application( + replaceFromIDs(formalToRealArgs, post), + Seq(fi) + ) + } getOrElse BooleanLiteral(true) + val pre = tfd.precondition getOrElse BooleanLiteral(true) + and(pre, post) + } + + // We want to check if the negation of the vc is sat under inductive hyp. + // So we need to see if (indHyp /\ !vc) is satisfiable + liftLets(matchToIfThenElse(andJoin(inductiveHyps :+ not(cond)))) + + } +} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 275013666..edb558e95 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -5,32 +5,23 @@ package solvers package smtlib import utils._ -import purescala._ -import Common._ -import Expressions._ -import Extractors._ -import ExprOps._ -import Types._ -import Constructors._ -import Definitions._ +import purescala.Common._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.Constructors._ +import purescala.Definitions._ import _root_.smtlib.common._ -import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter} -import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, FunDef => _, Assert => SMTAssert, _} -import _root_.smtlib.parser.Terms.{ - Forall => SMTForall, - Exists => SMTExists, - Identifier => SMTIdentifier, - Let => SMTLet, - _ -} +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, _} +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} import _root_.smtlib.theories._ -import _root_.smtlib.{Interpreter => SMTInterpreter} +import _root_.smtlib.interpreters.ProcessInterpreter -abstract class SMTLIBSolver(val context: LeonContext, val program: Program) - extends Solver - with NaiveAssumptionSolver { +abstract class SMTLIBSolver(val context: LeonContext, val program: Program) + extends Solver with SMTLIBTarget with NaiveAssumptionSolver { /* Solver name */ def targetName: String @@ -39,690 +30,15 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) /* Reporter */ protected val reporter = context.reporter - /* Interface with Interpreter */ - - protected def interpreterOps(ctx: LeonContext): Seq[String] - - protected def getNewInterpreter(ctx: LeonContext): SMTInterpreter - - protected val interpreter = getNewInterpreter(context) - - /* Printing VCs */ - protected lazy val out: Option[java.io.FileWriter] = if (reporter.isDebugEnabled) Some { - val file = context.files.headOption.map(_.getName).getOrElse("NA") - val n = VCNumbers.next(targetName+file) - - val dir = new java.io.File("smt-sessions") - - if (!dir.isDirectory) { - dir.mkdir - } - - val fileName = s"smt-sessions/$targetName-$file-$n.smt2" - - reporter.debug(s"Outputting smt session into $fileName" ) - - val javaFile = new java.io.File(fileName) - javaFile.getParentFile.mkdirs() - - val fw = new java.io.FileWriter(javaFile, false) - - fw.write("; Solver : "+name+"\n") - fw.write("; Options: "+interpreterOps(context).mkString(" ")+"\n") - - fw - } else None - - - /* Interruptible interface */ - - private var interrupted = false - - context.interruptManager.registerForInterrupts(this) - - override def interrupt(): Unit = { - interrupted = true - interpreter.interrupt() - } - override def recoverInterrupt(): Unit = { - interrupted = false - } - - - - /* - * Translation from Leon Expressions to SMTLIB terms and reverse - */ - - /* Symbol handling */ - protected object SimpleSymbol { - def unapply(term: Term): Option[SSymbol] = term match { - case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => Some(sym) - case _ => None - } - } - - import scala.language.implicitConversions - protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = { - QualifiedIdentifier(SMTIdentifier(s)) - } - - protected val adtManager = new ADTManager(context) - - protected val library = program.library - - protected def id2sym(id: Identifier): SSymbol = { - SSymbol(id.uniqueNameDelimited("!").replace("|", "$pipe").replace("\\", "$backslash")) - } - - protected def freshSym(id: Identifier): SSymbol = freshSym(id.name) - protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) - - - /* Metadata for CC, and variables */ - protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() - protected val testers = new IncrementalBijection[TypeTree, SSymbol]() - protected val variables = new IncrementalBijection[Identifier, SSymbol]() - protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() - protected val sorts = new IncrementalBijection[TypeTree, Sort]() - protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() - protected val errors = new IncrementalBijection[Unit, Boolean]() - protected def hasError = errors.getB(()) contains true - protected def addError() = errors += () -> true - - /* Helper functions */ - - protected def normalizeType(t: TypeTree): TypeTree = t match { - case ct: ClassType => ct.root - case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) - case _ => t - } - - protected def quantifiedTerm( - quantifier: (SortedVar, Seq[SortedVar], Term) => Term, - vars: Seq[Identifier], - body: Expr - )( - implicit bindings: Map[Identifier, Term] - ): Term = { - if (vars.isEmpty) toSMT(body) - else { - val sortedVars = vars map { id => - SortedVar(id2sym(id), declareSort(id.getType)) - } - quantifier( - sortedVars.head, - sortedVars.tail, - toSMT(body)(bindings ++ vars.map{ id => id -> (id2sym(id): Term)}) - ) - } - } - - // Returns a quantified term where all free variables in the body have been quantified - protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr)( - implicit bindings: Map[Identifier, Term] - ): Term = - quantifiedTerm(quantifier, variablesOf(body).toSeq, body) - - protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { - case SetType(base) => - if (r.default != BooleanLiteral(false)) { - unsupported(r, "Solver returned a co-finite set which is not supported.") - } - require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") - - FiniteSet(r.elems.keySet, base) - - case RawArrayType(from, to) => - r - - case FunctionType(from, to) => - r - - case MapType(from, to) => - // We expect a RawArrayValue with keys in from and values in Option[to], - // with default value == None - if (r.default.getType != library.noneType(to)) { - unsupported(r, "Solver returned a co-finite map which is not supported.") - } - require(r.keyTpe == from, s"Type error in solver model, expected $from, found ${r.keyTpe}") - - val elems = r.elems.flatMap { - case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x) - case (k, _) => None - }.toSeq - FiniteMap(elems, from, to) - - case other => - unsupported(other, "Unable to extract from raw array for "+tpe) - } - - protected def declareSort(t: TypeTree): Sort = { - val tpe = normalizeType(t) - sorts.cachedB(tpe) { - tpe match { - case BooleanType => Core.BoolSort() - case IntegerType => Ints.IntSort() - case RealType => Reals.RealSort() - case Int32Type => FixedSizeBitVectors.BitVectorSort(32) - case CharType => FixedSizeBitVectors.BitVectorSort(32) - - case RawArrayType(from, to) => - Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) - - case MapType(from, to) => - declareSort(RawArrayType(from, library.optionType(to))) - - case FunctionType(from, to) => - Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(tupleTypeWrap(from)), declareSort(to))) - - case TypeParameter(id) => - val s = id2sym(id) - val cmd = DeclareSort(s, 0) - sendCommand(cmd) - Sort(SMTIdentifier(s)) - - case _: ClassType | _: TupleType | _: ArrayType | UnitType => - declareStructuralSort(tpe) - - case other => - unsupported(other, s"Could not transform $other into an SMT sort") - } - } - } - - protected def declareDatatypes(datatypes: Map[TypeTree, DataType]): Unit = { - // We pre-declare ADTs - for ((tpe, DataType(sym, _)) <- datatypes) { - sorts += tpe -> Sort(SMTIdentifier(id2sym(sym))) - } - - def toDecl(c: Constructor): SMTConstructor = { - val s = id2sym(c.sym) - - testers += c.tpe -> SSymbol("is-"+s.name) - constructors += c.tpe -> s - - SMTConstructor(s, c.fields.zipWithIndex.map { - case ((cs, t), i) => - selectors += (c.tpe, i) -> id2sym(cs) - (id2sym(cs), declareSort(t)) - }) - } - - val adts = for ((tpe, DataType(sym, cases)) <- datatypes.toList) yield { - (id2sym(sym), cases.map(toDecl)) - } - - if (adts.nonEmpty) { - val cmd = DeclareDatatypes(adts) - sendCommand(cmd) - } - - } - - protected def declareStructuralSort(t: TypeTree): Sort = { - // Populates the dependencies of the structural type to define. - adtManager.defineADT(t) match { - case Left(adts) => - declareDatatypes(adts) - sorts.toB(normalizeType(t)) - - case Right(conflicts) => - conflicts.foreach { declareStructuralSort } - declareStructuralSort(t) - } - - } - - protected def declareVariable(id: Identifier): SSymbol = { - variables.cachedB(id) { - val s = id2sym(id) - val cmd = DeclareFun(s, List(), declareSort(id.getType)) - sendCommand(cmd) - s - } - } - - protected def declareFunction(tfd: TypedFunDef): SSymbol = { - functions.cachedB(tfd) { - val id = if (tfd.tps.isEmpty) { - tfd.id - } else { - FreshIdentifier(tfd.id.name) - } - val s = id2sym(id) - sendCommand(DeclareFun( - s, - tfd.params.map( (p: ValDef) => declareSort(p.getType)), - declareSort(tfd.returnType) - )) - s - } - } - - /* Translate a Leon Expr to an SMTLIB term */ - - protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = { - e match { - case Variable(id) => - declareSort(e.getType) - bindings.getOrElse(id, variables.toB(id)) - - case UnitLiteral() => - declareSort(UnitType) - - declareVariable(FreshIdentifier("Unit", UnitType)) - - case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) - case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) - case RealLiteral(d) => if (d >= 0) Reals.DecimalLit(d) else Reals.Neg(Reals.DecimalLit(-d)) - case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) - case BooleanLiteral(v) => Core.BoolConst(v) - case Let(b,d,e) => - val id = id2sym(b) - val value = toSMT(d) - val newBody = toSMT(e)(bindings + (b -> id)) - - SMTLet( - VarBinding(id, value), - Seq(), - newBody - ) - - case er @ Error(tpe, _) => - declareVariable(FreshIdentifier("error_value", tpe)) - - case s @ CaseClassSelector(cct, e, id) => - declareSort(cct) - val selector = selectors.toB((cct, s.selectorIndex)) - FunctionApplication(selector, Seq(toSMT(e))) - - case AsInstanceOf(expr, cct) => - toSMT(expr) - - case IsInstanceOf(e, cct) => - declareSort(cct) - val cases = cct match { - case act: AbstractClassType => - act.knownCCDescendants - case cct: CaseClassType => - Seq(cct) - } - val oneOf = cases map testers.toB - oneOf match { - case Seq(tester) => - FunctionApplication(tester, Seq(toSMT(e))) - case more => - val es = freshSym("e") - SMTLet(VarBinding(es, toSMT(e)), Seq(), - Core.Or(oneOf.map(FunctionApplication(_, Seq(es:Term))): _*) - ) - } - - case CaseClass(cct, es) => - declareSort(cct) - val constructor = constructors.toB(cct) - if (es.isEmpty) { - constructor - } else { - FunctionApplication(constructor, es.map(toSMT)) - } - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - declareSort(tpe) - val constructor = constructors.toB(tpe) - FunctionApplication(constructor, es.map(toSMT)) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - declareSort(tpe) - val selector = selectors.toB((tpe, i-1)) - FunctionApplication(selector, Seq(toSMT(t))) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val selector = selectors.toB((tpe, 0)) - - FunctionApplication(selector, Seq(toSMT(a))) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(toSMT(a))) - - ArraysEx.Select(scontent, toSMT(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = toSMT(a) - val ssize = FunctionApplication(selectors.toB((tpe, 0)), Seq(sa)) - val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(sa)) - - val newcontent = ArraysEx.Store(scontent, toSMT(i), toSMT(e)) - - val constructor = constructors.toB(tpe) - FunctionApplication(constructor, Seq(ssize, newcontent)) - - case ra @ RawArrayValue(keyTpe, elems, default) => - val s = declareSort(ra.getType) - - var res: Term = FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(s)), - List(toSMT(default)) - ) - for ((k, v) <- elems) { - res = ArraysEx.Store(res, toSMT(k), toSMT(v)) - } - - res - - case a @ FiniteArray(elems, oDef, size) => - val tpe @ ArrayType(to) = normalizeType(a.getType) - declareSort(tpe) - - val default: Expr = oDef.getOrElse(simplestValue(to)) - - val arr = toSMT(RawArrayValue(Int32Type, elems.map { - case (k, v) => IntLiteral(k) -> v - }, default)) - - FunctionApplication(constructors.toB(tpe), List(toSMT(size), arr)) - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, _, _) => - val mt @ MapType(from, to) = m.getType - declareSort(mt) - - toSMT(RawArrayValue(from, elems.map { - case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) - }.toMap, CaseClass(library.noneType(to), Seq()))) - - - case MapApply(m, k) => - val mt @ MapType(_, to) = m.getType - declareSort(mt) - // m(k) becomes - // (Some-value (select m k)) - FunctionApplication( - selectors.toB((library.someType(to), 0)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k))) - ) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, to) = m.getType - declareSort(mt) - // m.isDefinedAt(k) becomes - // (is-Some (select m k)) - FunctionApplication( - testers.toB(library.someType(to)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k))) - ) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val MapType(_, t) = m1.getType - - elems.foldLeft(toSMT(m1)) { case (m, (k,v)) => - ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) - } - - case p : Passes => - toSMT(matchToIfThenElse(p.asConstraint)) - - case m : MatchExpr => - toSMT(matchToIfThenElse(m)) - - - case gv @ GenericValue(tpe, n) => - genericValues.cachedB(gv) { - declareVariable(FreshIdentifier("gv"+n, tpe)) - } - - /** - * ===== Everything else ===== - */ - case ap @ Application(caller, args) => - ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) - - case Not(u) => Core.Not(toSMT(u)) - case UMinus(u) => Ints.Neg(toSMT(u)) - case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) - case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) - case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) - - case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b)) - case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b)) - case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b)) - case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b)) - case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b)) - case Division(a,b) => { - val ar = toSMT(a) - val br = toSMT(b) - - Core.ITE( - Ints.GreaterEquals(ar, Ints.NumeralLit(0)), - Ints.Div(ar, br), - Ints.Neg(Ints.Div(Ints.Neg(ar), br))) - } - case Remainder(a,b) => { - val q = toSMT(Division(a, b)) - Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q)) - } - case Modulo(a,b) => { - Ints.Mod(toSMT(a), toSMT(b)) - } - case LessThan(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) - case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) - case RealType => Reals.LessThan(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) - } - case LessEquals(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) - case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) - case RealType => Reals.LessEquals(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) - } - case GreaterThan(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) - case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) - case RealType => Reals.GreaterThan(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) - } - case GreaterEquals(a,b) => a.getType match { - case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) - case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) - case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) - case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) - } - case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) - case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) - case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) - case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) - case BVRemainder(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) - case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) - case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) - case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) - case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) - case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) - case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) - - case RealPlus(a,b) => Reals.Add(toSMT(a), toSMT(b)) - case RealMinus(a,b) => Reals.Sub(toSMT(a), toSMT(b)) - case RealTimes(a,b) => Reals.Mul(toSMT(a), toSMT(b)) - case RealDivision(a,b) => Reals.Div(toSMT(a), toSMT(b)) - - case And(sub) => Core.And(sub.map(toSMT): _*) - case Or(sub) => Core.Or(sub.map(toSMT): _*) - case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) - case f@FunctionInvocation(_, sub) => - if (sub.isEmpty) declareFunction(f.tfd) else { - FunctionApplication( - declareFunction(f.tfd), - sub.map(toSMT) - ) - } - case o => - unsupported(o, "") - } - } - - /* Translate an SMTLIB term back to a Leon Expr */ - - protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - fromSMT(pair._1, pair._2) - } - - protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - case (_, UnitType) => - UnitLiteral() - - case (FixedSizeBitVectors.BitVectorConstant(n, b), CharType) if b == BigInt(32) => - CharLiteral(n.toInt.toChar) - - case (SHexadecimal(h), CharType) => - CharLiteral(h.toInt.toChar) - - case (SNumeral(n), IntegerType) => - InfiniteIntegerLiteral(n) - - case (SDecimal(d), RealType) => - RealLiteral(d) - - case (SNumeral(n), RealType) => - RealLiteral(BigDecimal(n)) - - case (Core.True(), BooleanType) => BooleanLiteral(true) - case (Core.False(), BooleanType) => BooleanLiteral(false) - - case (FixedSizeBitVectors.BitVectorConstant(n, b), Int32Type) if b == BigInt(32) => IntLiteral(n.toInt) - case (SHexadecimal(hexa), Int32Type) => IntLiteral(hexa.toInt) - - case (SimpleSymbol(s), _: ClassType) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - CaseClass(cct, Nil) - case t => - unsupported(t, "woot? for a single constructor for non-case-object") - } - - case (SimpleSymbol(s), tpe) if lets contains s => - fromSMT(lets(s), tpe) - - case (SimpleSymbol(s), _) => - variables.getA(s).map(_.toVariable).getOrElse { - reporter.fatalError("Unknown symbol: "+s) - } - - case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => - IfExpr( - fromSMT(cond, BooleanType), - fromSMT(thenn, t), - fromSMT(elze, t) - ) - - case (FunctionApplication(SimpleSymbol(s), args), tpe) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) - CaseClass(cct, rargs) - case tt: TupleType => - val rargs = args.zip(tt.bases).map(fromSMT) - tupleWrap(rargs) - - case ArrayType(baseType) => - val IntLiteral(size) = fromSMT(args(0), Int32Type) - val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) - - if(size > 10) { - val definedElements = elems.collect { - case (IntLiteral(i), value) => (i, value) - } - finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) - - } else { - val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) - - finiteArray(entries, None, baseType) - } - - case t => - unsupported(t, "Woot? structural type that is non-structural") - } - - // EK: Since we have no type information, we cannot do type-directed - // extraction of defs, instead, we expand them in smt-world - case (SMTLet(binding, bindings, body), tpe) => - val defsMap: Map[SSymbol, Term] = (binding +: bindings).map { - case VarBinding(s, value) => (s, value) - }.toMap - - fromSMT(body, tpe)(lets ++ defsMap, letDefs) - - case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => { - app match { - case "-" => - (args, tpe) match { - case (List(a), IntegerType) => UMinus(fromSMT(a, IntegerType)) - case (List(a, b), IntegerType) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (List(a), RealType) => RealUMinus(fromSMT(a, RealType)) - case (List(a, b), RealType) => RealMinus(fromSMT(a, RealType), fromSMT(b, RealType)) - } - case "/" => - (args, tpe) match { - case (List(a, b), RealType) => RealDivision(fromSMT(a, RealType), fromSMT(b, RealType)) - } - - case _ => - reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) - } - } - case (QualifiedIdentifier(id, sort), tpe) => - reporter.fatalError("Unhandled case in fromSMT: " + id +": "+sort +" ("+tpe+")") - case _ => - reporter.fatalError("Unhandled case in fromSMT: " + (s, tpe)) - } - - - /* Send a command to the solver */ - def sendCommand(cmd: Command, rawOut: Boolean = false): CommandResponse = { - out foreach { o => - SMTPrinter.printCommand(cmd, o) - o.write("\n") - o.flush() - } - interpreter.eval(cmd) match { - case err@ErrorResponse(msg) if !hasError && !interrupted && !rawOut => - reporter.warning(s"Unexpected error from $name solver: $msg") - // Store that there was an error. Now all following check() - // invocations will return None - addError() - err - case res => res - } - } - - /* Public solver interface */ - - def free() = { - interpreter.free() - context.interruptManager.unregisterForInterrupts(this) - out foreach { _.close } - } - override def assertCnstr(expr: Expr): Unit = if(!hasError) { - variablesOf(expr).foreach(declareVariable) try { + variablesOf(expr).foreach(declareVariable) + val term = toSMT(expr)(Map()) - sendCommand(SMTAssert(term)) + emit(SMTAssert(term)) } catch { - case _ : SolverUnsupportedError => + case _ : SMTLIBUnsupportedError => // Store that there was an error. Now all following check() // invocations will return None addError() @@ -730,7 +46,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } override def reset() = { - sendCommand(Reset(), rawOut = true) match { + emit(Reset(), rawOut = true) match { case ErrorResponse(msg) => reporter.warning(s"Failed to reset $name: $msg") throw new CantResetException(this) @@ -740,7 +56,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) override def check: Option[Boolean] = { if (hasError) None - else sendCommand(CheckSat()) match { + else emit(CheckSat()) match { case CheckSatStatus(SatStatus) => Some(true) case CheckSatStatus(UnsatStatus) => Some(false) case CheckSatStatus(UnknownStatus) => None @@ -753,22 +69,27 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) if (syms.isEmpty) { Model.empty } else { - val cmd: Command = GetValue( - syms.head, - syms.tail.map(s => QualifiedIdentifier(SMTIdentifier(s))) - ) + try { + val cmd: Command = GetValue( + syms.head, + syms.tail.map(s => QualifiedIdentifier(SMTIdentifier(s))) + ) - sendCommand(cmd) match { - case GetValueResponseSuccess(valuationPairs) => - new Model(valuationPairs.collect { - case (SimpleSymbol(sym), value) if variables.containsB(sym) => - val id = variables.toA(sym) + emit(cmd) match { + case GetValueResponseSuccess(valuationPairs) => - (id, fromSMT(value, id.getType)(Map(), Map())) - }.toMap) + new Model(valuationPairs.collect { + case (SimpleSymbol(sym), value) if variables.containsB(sym) => + val id = variables.toA(sym) - case _ => - Model.empty //FIXME improve this + (id, fromSMT(value, id.getType)(Map(), Map())) + }.toMap) + case _ => + Model.empty //FIXME improve this + } + } catch { + case e : SMTLIBUnsupportedError => + throw new SolverUnsupportedError(e.t, this, e.reason) } } } @@ -784,7 +105,8 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) sorts.push() functions.push() errors.push() - sendCommand(Push(1)) + + emit(Push(1)) } override def pop(): Unit = { @@ -797,10 +119,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) functions.pop() errors.pop() - sendCommand(Pop(1)) + emit(Pop(1)) } } - -// Unique numbers -private [smtlib] object VCNumbers extends UniqueCounter[String] diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala new file mode 100644 index 000000000..c1b622242 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -0,0 +1,818 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import utils._ + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.Constructors._ +import purescala.Definitions._ + +import _root_.smtlib.common._ +import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter} +import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, FunDef => _, Assert => SMTAssert, _} +import _root_.smtlib.parser.Terms.{ + Forall => SMTForall, + Exists => SMTExists, + Identifier => SMTIdentifier, + Let => SMTLet, + _ +} +import _root_.smtlib.theories.Core.{ + Equals => SMTEquals +} +import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} +import _root_.smtlib.theories._ +import _root_.smtlib.interpreters.ProcessInterpreter + +trait SMTLIBTarget extends Interruptible { + val context: LeonContext; + val program: Program; + protected val reporter: Reporter; + + def targetName: String + + implicit val debugSection: DebugSection; + + protected def interpreterOps(ctx: LeonContext): Seq[String] + + protected def getNewInterpreter(ctx: LeonContext): ProcessInterpreter + + protected def unsupported(t: Tree, str: String): Nothing; + + + protected lazy val interpreter = getNewInterpreter(context) + + /* Interruptible interface */ + private var interrupted = false + + context.interruptManager.registerForInterrupts(this) + + override def interrupt(): Unit = { + interrupted = true + interpreter.interrupt() + } + override def recoverInterrupt(): Unit = { + interrupted = false + } + + def free() = { + interpreter.free() + context.interruptManager.unregisterForInterrupts(this) + debugOut foreach { _.close } + } + + /* Printing VCs */ + protected lazy val debugOut: Option[java.io.FileWriter] = { + if (reporter.isDebugEnabled) { + val file = context.files.headOption.map(_.getName).getOrElse("NA") + val n = DebugFileNumbers.next(targetName+file) + + val fileName = s"smt-sessions/$targetName-$file-$n.smt2" + + val javaFile = new java.io.File(fileName) + javaFile.getParentFile.mkdirs() + + reporter.debug(s"Outputting smt session into $fileName" ) + + val fw = new java.io.FileWriter(javaFile, false) + + fw.write("; Options: "+interpreterOps(context).mkString(" ")+"\n") + + Some(fw) + } else { + None + } + } + + /* Send a command to the solver */ + def emit(cmd: SExpr, rawOut: Boolean = false): SExpr = { + debugOut foreach { o => + SMTPrinter.printSExpr(cmd, o) + o.write("\n") + o.flush() + } + interpreter.eval(cmd) match { + case err@ErrorResponse(msg) if !hasError && !interrupted && !rawOut => + reporter.warning(s"Unexpected error from $targetName solver: $msg") + // Store that there was an error. Now all following check() + // invocations will return None + addError() + err + case res => + res + } + } + + def parseSuccess() = { + val res = interpreter.parser.parseGenResponse + if (res != Success) { + reporter.warning("Unnexpected result from "+targetName+": "+res+" expected success") + } + } + + /* + * Translation from Leon Expressions to SMTLIB terms and reverse + */ + + /* Symbol handling */ + protected object SimpleSymbol { + def unapply(term: Term): Option[SSymbol] = term match { + case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => Some(sym) + case _ => None + } + } + + import scala.language.implicitConversions + protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = { + QualifiedIdentifier(SMTIdentifier(s)) + } + + protected val adtManager = new ADTManager(context) + + protected val library = program.library + + protected def id2sym(id: Identifier): SSymbol = { + SSymbol(id.uniqueNameDelimited("!").replace("|", "$pipe").replace("\\", "$backslash")) + } + + + protected def freshSym(id: Identifier): SSymbol = freshSym(id.name) + protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) + + + /* Metadata for CC, and variables */ + protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() + protected val testers = new IncrementalBijection[TypeTree, SSymbol]() + protected val variables = new IncrementalBijection[Identifier, SSymbol]() + protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() + protected val sorts = new IncrementalBijection[TypeTree, Sort]() + protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() + protected val errors = new IncrementalBijection[Unit, Boolean]() + protected def hasError = errors.getB(()) contains true + protected def addError() = errors += () -> true + + /* Helper functions */ + + protected def normalizeType(t: TypeTree): TypeTree = t match { + case ct: ClassType => ct.root + case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) + case _ => t + } + + protected def quantifiedTerm( + quantifier: (SortedVar, Seq[SortedVar], Term) => Term, + vars: Seq[Identifier], + body: Expr + )( + implicit bindings: Map[Identifier, Term] + ): Term = { + if (vars.isEmpty) toSMT(body) + if (vars.isEmpty) toSMT(body)(Map()) + else { + val sortedVars = vars map { id => + SortedVar(id2sym(id), declareSort(id.getType)) + } + quantifier( + sortedVars.head, + sortedVars.tail, + toSMT(body)(bindings ++ vars.map{ id => id -> (id2sym(id): Term)}.toMap) + ) + } + } + + // Returns a quantified term where all free variables in the body have been quantified + protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr)( + implicit bindings: Map[Identifier, Term] + ): Term = + quantifiedTerm(quantifier, variablesOf(body).toSeq, body) + + protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { + case SetType(base) => + if (r.default != BooleanLiteral(false)) { + unsupported(r, "Solver returned a co-finite set which is not supported.") + } + require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") + + FiniteSet(r.elems.keySet, base) + + case RawArrayType(from, to) => + r + + case ft @ FunctionType(from, to) => + r + + case MapType(from, to) => + // We expect a RawArrayValue with keys in from and values in Option[to], + // with default value == None + if (r.default.getType != library.noneType(to)) { + unsupported(r, "Solver returned a co-finite map which is not supported.") + } + require(r.keyTpe == from, s"Type error in solver model, expected $from, found ${r.keyTpe}") + + val elems = r.elems.flatMap { + case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x) + case (k, _) => None + }.toSeq + FiniteMap(elems, from, to) + + case other => + unsupported(other, "Unable to extract from raw array for "+tpe) + } + + protected def declareUninterpretedSort(t: TypeParameter): Sort = { + val s = id2sym(t.id) + val cmd = DeclareSort(s, 0) + emit(cmd) + Sort(SMTIdentifier(s)) + } + + protected def declareSort(t: TypeTree): Sort = { + val tpe = normalizeType(t) + sorts.cachedB(tpe) { + tpe match { + case BooleanType => Core.BoolSort() + case IntegerType => Ints.IntSort() + case RealType => Reals.RealSort() + case Int32Type => FixedSizeBitVectors.BitVectorSort(32) + case CharType => FixedSizeBitVectors.BitVectorSort(32) + + case RawArrayType(from, to) => + Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) + + case MapType(from, to) => + declareSort(RawArrayType(from, library.optionType(to))) + + case FunctionType(from, to) => + Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(tupleTypeWrap(from)), declareSort(to))) + + case tp: TypeParameter => + declareUninterpretedSort(tp) + + case _: ClassType | _: TupleType | _: ArrayType | UnitType => + declareStructuralSort(tpe) + + case other => + unsupported(other, s"Could not transform $other into an SMT sort") + } + } + } + + protected def declareDatatypes(datatypes: Map[TypeTree, DataType]): Unit = { + // We pre-declare ADTs + for ((tpe, DataType(sym, _)) <- datatypes) { + sorts += tpe -> Sort(SMTIdentifier(id2sym(sym))) + } + + def toDecl(c: Constructor): SMTConstructor = { + val s = id2sym(c.sym) + + testers += c.tpe -> SSymbol("is-"+s.name) + constructors += c.tpe -> s + + SMTConstructor(s, c.fields.zipWithIndex.map { + case ((cs, t), i) => + selectors += (c.tpe, i) -> id2sym(cs) + (id2sym(cs), declareSort(t)) + }) + } + + val adts = for ((tpe, DataType(sym, cases)) <- datatypes.toList) yield { + (id2sym(sym), cases.map(toDecl)) + } + + if (adts.nonEmpty) { + val cmd = DeclareDatatypes(adts) + emit(cmd) + } + + } + + protected def declareStructuralSort(t: TypeTree): Sort = { + // Populates the dependencies of the structural type to define. + adtManager.defineADT(t) match { + case Left(adts) => + declareDatatypes(adts) + sorts.toB(normalizeType(t)) + + case Right(conflicts) => + conflicts.foreach { declareStructuralSort } + declareStructuralSort(t) + } + + } + + protected def declareVariable(id: Identifier): SSymbol = { + variables.cachedB(id) { + val s = id2sym(id) + val cmd = DeclareFun(s, List(), declareSort(id.getType)) + emit(cmd) + s + } + } + + protected def declareFunction(tfd: TypedFunDef): SSymbol = { + functions.cachedB(tfd) { + val id = if (tfd.tps.isEmpty) { + tfd.id + } else { + FreshIdentifier(tfd.id.name) + } + val s = id2sym(id) + emit(DeclareFun( + s, + tfd.params.map( (p: ValDef) => declareSort(p.getType)), + declareSort(tfd.returnType) + )) + s + } + } + + /* Translate a Leon Expr to an SMTLIB term */ + + def sortToSMT(s: Sort): SExpr = { + s match { + case Sort(id, Nil) => + id.symbol + + case Sort(id, subs) => + SList((id.symbol +: subs.map(sortToSMT)).toList) + } + } + + protected def toSMT(t: TypeTree): SExpr = { + sortToSMT(declareSort(t)) + } + + protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = { + e match { + case Variable(id) => + declareSort(e.getType) + bindings.getOrElse(id, variables.toB(id)) + + case UnitLiteral() => + declareSort(UnitType) + + declareVariable(FreshIdentifier("Unit", UnitType)) + + case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) + case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) + case RealLiteral(d) => if (d >= 0) Reals.DecimalLit(d) else Reals.Neg(Reals.DecimalLit(-d)) + case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) + case BooleanLiteral(v) => Core.BoolConst(v) + case Let(b,d,e) => + val id = id2sym(b) + val value = toSMT(d) + val newBody = toSMT(e)(bindings + (b -> id)) + + SMTLet( + VarBinding(id, value), + Seq(), + newBody + ) + + case er @ Error(tpe, _) => + declareVariable(FreshIdentifier("error_value", tpe)) + + case s @ CaseClassSelector(cct, e, id) => + declareSort(cct) + val selector = selectors.toB((cct, s.selectorIndex)) + FunctionApplication(selector, Seq(toSMT(e))) + + case AsInstanceOf(expr, cct) => + toSMT(expr) + + case IsInstanceOf(e, cct) => + declareSort(cct) + val cases = cct match { + case act: AbstractClassType => + act.knownCCDescendants + case cct: CaseClassType => + Seq(cct) + } + val oneOf = cases map testers.toB + oneOf match { + case Seq(tester) => + FunctionApplication(tester, Seq(toSMT(e))) + case more => + val es = freshSym("e") + SMTLet(VarBinding(es, toSMT(e)), Seq(), + Core.Or(oneOf.map(FunctionApplication(_, Seq(es:Term))): _*) + ) + } + + case CaseClass(cct, es) => + declareSort(cct) + val constructor = constructors.toB(cct) + if (es.isEmpty) { + constructor + } else { + FunctionApplication(constructor, es.map(toSMT)) + } + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + declareSort(tpe) + val constructor = constructors.toB(tpe) + FunctionApplication(constructor, es.map(toSMT)) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + declareSort(tpe) + val selector = selectors.toB((tpe, i-1)) + FunctionApplication(selector, Seq(toSMT(t))) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val selector = selectors.toB((tpe, 0)) + + FunctionApplication(selector, Seq(toSMT(a))) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(toSMT(a))) + + ArraysEx.Select(scontent, toSMT(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = toSMT(a) + val ssize = FunctionApplication(selectors.toB((tpe, 0)), Seq(sa)) + val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(sa)) + + val newcontent = ArraysEx.Store(scontent, toSMT(i), toSMT(e)) + + val constructor = constructors.toB(tpe) + FunctionApplication(constructor, Seq(ssize, newcontent)) + + case ra @ RawArrayValue(keyTpe, elems, default) => + val s = declareSort(ra.getType) + + var res: Term = FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(s)), + List(toSMT(default)) + ) + for ((k, v) <- elems) { + res = ArraysEx.Store(res, toSMT(k), toSMT(v)) + } + + res + + case a @ FiniteArray(elems, oDef, size) => + val tpe @ ArrayType(to) = normalizeType(a.getType) + declareSort(tpe) + + val default: Expr = oDef.getOrElse(simplestValue(to)) + + val arr = toSMT(RawArrayValue(Int32Type, elems.map { + case (k, v) => IntLiteral(k) -> v + }, default)) + + FunctionApplication(constructors.toB(tpe), List(toSMT(size), arr)) + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, _, _) => + val mt @ MapType(from, to) = m.getType + declareSort(mt) + + toSMT(RawArrayValue(from, elems.map { + case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) + }.toMap, CaseClass(library.noneType(to), Seq()))) + + + case MapApply(m, k) => + val mt @ MapType(_, to) = m.getType + declareSort(mt) + // m(k) becomes + // (Some-value (select m k)) + FunctionApplication( + selectors.toB((library.someType(to), 0)), + Seq(ArraysEx.Select(toSMT(m), toSMT(k))) + ) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, to) = m.getType + declareSort(mt) + // m.isDefinedAt(k) becomes + // (is-Some (select m k)) + FunctionApplication( + testers.toB(library.someType(to)), + Seq(ArraysEx.Select(toSMT(m), toSMT(k))) + ) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val MapType(_, t) = m1.getType + + elems.foldLeft(toSMT(m1)) { case (m, (k,v)) => + ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) + } + + case p : Passes => + toSMT(matchToIfThenElse(p.asConstraint)) + + case m : MatchExpr => + toSMT(matchToIfThenElse(m)) + + + case gv @ GenericValue(tpe, n) => + genericValues.cachedB(gv) { + declareVariable(FreshIdentifier("gv"+n, tpe)) + } + + /** + * ===== Everything else ===== + */ + case ap @ Application(caller, args) => + ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) + + case Not(u) => Core.Not(toSMT(u)) + case UMinus(u) => Ints.Neg(toSMT(u)) + case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) + case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) + case Assert(a,_, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) + + case Equals(a,b) => Core.Equals(toSMT(a), toSMT(b)) + case Implies(a,b) => Core.Implies(toSMT(a), toSMT(b)) + case Plus(a,b) => Ints.Add(toSMT(a), toSMT(b)) + case Minus(a,b) => Ints.Sub(toSMT(a), toSMT(b)) + case Times(a,b) => Ints.Mul(toSMT(a), toSMT(b)) + case Division(a,b) => { + val ar = toSMT(a) + val br = toSMT(b) + + Core.ITE( + Ints.GreaterEquals(ar, Ints.NumeralLit(0)), + Ints.Div(ar, br), + Ints.Neg(Ints.Div(Ints.Neg(ar), br))) + } + case Remainder(a,b) => { + val q = toSMT(Division(a, b)) + Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q)) + } + case Modulo(a,b) => { + Ints.Mod(toSMT(a), toSMT(b)) + } + case LessThan(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) + case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) + case RealType => Reals.LessThan(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) + } + case LessEquals(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) + case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) + case RealType => Reals.LessEquals(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) + } + case GreaterThan(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) + case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterThan(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) + } + case GreaterEquals(a,b) => a.getType match { + case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) + case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) + case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) + } + case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) + case BVMinus(a,b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) + case BVTimes(a,b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) + case BVDivision(a,b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) + case BVRemainder(a,b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) + case BVAnd(a,b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) + case BVOr(a,b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) + case BVXOr(a,b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) + case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) + case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) + case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) + + case RealPlus(a,b) => Reals.Add(toSMT(a), toSMT(b)) + case RealMinus(a,b) => Reals.Sub(toSMT(a), toSMT(b)) + case RealTimes(a,b) => Reals.Mul(toSMT(a), toSMT(b)) + case RealDivision(a,b) => Reals.Div(toSMT(a), toSMT(b)) + + case And(sub) => Core.And(sub.map(toSMT): _*) + case Or(sub) => Core.Or(sub.map(toSMT): _*) + case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) + case f@FunctionInvocation(_, sub) => + if (sub.isEmpty) declareFunction(f.tfd) else { + FunctionApplication( + declareFunction(f.tfd), + sub.map(toSMT) + ) + } + case Forall(vs, bd) => + quantifiedTerm(SMTForall, vs map { _.id }, bd)(Map()) + case o => + unsupported(o, "") + } + } + + /* Translate an SMTLIB term back to a Leon Expr */ + + protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + fromSMT(pair._1, pair._2) + } + + protected def fromUntypedSMT(t: Term)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = t match { + case SimpleSymbol(s) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + CaseClass(cct, Nil) + case t => + unsupported(t, "woot? for a single constructor for non-case-object") + } + + case SimpleSymbol(s) if lets contains s => + fromUntypedSMT(lets(s)) + + case SimpleSymbol(s) => + variables.getA(s).map(_.toVariable).getOrElse { + reporter.fatalError("Unknown symbol: "+s) + } + case _ => + reporter.fatalError("Unhandled case in fromUntypedSMT: " + t) + + } + + protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { + case (_, UnitType) => + UnitLiteral() + + case (FixedSizeBitVectors.BitVectorConstant(n, b), CharType) if b == BigInt(32) => + CharLiteral(n.toInt.toChar) + + case (SHexadecimal(h), CharType) => + CharLiteral(h.toInt.toChar) + + case (SNumeral(n), IntegerType) => + InfiniteIntegerLiteral(n) + + case (SDecimal(d), RealType) => + RealLiteral(d) + + case (SNumeral(n), RealType) => + RealLiteral(BigDecimal(n)) + + case (Core.True(), BooleanType) => BooleanLiteral(true) + case (Core.False(), BooleanType) => BooleanLiteral(false) + + case (FixedSizeBitVectors.BitVectorConstant(n, b), Int32Type) if b == BigInt(32) => IntLiteral(n.toInt) + case (SHexadecimal(hexa), Int32Type) => IntLiteral(hexa.toInt) + + case (SimpleSymbol(s), _: ClassType) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + CaseClass(cct, Nil) + case t => + unsupported(t, "woot? for a single constructor for non-case-object") + } + + case (SimpleSymbol(s), tpe) if lets contains s => + fromSMT(lets(s), tpe) + + case (SimpleSymbol(s), _) => + variables.getA(s).map(_.toVariable).getOrElse { + reporter.fatalError("Unknown symbol: "+s) + } + + case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => + IfExpr( + fromSMT(cond, BooleanType), + fromSMT(thenn, t), + fromSMT(elze, t) + ) + + case (FunctionApplication(SimpleSymbol(s), args), tpe) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) + CaseClass(cct, rargs) + case tt: TupleType => + val rargs = args.zip(tt.bases).map(fromSMT) + tupleWrap(rargs) + + case ArrayType(baseType) => + val IntLiteral(size) = fromSMT(args(0), Int32Type) + val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) + + if(size > 10) { + val definedElements = elems.collect { + case (IntLiteral(i), value) => (i, value) + } + finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) + + } else { + val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) + + finiteArray(entries, None, baseType) + } + + case t => + unsupported(t, "Woot? structural type that is non-structural") + } + + // EK: Since we have no type information, we cannot do type-directed + // extraction of defs, instead, we expand them in smt-world + case (SMTLet(binding, bindings, body), tpe) => + val defsMap: Map[SSymbol, Term] = (binding +: bindings).map { + case VarBinding(s, value) => (s, value) + }.toMap + + fromSMT(body, tpe)(lets ++ defsMap, letDefs) + + case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => { + app match { + case ">=" => + (args, tpe) match { + case (List(a, b), BooleanType) => GreaterEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + } + + case "<=" => + (args, tpe) match { + case (List(a, b), BooleanType) => LessEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + } + + case ">" => + (args, tpe) match { + case (List(a, b), BooleanType) => GreaterThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + } + + case "<" => + (args, tpe) match { + case (List(a, b), BooleanType) => LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + } + + case "+" => + (args, tpe) match { + case (List(a, b), IntegerType) => Plus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + case (List(a, b), RealType) => RealPlus(fromSMT(a, RealType), fromSMT(b, RealType)) + } + + case "not" => + (args, tpe) match { + case (List(a), BooleanType) => Not(fromSMT(a, BooleanType)) + } + + case "or" => + (args, tpe) match { + case (List(a, b), BooleanType) => Or(fromSMT(a, BooleanType), fromSMT(b, BooleanType)) + } + + case "and" => + (args, tpe) match { + case (List(a, b), BooleanType) => And(fromSMT(a, BooleanType), fromSMT(b, BooleanType)) + } + + case "=" => + (args, tpe) match { + case (List(a, b), BooleanType) => + val ra = fromUntypedSMT(a) + Equals(ra, fromSMT(b, ra.getType)) + } + + case "*" => + (args, tpe) match { + case (List(a, b), IntegerType) => Times(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + case (List(a, b), RealType) => RealTimes(fromSMT(a, RealType), fromSMT(b, RealType)) + } + + case "-" => + (args, tpe) match { + case (List(a), IntegerType) => UMinus(fromSMT(a, IntegerType)) + case (List(a, b), IntegerType) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + case (List(a), RealType) => RealUMinus(fromSMT(a, RealType)) + case (List(a, b), RealType) => RealMinus(fromSMT(a, RealType), fromSMT(b, RealType)) + } + case "/" => + (args, tpe) match { + case (List(a, b), RealType) => RealDivision(fromSMT(a, RealType), fromSMT(b, RealType)) + } + + case _ => + reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) + } + } + case (QualifiedIdentifier(id, sort), tpe) => + reporter.fatalError("Unhandled case in fromSMT: " + id +": "+sort +" ("+tpe+")") + case _ => + reporter.fatalError("Unhandled case in fromSMT: " + (s, tpe)) + } + +} + +// Unique numbers +private [smtlib] object DebugFileNumbers extends UniqueCounter[String] diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala new file mode 100644 index 000000000..0a48c2988 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala @@ -0,0 +1,10 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Common.Tree + +case class SMTLIBUnsupportedError(t: Tree, s: SMTLIBTarget, reason: Option[String] = None) + extends Unsupported(t, s" is unsupported by ${s.targetName}" + reason.map(":\n " + _ ).getOrElse(""))(s.context) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala index 9a8e60e6b..aa78b97f0 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala @@ -3,74 +3,12 @@ package leon package solvers.smtlib -import leon.solvers.SolverUnsupportedError -import purescala._ -import DefOps._ -import Definitions._ -import Expressions._ -import Constructors._ -import smtlib.parser.Commands.{Assert => SMTAssert} -import smtlib.parser.Terms.{Forall => SMTForall, SSymbol} +import purescala.Definitions.Program /** * This solver models function definitions as universally quantified formulas. * It is not meant as an underlying solver to UnrollingSolver, and does not handle HOFs. */ -class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) - extends SMTLIBZ3Solver(context, program) +class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) extends SMTLIBZ3Solver(context, program) with SMTLIBQuantifiedSolver -{ - - protected val allowQuantifiedAssertions: Boolean = true - - override def targetName = "z3-q" - - override def declareFunction(tfd: TypedFunDef): SSymbol = { - - val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit)) - if (!exploredAll) { - reporter.warning( - s"Did not manage to explore the space of typed functions called from ${tfd.id}. The solver may fail" - ) - } - - val notSeen = funs.toSeq filterNot functions.containsA - - val smtFunDecls = notSeen map super.declareFunction - - smtFunDecls foreach { sym => - val tfd = functions.toA(sym) - val term = quantifiedTerm( - SMTForall, - tfd.params map { _.id }, - Equals( - FunctionInvocation(tfd, tfd.params.map {_.toVariable}), - tfd.body.get - ) - )(Map()) - sendCommand(SMTAssert(term)) - } - - // If we encounter a function that does not refer to the current function, - // it is sound to assume its contracts for all inputs - if (allowQuantifiedAssertions) for { - tfd <- notSeen if !refersToCurrent(tfd.fd) - post <- tfd.postcondition - } { - val term = implies( - tfd.precondition getOrElse BooleanLiteral(true), - application(post, Seq(tfd.applied)) - ) - try { - sendCommand(SMTAssert(quantifiedTerm(SMTForall, term)(Map()))) - } catch { - case _ : SolverUnsupportedError => - addError() - } - } - - functions.toB(tfd) - - } - -} + with SMTLIBZ3QuantifiedTarget diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala new file mode 100644 index 000000000..6c20b337d --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala @@ -0,0 +1,76 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Common._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Types._ +import purescala.Definitions._ +import purescala.DefOps.typedTransitiveCallees + +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => _, _} +import _root_.smtlib.parser.Terms.{Exists => SMTExists, Forall => SMTForall, _ } +import _root_.smtlib.theories.Core.{Equals => SMTEquals} +import _root_.smtlib.parser.Commands._ + +trait SMTLIBZ3QuantifiedTarget extends SMTLIBZ3Target with SMTLIBQuantifiedTarget { + + protected val allowQuantifiedAssertions: Boolean = true + + override def targetName = "z3-q" + + override def declareFunction(tfd: TypedFunDef): SSymbol = { + + val (funs, exploredAll) = typedTransitiveCallees(Set(tfd), Some(typedFunDefExplorationLimit)) + + if (!exploredAll) { + reporter.warning( + "Did not manage to explore the space of typed functions " + + s"transitively called from ${tfd.id}. The solver may fail" + ) + } + + val notSeen = funs.toSeq filterNot functions.containsA + + val smtFunDecls = notSeen map super.declareFunction + + smtFunDecls foreach { sym => + val tfd = functions.toA(sym) + val term = quantifiedTerm( + SMTForall, + tfd.params map { _.id }, + Equals( + FunctionInvocation(tfd, tfd.params.map {_.toVariable}), + tfd.body.get + ) + )(Map()) + emit(SMTAssert(term)) + } + + // If we encounter a function that does not refer to the current function, + // it is sound to assume its contracts for all inputs + if (allowQuantifiedAssertions) for { + tfd <- notSeen if !refersToCurrent(tfd.fd) + post <- tfd.postcondition + } { + val term = implies( + tfd.precondition getOrElse BooleanLiteral(true), + application(post, Seq(tfd.applied)) + ) + try { + emit(SMTAssert(quantifiedTerm(SMTForall, term)(Map()))) + } catch { + case _ : SMTLIBUnsupportedError => + addError() + } + } + + functions.toB(tfd) + + } +} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index eff01e304..0e289ff80 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -4,166 +4,21 @@ package leon package solvers package smtlib -import purescala._ -import Common._ -import Definitions.Program -import Expressions._ -import Types._ -import Constructors._ +import purescala.Definitions.Program +import purescala.Common.Identifier +import purescala.Expressions.Expr import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.parser.CommandsResponses.GetModelResponseSuccess import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx -class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) { - - def targetName = "z3" - - def interpreterOps(ctx: LeonContext) = { - Seq( - "-in", - "-smt2" - ) - } - - def getNewInterpreter(ctx: LeonContext) = new Z3Interpreter("z3", interpreterOps(ctx).toArray) - - protected val extSym = SSymbol("_") - - protected var setSort: Option[SSymbol] = None - - override protected def declareSort(t: TypeTree): Sort = { - val tpe = normalizeType(t) - sorts.cachedB(tpe) { - tpe match { - case SetType(base) => - super.declareSort(BooleanType) - declareSetSort(base) - case _ => - super.declareSort(t) - } - } - } - - protected def declareSetSort(of: TypeTree): Sort = { - setSort match { - case None => - val s = SSymbol("Set") - val t = SSymbol("T") - setSort = Some(s) - - val arraySort = Sort(SMTIdentifier(SSymbol("Array")), - Seq(Sort(SMTIdentifier(t)), BoolSort())) - - val cmd = DefineSort(s, Seq(t), arraySort) - sendCommand(cmd) - case _ => - } - - Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) - } - - override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - case (SimpleSymbol(s), tp: TypeParameter) => - val n = s.name.split("!").toList.last - GenericValue(tp, n.toInt) - - - case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), tpe) => - if (letDefs contains k) { - // Need to recover value form function model - fromRawArray(extractRawArray(letDefs(k), tpe), tpe) - } else { - throw LeonFatalError("Array on non-function or unknown symbol "+k) - } - - case (FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), - Seq(defV) - ), tpe) => - val ktpe = sorts.fromB(k) - val vtpe = sorts.fromB(v) - - fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) - - case _ => - super.fromSMT(s, tpe) - } - - override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { - - /** - * ===== Set operations ===== - */ - case fs @ FiniteSet(elems, base) => - declareSort(fs.getType) - - toSMT(RawArrayValue(base, elems.map { - case k => k -> BooleanLiteral(true) - }.toMap, BooleanLiteral(false))) - - case SubsetOf(ss, s) => - // a isSubset b ==> (a zip b).map(implies) == (* => true) - val allTrue = ArrayConst(declareSort(s.getType), True()) - - SMTEquals(ArrayMap(SSymbol("implies"), toSMT(ss), toSMT(s)), allTrue) - - case ElementOfSet(e, s) => - ArraysEx.Select(toSMT(s), toSMT(e)) - - case SetDifference(a, b) => - // a -- b - // becomes: - // a && not(b) - - ArrayMap(SSymbol("and"), toSMT(a), ArrayMap(SSymbol("not"), toSMT(b))) - - case SetUnion(l, r) => - ArrayMap(SSymbol("or"), toSMT(l), toSMT(r)) - - case SetIntersection(l, r) => - ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - - case _ => - super.toSMT(e) - } - - protected def extractRawArray(s: DefineFun, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match { - case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) => - val (argTpe, retTpe) = tpe match { - case SetType(base) => (base, BooleanType) - case MapType(from, to) => (from, library.optionType(to)) - case ArrayType(base) => (Int32Type, base) - case FunctionType(args, ret) => (tupleTypeWrap(args), ret) - case RawArrayType(from, to) => (from, to) - case _ => unsupported(tpe, "Unsupported type for (un)packing into raw arrays (got kinds "+akind+" -> "+rkind+")") - } - - def extractCases(e: Term): (Map[Expr, Expr], Expr) = e match { - case ITE(SMTEquals(SimpleSymbol(`arg`), k), v, e) => - val (cs, d) = extractCases(e) - (Map(fromSMT(k, argTpe) -> fromSMT(v, retTpe)) ++ cs, d) - case e => - (Map(),fromSMT(e, retTpe)) - } - - val (cases, default) = extractCases(body) - - RawArrayValue(argTpe, cases, default) - - case _ => - throw LeonFatalError("Unable to extract "+s) - } +class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBZ3Target { // EK: We use get-model instead in order to extract models for arrays override def getModel: Model = { - val cmd = GetModel() - - val res = sendCommand(cmd) + val res = emit(GetModel()) val smodel: Seq[SExpr] = res match { case GetModelResponseSuccess(model) => model @@ -202,20 +57,4 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve new Model(model) } - protected object ArrayMap { - def apply(op: SSymbol, arrs: Term*) = { - FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))), - arrs - ) - } - } - - protected object ArrayConst { - def apply(sort: Sort, default: Term) = { - FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(sort)), - List(default)) - } - } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala new file mode 100644 index 000000000..c0c696617 --- /dev/null +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -0,0 +1,183 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package smtlib + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Types._ + +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} +import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} +import _root_.smtlib.interpreters.Z3Interpreter +import _root_.smtlib.parser.CommandsResponses.GetModelResponseSuccess +import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} +import _root_.smtlib.theories.ArraysEx + + +trait SMTLIBZ3Target extends SMTLIBTarget { + + def targetName = "z3" + + def interpreterOps(ctx: LeonContext) = { + Seq( + "-in", + "-smt2" + ) + } + + def getNewInterpreter(ctx: LeonContext) = { + val opts = interpreterOps(ctx) + reporter.debug("Invoking solver "+targetName+" with "+opts.mkString(" ")) + + new Z3Interpreter("z3", opts.toArray) + } + + protected val extSym = SSymbol("_") + + protected var setSort: Option[SSymbol] = None + + override protected def declareSort(t: TypeTree): Sort = { + val tpe = normalizeType(t) + sorts.cachedB(tpe) { + tpe match { + case SetType(base) => + super.declareSort(BooleanType) + declareSetSort(base) + case _ => + super.declareSort(t) + } + } + } + + protected def declareSetSort(of: TypeTree): Sort = { + setSort match { + case None => + val s = SSymbol("Set") + val t = SSymbol("T") + setSort = Some(s) + + val arraySort = Sort(SMTIdentifier(SSymbol("Array")), + Seq(Sort(SMTIdentifier(t)), BoolSort())) + + val cmd = DefineSort(s, Seq(t), arraySort) + emit(cmd) + + case _ => + } + + Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) + } + + override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { + case (SimpleSymbol(s), tp: TypeParameter) => + val n = s.name.split("!").toList.last + GenericValue(tp, n.toInt) + + + case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), tpe) => + if (letDefs contains k) { + // Need to recover value form function model + fromRawArray(extractRawArray(letDefs(k), tpe), tpe) + } else { + throw LeonFatalError("Array on non-function or unknown symbol "+k) + } + + case (FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), + Seq(defV) + ), tpe) => + val ktpe = sorts.fromB(k) + val vtpe = sorts.fromB(v) + + fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) + + case _ => + super.fromSMT(s, tpe) + } + + override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { + + /** + * ===== Set operations ===== + */ + case fs @ FiniteSet(elems, base) => + declareSort(fs.getType) + + toSMT(RawArrayValue(base, elems.map { + case k => k -> BooleanLiteral(true) + }.toMap, BooleanLiteral(false))) + + case SubsetOf(ss, s) => + // a isSubset b ==> (a zip b).map(implies) == (* => true) + val allTrue = ArrayConst(declareSort(s.getType), True()) + + SMTEquals(ArrayMap(SSymbol("implies"), toSMT(ss), toSMT(s)), allTrue) + + case ElementOfSet(e, s) => + ArraysEx.Select(toSMT(s), toSMT(e)) + + case SetDifference(a, b) => + // a -- b + // becomes: + // a && not(b) + + ArrayMap(SSymbol("and"), toSMT(a), ArrayMap(SSymbol("not"), toSMT(b))) + + case SetUnion(l, r) => + ArrayMap(SSymbol("or"), toSMT(l), toSMT(r)) + + case SetIntersection(l, r) => + ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) + + case _ => + super.toSMT(e) + } + + protected def extractRawArray(s: DefineFun, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match { + case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) => + val (argTpe, retTpe) = tpe match { + case SetType(base) => (base, BooleanType) + case MapType(from, to) => (from, library.optionType(to)) + case ArrayType(base) => (Int32Type, base) + case FunctionType(args, ret) => (tupleTypeWrap(args), ret) + case RawArrayType(from, to) => (from, to) + case _ => unsupported(tpe, "Unsupported type for (un)packing into raw arrays (got kinds "+akind+" -> "+rkind+")") + } + + def extractCases(e: Term): (Map[Expr, Expr], Expr) = e match { + case ITE(SMTEquals(SimpleSymbol(`arg`), k), v, e) => + val (cs, d) = extractCases(e) + (Map(fromSMT(k, argTpe) -> fromSMT(v, retTpe)) ++ cs, d) + case e => + (Map(),fromSMT(e, retTpe)) + } + + val (cases, default) = extractCases(body) + + RawArrayValue(argTpe, cases, default) + + case _ => + throw LeonFatalError("Unable to extract "+s) + } + + protected object ArrayMap { + def apply(op: SSymbol, arrs: Term*) = { + FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))), + arrs + ) + } + } + + protected object ArrayConst { + def apply(sort: Sort, default: Term) = { + FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(sort)), + List(default)) + } + } +} diff --git a/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala b/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala new file mode 100644 index 000000000..a6aa654fb --- /dev/null +++ b/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala @@ -0,0 +1,29 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package sygus + +import purescala._ +import Definitions.Program + +import synthesis.Problem + +import leon.solvers.smtlib._ + +import _root_.smtlib.interpreters.CVC4Interpreter + +class CVC4SygusSolver(ctx: LeonContext, pgm: Program, p: Problem) extends SygusSolver(ctx, pgm, p) with SMTLIBCVC4QuantifiedTarget { + override def targetName = "cvc4-sygus"; + + def interpreterOps(ctx: LeonContext) = { + Seq( + "-q", + "--cegqi-si", + "--lang", "sygus", + "--print-success" + ) + } + + protected val allowQuantifiedAssertions: Boolean = true +} diff --git a/src/main/scala/leon/solvers/sygus/SygusSolver.scala b/src/main/scala/leon/solvers/sygus/SygusSolver.scala new file mode 100644 index 000000000..4ca63c6f7 --- /dev/null +++ b/src/main/scala/leon/solvers/sygus/SygusSolver.scala @@ -0,0 +1,131 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package sygus + +import utils._ +import grammars._ + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Types._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Expressions._ + +import scala.collection.mutable.ArrayBuffer + +import synthesis.Problem + +import leon.solvers.smtlib._ + +import _root_.smtlib.common._ +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} +import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} +import _root_.smtlib.parser.CommandsResponses.{Error => _, _} + +abstract class SygusSolver(val context: LeonContext, val program: Program, val p: Problem) extends SMTLIBTarget { + implicit val ctx = context + implicit val debugSection = leon.utils.DebugSectionSynthesis + + val reporter = context.reporter + + protected def unsupported(t: Tree, str: String): Nothing = { + throw new Unsupported(t, str) + } + + def checkSynth(): Option[Expr] = { + val out = p.xs.head + val c = FreshIdentifier("c") + val fd = new FunDef(c, Seq(), out.getType, p.as.map(a => ValDef(a))) + + val bindings = p.as.map(a => a -> (symbolToQualifiedId(id2sym(a)): Term)).toMap + + val constraintId = QualifiedIdentifier(SMTIdentifier(SSymbol("constraint"))) + + emit(SList(SSymbol("set-logic"), SSymbol("ALL_SUPPORTED"))) + + val fsym = id2sym(fd.id) + + functions += fd.typed -> fsym + + // declare function to synthesize + emit(SList(SSymbol("synth-fun"), id2sym(fd.id), SList(fd.params.map(vd => SList(id2sym(vd.id), toSMT(vd.getType))) :_*), toSMT(out.getType))) + + + // declare inputs + for (a <- p.as) { + emit(SList(SSymbol("declare-var"), id2sym(a), toSMT(a.getType))) + variables += a -> id2sym(a) + } + + val synthPhi = replaceFromIDs(Map(out -> FunctionInvocation(fd.typed, p.as.map(_.toVariable))), p.phi) + + val TopLevelAnds(clauses) = synthPhi + + for(c <- clauses) { + emit(FunctionApplication(constraintId, Seq(toSMT(c)(bindings)))) + } + + emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef + + // We currently cannot predict the amount of success we will get, so we read as many as possible + var lastRes = interpreter.parser.parseSExpr + while(lastRes == SSymbol("success")) { + lastRes = interpreter.parser.parseSExpr + } + + lastRes match { + case SSymbol("unsat") => + interpreter.parser.parseCommand match { + case DefineFun(SMTFunDef(name, params, retSort, body)) => + val res = fromSMT(body, sorts.toA(retSort))(Map(), Map()) + Some(res) + case r => + reporter.warning("Unnexpected result from cvc4-sygus: "+r) + None + } + case SSymbol("unknown") => + None + + case r => + reporter.warning("Unnexpected result from cvc4-sygus: "+r+" expected unsat") + None + } + } +} + + //val g = BaseGrammar || OneOf(p.as.map(_.toVariable)) + + //type Label = TypeTree + + //var defined = Map[Label, Identifier]() + //var definitions = new ArrayBuffer[(Identifier, Seq[Expr])]() + + //def getLabel(l: Label): Identifier = defined.getOrElse(l, { + // val id = FreshIdentifier(l.getType.toString, l.getType) + // defined += l -> id + + // val defs = g.getProductions(l).map { g => + // g.builder(g.subTrees.map(getLabel(_).toVariable)) + // } + + // definitions += (id -> defs) + + // id + //}) + + //// discover grammar + //getLabel(out.getType) + + //val grammarBindings: Map[Identifier, Term] = defined.map{ case (_, id) => + // id -> QualifiedIdentifier(SMTIdentifier(idToSymbol(id))) + //} + + //// define grammar + //val grammar = SList((for ((id, ds) <- definitions) yield { + // SList(idToSymbol(id), typeToSort(id.getType), SList(ds.map(d => smt.toSMT(d)(bindings++grammarBindings)): _*)) + //}).toList) + diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 1d2696772..51a5a15ec 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -41,7 +41,7 @@ object Rules { Unification.OccursCheck, // probably useless Disunification.Decomp, ADTDual, - OnePoint, + //OnePoint, Ground, CaseSplit, IfSplit, diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 3a44131eb..e87a7883d 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -65,6 +65,10 @@ object Solution { new Solution(simplify(pre), defs, simplify(term), isTrusted) } + def term(term: Expr, isTrusted: Boolean = true) = { + new Solution(BooleanLiteral(true), Set(), simplify(term), isTrusted) + } + def unapply(s: Solution): Option[(Expr, Set[FunDef], Expr)] = if (s eq null) None else Some((s.pre, s.defs, s.term)) def choose(p: Problem): Solution = { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 73820eb3c..9ee0febaa 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -53,7 +53,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { timeoutMs = timeout map { _ * 1000 }, generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), costModel = costModel, - rules = Rules.all ++ (ms map { _ => rules.AsChoose}), + rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.Sygus) else Seq()), manualSearch = ms, functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet }, cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout), diff --git a/src/main/scala/leon/synthesis/rules/Sygus.scala b/src/main/scala/leon/synthesis/rules/Sygus.scala new file mode 100644 index 000000000..29f75d1b4 --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/Sygus.scala @@ -0,0 +1,36 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package synthesis +package rules + +import purescala.Types._ +import solvers.sygus._ + +import grammars._ +import utils._ + +case object Sygus extends Rule("Sygus") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + if (p.xs.size != 1) { + Nil + } else { + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { + + val sctx = hctx.sctx + val grammar = Grammars.default(sctx, p) + + val s = new CVC4SygusSolver(sctx.context, sctx.program, p) + + s.checkSynth() match { + case Some(expr) => + RuleClosed(Solution.term(expr)) + case None => + RuleFailed() + } + } + }) + } + } +} diff --git a/src/main/scala/leon/utils/DebugFiles.scala b/src/main/scala/leon/utils/DebugFiles.scala new file mode 100644 index 000000000..1f006c423 --- /dev/null +++ b/src/main/scala/leon/utils/DebugFiles.scala @@ -0,0 +1,3 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.utils diff --git a/testcases/synthesis/sygus/listqueue.scala b/testcases/synthesis/sygus/listqueue.scala new file mode 100644 index 000000000..d3783aceb --- /dev/null +++ b/testcases/synthesis/sygus/listqueue.scala @@ -0,0 +1,64 @@ +import leon._ +import leon.lang._ +import leon.collection._ +import leon.lang.synthesis._ + +object ListQueue { + + case class Queue[T](in: List[T], out: List[T]) { + def toListOut: List[T] = { + out ++ in.reverse + } + + def toListIn: List[T] = { + in ++ out.reverse + } + + def size: BigInt = { + in.size + out.size + } ensuring { + _ >= 0 + } + + def content: Set[T] = { + in.content ++ out.content + } + + + def isEmpty: Boolean = { + + ???[Boolean] + + } ensuring { + res => res == (in == Nil[T]() && out == Nil[T]()) + } + + def enqueue(t: T): Queue[T] = { + + ???[Queue[T]] // Queue(Cons(t, in), out) + + } ensuring { res => + (res.size == size + 1) && + (res.content == content ++ Set(t)) && + (res.toListIn == Cons(t, toListIn)) + } + + def dequeue(): (Queue[T], T) = { + require(in.nonEmpty || out.nonEmpty) + + out match { + case Cons(h, t) => + (Queue(in, t), h) + case Nil() => + Queue(Nil(), in.reverse).dequeue() + } + } ensuring { resAndT => + val res = resAndT._1 + val t = resAndT._2 + + (res.size == size - 1) && + (content contains t) && + (Cons(t, res.toListOut) == toListOut) + } + } +} diff --git a/testcases/synthesis/sygus/max2.scala b/testcases/synthesis/sygus/max2.scala new file mode 100644 index 000000000..aebef892a --- /dev/null +++ b/testcases/synthesis/sygus/max2.scala @@ -0,0 +1,9 @@ +import leon.lang._ +import leon.lang.synthesis._ + + +object Sort { + def max2(a: BigInt, b: BigInt): BigInt = { + choose((x: BigInt) => x >= a && x >= b && (x == a || x == b)) + } +} diff --git a/testcases/synthesis/sygus/numerals1.scala b/testcases/synthesis/sygus/numerals1.scala new file mode 100644 index 000000000..a61c7166f --- /dev/null +++ b/testcases/synthesis/sygus/numerals1.scala @@ -0,0 +1,16 @@ +import leon.lang._ +import leon.lang.synthesis._ + +object Numerals { + abstract class N + case class S(succ: N) extends N + case object Z extends N + + + def plusone(a: N): N = { + choose((x: N) => x match { + case S(succ) => succ == a + case Z => false + }) + } +} diff --git a/testcases/synthesis/sygus/plusone.scala b/testcases/synthesis/sygus/plusone.scala new file mode 100644 index 000000000..6da981378 --- /dev/null +++ b/testcases/synthesis/sygus/plusone.scala @@ -0,0 +1,9 @@ +import leon.lang._ +import leon.lang.synthesis._ + + +object Sort { + def plusone(a: BigInt): BigInt = { + choose((x: BigInt) => x > a) + } +} -- GitLab