Skip to content
Snippets Groups Projects
Commit 46fabd14 authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Starting ADTInduction

parent 4340b6ed
No related branches found
No related tags found
No related merge requests found
...@@ -301,4 +301,8 @@ object Extractors { ...@@ -301,4 +301,8 @@ object Extractors {
} }
} }
object IsTyped {
def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType))
}
} }
...@@ -11,7 +11,8 @@ import purescala.Definitions._ ...@@ -11,7 +11,8 @@ import purescala.Definitions._
object Heuristics { object Heuristics {
def all = Set[Synthesizer => Rule]( def all = Set[Synthesizer => Rule](
new IntInduction(_), new IntInduction(_),
new OptimisticInjection(_) new OptimisticInjection(_),
new ADTInduction(_)
) )
} }
...@@ -26,7 +27,7 @@ object HeuristicStep { ...@@ -26,7 +27,7 @@ object HeuristicStep {
def verifyPre(synth: Synthesizer, problem: Problem)(s: Solution): (Solution, Boolean) = { def verifyPre(synth: Synthesizer, problem: Problem)(s: Solution): (Solution, Boolean) = {
synth.solver.solveSAT(And(Not(s.pre), problem.phi)) match { synth.solver.solveSAT(And(Not(s.pre), problem.phi)) match {
case (Some(true), model) => case (Some(true), model) =>
synth.reporter.warning("Heuristic failed to produce strongest precondition:") synth.reporter.warning("Heuristic failed to produce weakest precondition:")
synth.reporter.warning(" - problem: "+problem) synth.reporter.warning(" - problem: "+problem)
synth.reporter.warning(" - precondition: "+s.pre) synth.reporter.warning(" - precondition: "+s.pre)
(s, false) (s, false)
...@@ -46,7 +47,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) ...@@ -46,7 +47,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80)
val p = task.problem val p = task.problem
p.as match { p.as match {
case List(origId) if origId.getType == Int32Type => case List(IsTyped(origId, Int32Type)) =>
val tpe = TupleType(p.xs.map(_.getType)) val tpe = TupleType(p.xs.map(_.getType))
val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType)
...@@ -167,3 +168,116 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth, ...@@ -167,3 +168,116 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth,
} }
} }
class ADTInduction(synth: Synthesizer) extends Rule("ADT Induction", synth, 80) with Heuristic {
def applyOn(task: Task): RuleResult = {
val p = task.problem
val candidates = p.as.collect {
case IsTyped(origId, AbstractClassType(cd)) => (origId, cd)
}
if (!candidates.isEmpty) {
val (origId, cd) = candidates.head
val oas = p.as.filterNot(_ == origId)
val resType = TupleType(p.xs.map(_.getType))
val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType)
val residualArgs = oas.map(id => FreshIdentifier(id.name, true).setType(id.getType))
val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap
val residualArgDefs = residualArgs.map(a => VarDecl(a, a.getType))
def isAlternativeRecursive(cd: CaseClassDef): Boolean = {
cd.fieldsIds.exists(_.getType == origId.getType)
}
val isRecursive = cd.knownDescendents.exists {
case ccd: CaseClassDef => isAlternativeRecursive(ccd)
case _ => false
}
if (isRecursive) {
val newFun = new FunDef(FreshIdentifier("rec", true), resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs)
val innerPhi = substAll(residualMap + (origId -> Variable(inductOn)), p.phi)
val subProblemsInfo = for (dcd <- cd.knownDescendents) yield dcd match {
case ccd : CaseClassDef =>
var recCalls = Map[List[Identifier], Expr]()
var postFs = List[Expr]()
val newIds = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList
val inputs = (for (id <- newIds) yield {
if (id.getType == origId.getType) {
val postXs = p.xs map (id => FreshIdentifier("r", true).setType(id.getType))
val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable(_))
recCalls += postXs -> FunctionInvocation(newFun, Variable(id) +: residualArgs.map(id => Variable(id)))
postFs ::= substAll(postXsMap + (origId -> Variable(id)), innerPhi)
id :: postXs
} else {
List(id)
}
}).flatten
val subPhi = substAll(Map(inductOn -> CaseClass(ccd, newIds.map(Variable(_)))), innerPhi)
val subPre = CaseClassInstanceOf(ccd, Variable(origId))
val subProblem = Problem(inputs ::: residualArgs, And(p.c :: postFs), subPhi, p.xs)
(subProblem, subPre, recCalls)
case _ =>
sys.error("Woops, non case-class as descendent")
}
val onSuccess: List[Solution] => Solution = {
case sols =>
var globalPre = List[Expr]()
for ((sol, (problem, pre, calls)) <- (sols zip subProblemsInfo)) {
globalPre ::= And(pre, sol.pre)
}
val funPre = subst(origId -> Variable(inductOn), Or(globalPre))
val outerPre = funPre
/*
solPre ::= And(pre, sol.pre)
val preIn = Or(Seq(And(Equals(Variable(inductOn), IntLiteral(0)), base.pre),
And(GreaterThan(Variable(inductOn), IntLiteral(0)), gt.pre),
And(LessThan(Variable(inductOn), IntLiteral(0)), lt.pre)))
val preOut = subst(inductOn -> Variable(origId), preIn)
val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType)))
newFun.precondition = Some(preIn)
newFun.postcondition = Some(And(Equals(ResultVariable(), Tuple(p.xs.map(Variable(_)))), p.phi))
newFun.body = Some(
IfExpr(Equals(Variable(inductOn), IntLiteral(0)),
base.toExpr,
IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)),
LetTuple(postXs, FunctionInvocation(newFun, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr)
, LetTuple(postXs, FunctionInvocation(newFun, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr)))
)
Solution(preOut, base.defs++gt.defs++lt.defs+newFun, FunctionInvocation(newFun, Seq(Variable(origId))))
*/
Solution.none
}
println(subProblemsInfo)
RuleInapplicable
} else {
RuleInapplicable
}
} else {
RuleInapplicable
}
}
}
...@@ -24,8 +24,8 @@ object Rules { ...@@ -24,8 +24,8 @@ object Rules {
new OptimisticGround(_), new OptimisticGround(_),
new EqualitySplit(_), new EqualitySplit(_),
new CEGIS(_), new CEGIS(_),
new Assert(_), new Assert(_)
new IntegerEquation(_) //new IntegerEquation(_)
) )
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment