diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index e3c818354cf70d34aa0d9dd0926831d4600dd0d6..4127a28f5ebda80ecbfec4c3960b9d8fa18680b3 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -24,10 +24,15 @@ import Witnesses._ */ case class Problem(as: List[Identifier], ws: Expr, pc: Path, phi: Expr, xs: List[Identifier], eb: ExamplesBank = ExamplesBank.empty) extends Printable { + // Activate this for debugging... + // assert(eb.examples.forall(_.ins.size == as.size)) + + val TopLevelAnds(wsList) = ws + def inType = tupleTypeWrap(as.map(_.getType)) def outType = tupleTypeWrap(xs.map(_.getType)) - def allAs = as ++ pc.bindings.map(_._1) + def allAs = as ++ (pc.bindings.map(_._1) diff wsList.collect{ case Inactive(i) => i }) def asString(implicit ctx: LeonContext): String = { val pcws = pc withCond ws @@ -41,9 +46,8 @@ case class Problem(as: List[Identifier], ws: Expr, pc: Path, phi: Expr, xs: List |⟧ $ebInfo""".stripMargin } - def withWs(es: Seq[Expr]) = { - val TopLevelAnds(prev) = ws - copy(ws = andJoin(prev ++ es)) + def withWs(es: Traversable[Expr]) = { + copy(ws = andJoin(wsList ++ es)) } // Qualified example bank, allows us to perform operations (e.g. filter) with expressions diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala index b2359b1d406c838a0243e12f1659441774d83915..98b14905e9f740baea0bddad25aa96ebefd1cb0b 100644 --- a/src/main/scala/leon/synthesis/Witnesses.scala +++ b/src/main/scala/leon/synthesis/Witnesses.scala @@ -2,6 +2,7 @@ package leon.synthesis +import leon.purescala.Common.Identifier import leon.purescala._ import Types._ import Extractors._ @@ -38,5 +39,12 @@ object Witnesses { p"谶$e" } } - + + case class Inactive(i: Identifier) extends Witness { + def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((Seq(), _ => this )) + override def printWith(implicit pctx: PrinterContext): Unit = { + p"inactive($i)" + } + + } } diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 71f564859a8ec9a577a20790559c56e903b512b7..317f63807bf990624160e2f0cca40a332c3be879 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -4,7 +4,8 @@ package leon package synthesis package rules -import Witnesses.Hint +import Witnesses._ + import purescala.Expressions._ import purescala.Common._ import purescala.Types._ @@ -12,10 +13,21 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Constructors._ import purescala.Definitions._ +import evaluators.DefaultEvaluator /** Abstract data type split. If a variable is typed as an abstract data type, then * it will create a match case statement on all known subtypes. */ case object ADTSplit extends Rule("ADT Split.") { + + protected class NoChooseEvaluator(ctx: LeonContext, prog: Program) extends DefaultEvaluator(ctx, prog) { + override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { + case ch: Choose => + throw new EvalError("Choose!") + case _ => + super.e(expr) + } + } + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { // We approximate knowledge of types based on facts found at the top-level // we don't care if the variables are known to be equal or not, we just @@ -32,7 +44,7 @@ case object ADTSplit extends Rule("ADT Split.") { instChecks.toMap ++ boundCcs } - val candidates = p.as.collect { + val candidates = p.allAs.collect { case IsTyped(id, act @ AbstractClassType(cd, tpes)) => val optCases = cd.knownDescendants.sortBy(_.id.name).collect { @@ -63,27 +75,50 @@ case object ADTSplit extends Rule("ADT Split.") { case Some((id, act, cases)) => val oas = p.as.filter(_ != id) + val evaluator = new NoChooseEvaluator(hctx, hctx.program) + val subInfo0 = for(ccd <- cases) yield { - val cct = CaseClassType(ccd, act.tps) + val isInputVar = p.as.contains(id) + val cct = CaseClassType(ccd, act.tps) - val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList + val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList - val whole = CaseClass(cct, args.map(Variable)) + val whole = CaseClass(cct, args.map(Variable)) val subPhi = subst(id -> whole, p.phi) - val subPC = p.pc map (subst(id -> whole, _)) - val subWS = subst(id -> whole, p.ws) - - val eb2 = p.qeb.mapIns { inInfo => - inInfo.toMap.apply(id) match { - case CaseClass(`cct`, vs) => - List(vs ++ inInfo.filter(_._1 != id).map(_._2)) - case _ => - Nil - } + val subPC = { + val withSubst = p.pc map (subst(id -> whole, _)) + if (isInputVar) withSubst + else { + val mapping = cct.classDef.fields.zip(args).map { + case (f, a) => a -> caseClassSelector(cct, Variable(id), f.id) + } + withSubst.withBindings(mapping).withCond(isInstOf(id.toVariable, cct)) + } } - - val subProblem = Problem(args ::: oas, subWS, subPC, subPhi, p.xs, eb2).withWs(Seq(Hint(whole))) + val subWS = subst(id -> whole, p.ws) + + val eb2 = { + if (isInputVar) { + // Filter out examples where id has the wrong type, and fix input variables + p.qeb.mapIns { inInfo => + inInfo.toMap.apply(id) match { + case CaseClass(`cct`, vs) => + List(vs ++ inInfo.filter(_._1 != id).map(_._2)) + case _ => + Nil + } + } + } else { + // Filter out examples where id has the wrong type + p.qeb.filterIns { inValues => + evaluator.eval(id.toVariable, inValues ++ p.pc.bindings).result.exists(_.getType == cct) + }.eb + } + } + val newAs = if (isInputVar) args ::: oas else p.as + val inactive = (!isInputVar).option(Inactive(id)) + val subProblem = Problem(newAs, subWS, subPC, subPhi, p.xs, eb2).withWs(Seq(Hint(whole)) ++ inactive) val subPattern = CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id)))) (cct, subProblem, subPattern) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 933ae752b19eff0ff816718ef7e6c4ee79560f22..24a2010783fc27b4c86dc3ceab035226486c6771 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -451,13 +451,22 @@ abstract class CEGISLike(name: String) extends Rule(name) { cTreeFd.fullBody = innerSol timers.testForProgram.start() + + def withBindings(e: Expr) = p.pc.bindings.foldRight(e){ + case ((id, v), bd) => let(id, outerExprToInnerExpr(v), bd) + } + + val boundCnstr = withBindings(cnstr) + val res = ex match { case InExample(ins) => - evaluator.eval(cnstr, p.as.zip(ins).toMap ++ p.pc.bindings) + evaluator.eval(boundCnstr, p.as.zip(ex.ins).toMap) case InOutExample(ins, outs) => - val eq = equality(innerSol, tupleWrap(outs)) - evaluator.eval(eq, p.as.zip(ins).toMap ++ p.pc.bindings) + evaluator.eval( + withBindings(equality(innerSol, tupleWrap(outs))), + p.as.zip(ex.ins).toMap + ) } timers.testForProgram.stop() @@ -468,15 +477,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { Some(res == BooleanLiteral(true)) case EvaluationResults.RuntimeError(err) => - /*if (err.contains("Empty production rule")) { - println(programCTree.asString) - println(bValues) - println(ex) - println(this.getExpr(bValues)) - (new Throwable).printStackTrace() - println(err) - println() - }*/ debug("RE testing CE: "+err) Some(false) @@ -783,7 +783,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { /** * We (lazily) generate additional tests for discarding potential programs with a data generator */ - val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 + val nTests = if (p.pc.isEmpty) 50 else 20 val inputGenerator: Iterator[Example] = { val complicated = exists{ @@ -850,10 +850,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { debug(s"#Tests: >= ${gi.bufferedCount}") ifDebug{ printer => - for (e <- baseExampleInputs.take(10)) { + val es = allInputExamples() + for (e <- es.take(Math.min(gi.bufferedCount, 10))) { printer(" - "+e.asString) } - if(baseExampleInputs.size > 10) { + if(es.hasNext) { printer(" - ...") } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 57e85d7b0ce2a226781dbe8f982562666eaaa1c2..a08f10479aef5e51780829b29886a44522988d86 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -4,11 +4,11 @@ package leon package synthesis package rules -import Witnesses.Hint +import Witnesses._ import purescala.Expressions._ import purescala.Common._ import purescala.Types._ -import purescala.ExprOps.simplePostTransform +import purescala.ExprOps._ import purescala.Constructors._ import purescala.Extractors.LetPattern @@ -22,12 +22,14 @@ case object DetupleInput extends NormalizingRule("Detuple In") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { /** Returns true if this identifier is a tuple or a case class */ - def isDecomposable(id: Identifier) = id.getType match { + def typeCompatible(id: Identifier) = id.getType match { case CaseClassType(t, _) if !t.isAbstract => true case TupleType(ts) => true case _ => false } + def isDecomposable(id: Identifier) = typeCompatible(id) && !p.wsList.contains(Inactive(id)) + /* Decomposes a decomposable input identifier (eg of type Tuple or case class) * into a list of fresh typed identifiers, the tuple of these new identifiers, * and the mapping of those identifiers to their respective expressions. @@ -54,45 +56,52 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case _ => sys.error("woot") } - if (p.as.exists(isDecomposable)) { + if (p.allAs.exists(isDecomposable)) { var subProblem = p.phi var subPc = p.pc var subWs = p.ws var hints: Seq[Expr] = Nil var patterns = List[(Identifier, Pattern)]() var revMap = Map[Expr, Expr]().withDefault((e: Expr) => e) + var inactive = Set[Identifier]() var ebMapInfo = Map[Identifier, Expr => Seq[Expr]]() - val subAs = p.as.map { a => + val subAs = p.allAs.map { a => if (isDecomposable(a)) { val (newIds, expr, tMap) = decompose(a) + val patts = newIds map (id => WildcardPattern(Some(id))) + val patt = a.getType match { + case TupleType(_) => + TuplePattern(None, patts) + case cct: CaseClassType => + CaseClassPattern(None, cct, patts) + } subProblem = subst(a -> expr, subProblem) - subPc = subPc map (subst(a -> expr, _)) + subPc = { + val withSubst = subPc map (subst(a -> expr, _)) + if (!p.pc.boundIds.contains(a)){ + withSubst + } else { + inactive += a + val mapping = mapForPattern(a.toVariable, patt) + withSubst.withBindings(mapping) + } + } subWs = subst(a -> expr, subWs) revMap += expr -> Variable(a) hints +:= Hint(expr) - val patts = newIds map (id => WildcardPattern(Some(id))) - - patterns +:= (( - a, - a.getType match { - case TupleType(_) => - TuplePattern(None, patts) - case cct: CaseClassType => - CaseClassPattern(None, cct, patts) - } - )) + patterns +:= a -> patt - ebMapInfo += a -> tMap + ebMapInfo += a -> tMap - newIds + a -> newIds } else { - List(a) + a -> List(a) } - } + }.toMap val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => @@ -105,11 +114,13 @@ case object DetupleInput extends NormalizingRule("Detuple In") { }) } - val newAs = subAs.flatten + val newAs = p.as.flatMap(subAs) val (as, patts) = patterns.unzip - val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb).withWs(hints) + val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) + .withWs(hints) + .withWs(inactive.toSeq.map(Inactive)) val s = { (e: Expr) => val body = simplePostTransform(revMap)(e) diff --git a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala b/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala index b9be6da6d0004817b0cf1ff54ec7f0635626de0d..6df6624bc2c54c9ef9637b9ca65ade8fe9fae977 100644 --- a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala @@ -4,11 +4,12 @@ package leon package synthesis package rules -import leon.purescala.Common.Identifier +import purescala.Common.Identifier import purescala.Constructors._ import purescala.Expressions._ -import leon.purescala.Extractors.{IsTyped, TopLevelAnds} +import purescala.Extractors.{IsTyped, TopLevelAnds} import purescala.Types._ +import Witnesses._ /** For every pair of input variables of the same generic type, * checks equality and output an If-Then-Else statement with the two new branches. @@ -41,17 +42,27 @@ case object GenericTypeEqualitySplit extends Rule("Eq. Split") { case (a1, a2) => val v1 = Variable(a1) val v2 = Variable(a2) - val subProblems = List( - p.copy(as = p.as.diff(Seq(a1)), - pc = p.pc map (subst(a1 -> v2, _)), - ws = subst(a1 -> v2, p.ws), - phi = subst(a1 -> v2, p.phi), - eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(a1))), - p.copy(pc = p.pc withCond not(Equals(v1, v2)), - eb = p.qeb.filterIns(not(Equals(v1, v2)))) + val (f, t, isInput) = if (p.as contains a1) (a1, v2, true) else (a2, v1, p.as contains a2) + val eq = if (isInput) { + p.copy( + as = p.as.diff(Seq(f)), + pc = p.pc map (subst(f -> t, _)), + ws = subst(f -> t, p.ws), + phi = subst(f -> t, p.phi), + eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(f)) + ) + } else { + p.copy(pc = p.pc withCond Equals(v1,v2)).withWs(Seq(Inactive(f))) // FIXME! + } + + val neq = p.copy( + pc = p.pc withCond not(Equals(v1, v2)), + eb = p.qeb.filterIns(not(Equals(v1, v2))) // FIXME! ) + val subProblems = List(eq, neq) + val onSuccess: List[Solution] => Option[Solution] = { case sols @ List(sEQ, sNE) => val pre = or( diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index e4e35b226a4e45a139eab44898dc789ee0c6f873..cacb44a4657ac15f1d1228026a521182501bb0e3 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -4,6 +4,7 @@ package leon package synthesis package rules +import leon.synthesis.Witnesses.Inactive import purescala.Expressions._ import purescala.Types._ import purescala.Constructors._ @@ -43,8 +44,8 @@ case object InequalitySplit extends Rule("Ineq. Split.") { } val facts: Set[Fact] = { - val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) - as.toSet flatMap getFacts + val TopLevelAnds(fromPhi) = p.phi + (fromPhi.toSet ++ p.pc.conditions ++ p.pc.bindingsAsEqs) flatMap getFacts } val candidates = @@ -66,18 +67,24 @@ case object InequalitySplit extends Rule("Ineq. Split.") { val eq = if (!facts.contains(EQ(v1, v2)) && !facts.contains(EQ(v2,v1))) { val pc = Equals(v1, v2) - // One of v1, v2 will be an input variable - val a1 = (v1, v2) match { - case (Variable(a), _) => a - case (_, Variable(a)) => a + // Let's see if an input variable is involved + val (f, t, isInput) = (v1, v2) match { + case (Variable(a1), _) if p.as.contains(a1) => (a1, v2, true) + case (_, Variable(a2)) if p.as.contains(a2) => (a2, v1, true) + case (Variable(a1), _) => (a1, v2, false) } - val newP = p.copy( - as = p.as.diff(Seq(a1)), - pc = p.pc map (subst(a1 -> v2, _)), - ws = subst(a1 -> v2, p.ws), - phi = subst(a1 -> v2, p.phi), - eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(a1)) - ) + val newP = if (isInput) { + p.copy( + as = p.as.diff(Seq(f)), + pc = p.pc map (subst(f -> t, _)), + ws = subst(f -> t, p.ws), + phi = subst(f -> t, p.phi), + eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(f)) + ) + } else { + p.copy(pc = p.pc withCond pc).withWs(Seq(Inactive(f))) // equality in pc is fine for numeric types + } + Some(pc, newP) } else None diff --git a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala index 4463e6818b2b7b102cf404bd80ca794d9b96c3c7..22af64c189a33afe9890fa3a62d8b7c74bb1da62 100644 --- a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala +++ b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala @@ -39,16 +39,16 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { val rec = FreshIdentifier("rec", newCall.getType, alwaysShowUniqueID = true) // Assume the postcondition of recursive call - val (bound, path) = if (specifyCalls) { - (true, Path.empty withBinding (rec -> newCall)) + val path = if (specifyCalls) { + Path.empty withBinding (rec -> newCall) } else { - (false, Path(application( + Path(application( newCall.tfd.withParamSubst(newCall.args, newCall.tfd.postOrTrue), Seq(rec.toVariable) - ))) + )) } - (rec, bound, path) + (rec, path) } val onSuccess = forwardMap(letTuple(recs.map(_._1), tupleWrap(calls), _)) @@ -65,28 +65,16 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { val origImpl = hctx.functionContext.fullBody hctx.functionContext.fullBody = psol - val evaluator = new NoChooseEvaluator(hctx, hctx.program) - def mapExample(ex: Example): List[Example] = { - val results = calls map (evaluator.eval(_, p.as.zip(ex.ins).toMap).result) - if (results forall (_.isDefined)) List({ - val extra = results map (_.get) - ex match { - case InExample(ins) => - InExample(ins ++ extra) - case InOutExample(ins, outs) => - InOutExample(ins ++ extra, outs) - } - }) else Nil - } + //val evaluator = new NoChooseEvaluator(hctx, hctx.program) val newWs = calls map Terminating val TopLevelAnds(ws) = p.ws try { val newProblem = p.copy( - as = p.as ++ recs.collect { case (r, false, _) => r }, - pc = recs.map(_._3).foldLeft(p.pc)(_ merge _), + as = p.as ++ (if (specifyCalls) Nil else recs.map(_._1)), + pc = recs.map(_._2).foldLeft(p.pc)(_ merge _), ws = andJoin(ws ++ newWs), - eb = p.eb.map(mapExample) + eb = p.qeb//.filterIns(filter _) ) RuleExpanded(List(newProblem)) diff --git a/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala b/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala index 809c48786a984c224aa5a0d4a8f5d09fec34ca57..9f821d593201b488dbf6fccdd59d51d812feb8c4 100644 --- a/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala +++ b/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala @@ -186,7 +186,6 @@ class ManualStrategy(ctx: LeonContext, initCmd: Option[String], strat: Strategy) } manualGetNext() - case Best => strat.bestNext(c) match { case Some(n) =>