// SynthesisPhase.scala (authored by Manos Koukoutos)
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package synthesis
import purescala.ExprOps._
import purescala.ScalaPrinter
import purescala.Definitions.{Program, FunDef}
import leon.utils.ASCIIHelpers
import graph._
/** Leon phase that synthesizes implementations for the `choose` constructs of a program.
  *
  * Every `choose` extracted from the program (and accepted by the function filter) is handed
  * to a [[Synthesizer]]; the first solution of the validated result replaces the `choose`
  * inside the body of its enclosing function. Function bodies are updated in place and the
  * phase returns the very `Program` it was given.
  */
object SynthesisPhase extends LeonPhase[Program, Program] {
  val name        = "Synthesis"
  val description = "Partial synthesis of \"choose\" constructs. Also used by repair during the synthesis stage."

  val optManual     = LeonStringOptionDef("manual", "Manual search", default = "", "[cmd]")
  val optCostModel  = LeonStringOptionDef("costmodel", "Use a specific cost model for this search", "FIXME", "cm")
  val optDerivTrees = LeonFlagOptionDef("derivtrees", "Generate derivation trees", false)

  // CEGIS options
  val optCEGISShrink     = LeonFlagOptionDef("cegis:shrink", "Shrink non-det programs when tests pruning works well", true)
  val optCEGISOptTimeout = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true)
  val optCEGISVanuatoo   = LeonFlagOptionDef("cegis:vanuatoo", "Generate inputs using new korat-style generator", false)

  override val definedOptions : Set[LeonOptionDef[Any]] =
    Set(optManual, optCostModel, optDerivTrees, optCEGISShrink, optCEGISOptTimeout, optCEGISVanuatoo)

  /** Builds the [[SynthesisSettings]] from the options found in `ctx`.
    *
    * A warning is reported when a timeout is given together with manual search. A cost model
    * name that matches none of `CostModels.all` (compared case-insensitively) is reported
    * through `ctx.reporter.fatalError`, along with the sorted list of known models.
    *
    * @param ctx context whose options and reporter are used
    * @return the settings to run the synthesizer with
    */
  def processOptions(ctx: LeonContext): SynthesisSettings = {
    val manual  = ctx.findOption(optManual)
    val timeout = ctx.findOption(SharedOptions.optTimeout)

    if (manual.isDefined && timeout.isDefined) {
      ctx.reporter.warning("Defining timeout with manual search")
    }

    SynthesisSettings(
      manualSearch            = manual,
      functions               = ctx.findOption(SharedOptions.optFunctions) map { _.toSet },
      // The option value is scaled by 1000 to fill `timeoutMs`
      // (presumably seconds -> milliseconds; confirm against SharedOptions.optTimeout).
      timeoutMs               = timeout map { _ * 1000 },
      generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees),
      cegisUseOptTimeout      = ctx.findOption(optCEGISOptTimeout),
      cegisUseShrink          = ctx.findOption(optCEGISShrink),
      cegisUseVanuatoo        = ctx.findOption(optCEGISVanuatoo),
      // The AsChoose rule is only added when manual search was requested.
      rules                   = Rules.all ++ (manual map { _ => rules.AsChoose }),
      costModel               = ctx.findOption(optCostModel) match {
        case None =>
          CostModels.default

        case Some(requested) =>
          CostModels.all.find(_.name.toLowerCase == requested.toLowerCase).getOrElse {
            // One " - <name>" line per known model, the default one being flagged.
            val known = CostModels.all.toSeq.sortBy(_.name).map { cm =>
              " - " + cm.name + (if (cm == CostModels.default) " (default)" else "") + "\n"
            }.mkString

            ctx.reporter.fatalError(
              "Unknown cost model: " + requested + "\n" +
              "Defined cost models: \n" +
              known
            )
          }
      }
    )
  }

  /** Synthesizes every selected `choose` of `p` and prints the resulting functions.
    *
    * Functions annotated with "library" are excluded by default; the functions named in the
    * settings (if any) drive the inclusive filter. For each remaining `choose`, the first
    * solution is simplified and substituted for the `choose` in the function body. When
    * derivation trees are requested, the search graph is written to a fresh
    * `derivation<N>.dot` file.
    *
    * @param ctx context providing options and the reporter
    * @param p   program whose `choose` constructs are synthesized; its functions are mutated
    * @return `p` itself
    */
  def run(ctx: LeonContext)(p: Program): Program = {
    val options = processOptions(ctx)

    def excludeByDefault(fd: FunDef): Boolean = fd.annotations contains "library"

    val fdFilter = {
      import OptionsHelpers._
      val ciTofd = { (ci: ChooseInfo) => ci.fd }

      filterInclusive(options.functions.map(fdMatcher), Some(excludeByDefault _)) compose ciTofd
    }

    val chooses = ChooseInfo.extractFromProgram(p).filter(fdFilter)

    // Strict left fold: each `choose` is processed in order, collecting the updated functions.
    val functions = chooses.foldLeft(Set.empty[FunDef]) { (updated, ci) =>
      val synthesizer = new Synthesizer(ctx, p, ci, options)

      // Synthesis itself runs inside the `try`, so the synthesizer is shut down
      // even when synthesis or validation throws.
      try {
        val (search, solutions) = synthesizer.validate(synthesizer.synthesize(), true)
        val fd = ci.fd

        if (options.generateDerivationTrees) {
          val dot = new DotGenerator(search.g)
          dot.writeFile("derivation" + DotGenerator.nextId() + ".dot")
        }

        solutions.headOption match {
          case Some((sol, _)) =>
            val expr = sol.toSimplifiedExpr(ctx, p)
            fd.body = fd.body.map(b => replace(Map(ci.source -> expr), b))
            updated + fd

          case None =>
            // Defensive: nothing visible here guarantees a non-empty result, so leave
            // the function untouched instead of failing with NoSuchElementException.
            ctx.reporter.warning("No solution found for choose in " + fd.id.name)
            updated
        }
      } finally {
        synthesizer.shutdown()
      }
    }

    for (fd <- functions) {
      ctx.reporter.info(ASCIIHelpers.title(fd.id.name))
      ctx.reporter.info(ScalaPrinter(fd))
      ctx.reporter.info("")
    }

    p
  }
}