diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 4a5ca356c696ae6c4da4aabb4fcc8852c5b4c56a..26148a779e8683244b71b3317a141cc34706ec47 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1118,7 +1118,7 @@ object TreeOps { p } - MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze))) + MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze))).setType(e.getType) } else { e } @@ -1197,6 +1197,36 @@ object TreeOps { val se = rec(e, path) Let(i, se, rec(b, Equals(Variable(i), se) +: path)) + case MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + + var stillPossible = true + + if (cases.exists(_.hasGuard)) { + // unsupported for now + e + } else { + MatchExpr(rs, cases.flatMap { c => + val patternExpr = conditionForPattern(rs, c.pattern) + + if (stillPossible && !contradictedBy(patternExpr, path)) { + + if (impliedBy(patternExpr, path)) { + stillPossible = false + } + + c match { + case SimpleCase(p, rhs) => + Some(SimpleCase(p, rec(rhs, patternExpr +: path))) + case GuardedCase(_, _, _) => + sys.error("woot.") + } + } else { + None + } + }) + } + case LetTuple(is, e, b) => // Similar to the Let case val se = rec(e, path) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 4e5e0fea66072d14775847fc1e6df0c01a0477da..98888aa5e19859bd062d98ed341089c2f3978642 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -49,7 +49,9 @@ object Trees { funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) } } - case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr + case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr with FixedType { + val fixedType = leastUpperBound(then.getType, elze.getType).getOrElse(AnyType) + } case class Tuple(exprs: Seq[Expr]) extends Expr { val subTpes = exprs.map(_.getType) @@ -87,7 +89,10 @@ object Trees { def unapply(me: MatchExpr) : Option[(Expr,Seq[MatchCase])] = if (me == null) None else Some((me.scrutinee, me.cases)) } - class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with ScalacPositional { + class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with ScalacPositional with FixedType { + + val fixedType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(AnyType) + def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType] override def equals(that: Any): Boolean = (that != null) && (that match { diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 7d2897ad5726251f5b759a0e01cee9d6be7b7b1c..d8c673fcc627b6c57d158161c5f9d706724f0788 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -107,6 +107,15 @@ object TypeTrees { case _ => None } + def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = { + def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match { + case Some(t1) => leastUpperBound(t1, t2.get) + case None => None + } + + ts.map(Some(_)).reduceLeft(olub) + } + def isSubtypeOf(t1: TypeTree, t2: TypeTree): Boolean = { leastUpperBound(t1, t2) == Some(t2) } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index aefe9fd2c89a8d4728e2f85282a68e93b6bd6adc..2a14a92f5636cab373bfb9f5ca1d496c948050d3 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -393,7 +393,9 @@ trait AbstractZ3Solver extends solvers.IncrementalSolverBuilder { // scala.sys.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition") // } - assert(!this.isInstanceOf[FairZ3Solver], "Trying to convert unknown variable '"+id+"' while using FairZ3") + // Remove this safety check, since choose() expresions are now + // translated to non-unrollable variables, that end up here. + // assert(!this.isInstanceOf[FairZ3Solver], "Trying to convert unknown variable '"+id+"' while using FairZ3") val newAST = z3.mkFreshConst(id.uniqueName/*name*/, typeToSort(v.getType)) z3Vars = z3Vars + (id -> newAST) diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 6a4a96e1519d01077bfcecca0664702ef9da524f..2d63996602d92af44b748e81b0bc2de21cedba79 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -153,6 +153,8 @@ object FunctionTemplate { } } + case c @ Choose(_, _) => Variable(FreshIdentifier("choose", true).setType(c.getType)) + case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, pathPol, a))).setType(n.getType) case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, pathPol, a1), rec(pathVar, pathPol, a2)).setType(b.getType) case u @ UnaryOperator(a, r) => r(rec(pathVar, pathPol, a)).setType(u.getType) diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index 5ec24b47e4857a6eb953065ca331236cba90744a..fb3f824c820c5b4f493d75ff85917feecf5d45e9 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -24,7 +24,7 @@ class ParallelSearch(synth: Synthesizer, solver.initZ3 - val ctx = SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop) + val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver) synchronized { contexts = ctx :: contexts diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index d23f4df29b7b07473518104988e35a57c280e43c..aa46f6381811a10de1a1436093cf338a492b070a 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -2,16 +2,30 @@ package leon package synthesis import solvers.Solver +import purescala.Trees._ +import purescala.Definitions.{Program, FunDef} +import purescala.Common.Identifier import java.util.concurrent.atomic.AtomicBoolean case class SynthesisContext( + options: SynthesizerOptions, + functionContext: Option[FunDef], + program: Program, solver: Solver, reporter: Reporter, shouldStop: AtomicBoolean ) object SynthesisContext { - def fromSynthesizer(synth: Synthesizer) = SynthesisContext(synth.solver, synth.reporter, new AtomicBoolean(false)) + def fromSynthesizer(synth: Synthesizer) = { + SynthesisContext( + synth.options, + synth.functionContext, + synth.program, + synth.solver, + synth.reporter, + new AtomicBoolean(false)) + } } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 7adb833661334ac6ee8ba9a6b3444f757fd22028..ed7cf742958023fa68b7f99cc6dfce984a763d2f 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -14,13 +14,14 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val description = "Synthesis" override val definedOptions : Set[LeonOptionDef] = Set( - LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), - LeonOptValueOptionDef("parallel", "--parallel[=N]", "Parallel synthesis search using N workers"), - LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), - LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), - LeonValueOptionDef( "timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), - LeonValueOptionDef( "costmodel", "--costmodel=cm", "Use a specific cost model for this search"), - LeonValueOptionDef( "functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") + LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), + LeonOptValueOptionDef("parallel", "--parallel[=N]", "Parallel synthesis search using N workers"), + LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), + LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), + LeonValueOptionDef( "timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), + LeonValueOptionDef( "costmodel", "--costmodel=cm", "Use a specific cost model for this search"), + LeonValueOptionDef( "functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,.."), + LeonFlagOptionDef( "cegis:gencalls", "--cegis:gencalls", "Include function calls in CEGIS generators") ) def run(ctx: LeonContext)(p: Program): Program = { @@ -74,6 +75,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] { options = options.copy(searchWorkers = nWorkers) } + case LeonFlagOption("cegis:gencalls") => + options = options.copy(cegisGenerateFunCalls = true) + case LeonFlagOption("derivtrees") => options = options.copy(generateDerivationTrees = true) @@ -89,6 +93,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { case ch @ Choose(vars, pred) => val problem = Problem.fromChoose(ch) val synth = new Synthesizer(ctx, + Some(f), mainSolver, p, problem, @@ -117,9 +122,11 @@ object SynthesisPhase extends LeonPhase[Program, Program] { // Simplify expressions val simplifiers = List[Expr => Expr]( - simplifyTautologies(uninterpretedZ3)(_), + simplifyTautologies(uninterpretedZ3)(_), simplifyLets _, decomposeIfs _, + matchToIfThenElse _, + simplifyPaths(uninterpretedZ3)(_), patternMatchReconstruction _, simplifyTautologies(uninterpretedZ3)(_), simplifyLets _, @@ -129,7 +136,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { def simplify(e: Expr): Expr = simplifiers.foldLeft(e){ (x, sim) => sim(x) } val chooseToExprs = solutions.map { - case (ch, (fd, sol)) => (ch, (fd, simplify(sol.toExpr))) + case (ch, (fd, sol)) => + (ch, (fd, simplify(sol.toExpr))) } if (inPlace) { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index b15e1ec69801fec4b57085660bc2b02018570a3a..0a402e5db9cbfce350bdd69302479f6aa9b636df 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -18,6 +18,7 @@ import synthesis.search._ import java.util.concurrent.atomic.AtomicBoolean class Synthesizer(val context : LeonContext, + val functionContext: Option[FunDef], val solver: Solver, val program: Program, val problem: Problem, diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala index 177d21c49bbfcb6fd4d8e804aaf1647e1405d17a..e9d2f4c94b89e78c473850bc8e0eda343e487f3f 100644 --- a/src/main/scala/leon/synthesis/SynthesizerOptions.scala +++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala @@ -7,5 +7,6 @@ case class SynthesizerOptions( searchWorkers: Int = 1, firstOnly: Boolean = false, timeoutMs: Option[Long] = None, - costModel: CostModel = CostModel.default + costModel: CostModel = CostModel.default, + cegisGenerateFunCalls: Boolean = false ) diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 61d3a9768a87cb7cc3f7fe63805b1c626f15845e..7235f9bb4166ba5c61b539967fdd6bbb1f466e5b 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -41,7 +41,6 @@ case object ADTSplit extends Rule("ADT Split.") { } } - candidates.collect{ _ match { case Some((id, cases)) => val oas = p.as.filter(_ != id) diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index c7f84e483fd78b92415e319d3def1c20faee633a..0be34792948c462135e7fe1a982c8bc08442d4b9 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -13,6 +13,13 @@ import solvers.z3.FairZ3Solver case object CEGIS extends Rule("CEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + + // CEGIS Flags to actiave or de-activate features + val useCounterExamples = false + val useUninterpretedProbe = false + val useUnsatCores = true + val useFunGenerators = sctx.options.cegisGenerateFunCalls + case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]); var generators = Map[TypeTree, Generator]() @@ -64,6 +71,31 @@ case object CEGIS extends Rule("CEGIS") { p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) } + def funcAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { + if (useFunGenerators) { + def isCandidate(fd: FunDef): Boolean = { + // Prevents recursive calls + val isRecursiveCall = sctx.functionContext match { + case Some(cfd) => + (sctx.program.transitiveCallers(cfd) + cfd) contains fd + + case None => + false + } + + isSubtypeOf(fd.returnType, t) && !isRecursiveCall + } + + sctx.program.definedFunctions.filter(isCandidate).map{ fd => + val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType)) + + (FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet) + }.toList + } else { + Nil + } + } + class TentativeFormula(val pathcond: Expr, val phi: Expr, var program: Expr, @@ -77,7 +109,7 @@ case object CEGIS extends Rule("CEGIS") { for ((_, recIds) <- recTerms; recId <- recIds) { val gen = getGenerator(recId.getType) - val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) + val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) ::: funcAlternatives(recId.getType) val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt) @@ -122,6 +154,7 @@ case object CEGIS extends Rule("CEGIS") { val xsSet = p.xs.toSet + val (exprsA, others) = ands.partition(e => (variablesOf(e) & xsSet).isEmpty) if (exprsA.isEmpty) { val res = new RuleInstantiation(p, this, SolutionBuilder.none) { @@ -150,42 +183,35 @@ case object CEGIS extends Rule("CEGIS") { try { do { val (clauses, bounds) = unrolling.unroll - //println("UNROLLING: "+clauses+" WITH BOUNDS "+bounds) - solver1.assertCnstr(And(clauses)) - solver2.assertCnstr(And(clauses)) + //println("UNROLLING: ") + //for (c <- clauses) { + // println(" - " + c) + //} + //println("BOUNDS "+bounds) - //println("="*80) - //println("Was: "+lastF.entireFormula) - //println("Now Trying : "+currentF.entireFormula) + val clause = And(clauses) + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) val tpe = TupleType(p.xs.map(_.getType)) val bss = unrolling.bss var continue = !clauses.isEmpty - //println("Unrolling #"+unrolings+" bss size: "+bss.size) - while (result.isEmpty && continue && !sctx.shouldStop.get) { //println("Looking for CE...") //println("-"*80) - //println(basePhi) - //println("To satisfy: "+constrainedPhi) solver1.checkAssumptions(bounds.map(id => Not(Variable(id)))) match { case Some(true) => val satModel = solver1.getModel - //println("Found solution: "+satModel) - //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) - //val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) - //println("Phi with fixed sat bss: "+fixedBss) - val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { case BooleanLiteral(true) => Variable(b) case BooleanLiteral(false) => Not(Variable(b)) }) - //println("FORMULA: "+And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: fixedBss :: Nil)) + //println("Found solution: "+bssAssumptions) //println("#"*80) solver2.checkAssumptions(bssAssumptions) match { @@ -201,22 +227,18 @@ case object CEGIS extends Rule("CEGIS") { solver1.assertCnstr(fixedAss) //println("Found counter example: "+fixedAss) - val unsatCore = solver1.checkAssumptions(bssAssumptions) match { - case Some(false) => - val core = solver1.getUnsatCore - //println("Formula: "+mustBeUnsat) - //println("Core: "+core) - //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq))) - //println("maxcore: "+bssAssumptions) - if (core.isEmpty) { - // This happens if unrolling level is insufficient, it becomes unsat no matter what the assumptions are. - //sctx.reporter.warning("Got empty core, must be unsat without assumptions!") - Set() - } else { - core - } - case _ => - bssAssumptions + val unsatCore = if (useUnsatCores) { + solver1.checkAssumptions(bssAssumptions) match { + case Some(false) => + // Core might be empty if unrolling level is + // insufficient, it becomes unsat no matter what + // the assumptions are. + solver1.getUnsatCore + case _ => + bssAssumptions + } + } else { + bssAssumptions } solver1.pop() @@ -224,29 +246,31 @@ case object CEGIS extends Rule("CEGIS") { if (unsatCore.isEmpty) { continue = false } else { + if (useCounterExamples) { + val freshCss = unrolling.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap + val ceIn = ass.collect { + case id if invalidModel contains id => id -> invalidModel(id) + } - val freshCss = unrolling.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap - val ceIn = ass.collect { - case id if invalidModel contains id => id -> invalidModel(id) - } + val ceMap = (freshCss ++ ceIn) + + val counterexample = substAll(ceMap, And(Seq(unrolling.program, unrolling.phi))) - val counterexample = substAll(freshCss ++ ceIn, And(Seq(unrolling.program, unrolling.phi))) + //val And(ands) = counterexample + //println("CE:") + //for (a <- ands) { + // println(" - "+a) + //} - solver1.assertCnstr(counterexample) - solver2.assertCnstr(counterexample) + solver1.assertCnstr(counterexample) + } - //predicates = Not(And(unsatCore.toSeq)) +: counterexample +: predicates solver1.assertCnstr(Not(And(unsatCore.toSeq))) - solver2.assertCnstr(Not(And(unsatCore.toSeq))) } case Some(false) => - //println("#"*80) - //println("UNSAT!") - //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) var mapping = unrolling.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap - // Resolve mapping for ((c, e) <- mapping) { mapping += c -> substAll(mapping, e) @@ -263,7 +287,19 @@ case object CEGIS extends Rule("CEGIS") { case Some(false) => //println("%%%% UNSAT") + + if (useUninterpretedProbe) { + solver1.check match { + case Some(false) => + // Unsat even without blockers (under which fcalls are then uninterpreted) + result = Some(RuleApplicationImpossible) + + case _ => + } + } + continue = false + case _ => //println("%%%% WOOPS") continue = false diff --git a/src/main/scala/leon/synthesis/utils/Benchmarks.scala b/src/main/scala/leon/synthesis/utils/Benchmarks.scala index f5cd758c4f2576a671b232410ecf160f112eb666..99bd387fd3d034937848b6be88301c076ffd210e 100644 --- a/src/main/scala/leon/synthesis/utils/Benchmarks.scala +++ b/src/main/scala/leon/synthesis/utils/Benchmarks.scala @@ -80,13 +80,21 @@ object Benchmarks extends App { val pipeline = leon.plugin.ExtractionPhase andThen SynthesisProblemExtractionPhase - val (results, solver) = pipeline.run(innerCtx)(file.getPath :: Nil) + val (program, results) = pipeline.run(innerCtx)(file.getPath :: Nil) - - val sctx = SynthesisContext(solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean) + val solver = new FairZ3Solver(ctx.copy(reporter = new SilentReporter)) for ((f, ps) <- results.toSeq.sortBy(_._1.id.toString); p <- ps) { + val sctx = SynthesisContext( + options = opts, + functionContext = Some(f), + program = program, + solver = solver, + reporter = new DefaultReporter, + shouldStop = new java.util.concurrent.atomic.AtomicBoolean + ) + val ts = System.currentTimeMillis val rr = rule.instantiateOn(sctx, p) diff --git a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala index d875118df2cd4476c4e52f27d6133de708975d66..c31a04a6a99003d96b3c83514a1fb88ac1f4a5d3 100644 --- a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala +++ b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala @@ -8,16 +8,11 @@ import purescala.Definitions._ import solvers.z3._ import solvers.Solver -object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Map[FunDef, Seq[Problem]], Solver)] { +object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Program, Map[FunDef, Seq[Problem]])] { val name = "Synthesis Problem Extraction" val description = "Synthesis Problem Extraction" - def run(ctx: LeonContext)(p: Program): (Map[FunDef, Seq[Problem]], Solver) = { - - val silentContext : LeonContext = ctx.copy(reporter = new SilentReporter) - val mainSolver = new FairZ3Solver(silentContext) - mainSolver.setProgram(p) - + def run(ctx: LeonContext)(p: Program): (Program, Map[FunDef, Seq[Problem]]) = { var results = Map[FunDef, Seq[Problem]]() def noop(u:Expr, u2: Expr) = u @@ -38,7 +33,7 @@ object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Map[FunDef, S treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get) } - (results, mainSolver) + (p, results) } } diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala index 55ec2990603fbdf2e1b150ccea563d7e248b7459..96d7104d7521bd9b95ba27bc8d91f515e2479b33 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala @@ -21,7 +21,7 @@ class SynthesisSuite extends FunSuite { counter } - def forProgram(title: String)(content: String)(block: (Solver, FunDef, Problem) => Unit) { + def forProgram(title: String)(content: String)(block: (SynthesisContext, FunDef, Problem) => Unit) { val ctx = LeonContext( settings = Settings( @@ -37,11 +37,16 @@ class SynthesisSuite extends FunSuite { val pipeline = leon.plugin.TemporaryInputPhase andThen leon.plugin.ExtractionPhase andThen SynthesisProblemExtractionPhase - val (results, solver) = pipeline.run(ctx)((content, Nil)) + val (program, results) = pipeline.run(ctx)((content, Nil)) + + val solver = new FairZ3Solver(ctx) + solver.setProgram(program) for ((f, ps) <- results; p <- ps) { test("Synthesizing %3d: %-20s [%s]".format(nextInt(), f.id.toString, title)) { - block(solver, f, p) + val sctx = SynthesisContext(opts, Some(f), program, solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean) + + block(sctx, f, p) } } } @@ -99,9 +104,7 @@ object Injection { } """ ) { - case (solver, fd, p) => - val sctx = SynthesisContext(solver, new SilentReporter, new java.util.concurrent.atomic.AtomicBoolean) - + case (sctx, fd, p) => assertAllAlternativesSucceed(sctx, rules.CEGIS.instantiateOn(sctx, p)) assertFastEnough(sctx, rules.CEGIS.instantiateOn(sctx, p), 100) } @@ -127,9 +130,7 @@ object Injection { } """ ) { - case (solver, fd, p) => - val sctx = SynthesisContext(solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean) - + case (sctx, fd, p) => rules.CEGIS.instantiateOn(sctx, p).head.apply(sctx) match { case RuleSuccess(sol) => assert(false, "CEGIS should have failed, but found : %s".format(sol)) diff --git a/testcases/synthesis/CegisFunctions.scala b/testcases/synthesis/CegisFunctions.scala new file mode 100644 index 0000000000000000000000000000000000000000..a3451d01dd35fbeee1f55c5874c141db1bac0739 --- /dev/null +++ b/testcases/synthesis/CegisFunctions.scala @@ -0,0 +1,30 @@ +import leon.Utils._ + +object CegisTests { + sealed abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + // proved with unrolling=0 + def size(l: List) : Int = (l match { + case Nil() => 0 + case Cons(_, t) => 1 + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[Int] = l match { + case Nil() => Set() + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(l: List, i: Int) = { + Cons(i, l) + }.ensuring(res => size(res) == size(l)+1 && content(res) == content(l) ++ Set(i)) + + def testInsert(l: List, i: Int) = { + choose { (o: List) => size(o) == size(l) + 1 } + } + + def testDelete(l: List, i: Int) = { + choose { (o: List) => size(o) == size(l) - 1 } + } +}