diff --git a/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala b/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala index a6aa654fbc9759553ba2c390d5a169ca3fe34eb1..3e05c3ee3eba90153338c210ed8a69b31340ae05 100644 --- a/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala +++ b/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala @@ -20,6 +20,7 @@ class CVC4SygusSolver(ctx: LeonContext, pgm: Program, p: Problem) extends SygusS Seq( "-q", "--cegqi-si", + "--macros-quant", "--lang", "sygus", "--print-success" ) diff --git a/src/main/scala/leon/solvers/sygus/SygusSolver.scala b/src/main/scala/leon/solvers/sygus/SygusSolver.scala index 4ca63c6f7a64fe4238e40a0c2e95e76064650cab..952745b338a77f5cc5db9936da368808e2b5fe03 100644 --- a/src/main/scala/leon/solvers/sygus/SygusSolver.scala +++ b/src/main/scala/leon/solvers/sygus/SygusSolver.scala @@ -37,23 +37,12 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p } def checkSynth(): Option[Expr] = { - val out = p.xs.head - val c = FreshIdentifier("c") - val fd = new FunDef(c, Seq(), out.getType, p.as.map(a => ValDef(a))) - - val bindings = p.as.map(a => a -> (symbolToQualifiedId(id2sym(a)): Term)).toMap - - val constraintId = QualifiedIdentifier(SMTIdentifier(SSymbol("constraint"))) emit(SList(SSymbol("set-logic"), SSymbol("ALL_SUPPORTED"))) - val fsym = id2sym(fd.id) - - functions += fd.typed -> fsym - - // declare function to synthesize - emit(SList(SSymbol("synth-fun"), id2sym(fd.id), SList(fd.params.map(vd => SList(id2sym(vd.id), toSMT(vd.getType))) :_*), toSMT(out.getType))) + val constraintId = QualifiedIdentifier(SMTIdentifier(SSymbol("constraint"))) + val bindings = p.as.map(a => a -> (symbolToQualifiedId(id2sym(a)): Term)).toMap // declare inputs for (a <- p.as) { @@ -61,7 +50,23 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p variables += a -> id2sym(a) } - val synthPhi = replaceFromIDs(Map(out -> FunctionInvocation(fd.typed, p.as.map(_.toVariable))), p.phi) + // declare outputs + val xToFd = for (x <- p.xs) yield { + val fd = new FunDef(x.freshen, Seq(), x.getType, p.as.map(a => ValDef(a))) + + val fsym = id2sym(fd.id) + + functions += fd.typed -> fsym + + // declare function to synthesize + emit(SList(SSymbol("synth-fun"), id2sym(fd.id), SList(fd.params.map(vd => SList(id2sym(vd.id), toSMT(vd.getType))) :_*), toSMT(fd.returnType))) + + x -> fd + } + + val xToFdCall = xToFd.toMap.mapValues(fd => FunctionInvocation(fd.typed, p.as.map(_.toVariable))) + + val synthPhi = replaceFromIDs(xToFdCall, p.phi) val TopLevelAnds(clauses) = synthPhi @@ -69,7 +74,7 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p emit(FunctionApplication(constraintId, Seq(toSMT(c)(bindings)))) } - emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef + emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef* // We currently cannot predict the amount of success we will get, so we read as many as possible var lastRes = interpreter.parser.parseSExpr @@ -79,14 +84,24 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p lastRes match { case SSymbol("unsat") => - interpreter.parser.parseCommand match { - case DefineFun(SMTFunDef(name, params, retSort, body)) => - val res = fromSMT(body, sorts.toA(retSort))(Map(), Map()) - Some(res) - case r => - reporter.warning("Unnexpected result from cvc4-sygus: "+r) - None + + val solutions = (for (x <- p.xs) yield { + interpreter.parser.parseCommand match { + case DefineFun(SMTFunDef(name, params, retSort, body)) => + val res = fromSMT(body, sorts.toA(retSort))(Map(), Map()) + Some(res) + case r => + reporter.warning("Unnexpected result from cvc4-sygus: "+r) + None + } + }).flatten + + if (solutions.size == p.xs.size) { + Some(tupleWrap(solutions)) + } else { + None } + case SSymbol("unknown") => None diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 9ee0febaa3cc83d5ad610497627fdf8dbc9cb051..082e5dd97e8e9b57500f63d78128c5426e4f9190 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -53,7 +53,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { timeoutMs = timeout map { _ * 1000 }, generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), costModel = costModel, - rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.Sygus) else Seq()), + rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), manualSearch = ms, functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet }, cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout), diff --git a/src/main/scala/leon/synthesis/rules/Sygus.scala b/src/main/scala/leon/synthesis/rules/Sygus.scala deleted file mode 100644 index 29f75d1b48b0fe90e9e335785938e2abf65bd9b8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Sygus.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Types._ -import solvers.sygus._ - -import grammars._ -import utils._ - -case object Sygus extends Rule("Sygus") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - if (p.xs.size != 1) { - Nil - } else { - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - - val sctx = hctx.sctx - val grammar = Grammars.default(sctx, p) - - val s = new CVC4SygusSolver(sctx.context, sctx.program, p) - - s.checkSynth() match { - case Some(expr) => - RuleClosed(Solution.term(expr)) - case None => - RuleFailed() - } - } - }) - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/SygusCVC4.scala b/src/main/scala/leon/synthesis/rules/SygusCVC4.scala new file mode 100644 index 0000000000000000000000000000000000000000..2e8c40dd393cca442dd9758172bbd7aa2bade23a --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/SygusCVC4.scala @@ -0,0 +1,32 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package synthesis +package rules + +import purescala.Types._ +import solvers.sygus._ + +import grammars._ +import utils._ + +case object SygusCVC4 extends Rule("SygusCVC4") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { + + val sctx = hctx.sctx + val grammar = Grammars.default(sctx, p) + + val s = new CVC4SygusSolver(sctx.context, sctx.program, p) + + s.checkSynth() match { + case Some(expr) => + RuleClosed(Solution.term(expr)) + case None => + RuleFailed() + } + } + }) + } +}