diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala index 562e1845f9099ab38da406d6704304fe9723bc3f..a77ed826a25bb8142d097c368525b11907b7b0de 100644 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -4,9 +4,6 @@ package leon package repair package rules -import synthesis._ -import leon.evaluators._ - import purescala.Path import purescala.Expressions._ import purescala.Common._ @@ -15,8 +12,11 @@ import purescala.ExprOps._ import purescala.Constructors._ import purescala.Extractors._ -import Witnesses._ +import utils.fixpoint +import evaluators._ +import synthesis._ +import Witnesses._ import graph.AndNode case object Focus extends PreprocessingRule("Focus") { @@ -75,10 +75,7 @@ case object Focus extends PreprocessingRule("Focus") { } } - val fdSpec = { - val id = FreshIdentifier("res", fd.returnType) - Let(id, fd.body.get, application(fd.postOrTrue, Seq(id.toVariable))) - } + val TopLevelAnds(clauses) = p.ws @@ -93,17 +90,15 @@ case object Focus extends PreprocessingRule("Focus") { def ws(g: Expr) = andJoin(Guide(g) +: wss) - def testCondition(cond: Expr) = { - val ndSpec = postMap { - case c if c eq cond => Some(not(cond)) - case _ => None - }(fdSpec) - forAllTests(ndSpec, Map(), new AngelicEvaluator(new RepairNDEvaluator(hctx, program, cond))) + def testCondition(guide: IfExpr) = { + val IfExpr(cond, thenn, elze) = guide + val spec = letTuple(p.xs, IfExpr(Not(cond), thenn, elze), p.phi) + forAllTests(spec, Map(), new AngelicEvaluator(new RepairNDEvaluator(hctx, program, cond))) } guides.flatMap { case g @ IfExpr(c, thn, els) => - testCondition(c) match { + testCondition(g) match { case Some(true) => val cx = FreshIdentifier("cond", BooleanType) // Focus on condition @@ -146,33 +141,49 @@ case object Focus extends PreprocessingRule("Focus") { val map = mapForPattern(scrut, c.pattern) val thisCond = matchCaseCondition(scrut, c) + val prevPCSoFar = pcSoFar val cond = pcSoFar merge thisCond pcSoFar = pcSoFar merge thisCond.negate val subP = if (existsFailing(cond.toClause, map, evaluator)) { - val vars = map.toSeq.map(_._1) - // Filter tests by the path-condition - val eb2 = qeb.filterIns(cond.toClause) + val vars = map.keys.toSeq - // Augment test with the additional variables and their valuations - val ebF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => - val emap = (p.as zip e).toMap + val (p2e, _) = patternToExpression(c.pattern, scrut.getType) - evaluator.eval(tupleWrap(vars.map(map)), emap).result.map { r => - e ++ unwrapTuple(r, vars.size) - }.toList - } - - val eb3 = if (vars.nonEmpty) { - eb2.flatMapIns(ebF) + val substAs = ((scrut, p2e) match { + case (Variable(i), _) if p.as.contains(i) => Seq(i -> p2e) + case (Tuple(as), Tuple(tos)) => + val res = as.zip(tos) collect { + case (Variable(i), to) if p.as.contains(i) => i -> to + } + if (res.size == as.size) res else Nil + case _ => Nil + }).toMap + + if (substAs.nonEmpty) { + val subst = replaceFromIDs(substAs, (_:Expr)) + // FIXME intermediate binders?? + val newAs = (p.as diff substAs.keys.toSeq) ++ vars + val newPc = (p.pc merge prevPCSoFar) map subst + val newWs = subst(ws(c.rhs)) + val newPhi = subst(p.phi) + val eb2 = qeb.filterIns(cond.toClause) + val ebF: Seq[(Identifier, Expr)] => List[Seq[Expr]] = { (ins: Seq[(Identifier, Expr)]) => + val eval = evaluator.eval(tupleWrap(vars map Variable), map ++ ins) + val insWithout = ins.collect{ case (id, v) if !substAs.contains(id) => v } + eval.result.map(r => insWithout ++ unwrapTuple(r, vars.size)).toList + } + val newEb = eb2.flatMapIns(ebF) + Some(Problem(newAs, newWs, newPc, newPhi, p.xs, newEb)) } else { - eb2.eb - } + // Filter tests by the path-condition + val eb2 = qeb.filterIns(cond.toClause) - val newPc = cond withBindings vars.map(id => id -> map(id)) + val newPc = cond withBindings vars.map(id => id -> map(id)) - Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb3)) + Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb2)) + } } else { None }