diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 858cc6efc064731c1b816cce474fa83ced38d0e7..cbfd3947dcce3b9458ed6914b374fe4bf5a57433 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -19,7 +19,7 @@ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends Express val calls = terminatingCalls(prog,ws, pc, Some(t), true) calls.map { c => (c: @unchecked) match { - case (_, fi, Some(free)) => + case (fi, Some(free)) => val freeSeq = free.toSeq nonTerminal( diff --git a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala index 18fe97478bade0ffa6e0c9573e07d6cbfa76377b..5103ad0c63a6ef01bb5762c945c93d9d9cecec41 100644 --- a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala +++ b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala @@ -25,8 +25,11 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { } def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val TopLevelAnds(pcs) = p.pc + val existingCalls = pcs.collect { case Equals(_, fi: FunctionInvocation) => fi }.toSet - val (orig, calls, _) = terminatingCalls(hctx.program, p.ws, p.pc, None, false).unzip3 + val calls = terminatingCalls(hctx.program, p.ws, p.pc, None, false) + .map(_._1).filterNot(existingCalls) if (calls.isEmpty) return Nil @@ -76,14 +79,12 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { } val newWs = calls map Terminating - val TopLevelAnds(ws) = p.ws - try { val newProblem = p.copy( as = p.as ++ recs, pc = andJoin(p.pc +: posts), - ws = andJoin((ws diff orig) ++ newWs), + ws = andJoin(ws ++ newWs), eb = p.eb.map(mapExample) ) diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index a0b940958ffecee169005512693d00539ffef196..64a8f4d9d046464c8ebfe8e7ad157a1411740ae4 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -43,10 +43,10 @@ object Helpers { * @param ws Helper predicates that contain [[Terminating]]s with the initial calls * @param pc The path condition * @param tpe The expected type for the returned function calls. If absent, all types are permitted. - * @return A list of trips of (original terminating call, safe function call, holes), + * @return A list of pairs (safe function call, holes), * where holes stand for the rest of the arguments of the function. */ - def terminatingCalls(prog: Program, ws: Expr, pc: Expr, tpe: Option[TypeTree], introduceHoles: Boolean): List[(Terminating, FunctionInvocation, Option[Set[Identifier]])] = { + def terminatingCalls(prog: Program, ws: Expr, pc: Expr, tpe: Option[TypeTree], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc @@ -88,13 +88,13 @@ object Helpers { } val res = gs.flatMap { - case term@Terminating(FunctionInvocation(tfd, args)) if tpe forall (isSubtypeOf(tfd.returnType, _)) => + case Terminating(FunctionInvocation(tfd, args)) if tpe forall (isSubtypeOf(tfd.returnType, _)) => val ids = tfd.params.map(vd => FreshIdentifier("<hole>", vd.getType, true)).toList for (((a, i), tpe) <- args.zipWithIndex zip tfd.params.map(_.getType); smaller <- argsSmaller(a, tpe)) yield { val newArgs = (if (introduceHoles) ids.map(_.toVariable) else args).updated(i, smaller) - (term, FunctionInvocation(tfd, newArgs), if(introduceHoles) Some(ids.toSet - ids(i)) else None) + (FunctionInvocation(tfd, newArgs), if(introduceHoles) Some(ids.toSet - ids(i)) else None) } case _ => Nil