diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 482f4056478420680d87a7ca7663212824a7ce18..d5b03204d5448a25dbb14592492ca1ca31c553b3 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -48,7 +48,7 @@ case object ADTSplit extends Rule("ADT Split.") { } } - candidates.collect{ _ match { + candidates.collect { case Some((id, act, cases)) => val oas = p.as.filter(_ != id) @@ -72,7 +72,14 @@ case object ADTSplit extends Rule("ADT Split.") { var globalPre = List[Expr]() val cases = for ((sol, (cct, problem, pattern)) <- (sols zip subInfo)) yield { - globalPre ::= and(CaseClassInstanceOf(cct, Variable(id)), sol.pre) + if (sol.pre != BooleanLiteral(true)) { + val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield { + (arg, CaseClassSelector(cct, id.toVariable, field.id)) + }).toMap + globalPre ::= and(CaseClassInstanceOf(cct, Variable(id)), replaceFromIDs(substs, sol.pre)) + } else { + globalPre ::= BooleanLiteral(true) + } SimpleCase(pattern, sol.term) } @@ -80,9 +87,7 @@ case object ADTSplit extends Rule("ADT Split.") { Some(Solution(orJoin(globalPre), sols.flatMap(_.defs).toSet, matchExpr(Variable(id), cases), sols.forall(_.isTrusted))) } - Some(decomp(subInfo.map(_._2).toList, onSuccess, s"ADT Split on '$id'")) - case _ => - None - }}.flatten + decomp(subInfo.map(_._2).toList, onSuccess, s"ADT Split on '$id'") + } } }