diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index ddc46ee306d945cf418a3722d5040dcd5c0cfb19..4d202f3b53eeb320141f0179af6c900d659f446d 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -301,4 +301,8 @@ object Extractors { } } + object IsTyped { + def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) + } + } diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala index 13bda85f391975c419537fcdec894ec2945bd2f1..51e68ca13ecddbfcd9910071bce5bbe4ec860a2c 100644 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ b/src/main/scala/leon/synthesis/Heuristics.scala @@ -11,7 +11,8 @@ import purescala.Definitions._ object Heuristics { def all = Set[Synthesizer => Rule]( new IntInduction(_), - new OptimisticInjection(_) + new OptimisticInjection(_), + new ADTInduction(_) ) } @@ -26,7 +27,7 @@ object HeuristicStep { def verifyPre(synth: Synthesizer, problem: Problem)(s: Solution): (Solution, Boolean) = { synth.solver.solveSAT(And(Not(s.pre), problem.phi)) match { case (Some(true), model) => - synth.reporter.warning("Heuristic failed to produce strongest precondition:") + synth.reporter.warning("Heuristic failed to produce weakest precondition:") synth.reporter.warning(" - problem: "+problem) synth.reporter.warning(" - precondition: "+s.pre) (s, false) @@ -46,7 +47,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) val p = task.problem p.as match { - case List(origId) if origId.getType == Int32Type => + case List(IsTyped(origId, Int32Type)) => val tpe = TupleType(p.xs.map(_.getType)) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) @@ -167,3 +168,116 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth, } } +class ADTInduction(synth: Synthesizer) extends Rule("ADT Induction", synth, 80) with Heuristic { + def applyOn(task: Task): RuleResult = { + val p = task.problem + + val candidates = p.as.collect { + case IsTyped(origId, AbstractClassType(cd)) => (origId, cd) + } + + if (!candidates.isEmpty) { + val (origId, cd) = candidates.head + val oas = p.as.filterNot(_ == origId) + + val resType = TupleType(p.xs.map(_.getType)) + + val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) + val residualArgs = oas.map(id => FreshIdentifier(id.name, true).setType(id.getType)) + val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap + val residualArgDefs = residualArgs.map(a => VarDecl(a, a.getType)) + + def isAlternativeRecursive(cd: CaseClassDef): Boolean = { + cd.fieldsIds.exists(_.getType == origId.getType) + } + + val isRecursive = cd.knownDescendents.exists { + case ccd: CaseClassDef => isAlternativeRecursive(ccd) + case _ => false + } + + if (isRecursive) { + val newFun = new FunDef(FreshIdentifier("rec", true), resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs) + + val innerPhi = substAll(residualMap + (origId -> Variable(inductOn)), p.phi) + + val subProblemsInfo = for (dcd <- cd.knownDescendents) yield dcd match { + case ccd : CaseClassDef => + var recCalls = Map[List[Identifier], Expr]() + var postFs = List[Expr]() + + val newIds = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList + + val inputs = (for (id <- newIds) yield { + if (id.getType == origId.getType) { + val postXs = p.xs map (id => FreshIdentifier("r", true).setType(id.getType)) + val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable(_)) + + recCalls += postXs -> FunctionInvocation(newFun, Variable(id) +: residualArgs.map(id => Variable(id))) + + postFs ::= substAll(postXsMap + (origId -> Variable(id)), innerPhi) + id :: postXs + } else { + List(id) + } + }).flatten + + val subPhi = substAll(Map(inductOn -> CaseClass(ccd, newIds.map(Variable(_)))), innerPhi) + + val subPre = CaseClassInstanceOf(ccd, Variable(origId)) + + val subProblem = Problem(inputs ::: residualArgs, And(p.c :: postFs), subPhi, p.xs) + + (subProblem, subPre, recCalls) + case _ => + sys.error("Woops, non case-class as descendent") + } + + val onSuccess: List[Solution] => Solution = { + case sols => + var globalPre = List[Expr]() + + for ((sol, (problem, pre, calls)) <- (sols zip subProblemsInfo)) { + globalPre ::= And(pre, sol.pre) + + } + + val funPre = subst(origId -> Variable(inductOn), Or(globalPre)) + val outerPre = funPre + /* + solPre ::= And(pre, sol.pre) + + val preIn = Or(Seq(And(Equals(Variable(inductOn), IntLiteral(0)), base.pre), + And(GreaterThan(Variable(inductOn), IntLiteral(0)), gt.pre), + And(LessThan(Variable(inductOn), IntLiteral(0)), lt.pre))) + val preOut = subst(inductOn -> Variable(origId), preIn) + + val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType))) + newFun.precondition = Some(preIn) + newFun.postcondition = Some(And(Equals(ResultVariable(), Tuple(p.xs.map(Variable(_)))), p.phi)) + + newFun.body = Some( + IfExpr(Equals(Variable(inductOn), IntLiteral(0)), + base.toExpr, + IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)), + LetTuple(postXs, FunctionInvocation(newFun, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) + , LetTuple(postXs, FunctionInvocation(newFun, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) + ) + + + Solution(preOut, base.defs++gt.defs++lt.defs+newFun, FunctionInvocation(newFun, Seq(Variable(origId)))) + */ + Solution.none + } + + println(subProblemsInfo) + + RuleInapplicable + } else { + RuleInapplicable + } + } else { + RuleInapplicable + } + } +} diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 4bc36af3dfa0c638f2086500af95a205047635c3..36576a1f4ac9e59a298598270568424f0a0a0f56 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -24,8 +24,8 @@ object Rules { new OptimisticGround(_), new EqualitySplit(_), new CEGIS(_), - new Assert(_), - new IntegerEquation(_) + new Assert(_) + //new IntegerEquation(_) ) }