-
Manos Koukoutos authoredManos Koukoutos authored
Synthesizer.scala 3.84 KiB
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package synthesis
import purescala.Common._
import purescala.Definitions.{Program, FunDef, ModuleDef, DefType, ValDef}
import purescala.TreeOps._
import purescala.Trees._
import purescala.Constructors._
import purescala.ScalaPrinter
import purescala.TypeTrees._
import solvers._
import solvers.combinators._
import solvers.z3._
import java.io.File
import synthesis.graph._
/**
 * Drives synthesis for a single `choose` construct and validates the solutions it finds.
 *
 * @param context  Leon context; its reporter and timers are used here
 * @param program  the program containing the `choose` being synthesized
 * @param ci       information about the `choose` (its synthesis problem and enclosing function)
 * @param settings synthesis settings (search strategy, cost model, search bound, ...)
 */
class Synthesizer(val context: LeonContext,
                  val program: Program,
                  val ci: ChooseInfo,
                  val settings: SynthesisSettings) {

  /** The synthesis problem extracted from the `choose`. */
  val problem = ci.problem

  val reporter = context.reporter

  // Lazy because it is built from `this`; presumably so that `fromSynthesizer`
  // only ever sees a fully constructed Synthesizer.
  lazy val sctx = SynthesisContext.fromSynthesizer(this)

  /**
   * Selects the search strategy according to `settings`.
   *
   * A manual search is used if one is configured. Otherwise a sequential `SimpleSearch` is used.
   * Parallel search (`searchWorkers > 1`) is not implemented. Instead of aborting with
   * `NotImplementedError`, a warning is reported and the sequential search is used.
   */
  def getSearch(): Search = {
    if (settings.manualSearch.isDefined) {
      new ManualSearch(context, ci, problem, settings.costModel, settings.manualSearch)
    } else {
      if (settings.searchWorkers > 1) {
        // TODO: use a parallel search once one exists, e.g.
        //   new ParallelSearch(this, problem, settings.searchWorkers)
        reporter.warning("Parallel search is not supported; falling back to sequential search.")
      }
      new SimpleSearch(context, ci, problem, settings.costModel, settings.searchBound)
    }
  }

  /**
   * Runs the search selected by [[getSearch]] on this problem.
   *
   * @return the search object (its graph is needed later to build partial solutions)
   *         together with the stream of solutions it produced
   */
  def synthesize(): (Search, Stream[Solution]) = {
    val search = getSearch()

    // NOTE(review): only the work done inside `search.search` is timed. If the returned
    // stream is lazy, solutions forced later are not included in the reported time.
    val timer = context.timers.synthesis.search.start()
    val sols = search.search(sctx)
    val diff = timer.stop()
    reporter.info(s"Finished in ${diff}ms")

    (search, sols)
  }

  /**
   * Pairs each solution with a flag telling whether it is known to be correct.
   *
   * Trusted solutions are accepted as-is. The others are checked with [[validateSolution]].
   * If the search produced no solution at all, a single partial solution is built from the
   * search graph and returned, flagged `false`.
   *
   * @param results   the search and its solutions, as returned by [[synthesize]]
   * @param timeoutMs solver timeout, in milliseconds, used to validate each untrusted solution
   * @return the search together with the (solution, validated) pairs
   */
  def validate(results: (Search, Stream[Solution]),
               timeoutMs: Long = 5000L): (Search, Stream[(Solution, Boolean)]) = {
    val (search, sols) = results

    val validated = sols.map {
      case sol if sol.isTrusted => (sol, true)
      case sol                  => validateSolution(search, sol, timeoutMs)
    }

    val withFallback =
      if (validated.isEmpty) Stream((new PartialSolution(search.g, true).getSolution, false))
      else validated

    (search, withFallback)
  }

  /**
   * Checks a solution by verifying the program produced by [[solutionToProgram]].
   *
   * @param search    the search that produced `sol`. Its graph is used to build a fallback
   *                  partial solution when `sol` turns out to be invalid.
   * @param sol       the solution to validate
   * @param timeoutMs solver timeout in milliseconds
   * @return one of:
   *         - `(sol, true)` when every verification condition is valid;
   *         - `(sol, false)` when none is invalid but some are unknown;
   *         - otherwise, a partial solution from the search graph paired with `false`.
   */
  def validateSolution(search: Search, sol: Solution, timeoutMs: Long): (Solution, Boolean) = {
    import verification.AnalysisPhase._
    import verification.VerificationContext

    reporter.info("Solution requires validation")

    val (npr, fds) = solutionToProgram(sol)

    val solverf = SolverFactory(() => (new FairZ3Solver(context, npr) with TimeoutSolver).setTimeout(timeoutMs))

    val vctx = VerificationContext(context, npr, solverf, context.reporter)
    // Only the newly generated functions are verified; they are selected by name.
    val vcs = generateVerificationConditions(vctx, Some(fds.map(_.id.name)))
    val vcreport = checkVerificationConditions(vctx, vcs)

    // Shows the candidate code and the verification summary after a warning header.
    def reportCandidate(): Unit = {
      reporter.warning(fds.map(ScalaPrinter(_)).mkString("\n\n"))
      reporter.warning(vcreport.summaryString)
    }

    if (vcreport.totalValid == vcreport.totalConditions) {
      (sol, true)
    } else if (vcreport.totalValid + vcreport.totalUnknown == vcreport.totalConditions) {
      reporter.warning("Solution may be invalid:")
      reportCandidate()
      (sol, false)
    } else {
      reporter.warning("Solution was invalid:")
      reportCandidate()
      (new PartialSolution(search.g, false).getSolution, false)
    }
  }

  /**
   * Builds a program in which `sol` can be verified.
   *
   * A fresh function named after `ci.fd` with a `_final` suffix is created:
   *   - its parameters are the problem inputs (`problem.as`);
   *   - its precondition is the path condition conjoined with the solution's precondition;
   *   - its body is the solution term;
   *   - its postcondition is the problem's specification, with every output variable replaced
   *     by the matching component of the result.
   *
   * That function and the solution's auxiliary definitions are placed in a new `synthesis`
   * module, which is prepended to every unit of the program.
   *
   * @return the new program and the newly generated functions (the `_final` function first)
   */
  def solutionToProgram(sol: Solution): (Program, List[FunDef]) = {
    // Return type covering all outputs; presumably a tuple only when there are several
    // (see `tupleTypeWrap` / `tupleSelect`).
    val retType = tupleTypeWrap(problem.xs.map(_.getType))
    val res = Variable(FreshIdentifier("res", retType))

    // Each output variable x_i is read from the result as its (i+1)-th component.
    val outputsToResult: Map[Expr, Expr] = problem.xs.zipWithIndex.map { case (id, i) =>
      Variable(id) -> tupleSelect(res, i + 1, problem.xs.size)
    }.toMap

    val fd = new FunDef(
      FreshIdentifier(ci.fd.id.name + "_final", alwaysShowUniqueID = true),
      Nil,
      retType,
      problem.as.map(ValDef(_)),
      DefType.MethodDef
    )
    fd.precondition = Some(and(problem.pc, sol.pre))
    fd.postcondition = Some(Lambda(Seq(ValDef(res.id)), replace(outputsToResult, problem.phi)))
    fd.body = Some(sol.term)

    val newDefs = fd :: sol.defs.toList

    // NOTE(review): the same module (and the same FunDef instances) is prepended to *every*
    // unit, so the new functions appear once per unit in the resulting program.
    val npr = program.copy(units = program.units.map { u =>
      u.copy(modules = ModuleDef(FreshIdentifier("synthesis"), newDefs, false) +: u.modules)
    })

    (npr, newDefs)
  }
}