Skip to content
Snippets Groups Projects
Commit 439af53e authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

Fix SafeRecursiveCalls to not forget previous calls

simplify terminatingCalls
parent 66114616
Branches
Tags
Loading
...@@ -19,7 +19,7 @@ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends Express ...@@ -19,7 +19,7 @@ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends Express
val calls = terminatingCalls(prog,ws, pc, Some(t), true) val calls = terminatingCalls(prog,ws, pc, Some(t), true)
calls.map { c => (c: @unchecked) match { calls.map { c => (c: @unchecked) match {
case (_, fi, Some(free)) => case (fi, Some(free)) =>
val freeSeq = free.toSeq val freeSeq = free.toSeq
nonTerminal( nonTerminal(
......
...@@ -25,8 +25,11 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { ...@@ -25,8 +25,11 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") {
} }
def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = {
val TopLevelAnds(pcs) = p.pc
val existingCalls = pcs.collect { case Equals(_, fi: FunctionInvocation) => fi }.toSet
val (orig, calls, _) = terminatingCalls(hctx.program, p.ws, p.pc, None, false).unzip3 val calls = terminatingCalls(hctx.program, p.ws, p.pc, None, false)
.map(_._1).filterNot(existingCalls)
if (calls.isEmpty) return Nil if (calls.isEmpty) return Nil
...@@ -76,14 +79,12 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { ...@@ -76,14 +79,12 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") {
} }
val newWs = calls map Terminating val newWs = calls map Terminating
val TopLevelAnds(ws) = p.ws val TopLevelAnds(ws) = p.ws
try { try {
val newProblem = p.copy( val newProblem = p.copy(
as = p.as ++ recs, as = p.as ++ recs,
pc = andJoin(p.pc +: posts), pc = andJoin(p.pc +: posts),
ws = andJoin((ws diff orig) ++ newWs), ws = andJoin(ws ++ newWs),
eb = p.eb.map(mapExample) eb = p.eb.map(mapExample)
) )
......
...@@ -43,10 +43,10 @@ object Helpers { ...@@ -43,10 +43,10 @@ object Helpers {
* @param ws Helper predicates that contain [[Terminating]]s with the initial calls * @param ws Helper predicates that contain [[Terminating]]s with the initial calls
* @param pc The path condition * @param pc The path condition
* @param tpe The expected type for the returned function calls. If absent, all types are permitted. * @param tpe The expected type for the returned function calls. If absent, all types are permitted.
* @return A list of trips of (original terminating call, safe function call, holes), * @return A list of pairs (safe function call, holes),
* where holes stand for the rest of the arguments of the function. * where holes stand for the rest of the arguments of the function.
*/ */
def terminatingCalls(prog: Program, ws: Expr, pc: Expr, tpe: Option[TypeTree], introduceHoles: Boolean): List[(Terminating, FunctionInvocation, Option[Set[Identifier]])] = { def terminatingCalls(prog: Program, ws: Expr, pc: Expr, tpe: Option[TypeTree], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = {
val TopLevelAnds(wss) = ws val TopLevelAnds(wss) = ws
val TopLevelAnds(clauses) = pc val TopLevelAnds(clauses) = pc
...@@ -88,13 +88,13 @@ object Helpers { ...@@ -88,13 +88,13 @@ object Helpers {
} }
val res = gs.flatMap { val res = gs.flatMap {
case term@Terminating(FunctionInvocation(tfd, args)) if tpe forall (isSubtypeOf(tfd.returnType, _)) => case Terminating(FunctionInvocation(tfd, args)) if tpe forall (isSubtypeOf(tfd.returnType, _)) =>
val ids = tfd.params.map(vd => FreshIdentifier("<hole>", vd.getType, true)).toList val ids = tfd.params.map(vd => FreshIdentifier("<hole>", vd.getType, true)).toList
for (((a, i), tpe) <- args.zipWithIndex zip tfd.params.map(_.getType); for (((a, i), tpe) <- args.zipWithIndex zip tfd.params.map(_.getType);
smaller <- argsSmaller(a, tpe)) yield { smaller <- argsSmaller(a, tpe)) yield {
val newArgs = (if (introduceHoles) ids.map(_.toVariable) else args).updated(i, smaller) val newArgs = (if (introduceHoles) ids.map(_.toVariable) else args).updated(i, smaller)
(term, FunctionInvocation(tfd, newArgs), if(introduceHoles) Some(ids.toSet - ids(i)) else None) (FunctionInvocation(tfd, newArgs), if(introduceHoles) Some(ids.toSet - ids(i)) else None)
} }
case _ => case _ =>
Nil Nil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment