-
Etienne Kneuss authored
ADTSplit.scala 2.33 KiB
package leon
package synthesis
package rules
import solvers.TimeoutSolver
import purescala.Trees._
import purescala.Common._
import purescala.TypeTrees._
import purescala.TreeOps._
import purescala.Extractors._
import purescala.Definitions._
case object ADTSplit extends Rule("ADT Split.") {

  /** Proposes one decomposition per input variable whose type is an abstract class.
    *
    * For every input `id` of type `AbstractClassType(cd)`:
    *  - the known case-class descendants of `cd` are enumerated, sorted by name;
    *  - those that are provably impossible under the path condition are discarded;
    *  - each remaining one yields a sub-problem in which `id` is replaced by that case
    *    class applied to fresh identifiers for its fields.
    *
    * The sub-solutions are recombined into a `match` on `id`. The recombined
    * precondition is the disjunction, over cases, of "`id` is an instance of that
    * case" and "that case's sub-precondition".
    *
    * @param sctx synthesis context; its solver is queried (under a timeout) to discard
    *             case classes that are infeasible under `p.pc`
    * @param p    the problem to split
    * @return one rule instantiation per abstract-class-typed input that has at least one
    *         case class not proven infeasible
    */
  def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
    val solver = new TimeoutSolver(sctx.solver, 100L) // We give that 100ms

    // A case is infeasible when `pc && id.isInstanceOf[ccd]` is proven UNSAT.
    // The previous check negated the instance test, which discarded precisely the case
    // the path condition forces. An unknown answer (e.g. a timeout) keeps the case,
    // which is the conservative choice.
    def isInfeasible(id: Identifier, ccd: CaseClassDef): Boolean =
      solver.solveSAT(And(p.pc, CaseClassInstanceOf(ccd, Variable(id)))) match {
        case (Some(false), _) => true
        case _                => false
      }

    // For each abstract-class-typed input: the case classes it may still be.
    // Inputs with no remaining case are dropped.
    val candidates = p.as.collect {
      case IsTyped(id, AbstractClassType(cd)) =>
        val caseClasses = cd.knownDescendents.sortBy(_.id.name).collect {
          case ccd: CaseClassDef => ccd
        }
        // Filtering is kept outside `collect` so each solver query runs exactly once.
        (id, caseClasses.filterNot(ccd => isInfeasible(id, ccd)))
    }.filter(_._2.nonEmpty)

    candidates.map { case (id, cases) =>
      val otherAs = p.as.filter(_ != id)

      // One sub-problem per case: `id` is substituted by `ccd(<fresh field ids>)`,
      // and the fresh field identifiers become inputs of the sub-problem.
      val subInfo = for (ccd <- cases) yield {
        val fieldIds = ccd.fieldsIds.map(f => FreshIdentifier(f.name, true).setType(f.getType)).toList
        val subPhi = subst(id -> CaseClass(ccd, fieldIds.map(Variable(_))), p.phi)
        val subProblem = Problem(fieldIds ::: otherAs, p.pc, subPhi, p.xs)
        val subPattern = CaseClassPattern(None, ccd, fieldIds.map(f => WildcardPattern(Some(f))))
        (ccd, subProblem, subPattern)
      }

      // Pairs each sub-solution with its case (same order as `subInfo`) and rebuilds
      // a match on `id`.
      val onSuccess: List[Solution] => Option[Solution] = { sols =>
        val combined = sols zip subInfo
        val globalPre: List[Expr] = combined.map { case (sol, (ccd, _, _)) =>
          And(CaseClassInstanceOf(ccd, Variable(id)), sol.pre)
        }
        val matchCases = combined.map { case (sol, (_, _, pattern)) =>
          SimpleCase(pattern, sol.term)
        }
        Some(Solution(Or(globalPre), sols.flatMap(_.defs).toSet, MatchExpr(Variable(id), matchCases)))
      }

      RuleInstantiation.immediateDecomp(p, this, subInfo.map(_._2).toList, onSuccess)
    }
  }
}