Skip to content
Snippets Groups Projects
Commit ce9875ff authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Use --macro-quant, support >1 outputs

parent 5dad01db
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ class CVC4SygusSolver(ctx: LeonContext, pgm: Program, p: Problem) extends SygusS
Seq(
"-q",
"--cegqi-si",
"--macros-quant",
"--lang", "sygus",
"--print-success"
)
......
......@@ -37,23 +37,12 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p
}
def checkSynth(): Option[Expr] = {
val out = p.xs.head
val c = FreshIdentifier("c")
val fd = new FunDef(c, Seq(), out.getType, p.as.map(a => ValDef(a)))
val bindings = p.as.map(a => a -> (symbolToQualifiedId(id2sym(a)): Term)).toMap
val constraintId = QualifiedIdentifier(SMTIdentifier(SSymbol("constraint")))
emit(SList(SSymbol("set-logic"), SSymbol("ALL_SUPPORTED")))
val fsym = id2sym(fd.id)
functions += fd.typed -> fsym
// declare function to synthesize
emit(SList(SSymbol("synth-fun"), id2sym(fd.id), SList(fd.params.map(vd => SList(id2sym(vd.id), toSMT(vd.getType))) :_*), toSMT(out.getType)))
val constraintId = QualifiedIdentifier(SMTIdentifier(SSymbol("constraint")))
val bindings = p.as.map(a => a -> (symbolToQualifiedId(id2sym(a)): Term)).toMap
// declare inputs
for (a <- p.as) {
......@@ -61,7 +50,23 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p
variables += a -> id2sym(a)
}
val synthPhi = replaceFromIDs(Map(out -> FunctionInvocation(fd.typed, p.as.map(_.toVariable))), p.phi)
// declare outputs
val xToFd = for (x <- p.xs) yield {
val fd = new FunDef(x.freshen, Seq(), x.getType, p.as.map(a => ValDef(a)))
val fsym = id2sym(fd.id)
functions += fd.typed -> fsym
// declare function to synthesize
emit(SList(SSymbol("synth-fun"), id2sym(fd.id), SList(fd.params.map(vd => SList(id2sym(vd.id), toSMT(vd.getType))) :_*), toSMT(fd.returnType)))
x -> fd
}
val xToFdCall = xToFd.toMap.mapValues(fd => FunctionInvocation(fd.typed, p.as.map(_.toVariable)))
val synthPhi = replaceFromIDs(xToFdCall, p.phi)
val TopLevelAnds(clauses) = synthPhi
......@@ -69,7 +74,7 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p
emit(FunctionApplication(constraintId, Seq(toSMT(c)(bindings))))
}
emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef
emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef*
// We currently cannot predict the amount of success we will get, so we read as many as possible
var lastRes = interpreter.parser.parseSExpr
......@@ -79,14 +84,24 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p
lastRes match {
case SSymbol("unsat") =>
interpreter.parser.parseCommand match {
case DefineFun(SMTFunDef(name, params, retSort, body)) =>
val res = fromSMT(body, sorts.toA(retSort))(Map(), Map())
Some(res)
case r =>
reporter.warning("Unnexpected result from cvc4-sygus: "+r)
None
val solutions = (for (x <- p.xs) yield {
interpreter.parser.parseCommand match {
case DefineFun(SMTFunDef(name, params, retSort, body)) =>
val res = fromSMT(body, sorts.toA(retSort))(Map(), Map())
Some(res)
case r =>
reporter.warning("Unnexpected result from cvc4-sygus: "+r)
None
}
}).flatten
if (solutions.size == p.xs.size) {
Some(tupleWrap(solutions))
} else {
None
}
case SSymbol("unknown") =>
None
......
......@@ -53,7 +53,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
timeoutMs = timeout map { _ * 1000 },
generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees),
costModel = costModel,
rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.Sygus) else Seq()),
rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()),
manualSearch = ms,
functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet },
cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout),
......
......@@ -10,27 +10,23 @@ import solvers.sygus._
import grammars._
import utils._
case object Sygus extends Rule("Sygus") {
case object SygusCVC4 extends Rule("SygusCVC4") {
def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = {
if (p.xs.size != 1) {
Nil
} else {
List(new RuleInstantiation(this.name) {
def apply(hctx: SearchContext): RuleApplication = {
List(new RuleInstantiation(this.name) {
def apply(hctx: SearchContext): RuleApplication = {
val sctx = hctx.sctx
val grammar = Grammars.default(sctx, p)
val sctx = hctx.sctx
val grammar = Grammars.default(sctx, p)
val s = new CVC4SygusSolver(sctx.context, sctx.program, p)
val s = new CVC4SygusSolver(sctx.context, sctx.program, p)
s.checkSynth() match {
case Some(expr) =>
RuleClosed(Solution.term(expr))
case None =>
RuleFailed()
}
s.checkSynth() match {
case Some(expr) =>
RuleClosed(Solution.term(expr))
case None =>
RuleFailed()
}
})
}
}
})
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment