diff --git a/library/lang/synthesis/package.scala b/library/lang/synthesis/package.scala index df19bcf8e280aed3e81c117c91623235aa7d769b..9fd244605123a6fcf1dfd4d4bd4b51e8a9fdb383 100644 --- a/library/lang/synthesis/package.scala +++ b/library/lang/synthesis/package.scala @@ -31,6 +31,10 @@ package object synthesis { @ignore def ?[T](e1: T, es: T*): T = noImpl + // Repair with Holes + @ignore + def ?: T = noImpl + @ignore def withOracle[A, R](body: Oracle[A] => R): R = noImpl diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index f8f6b67c597465204c43512e0c241c0cbed5f491..33c5f9db03eb7e50772e377a0d87977c00ed1987 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,6 +6,7 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps.simplestValue import purescala.TypeTrees._ import purescala.TypeTreeOps.instantiateType import utils._ @@ -671,6 +672,9 @@ trait CodeGeneration { ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") ch << ATHROW + case rh: RepairHole => + mkExpr(simplestValue(rh.getType), ch) // It is expected to be invalid, we want to repair it + case choose @ Choose(_, _) => val prob = synthesis.Problem.fromChoose(choose) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 79123e9c383063065cf4fd50a61f25efeb92e9e3..dbe7d651b8a13de82626f27b52de7751bd016646 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -390,6 +390,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case gv: GenericValue => gv + case rh: RepairHole => + simplestValue(rh.getType) // It will be wrong, we don't care + case choose: Choose => import purescala.TreeOps.simplestValue diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index c0eb1cc9a8b55896dac73e1ed16894dfb17ab0b8..1d89142d3de669bfeeea5bd0a2a3499b4b138464 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -376,6 +376,16 @@ trait ASTExtractors { } } + object ExRepairHoleExpression { + def unapply(tree: Tree) : Option[(Tree, List[Tree])] = tree match { + case a @ Apply(TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark$bang"), List(tpt)), args1) => + Some((tpt, args1)) + case TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark$bang"), List(tpt)) => + Some((tpt, Nil)) + case _ => None + } + } + object ExHoleExpression { def unapply(tree: Tree) : Option[(Tree, List[Tree])] = tree match { case a @ Apply(TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark"), List(tpt)), args1) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index e31ffba62a8b57c473f5155bf130f649ad873d66..74ce0498cd52674dd9e0840a06b19ffe5598daee 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1206,6 +1206,11 @@ trait CodeExtraction extends ASTExtractors { Hole(extractType(tpt), exprs.map(extractTree)) + case hole @ ExRepairHoleExpression(tpt, exprs) => + val leonExprs = exprs.map(extractTree) + + RepairHole(extractType(tpt), exprs.map(extractTree)) + case ops @ ExWithOracleExpression(oracles, body) => val newOracles = oracles map { case (tpt, sym) => val aTpe = extractType(tpt) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 8af4b527b4529ef4d7800aee460f027a7d169144..4fbf90ba6160273ffa15faddd4b8575c41f9933a 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -224,6 +224,9 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe | $pred |}""" + case h @ RepairHole(tpe, es) => + p"""|?""" + case h @ Hole(tpe, es) => if (es.isEmpty) { p"""|???[$tpe]""" diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index a47f702c9ad0e6c6c1ac953eea24917e0245a4e8..2c264755762c4e4f5d8a3f3f7935f4c24bbd5f1d 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -450,7 +450,7 @@ object TreeOps { case TupleSelect(LetTuple(ids, v, b), ts) => Some(LetTuple(ids, v, TupleSelect(b, ts))) - case IfExpr(c, thenn, elze) if (thenn == elze) && !containsChoose(e) => + case IfExpr(c, thenn, elze) if (thenn == elze) && isDeterministic(e) => Some(thenn) case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) => @@ -473,7 +473,7 @@ object TreeOps { } def isGround(e: Expr): Boolean = { - variablesOf(e).isEmpty && !containsChoose(e) + variablesOf(e).isEmpty && isDeterministic(e) } def evalGround(ctx: LeonContext, program: Program): Expr => Expr = { @@ -503,10 +503,10 @@ object TreeOps { def simplifyLets(expr: Expr) : Expr = { def simplerLet(t: Expr) : Option[Expr] = t match { - case letExpr @ Let(i, t: Terminal, b) if !containsChoose(b) => + case letExpr @ Let(i, t: Terminal, b) if isDeterministic(b) => Some(replace(Map((Variable(i) -> t)), b)) - case letExpr @ Let(i,e,b) if !containsChoose(b) => { + case letExpr @ Let(i,e,b) if isDeterministic(b) => { val occurences = count{ (e: Expr) => e match { case Variable(x) if x == i => 1 case _ => 0 @@ -521,7 +521,7 @@ object TreeOps { } } - case letTuple @ LetTuple(ids, Tuple(exprs), body) if !containsChoose(body) => + case letTuple @ LetTuple(ids, Tuple(exprs), body) if isDeterministic(body) => var newBody = body val (remIds, remExprs) = (ids zip exprs).filter { @@ -554,14 +554,14 @@ object TreeOps { Some(LetTuple(remIds, Tuple(remExprs), newBody)) } - case l @ LetTuple(ids, tExpr: Terminal, body) if !containsChoose(body) => + case l @ LetTuple(ids, tExpr: Terminal, body) if isDeterministic(body) => val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { case (v,i) => (v -> TupleSelect(tExpr, i + 1).copiedFrom(v)) } Some(replace(substMap, body)) - case l @ LetTuple(ids, tExpr, body) if !containsChoose(body) => + case l @ LetTuple(ids, tExpr, body) if isDeterministic(body) => val arity = ids.size val zeroVec = Seq.fill(arity)(0) val idMap = ids.zipWithIndex.toMap.mapValues(i => zeroVec.updated(i, 1)) @@ -1504,6 +1504,16 @@ object TreeOps { false } + def isDeterministic(e: Expr): Boolean = { + preTraversal{ + case Choose(_, _) => return false + case Hole(_, _) => return false + case RepairHole(_, _) => return false + case _ => + }(e) + true + } + /** * Returns the value for an identifier given a model. */ diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index a7088b12790aae3d617c10638e3f7d4289e8dfe2..e43986632c96fb2364053af66feed28539c18d99 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -71,6 +71,13 @@ object Trees { } } + case class RepairHole(fixedType: TypeTree, components: Seq[Expr]) extends Expr with FixedType with NAryExtractable { + + def extract = { + Some((components, (es: Seq[Expr]) => RepairHole(fixedType, es).setPos(this))) + } + } + /* Like vals */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr with FixedType { val fixedType = body.getType diff --git a/src/main/scala/leon/refactor/RepairCostModel.scala b/src/main/scala/leon/refactor/RepairCostModel.scala new file mode 100644 index 0000000000000000000000000000000000000000..40dc2ad4375d1dc0056a829625b8cb2892059845 --- /dev/null +++ b/src/main/scala/leon/refactor/RepairCostModel.scala @@ -0,0 +1,22 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package refactor +import synthesis._ + +import purescala.Trees._ +import purescala.TreeOps._ + +case class RepairCostModel(cm: CostModel) extends CostModel(cm.name) { + override def ruleAppCost(app: RuleInstantiation): Cost = { + app.rule match { + case rules.GuidedDecomp => 0 + case rules.CEGLESS => 0 + case _ => cm.ruleAppCost(app) + } + } + def solutionCost(s: Solution) = cm.solutionCost(s) + def problemCost(p: Problem) = cm.problemCost(p) +} + + diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index 1a63edeb2fb2e749a558ec82f113ba491f851bf5..1c1e178440cafe3558eb54774bd5e005b596a6de 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -19,6 +19,8 @@ import solvers.z3._ import codegen._ import verification._ import synthesis._ +import synthesis.rules._ +import synthesis.heuristics._ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val reporter = ctx.reporter @@ -61,7 +63,7 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { // Compute initial call val termfd = program.library.terminating.get val withinCall = FunctionInvocation(fd.typedWithDef, fd.params.map(_.id.toVariable)) - val term = FunctionInvocation(termfd.typed(Seq(fd.returnType)), Seq(withinCall)) + val terminating = FunctionInvocation(termfd.typed(Seq(fd.returnType)), Seq(withinCall)) val spec = And(Seq( fd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)), @@ -71,13 +73,24 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val pc = And(Seq( pre, guide, - term + terminating )) // Synthesis from the ground up val p = Problem(fd.params.map(_.id).toList, pc, spec, List(out)) - - val soptions = SynthesisPhase.processOptions(ctx).copy(costModel = RepairCostModel(CostModel.default)); + val ch = Choose(List(out), spec) + //fd.body = Some(ch) + + val soptions0 = SynthesisPhase.processOptions(ctx); + + val soptions = soptions0.copy( + costModel = RepairCostModel(soptions0.costModel), + rules = (soptions0.rules ++ Seq( + GuidedDecomp, + GuidedCloser, + CEGLESS + )) diff Seq(ADTInduction) + ); val synthesizer = new Synthesizer(ctx, fd, program, p, soptions) diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index a276e41cb5d223ae9d51a46fbf78bc304d858572..efd68f68069a927035919d0e56b670ea7f9b6bb1 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -261,6 +261,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { } } + case h @ RepairHole(_, _) => + val hid = FreshIdentifier("hole", true).setType(h.getType) + exprVars += hid + Variable(hid) + case c @ Choose(ids, cond) => val cid = FreshIdentifier("choose", true).setType(c.getType) storeExpr(cid) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index f5b8b4f12150f9f5f0eddf14641b07a390ea2404..d4df8948d96e98d44866e90ef8b2af2a6554fb42 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -678,6 +678,11 @@ trait AbstractZ3Solver case gv @ GenericValue(tp, id) => z3.mkApp(genericValueToDecl(gv)) + case h @ RepairHole(_, _) => + val newAST = z3.mkFreshConst("hole", typeToSort(h.getType)) + variables += (h -> newAST) + newAST + case _ => { reporter.warning(ex.getPos, "Can't handle this in translation to Z3: " + ex) throw new CantTranslateException diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index 33529ccecbd7f9bf1c8d651ac5df5981357663ba..b4cac6662ae50934b9231a8867910f2aaa6ec757 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -40,18 +40,6 @@ object CostModel { ) } -case class RepairCostModel(cm: CostModel) extends CostModel(cm.name) { - override def ruleAppCost(app: RuleInstantiation): Cost = { - app.rule match { - case rules.GuidedDecomp => 0 - case _ => cm.ruleAppCost(app) - } - } - def solutionCost(s: Solution) = cm.solutionCost(s) - def problemCost(p: Problem) = cm.problemCost(p) -} - - case object NaiveCostModel extends CostModel("Naive") { def solutionCost(s: Solution): Cost = { val chooses = collectChooses(s.toExpr) diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala index 83c5982bc163af6ceedadf432cfb67ca168735e2..378b0018e9bf25e02f04ecc4dec36e2668adee16 100644 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ b/src/main/scala/leon/synthesis/Heuristics.scala @@ -11,11 +11,11 @@ import heuristics._ object Heuristics { def all = List[Rule]( IntInduction, - InnerCaseSplit + InnerCaseSplit, //new OptimisticInjection(_), //new SelectiveInlining(_), //ADTLongInduction, - //ADTInduction + ADTInduction ) } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index efaca2093929c7122b8bc6b2aebd432cbe4fec2e..95877335937746e7c365aa48f11ff1ccc75618c4 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -26,9 +26,6 @@ object Rules { InequalitySplit, CEGIS, TEGIS, - GuidedDecomp, - GuidedCloser, - CEGLESS, rules.Assert, DetupleOutput, DetupleInput, @@ -44,7 +41,13 @@ object Rules { val rulesPrio = sctx.rules.groupBy(_.priority).toSeq.sortBy(_._1) for ((_, rs) <- rulesPrio) { - val results = rs.flatMap(_.instantiateOn(sctx, problem)).toList + val results = rs.flatMap{ r => + val ts = System.currentTimeMillis + val res = r.instantiateOn(sctx, problem) + println("Instantiating "+r+" ("+(System.currentTimeMillis-ts)+")") + res + }.toList + if (results.nonEmpty) { return results; } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 33b583d394816ccc3a871c4e0f2ca41d48240075..b5997165ba6a44ef68e2dc3fe02dbaaf131d463a 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -34,7 +34,7 @@ case object CEGIS extends CEGISLike("CEGIS") { } case object CEGLESS extends CEGISLike("CEGLESS") { - override val maxUnrolings = 2; + override val maxUnfoldings = 3; def getGrammar(sctx: SynthesisContext, p: Problem) = { import ExpressionGrammars._ @@ -47,9 +47,11 @@ case object CEGLESS extends CEGISLike("CEGLESS") { case FunctionInvocation(TypedFunDef(`guide`, _), Seq(expr)) => expr } - val guidedGrammar = guides.map(SimilarTo(_)).foldLeft[ExpressionGrammar](Empty)(_ || _) + val inputs = p.as.map(_.toVariable) - guidedGrammar || OneOf(p.as.map(_.toVariable)) + val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet)).foldLeft[ExpressionGrammar](Empty)(_ || _) + + guidedGrammar || OneOf(inputs) } } @@ -58,7 +60,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar - val maxUnrolings = 3 + val maxUnfoldings = 3 def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { @@ -185,7 +187,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val cToExprs = mappings.groupBy(_._2._1).map { case (c, maps) => - // We only keep cases within the current unrolling closedBs + // We only keep cases within the current unfoldings closedBs val cases = maps.flatMap{ case (b, (_, ex)) => if (isBClosed(b)) None else Some(b -> ex) } // We compute the IF expression corresponding to each c @@ -306,11 +308,22 @@ abstract class CEGISLike(name: String) extends Rule(name) { triedCompilation = false progEvaluator = None + + var cGroups = Map[Identifier, (Set[Identifier], Set[Identifier])]() + + for ((parentGuard, cToBs) <- bTree; (c, bss) <- cToBs) { + val (ps, bs) = cGroups.getOrElse(c, (Set[Identifier](), Set[Identifier]())) + + cGroups += c -> (ps + parentGuard, bs ++ bss) + } + // We need to regenerate clauses for each b - val pathConstraints = for ((parentGuard, cToBs) <- bTree; (c, bs) <- cToBs) yield { + val pathConstraints = for ((_, (parentGuards, bs)) <- cGroups) yield { val bvs = bs.toList.map(Variable(_)) - val failedPath = Not(Variable(parentGuard)) + // Represents the case where all parents guards are false, indicating + // that this C should not be considered at all + val failedPath = And(parentGuards.toSeq.map(p => Not(p.toVariable))) val distinct = bvs.combinations(2).collect { case List(a, b) => @@ -325,23 +338,25 @@ abstract class CEGISLike(name: String) extends Rule(name) { Implies(Variable(bid), Equals(Variable(recId), ex)) } - //for (i <- impliess) { - // println(": "+i) - //} - (pathConstraints ++ impliess).toSeq } - def unroll(finalUnrolling: Boolean): (List[Expr], Set[Identifier]) = { + def unfold(finalUnfolding: Boolean): (List[Expr], Set[Identifier]) = { var newClauses = List[Expr]() var newGuardedTerms = Map[Identifier, Set[Identifier]]() var newMappings = Map[Identifier, (Identifier, Expr)]() + var cGroups = Map[Identifier, Set[Identifier]]() + for ((parentGuard, recIds) <- guardedTerms; recId <- recIds) { + cGroups += recId -> (cGroups.getOrElse(recId, Set()) + parentGuard) + } + + for ((recId, parentGuards) <- cGroups) { var alts = grammar.getProductions(recId.getType) - if (finalUnrolling) { + if (finalUnfolding) { alts = alts.filter(_.subTrees.isEmpty) } @@ -349,7 +364,9 @@ abstract class CEGISLike(name: String) extends Rule(name) { val bvs = altsWithBranches.map(alt => Variable(alt._1)) - val failedPath = Not(Variable(parentGuard)) + // Represents the case where all parents guards are false, indicating + // that this C should not be considered at all + val failedPath = And(parentGuards.toSeq.map(p => Not(p.toVariable))) val distinct = bvs.combinations(2).collect { case List(a, b) => @@ -380,7 +397,10 @@ abstract class CEGISLike(name: String) extends Rule(name) { } val newBIds = altsWithBranches.map(_._1).toSet - bTree += parentGuard -> (bTree.getOrElse(parentGuard, Map()) + (recId -> newBIds)) + + for (parentGuard <- parentGuards) { + bTree += parentGuard -> (bTree.getOrElse(parentGuard, Map()) + (recId -> newBIds)) + } newClauses = newClauses ::: pre :: cases } @@ -417,8 +437,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { val initGuard = FreshIdentifier("START", true).setType(BooleanType) val ndProgram = new NonDeterministicProgram(p, initGuard) - var unrolings = 1 - val maxUnrolings = CEGISLike.this.maxUnrolings + var unfolding = 1 + val maxUnfoldings = CEGISLike.this.maxUnfoldings val exSolverTo = 2000L val cexSolverTo = 2000L @@ -546,12 +566,12 @@ abstract class CEGISLike(name: String) extends Rule(name) { var bssAssumptions = Set[Identifier]() if (!didFilterAlready) { - val (clauses, closedBs) = ndProgram.unroll(unrolings == maxUnrolings) + val (clauses, closedBs) = ndProgram.unfold(unfolding == maxUnfoldings) bssAssumptions = closedBs sctx.reporter.ifDebug { debug => - debug("UNROLLING: ") + debug("UNFOLDING: ") for (c <- clauses) { debug(" - " + c.asString(sctx.context)) } @@ -618,8 +638,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { // We filter the Bss so that the formula we give to z3 is much smalled val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) - // Cannot unroll normally after having filtered, so we need to - // repeat the filtering procedure at next unrolling. + // Cannot unfold normally after having filtered, so we need to + // repeat the filtering procedure at next unfolding. didFilterAlready = true // Freshening solvers @@ -699,7 +719,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val core = solver1.checkAssumptions(bssAssumptions) match { case Some(false) => - // Core might be empty if unrolling level is + // Core might be empty if unfolding level is // insufficient, it becomes unsat no matter what // the assumptions are. solver1.getUnsatCore @@ -765,8 +785,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { } } - unrolings += 1 - } while(unrolings <= maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) + unfolding += 1 + } while(unfolding <= maxUnfoldings && result.isEmpty && !interruptManager.isInterrupted()) result.getOrElse(RuleFailed()) diff --git a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala index eb4843903ba4f9f06d8b8d26f679fddeacc1651e..91b1494c7ecfceb8e85e484c92f51d41acda3284 100644 --- a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala +++ b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala @@ -25,7 +25,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { case FunctionInvocation(TypedFunDef(`guide`, _), Seq(expr)) => expr } - val alts = guides.flatMap { e => + val alts = guides.filter(isDeterministic).flatMap { e => // Tentative solution using e val wrappedE = if (p.xs.size == 1) Tuple(Seq(e)) else e @@ -43,10 +43,18 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { Some(Solution(BooleanLiteral(true), Set(), wrappedE, true)) case None => + sctx.reporter.ifDebug { printer => + printer(vc) + printer("== Unknown ==") + } None //Some(Solution(BooleanLiteral(true), Set(), wrappedE, false)) case _ => + sctx.reporter.ifDebug { printer => + printer(vc) + printer("== Invalid! ==") + } None } diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala index ada8bce10dba8afb7a1c28c8191306c56eda7f70..e04dc0c718909ead413d1a024e61dfefed9e8938 100644 --- a/src/main/scala/leon/synthesis/rules/Tegis.scala +++ b/src/main/scala/leon/synthesis/rules/Tegis.scala @@ -32,12 +32,22 @@ import bonsai.enumerators._ case object TEGIS extends Rule("TEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val grammar = ExpressionGrammars.default(sctx, p) - var tests = p.getTests(sctx).map(_.ins).distinct - if (tests.nonEmpty) { - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { + // check if the formula contains passes: + val passes = sctx.program.library.passes.get + + val mayHaveTests = exists({ + case FunctionInvocation(TypedFunDef(`passes`, _), _) => true + case _ => false + })(p.phi) + + List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { + def apply(sctx: SynthesisContext): RuleApplication = { + + val grammar = ExpressionGrammars.default(sctx, p) + + var tests = p.getTests(sctx).map(_.ins).distinct + if (tests.nonEmpty) { val evalParams = CodeGenParams(maxFunctionInvocations = 2000, checkContracts = true) //val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) @@ -117,10 +127,10 @@ case object TEGIS extends Rule("TEGIS") { } RuleClosed(toStream()) + } else { + RuleFailed() } - }) - } else { - Nil - } + } + }) } } diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 3072a6999171bd26c2e536d396d613fbf6569cd8..f20fa982b131ee7b8cefe3a4f8ef7fefba4f2cf4 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -112,7 +112,7 @@ object ExpressionGrammars { } } - case class SimilarTo(e: Expr) extends ExpressionGrammar { + case class SimilarTo(e: Expr, exclude: Set[Expr] = Set()) extends ExpressionGrammar { lazy val allSimilar = computeSimilar(e).groupBy(_._1).mapValues(_.map(_._2)) def computeProductions(t: TypeTree): Seq[Gen] = { @@ -121,29 +121,39 @@ object ExpressionGrammars { def computeSimilar(e : Expr) : Seq[(TypeTree, Gen)] = { - def gen(tp : TypeTree, retType : TypeTree, f : Seq[Expr] => Expr) : (TypeTree, Gen) = - (retType, Generator[TypeTree, Expr](Seq(tp),f)) + var seenSoFar = exclude; + + def gen(retType : TypeTree, tps : Seq[TypeTree], f : Seq[Expr] => Expr) : (TypeTree, Gen) = + (bestRealType(retType), Generator[TypeTree, Expr](tps.map(bestRealType), f)) // A generator that always regenerates its input - def const(e: Expr) = ( e.getType, Generator[TypeTree, Expr](Seq(), _ => e) ) + def const(e: Expr) = ( bestRealType(e.getType), Generator[TypeTree, Expr](Seq(), _ => e) ) def rec(e : Expr) : Seq[(TypeTree, Gen)] = { - val tp = e.getType - const(e) +: (e match { - case _: Terminal | _: Let | _: LetTuple | _: LetDef | _: MatchExpr => - Seq() - case UnaryOperator(sub, builder) => Seq( - gen( sub.getType, tp, { case Seq(ex) => builder(ex) } ) - ) ++ rec(sub) - case BinaryOperator(sub1, sub2, builder) => Seq( - gen( sub1.getType, tp, { case Seq(ex) => builder(ex, sub2) } ), - gen( sub2.getType, tp, { case Seq(ex) => builder(sub1, ex) } ) - ) ++ rec(sub1) ++ rec(sub2) - case NAryOperator(subs, builder) => - (for ((sub,index) <- subs.zipWithIndex) yield { - gen( sub.getType, tp, { case Seq(ex) => builder(subs updated (index, ex) )} ) - }) ++ subs.flatMap(rec) - }) + if (seenSoFar contains e) { + Seq() + } else { + seenSoFar += e + val tp = e.getType + val self: Seq[(TypeTree, Gen)] = e match { + case RepairHole(_, _) => Seq() + case _ => Seq(const(e)) + } + val subs: Seq[(TypeTree, Gen)] = e match { + case _: Terminal | _: Let | _: LetTuple | _: LetDef | _: MatchExpr => + Seq() + case UnaryOperator(sub, builder) => Seq( + gen(tp, List(sub.getType), { case Seq(ex) => builder(ex) } ) + ) ++ rec(sub) + case BinaryOperator(sub1, sub2, builder) => Seq( + gen(tp, List(sub1.getType, sub2.getType), { case Seq(e1, e2) => builder(e1, e2) } ) + ) ++ rec(sub1) ++ rec(sub2) + case NAryOperator(subs, builder) => + Seq(gen(tp, subs.map(_.getType), builder)) ++ subs.flatMap(rec) + } + + self ++ subs + } } rec(e).tail // Don't want the expression itself @@ -159,16 +169,9 @@ object ExpressionGrammars { val isRecursiveCall = (prog.callGraph.transitiveCallers(cfd) + cfd) contains fd - val isNotSynthesizable = fd.body match { - case Some(b) => - !containsChoose(b) - - case None => - false - } - + val isDet = fd.body.map(isDeterministic).getOrElse(false) - if (!isRecursiveCall && isNotSynthesizable) { + if (!isRecursiveCall && isDet) { val free = fd.tparams.map(_.tp) canBeSubtypeOf(fd.returnType, free, t) match { case Some(tpsMap) => diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala index da10dd8f2cfb0743ddc6392f269bfc0708d6135f..6849a97dcfc8b933cf50854c8b18d9fd6da8999f 100644 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ b/src/main/scala/leon/utils/Simplifiers.scala @@ -50,8 +50,8 @@ object Simplifiers { matchToIfThenElse _, simplifyPaths(uninterpretedZ3)(_), patternMatchReconstruction _, - rewriteTuples _, - normalizeExpression _ + rewriteTuples _//, + //normalizeExpression _ ) val simple = { expr: Expr => diff --git a/testcases/synthesis/repair/SortedList/SortedList1.scala b/testcases/synthesis/repair/SortedList/SortedList1.scala index 39ca45d4957d52fc4df7719e1ef809ab20c0f05d..601912906fe71e3e5b91d7c9f2a82d197aa9fce0 100644 --- a/testcases/synthesis/repair/SortedList/SortedList1.scala +++ b/testcases/synthesis/repair/SortedList/SortedList1.scala @@ -41,8 +41,7 @@ object SortedList { case Nil() => l case Cons(_, Nil()) => l case _ => - val (l1, l2) = split(l) - merge(l1,l2) // FIXME: Forgot to mergeSort l1 and l2 + merge(split(l)._1, split(l)._2) // FIXME: Forgot to mergeSort l1 and l2 }} ensuring { res => res.content == l.content && isSorted(res)