diff --git a/src/main/scala/leon/purescala/CheckForalls.scala b/src/main/scala/leon/purescala/CheckForalls.scala index 600b0ea26e8831779af4418a9c325692870e5fde..6afcfad1f91d5aff1605c7f1198a001f2a032395 100644 --- a/src/main/scala/leon/purescala/CheckForalls.scala +++ b/src/main/scala/leon/purescala/CheckForalls.scala @@ -60,7 +60,10 @@ object CheckForalls extends UnitPhase[Program] { }) }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - if (matchers.map(_._1).toSet.size != 1) + if (matchers.filter(_._2.exists { + case Variable(id) => quantified(id) + case _ => false + }).map(_._1).toSet.size != 1) ctx.reporter.warning("Quantification conjuncts must contain exactly one matcher in " + conjunct) preTraversal { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index 9d19ea839601d631fe7210786c1c9d49894503af..d5bec5c31a1cef1129fd5ab768754351ffc364fe 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -1,3 +1,5 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + package leon package solvers package templates @@ -201,7 +203,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for { (b, Matcher(qcaller, _, qargs), Matcher(caller, _, args)) <- mapping _ = constraints :+= b - _ = if (qcaller != caller) constraints :+= encoder.mkEquals(qcaller, caller) (qarg, arg) <- (qargs zip args) } if (subst.isDefinedAt(qarg)) { constraints :+= encoder.mkEquals(subst(qarg), arg) @@ -229,7 +230,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { - val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) val instantiationSubst: Map[T, T] = substMap + (template.guardVar -> trueT) diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index d727a2b9f0000eb17fe6b1173e7cdab07dc225ef..d14a294c074fedf82c748456e1ee36ef87151485 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -124,12 +124,31 @@ object Template { (templates, apps) } - private def matchersOf[T](encodeExpr: Expr => T)(expr: Expr): Set[Matcher[T]] = collect[Matcher[T]] { - case Application(caller, args) => Set(Matcher(encodeExpr(caller), caller.getType, args.map(encodeExpr))) - case ArraySelect(arr, index) => Set(Matcher(encodeExpr(arr), arr.getType, Seq(encodeExpr(index)))) - case MapGet(map, key) => Set(Matcher(encodeExpr(map), map.getType, Seq(encodeExpr(key)))) - case _ => Set.empty - }(expr) + private def selectMatchInfos[T](encodeExpr: Expr => T)(expr: Expr): Set[Matcher[T]] = { + collect[Matcher[T]] { + case ArraySelect(arr, index) => + Set(Matcher(encodeExpr(arr), arr.getType, Seq(encodeExpr(index)))) + case MapGet(map, key) => + Set(Matcher(encodeExpr(map), map.getType, Seq(encodeExpr(key)))) + case _ => Set.empty + }(expr) + } + + private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { + assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") + + def rec(e: Expr, args: Seq[Expr]): Expr = e.getType match { + case FunctionType(from, to) => + val (appArgs, outerArgs) = args.splitAt(from.size) + rec(Application(e, appArgs), outerArgs) + case _ if args.isEmpty => e + case _ => scala.sys.error("Should never happen") + } + + val (fiArgs, appArgs) = args.splitAt(tfd.params.size) + val Application(caller, arguments) = rec(FunctionInvocation(tfd, fiArgs), appArgs) + Matcher(encodeExpr(caller), caller.getType, arguments.map(encodeExpr)) + } def encode[T]( encoder: TemplateEncoder[T], @@ -155,12 +174,14 @@ object Template { encodeExpr(Implies(Variable(b), e)) }).toSeq - val extractInfos : Expr => (Set[TemplateCallInfo[T]], Set[App[T]]) = functionCallInfos(encodeExpr) - val extractMatchers: Expr => Set[Matcher[T]] = matchersOf(encodeExpr) + val extractInfos : Expr => (Set[TemplateCallInfo[T]], Set[App[T]]) = functionCallInfos(encodeExpr) + val extractMatchers : Expr => Set[Matcher[T]] = selectMatchInfos(encodeExpr) val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } - val optIdMatch = optIdApp.map { case App(caller, tpe, args) => Matcher(caller, tpe, args) } + + val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) + .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) val (blockers, applications, matchers) = { var blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = Map.empty @@ -185,7 +206,14 @@ object Template { val apps = appInfos -- optIdApp if (apps.nonEmpty) applications += b -> apps - val matchs = matchInfos -- optIdMatch + val matchs = matchInfos ++ + apps.map(app => Matcher(app.caller, app.tpe, app.args)) ++ + funInfos.flatMap { + case info @ TemplateCallInfo(tfd, args) if Some(info) == optIdCall => + invocMatcher + case _ => None + } + if (matchs.nonEmpty) matchers += b -> matchs } @@ -208,8 +236,11 @@ object Template { " * Application-blocks :" + (if (applications.isEmpty) "\n" else { "\n " + applications.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" }) + + " * Matchers :" + (if (matchers.isEmpty) "\n" else { + "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + " * Lambdas :\n" + lambdas.map { case (_, template) => - " +> " + template.toString.split("\n").mkString("\n ") + " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") }