diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index b4ed90a7bca8baa26d15030ef26bc7a99a4d6d96..0ddc9605372cc64dc2aff0521e5a231f4cca5a76 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -2002,118 +2002,34 @@ object TreeOps { }(expr) } - private val lambdaArgumentsCache = new TrieMap[TypeTree,Seq[Identifier]] - def lambdaArguments(tpe: TypeTree): Seq[Identifier] = lambdaArgumentsCache.get(tpe) match { - case Some(ids) => ids - case None => - val seq = tpe match { - case FunctionType(argTypes, returnType) => - argTypes.map(FreshIdentifier("x", true).setType(_)) ++ lambdaArguments(returnType) - case _ => Seq() - } - lambdaArgumentsCache(tpe) = seq - seq - } - - def functionApplication(expr: Expr, args: Seq[Expr]): Expr = expr.getType match { - case FunctionType(argTypes, returnType) => - val (currentArgs, nextArgs) = args.splitAt(argTypes.size) - val application = Application(expr, currentArgs) - functionApplication(application, nextArgs) - case tpe => - assert(args.isEmpty && !tpe.isInstanceOf[FunctionType]) - expr - } - - def createLambda(expr: Expr, args: Seq[Identifier]): Expr = expr.getType match { - case FunctionType(argTypes, returnType) => - val (currentArgs, nextArgs) = args.splitAt(argTypes.size) - val application = Application(expr, currentArgs.map(_.toVariable)) - Lambda(currentArgs.map(id => ValDef(id, id.getType)), createLambda(application, nextArgs)) - case tpe => - assert(args.isEmpty && !tpe.isInstanceOf[FunctionType]) - expr - } - - def lambdaTransform(expr: Expr) : Expr = { - - def hoistHOIte(expr: Expr) = { - def transform(expr: Expr): Option[Expr] = expr match { - case uop @ UnaryOperator(ife @ IfExpr(c, t, e), op) if ife.getType.isInstanceOf[FunctionType] => - Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) - case bop @ BinaryOperator(ife @ IfExpr(c, t, e), t2, op) if ife.getType.isInstanceOf[FunctionType] => - Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) - case bop @ BinaryOperator(t1, ife @ IfExpr(c, t, e), op) if ife.getType.isInstanceOf[FunctionType] => - Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) - case nop @ NAryOperator(ts, op) => { - val iteIndex = ts.indexWhere { - case ife @ IfExpr(_, _, _) if ife.getType.isInstanceOf[FunctionType] => true - case _ => false - } - if(iteIndex == -1) None else { - val (beforeIte, startIte) = ts.splitAt(iteIndex) - val afterIte = startIte.tail - val IfExpr(c, t, e) = startIte.head - Some(IfExpr(c, - op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), - op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) - ).setType(nop.getType)) - } - } - case _ => None + def simplifyHOFunctions(expr: Expr) : Expr = { + + def liftToLambdas(expr: Expr) = { + def apply(expr: Expr, args: Seq[Expr]): Expr = expr match { + case IfExpr(cond, thenn, elze) => + IfExpr(cond, apply(thenn, args), apply(elze, args)) + case Let(i, e, b) => + Let(i, e, apply(b, args)) + case LetTuple(is, es, b) => + LetTuple(is, es, apply(b, args)) + case Lambda(params, body) => + replaceFromIDs((params.map(_.id) zip args).toMap, body) + case _ => Application(expr, args) } - fixpoint(postMap(transform))(expr) - } - - def expandHOLets(expr: Expr) : Expr = { - def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { - case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) - case l @ Let(i,e,b) => - if (i.getType.isInstanceOf[FunctionType]) rec(b, s + (i -> rec(e, s))) - else Let(i, rec(e,s), rec(b,s)) - case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1,s), rec(t2,s), rec(t3,s)).setType(i.getType) - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut,s), cses.map(inCase(_, s))).setType(m.getType).setPos(m) - case n @ NAryOperator(args, recons) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a, s) - if (ra != a) { - change = true - ra - } else { - a - } - }) - if (change) recons(rargs).setType(n.getType) - else n + def lift(expr: Expr): Expr = expr.getType match { + case FunctionType(from, to) => expr match { + case _ : Lambda => expr + case _ : Variable => expr + case e => + val args = from.map(tpe => ValDef(FreshIdentifier("x", true).setType(tpe), tpe)) + val application = apply(expr, args.map(_.toVariable)) + Lambda(args, lift(application)) } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1, s) - val r2 = rec(t2, s) - if (r1 != t1 || r2 != t2) recons(r1, r2).setType(b.getType) - else b - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t, s) - if (r != t) recons(r).setType(u.getType) - else u - } - case t: Terminal => t - case unhandled => scala.sys.error("Unhandled case in expandHOLets: " + unhandled) - } - - def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) + case _ => expr } - rec(expr, Map.empty) - } - - def extractToLambda(expr: Expr) = { - def extract(expr: Expr, build: Boolean) = - if (build) createLambda(expr, lambdaArguments(expr.getType)) else expr + def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr def rec(expr: Expr, build: Boolean): Expr = expr match { case Application(caller, args) => @@ -2123,25 +2039,21 @@ object TreeOps { case FunctionInvocation(fd, args) => val newArgs = args.map(rec(_, true)) extract(FunctionInvocation(fd, newArgs), build) - case l @ Lambda(args, body) => l - case NAryOperator(es, recons) => recons(es.map(rec(_, build))) - case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)) - case UnaryOperator(e, recons) => recons(rec(e, build)) + case l @ Lambda(args, body) => + val newBody = rec(body, true) + extract(Lambda(args, newBody), build) + case NAryOperator(es, recons) => recons(es.map(rec(_, build))).setType(expr.getType) + case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)).setType(expr.getType) + case UnaryOperator(e, recons) => recons(rec(e, build)).setType(expr.getType) case t: Terminal => t } - rec(expr, true) + rec(lift(expr), true) } - extractToLambda( - hoistHOIte( - expandHOLets( - simplifyLets( - matchToIfThenElse( - expr - ) - ) - ) + liftToLambdas( + matchToIfThenElse( + expr ) ) } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index af6d1ad9e84bf8253eda0c39b03ba2e4e477ad24..a7088b12790aae3d617c10638e3f7d4289e8dfe2 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -126,7 +126,7 @@ object Trees { lazy val argsTuple = if (lambda.args.size > 1) Tuple(args) else args.head def rec(body: Expr): Option[(Expr, Seq[(Expr, Expr)])] = body match { - case _ : IntLiteral | _ : BooleanLiteral | _ : GenericValue | _ : Tuple | + case _ : IntLiteral | _ : UMinus | _ : BooleanLiteral | _ : GenericValue | _ : Tuple | _ : CaseClass | _ : FiniteArray | _ : FiniteSet | _ : FiniteMap | _ : Lambda => Some(body -> Seq.empty) case IfExpr(Equals(tpArgs, key), expr, elze) if tpArgs == argsTuple => diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index a69ec41e9d976481f7e5d7a0774e22a944b0a0ad..a276e41cb5d223ae9d51a46fbf78bc304d858572 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -47,12 +47,13 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) + val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b)) val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) - val invocationEqualsBody : Option[Expr] = newBody match { + val invocationEqualsBody : Option[Expr] = lambdaBody match { case Some(body) if isRealFunDef => - val b : Expr = Equals(invocation, body) + val b : Expr = appliedEquals(invocation, body) Some(if(prec.isDefined) { Implies(prec.get, b) @@ -66,7 +67,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { val start : Identifier = FreshIdentifier("start", true).setType(BooleanType) val pathVar : (Identifier, T) = start -> encoder.encodeId(start) - val arguments : Seq[(Identifier, T)] = tfd.params.map(vd => vd.id -> encoder.encodeId(vd.id)) + + val funDefArgs : Seq[Identifier] = tfd.params.map(_.id) + val allArguments = funDefArgs ++ lambdaBody.map(lambdaArgs).toSeq.flatten + val arguments : Seq[(Identifier, T)] = allArguments.map(id => id -> encoder.encodeId(id)) + val substMap : Map[Identifier, T] = arguments.toMap + pathVar val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas) = if (isRealFunDef) { @@ -74,7 +79,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Map[T,LambdaTemplate[T]]()) } } else { - mkClauses(start, newBody.get, substMap) + mkClauses(start, lambdaBody.get, substMap) } // Now the postcondition. @@ -106,6 +111,17 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { template } + private def lambdaArgs(expr: Expr): Seq[Identifier] = expr match { + case Lambda(args, body) => args.map(_.id) ++ lambdaArgs(body) + case _ => Seq.empty + } + + private def appliedEquals(invocation: Expr, body: Expr): Expr = body match { + case Lambda(args, lambdaBody) => + appliedEquals(Application(invocation, args.map(_.toVariable)), lambdaBody) + case _ => Equals(invocation, body) + } + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Map[T, LambdaTemplate[T]]) = { @@ -127,8 +143,15 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { guardedExprs += guardVar -> (expr +: prev) } + var lambdaVars = Map[Identifier, T]() + @inline def storeLambda(id: Identifier) : T = { + val idT = encoder.encodeId(id) + lambdaVars += id -> idT + idT + } + var lambdas = Map[T, LambdaTemplate[T]]() - @inline def storeLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda + @inline def registerLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda // Group elements that satisfy p toghether // List(a, a, a, b, c, a, a), with p = _ == a will produce: @@ -169,6 +192,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { case e @ Ensuring(body, id, post) => rec(pathVar, Let(id, body, Assert(post, None, Variable(id)))) + case l @ Let(i, e : Lambda, b) => + val re = rec(pathVar, e) // guaranteed variable! + val rb = rec(pathVar, replace(Map(Variable(i) -> re), b)) + rb + case l @ Let(i, e, b) => val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) storeExpr(newExpr) @@ -247,20 +275,20 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { Variable(cid) case l @ Lambda(args, body) => - val idArgs : Seq[Identifier] = args.map(_.id) + val idArgs : Seq[Identifier] = lambdaArgs(l) val trArgs : Seq[T] = idArgs.map(encoder.encodeId(_)) val lid = FreshIdentifier("lambda", true).setType(l.getType) - val clause = Equals(Application(Variable(lid), idArgs.map(Variable(_))), body) + val clause = appliedEquals(Variable(lid), l) - val localSubst : Map[Identifier, T] = substMap ++ condVars ++ exprVars + val localSubst : Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst : Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates) = mkClauses(pathVar, clause, clauseSubst) - val ids: (Identifier, T) = lid -> encoder.encodeId(lid) + val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap val template = LambdaTemplate(ids, encoder, lambdaManager, pathVar -> encodedCond(pathVar), idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l) - storeLambda(ids._2, template) + registerLambda(ids._2, template) Variable(lid) diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 6c26fa8d88f856c24b710a38b936492e828fe7ad..30ea6d40e8becf853bbba3a642b847f189c106fb 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -82,6 +82,48 @@ trait Template[T] { self => object Template { + private def functionCallInfos[T](encodeExpr: Expr => T)(expr: Expr): (Set[TemplateCallInfo[T]], Set[App[T]]) = { + def invocationCaller(expr: Expr): Boolean = expr match { + case fi: FunctionInvocation => true + case Application(caller, _) => invocationCaller(caller) + case _ => false + } + + val calls = collect[Expr] { + case IsTyped(f: FunctionInvocation, ft: FunctionType) => Set.empty + case IsTyped(f: Application, ft: FunctionType) => Set.empty + case f: FunctionInvocation => Set(f) + case f: Application => Set(f) + case _ => Set.empty + }(expr) + + val (functionCalls, appCalls) = calls partition invocationCaller + + def functionTemplate(expr: Expr): TemplateCallInfo[T] = expr match { + case FunctionInvocation(tfd, args) => + TemplateCallInfo(tfd, args.map(encodeExpr)) + case Application(caller, args) => + val TemplateCallInfo(tfd, prevArgs) = functionTemplate(caller) + TemplateCallInfo(tfd, prevArgs ++ args.map(encodeExpr)) + case _ => scala.sys.error("Should never happen!") + } + + val templates : Set[TemplateCallInfo[T]] = functionCalls map functionTemplate + + def applicationTemplate(expr: Expr): App[T] = expr match { + case Application(caller : Application, args) => + val App(c, tpe, prevArgs) = applicationTemplate(caller) + App(c, tpe, prevArgs ++ args.map(encodeExpr)) + case Application(c, args) => + App(encodeExpr(c), c.getType, args.map(encodeExpr)) + case _ => scala.sys.error("Should never happen!") + } + + val apps : Set[App[T]] = appCalls map applicationTemplate + + (templates, apps) + } + def encode[T]( encoder: TemplateEncoder[T], pathVar: (Identifier, T), @@ -106,32 +148,35 @@ object Template { encodeExpr(Implies(Variable(b), e)) }).toSeq - val blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = { - val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) + val extractInfos : Expr => (Set[TemplateCallInfo[T]], Set[App[T]]) = functionCallInfos(encodeExpr) _ + val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) + val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } - Map((for ((b,es) <- guardedExprs) yield { - val calls = es.flatMap(e => functionCallsOf(e).map { fi => - TemplateCallInfo[T](fi.tfd, fi.args.map(encodeExpr)) - }).toSet -- optIdCall + val (blockers, applications) : (Map[Identifier, Set[TemplateCallInfo[T]]], Map[Identifier, Set[App[T]]]) = { + var blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = Map.empty + var applications : Map[Identifier, Set[App[T]]] = Map.empty - if (calls.isEmpty) None else Some(b -> calls) - }).flatten.toSeq : _*) - } + for ((b,es) <- guardedExprs) { + var funInfos : Set[TemplateCallInfo[T]] = Set.empty + var appInfos : Set[App[T]] = Set.empty - val encodedBlockers : Map[T, Set[TemplateCallInfo[T]]] = blockers.map(p => idToTrId(p._1) -> p._2) + for (e <- es) { + val (newFunInfos, newAppInfos) = extractInfos(e) + funInfos ++= newFunInfos + appInfos ++= newAppInfos + } - val applications : Map[Identifier, Set[App[T]]] = { - val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } + val calls = funInfos -- optIdCall + if (calls.nonEmpty) blockers += b -> calls - Map((for ((b,es) <- guardedExprs) yield { - val apps = es.flatMap(e => functionAppsOf(e).map { fa => - App[T](encodeExpr(fa.caller), fa.caller.getType, fa.args.map(encodeExpr)) - }).toSet -- optIdApp + val apps = appInfos -- optIdApp + if (apps.nonEmpty) applications += b -> apps + } - if (apps.isEmpty) None else Some(b -> apps) - }).flatten.toSeq : _*) + (blockers, applications) } + val encodedBlockers : Map[T, Set[TemplateCallInfo[T]]] = blockers.map(p => idToTrId(p._1) -> p._2) val encodedApps : Map[T, Set[App[T]]] = applications.map(p => idToTrId(p._1) -> p._2) val stringRepr : () => String = () => {