diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala index 562e1845f9099ab38da406d6704304fe9723bc3f..a4bf840de307fc8d43a5e8522deafa584204a85b 100644 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -4,9 +4,6 @@ package leon package repair package rules -import synthesis._ -import leon.evaluators._ - import purescala.Path import purescala.Expressions._ import purescala.Common._ @@ -15,8 +12,11 @@ import purescala.ExprOps._ import purescala.Constructors._ import purescala.Extractors._ -import Witnesses._ +import utils.fixpoint +import evaluators._ +import synthesis._ +import Witnesses._ import graph.AndNode case object Focus extends PreprocessingRule("Focus") { @@ -146,33 +146,64 @@ case object Focus extends PreprocessingRule("Focus") { val map = mapForPattern(scrut, c.pattern) val thisCond = matchCaseCondition(scrut, c) + val prevPCSoFar = pcSoFar val cond = pcSoFar merge thisCond pcSoFar = pcSoFar merge thisCond.negate val subP = if (existsFailing(cond.toClause, map, evaluator)) { - val vars = map.toSeq.map(_._1) - // Filter tests by the path-condition - val eb2 = qeb.filterIns(cond.toClause) + val vars = map.toSeq.map(_._1) - // Augment test with the additional variables and their valuations - val ebF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => - val emap = (p.as zip e).toMap + val (p2e, _) = patternToExpression(c.pattern, scrut.getType) - evaluator.eval(tupleWrap(vars.map(map)), emap).result.map { r => - e ++ unwrapTuple(r, vars.size) - }.toList - } + val substAs = ((scrut, p2e) match { + case (Variable(i), _) if p.as.contains(i) => Seq(i -> p2e) + case (Tuple(as), Tuple(tos)) => + val res = as.zip(tos) collect { + case (Variable(i), to) if p.as.contains(i) => i -> to + } + if (res.size == as.size) res else Nil + }).toMap - val eb3 = if (vars.nonEmpty) { - eb2.flatMapIns(ebF) + if (substAs.nonEmpty) { + val subst: Expr => Expr = { e => + replaceFromIDs(substAs, e) + } + // FIXME intermediate binders?? + val newAs = (p.as diff substAs.keys.toSeq) ++ vars + val newPc = (p.pc merge prevPCSoFar) map subst + val newWs = subst(ws(c.rhs)) + val newPhi = subst(p.phi) + val eb2 = qeb.filterIns(cond.toClause).removeIns(substAs.keySet) + val ebF: Seq[Expr] => List[Seq[Expr]] = { (ins: Seq[Expr]) => + val eval = evaluator.eval(tupleWrap(vars map Variable), p.as.zip(ins).toMap ++ map) + eval.result.map( r => ins ++ unwrapTuple(r, vars.size)).toList + } + val newEb = eb2 flatMapIns ebF + Some(Problem(newAs, newWs, newPc, newPhi, p.xs, newEb)) } else { - eb2.eb - } + // Filter tests by the path-condition + val eb2 = qeb.filterIns(cond.toClause) - val newPc = cond withBindings vars.map(id => id -> map(id)) + // Augment test with the additional variables and their valuations + val ebF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => + val emap = (p.as zip e).toMap + + evaluator.eval(tupleWrap(vars.map(map)), emap).result.map { r => + e ++ unwrapTuple(r, vars.size) + }.toList + } - Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb3)) + val eb3 = if (vars.nonEmpty) { + eb2.flatMapIns(ebF) + } else { + eb2.eb + } + + val newPc = cond withBindings vars.map(id => id -> map(id)) + + Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb3)) + } } else { None }