diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 269f612b8f53ec60820c05c0b417fb4f4fb7de35..31f3e7305c11104972c33e42655a1db687b70ae5 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -376,6 +376,11 @@ object ScalaPrinter { } case WildcardPattern(None) => sb.append("_") case WildcardPattern(Some(id)) => sb.append(id) + case InstanceOfPattern(bndr, ccd) => { + var nsb = sb + bndr.foreach(b => nsb.append(b + " : ")) + nsb.append(ccd.id) + } case TuplePattern(bndr, subPatterns) => { bndr.foreach(b => sb.append(b + " @ ")) sb.append("(") diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index e5b8d1f5827e2a7d8aa35c5e527fa1145785ebc6..2b5cda8ed885e2bc73b6129e602e7990d77931ba 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1040,6 +1040,39 @@ object TreeOps { simplePreTransform(pre)(e) } + /* + * Transforms complicated Ifs into multiple nested if blocks + * It will decompose every OR clauses, and it will group AND clauses checking + * isInstanceOf toghether. + * + * if (a.isInstanceof[T1] && a.tail.isInstanceof[T2] && a.head == a2 || C) { + * T + * } else { + * E + * } + * + * Becomes: + * + * if (a.isInstanceof[T1] && a.tail.isInstanceof[T2]) { + * if (a.head == a2) { + * T + * } else { + * if(C) { + * T + * } else { + * E + * } + * } + * } else { + * if(C) { + * T + * } else { + * E + * } + * } + * + * This transformation runs immediately before patternMatchReconstruction. + */ def decomposeIfs(e: Expr): Expr = { def pre(e: Expr): Expr = e match { case IfExpr(cond, then, elze) => @@ -1065,25 +1098,92 @@ object TreeOps { simplePreTransform(pre)(e) } + // This transformation assumes IfExpr of the form generated by decomposeIfs def patternMatchReconstruction(e: Expr): Expr = { - case class PMContext() - - def pre(e: Expr, c: PMContext): (Expr, PMContext) = e match { + def pre(e: Expr): Expr = e match { case IfExpr(cond, then, elze) => - val TopLevelOrs(cases) = toDNF(cond) - // find one variable on which we will match: - val casematches = for (caze <- cases) yield { - val TopLevelAnds(conds) = caze + val TopLevelAnds(cases) = cond + + if (cases.forall(_.isInstanceOf[CaseClassInstanceOf])) { + // matchingOn might initially be: a : T1, a.tail : T2, b: T2 + def selectorDepth(e: Expr): Int = e match { + case v: Variable => + 0 + case cd: CaseClassSelector => + 1+selectorDepth(cd.caseClass) + } - conds.filter(_.isInstanceOf[CaseClassInstanceOf]) - } + var scrutSet = Set[Expr]() + var conditions = Map[Expr, CaseClassDef]() + + var matchingOn = cases.collect { case cc : CaseClassInstanceOf => cc } sortBy(cc => selectorDepth(cc.expr)) + for (CaseClassInstanceOf(cd, expr) <- matchingOn) { + conditions += expr -> cd + + expr match { + case v: Variable => + scrutSet += v + case cd: CaseClassSelector => + if (!scrutSet.contains(cd.caseClass)) { + // we found a test looking like "a.foo.isInstanceof[..]" + // without a check on "a". + scrutSet += cd + } + } + } + + var substMap = Map[Expr, Expr]() - (e, c) + + def computePatternFor(cd: CaseClassDef, prefix: Expr): Pattern = { + + val id = prefix match { + case CaseClassSelector(_, _, id) => id + case Variable(id) => id + } + + val binder = FreshIdentifier(id.name, true).setType(id.getType) // Is it full of women though? + + // prefix becomes binder + substMap += prefix -> Variable(binder) + substMap += CaseClassInstanceOf(cd, prefix) -> BooleanLiteral(true) + + val subconds = for (id <- cd.fieldsIds) yield { + val fieldSel = CaseClassSelector(cd, prefix, id) + if (conditions contains fieldSel) { + computePatternFor(conditions(fieldSel), fieldSel) + } else { + WildcardPattern(None) + } + } + + if (subconds.forall(_.isInstanceOf[WildcardPattern])) { + // nothing to check underneath + InstanceOfPattern(Some(binder), cd) + } else { + CaseClassPattern(Some(binder), cd, subconds) + } + } + + val (scrutinees, patterns) = scrutSet.toSeq.map(s => (s, computePatternFor(conditions(s), s))) unzip + + val (scrutinee, pattern) = if (scrutinees.size > 1) { + (Tuple(scrutinees), TuplePattern(None, patterns)) + } else { + (scrutinees.head, patterns.head) + } + + val newThen = searchAndReplace(substMap.get)(then) + + MatchExpr(scrutinee, Seq(SimpleCase(pattern, newThen), SimpleCase(WildcardPattern(None), elze))) + } else { + e + } case _ => - (e, c) + e } - genericTransform[PMContext](pre, (_, _), noCombiner)(PMContext())(e)._1 + simplePreTransform(pre)(e) } def simplifyTautologies(solver : Solver)(expr : Expr) : Expr = { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 3f1c80a8f6c7f7a55f620177f449147eab1d56e7..6e809dc472e39206022ae413aea8964dfea58afd 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -45,7 +45,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val simplifiers = List[Expr => Expr]( simplifyTautologies(uninterpretedZ3)(_), simplifyLets _, - decomposeIfs _ + decomposeIfs _, + patternMatchReconstruction _, + simplifyTautologies(uninterpretedZ3)(_) ) val chooseToExprs = solutions.mapValues(sol => simplifiers.foldLeft(sol.toExpr){ (x, sim) => sim(x) })