diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 75c691edf1f7e5a4ca056bf8c8448bff9b5ead77..f7240f1b763951aee607241cc07bf41a5c535e4a 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -8,7 +8,7 @@ import purescala.Constructors._ import purescala.ExprOps._ import purescala.Expressions.Pattern import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.isSubtypeOf import purescala.Types._ import purescala.Common._ import purescala.Expressions._ diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 9cc6dd132036ffdde4a5ef70e2be87e6276f9f3c..385157e4d64c835841279f0970705f6a5aa10b1c 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -6,7 +6,7 @@ package evaluators import purescala.Constructors._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.{leastUpperBound, isSubtypeOf} import purescala.Types._ import purescala.Common.Identifier import purescala.Definitions.{TypedFunDef, Program} diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3e93664008654edef6fa057c56451db0a108bc6b..d7dd1a947ead97ccd1c30a6a89256e24e10654f4 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -22,7 +22,7 @@ import Common._ import Extractors._ import Constructors._ import ExprOps._ -import TypeOps._ +import TypeOps.{leastUpperBound, typesCompatible, typeParamsOf, canBeSubtypeOf} import xlang.Expressions.{Block => LeonBlock, _} import xlang.ExprOps._ diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 81ebbdbec38e1484baf106eac9652d62297ae493..8cae0aa72c7a6d7ac18f3af251eeccb4056a6441 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,6 +275,7 @@ object DefOps { None } + /** Returns the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 49d945ec6217a42578da7fef97bf572252cb6c26..99803ec2adea8a7c00dceacb579cdb4f10366b13 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -19,285 +19,19 @@ import solvers._ * * The generic operations lets you apply operations on a whole tree * expression. You can look at: - * - [[ExprOps.fold foldRight]] - * - [[ExprOps.preTraversal preTraversal]] - * - [[ExprOps.postTraversal postTraversal]] - * - [[ExprOps.preMap preMap]] - * - [[ExprOps.postMap postMap]] - * - [[ExprOps.genericTransform genericTransform]] + * - [[SubTreeOps.fold foldRight]] + * - [[SubTreeOps.preTraversal preTraversal]] + * - [[SubTreeOps.postTraversal postTraversal]] + * - [[SubTreeOps.preMap preMap]] + * - [[SubTreeOps.postMap postMap]] + * - [[SubTreeOps.genericTransform genericTransform]] * * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * */ -object ExprOps { - - /* ======== - * Core API - * ======== - * - * All these functions should be stable, tested, and used everywhere. Modify - * with care. - */ - - - /** Does a right tree fold - * - * A right tree fold applies the input function to the subnodes first (from left - * to right), and combine the results along with the current node value. - * - * @param f a function that takes the current node and the seq - * of results form the subtrees. - * @param e The Expr on which to apply the fold. - * @return The expression after applying `f` on all subtrees. - * @note the computation is lazy, hence you should not rely on side-effects of `f` - */ - def fold[T](f: (Expr, Seq[T]) => T)(e: Expr): T = { - val rec = fold(f) _ - val Operator(es, _) = e - - //Usages of views makes the computation lazy. (which is useful for - //contains-like operations) - f(e, es.view.map(rec)) - } - - /** Pre-traversal of the tree. - * - * Invokes the input function on every node '''before''' visiting - * children. Traverse children from left to right subtrees. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def preTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = preTraversal(f) _ - val Operator(es, _) = e - f(e) - es.foreach(rec) - } - - /** Post-traversal of the tree. - * - * Invokes the input function on every node '''after''' visiting - * children. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def postTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = postTraversal(f) _ - val Operator(es, _) = e - es.foreach(rec) - f(e) - } - - /** Pre-transformation of the tree. - * - * Takes a partial function of replacements and substitute - * '''before''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, d) // And not Add(a, f) because it only substitute once for each level. - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed - */ - def preMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = preMap(f, applyRec) _ - - val newV = if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (e) - } else { - f(e) getOrElse e - } - - val Operator(es, builder) = newV - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(newV) - } else { - newV - } - } - - /** Post-transformation of the tree. - * - * Takes a partial function of replacements. - * Substitutes '''after''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e - * }}} - * will yield: - * {{{ - * Add(a, Minus(e, c)) - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) - */ - def postMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = postMap(f, applyRec) _ - - val Operator(es, builder) = e - val newEs = es.map(rec) - val newV = { - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - } - - if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (newV) - } else { - f(newV) getOrElse newV - } - - } - - - /** Applies functions and combines results in a generic way - * - * Start with an initial value, and apply functions to nodes before - * and after the recursion in the children. Combine the results of - * all children and apply a final function on the resulting node. - * - * @param pre a function applied on a node before doing a recursion in the children - * @param post a function applied to the node built from the recursive application to - all children - * @param combiner a function to combine the resulting values from all children with - the current node - * @param init the initial value - * @param expr the expression on which to apply the transform - * - * @see [[simpleTransform]] - * @see [[simplePreTransform]] - * @see [[simplePostTransform]] - */ - def genericTransform[C](pre: (Expr, C) => (Expr, C), - post: (Expr, C) => (Expr, C), - combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { - - def rec(eIn: Expr, cIn: C): (Expr, C) = { - - val (expr, ctx) = pre(eIn, cIn) - val Operator(es, builder) = expr - val (newExpr, newC) = { - val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes).copiedFrom(expr) - - (newE, combiner(newE, cs)) - } - - post(newExpr, newC) - } - - rec(expr, init) - } - - /* - * ============= - * Auxiliary API - * ============= - * - * Convenient methods using the Core API. - */ - - /** Checks if the predicate holds in some sub-expression */ - def exists(matcher: Expr => Boolean)(e: Expr): Boolean = { - fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) - } - - /** Collects a set of objects from all sub-expressions */ - def collect[T](matcher: Expr => Set[T])(e: Expr): Set[T] = { - fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - def collectPreorder[T](matcher: Expr => Seq[T])(e: Expr): Seq[T] = { - fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - /** Returns a set of all sub-expressions matching the predicate */ - def filter(matcher: Expr => Boolean)(e: Expr): Set[Expr] = { - collect[Expr] { e => Set(e) filter matcher }(e) - } - - /** Counts how many times the predicate holds in sub-expressions */ - def count(matcher: Expr => Int)(e: Expr): Int = { - fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) - } - - /** Replaces bottom-up sub-expressions by looking up for them in a map */ - def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - postMap(substs.lift)(expr) - } - - /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ - def replaceSeq(substs: Seq[(Expr, Expr)], expr: Expr): Expr = { - var res = expr - for (s <- substs) { - res = replace(Map(s), res) - } - res - } - +object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { /** Replaces bottom-up sub-identifiers by looking up for them in a map */ def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = { postMap({ @@ -332,7 +66,7 @@ object ExprOps { Lambda(args, rec(binders ++ args.map(_.id), bd)) case Forall(args, bd) => Forall(args, rec(binders ++ args.map(_.id), bd)) - case Operator(subs, builder) => + case Deconstructor(subs, builder) => builder(subs map (rec(binders, _))) }).copiedFrom(e) @@ -695,7 +429,7 @@ object ExprOps { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Operator(args, recons) => { + case n @ Deconstructor(args, recons) => { var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -1204,7 +938,7 @@ object ExprOps { def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case nop@Operator(ts, op) => { + case nop@Deconstructor(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -1355,7 +1089,7 @@ object ExprOps { formulaSize(rhs) + og.map(formulaSize).getOrElse(0) + patternSize(p) }.sum - case Operator(es, _) => + case Deconstructor(es, _) => es.map(formulaSize).sum+1 } @@ -1669,7 +1403,7 @@ object ExprOps { // TODO: Seems a lot is missing, like Literals - case Same(Operator(es1, _), Operator(es2, _)) => + case Same(Deconstructor(es1, _), Deconstructor(es2, _)) => (es1.size == es2.size) && (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } @@ -2008,7 +1742,7 @@ object ExprOps { f(e, initParent) - val Operator(es, _) = e + val Deconstructor(es, _) = e es foreach rec } @@ -2101,7 +1835,7 @@ object ExprOps { case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case Operator(es, recons) => recons(es.map(rec(_, build))) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) } rec(lift(expr), true) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index e2581dd8cdb33e3d025e8206d72dd32fc4ca59f7..cfa4780efccad338b5df44400b53da7264a2c2d7 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -7,12 +7,11 @@ import Expressions._ import Common._ import Types._ import Constructors._ -import ExprOps._ -import Definitions.Program +import Definitions.{Program, AbstractClassDef, CaseClassDef} object Extractors { - object Operator { + object Operator extends SubTreeOps.Extractor[Expr] { def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { /* Unary operators */ case Not(t) => @@ -250,6 +249,8 @@ object Extractors { None } } + + // Extractors for types are available at Types.NAryType trait Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] @@ -367,7 +368,7 @@ object Extractors { def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => + case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, ExprOps.variablesOf(scrut)) => ( pattern, scrut, body ) } } diff --git a/src/main/scala/leon/purescala/SubTreeOps.scala b/src/main/scala/leon/purescala/SubTreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..795471604c56bdcc211caf897020449617855b74 --- /dev/null +++ b/src/main/scala/leon/purescala/SubTreeOps.scala @@ -0,0 +1,286 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Expressions.Expr +import Types.TypeTree +import Common._ +import utils._ + +object SubTreeOps { + trait Extractor[SubTree <: Tree] { + def unapply(e: SubTree): Option[(Seq[SubTree], (Seq[SubTree]) => SubTree)] + } +} +trait SubTreeOps[SubTree <: Tree] { + val Deconstructor: SubTreeOps.Extractor[SubTree] + + /* ======== + * Core API + * ======== + * + * All these functions should be stable, tested, and used everywhere. Modify + * with care. + */ + + /** Does a right tree fold + * + * A right tree fold applies the input function to the subnodes first (from left + * to right), and combine the results along with the current node value. + * + * @param f a function that takes the current node and the seq + * of results form the subtrees. + * @param e The value on which to apply the fold. + * @return The expression after applying `f` on all subtrees. + * @note the computation is lazy, hence you should not rely on side-effects of `f` + */ + def fold[T](f: (SubTree, Seq[T]) => T)(e: SubTree): T = { + val rec = fold(f) _ + val Deconstructor(es, _) = e + + //Usages of views makes the computation lazy. (which is useful for + //contains-like operations) + f(e, es.view.map(rec)) + } + + + /** Pre-traversal of the tree. + * + * Invokes the input function on every node '''before''' visiting + * children. Traverse children from left to right subtrees. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def preTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = preTraversal(f) _ + val Deconstructor(es, _) = e + f(e) + es.foreach(rec) + } + + /** Post-traversal of the tree. + * + * Invokes the input function on every node '''after''' visiting + * children. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def postTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = postTraversal(f) _ + val Deconstructor(es, _) = e + es.foreach(rec) + f(e) + } + + /** Pre-transformation of the tree. + * + * Takes a partial function of replacements and substitute + * '''before''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, d) // And not Add(a, f) because it only substitute once for each level. + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed + */ + def preMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = preMap(f, applyRec) _ + + val newV = if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (e) + } else { + f(e) getOrElse e + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(rec) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + } + + + /** Post-transformation of the tree. + * + * Takes a partial function of replacements. + * Substitutes '''after''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e + * }}} + * will yield: + * {{{ + * Add(a, Minus(e, c)) + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) + */ + def postMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = postMap(f, applyRec) _ + + val Deconstructor(es, builder) = e + val newEs = es.map(rec) + val newV = { + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + + if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (newV) + } else { + f(newV) getOrElse newV + } + + } + + + /** Applies functions and combines results in a generic way + * + * Start with an initial value, and apply functions to nodes before + * and after the recursion in the children. Combine the results of + * all children and apply a final function on the resulting node. + * + * @param pre a function applied on a node before doing a recursion in the children + * @param post a function applied to the node built from the recursive application to + all children + * @param combiner a function to combine the resulting values from all children with + the current node + * @param init the initial value + * @param expr the expression on which to apply the transform + * + * @see [[simpleTransform]] + * @see [[simplePreTransform]] + * @see [[simplePostTransform]] + */ + def genericTransform[C](pre: (SubTree, C) => (SubTree, C), + post: (SubTree, C) => (SubTree, C), + combiner: (SubTree, Seq[C]) => C)(init: C)(expr: SubTree) = { + + def rec(eIn: SubTree, cIn: C): (SubTree, C) = { + + val (expr, ctx) = pre(eIn, cIn) + val Deconstructor(es, builder) = expr + val (newExpr, newC) = { + val (nes, cs) = es.map{ rec(_, ctx)}.unzip + val newE = builder(nes).copiedFrom(expr) + + (newE, combiner(newE, cs)) + } + + post(newExpr, newC) + } + + rec(expr, init) + } + + + /* + * ============= + * Auxiliary API + * ============= + * + * Convenient methods using the Core API. + */ + + /** Checks if the predicate holds in some sub-expression */ + def exists(matcher: SubTree => Boolean)(e: SubTree): Boolean = { + fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) + } + + /** Collects a set of objects from all sub-expressions */ + def collect[T](matcher: SubTree => Set[T])(e: SubTree): Set[T] = { + fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + def collectPreorder[T](matcher: SubTree => Seq[T])(e: SubTree): Seq[T] = { + fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + /** Returns a set of all sub-expressions matching the predicate */ + def filter(matcher: SubTree => Boolean)(e: SubTree): Set[SubTree] = { + collect[SubTree] { e => Set(e) filter matcher }(e) + } + + /** Counts how many times the predicate holds in sub-expressions */ + def count(matcher: SubTree => Int)(e: SubTree): Int = { + fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) + } + + /** Replaces bottom-up sub-expressions by looking up for them in a map */ + def replace(substs: Map[SubTree,SubTree], expr: SubTree) : SubTree = { + postMap(substs.lift)(expr) + } + + /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ + def replaceSeq(substs: Seq[(SubTree, SubTree)], expr: SubTree): SubTree = { + var res = expr + for (s <- substs) { + res = replace(Map(s), res) + } + res + } + +} \ No newline at end of file diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index db655365c24a7304831e818849238bba0a849de9..20b8e5fdbee22a2b110fdc85839e9495cc19c7d4 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -11,16 +11,16 @@ import Extractors._ import Constructors._ import ExprOps.preMap -object TypeOps { +object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree] { def typeDepth(t: TypeTree): Int = t match { case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max } - def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { - case tp: TypeParameter => Set(tp) - case _ => - val NAryType(subs, _) = t - subs.flatMap(typeParamsOf).toSet + def typeParamsOf(t: TypeTree): Set[TypeParameter] = { + collect[TypeParameter]({ + case tp: TypeParameter => Set(tp) + case _ => Set.empty + })(t) } def canBeSubtypeOf( @@ -313,7 +313,7 @@ object TypeOps { val returnType = tpeSub(fd.returnType) val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = preMap { + val subCalls = ExprOps.preMap { case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) case _ => diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 3a0a85bb24045df18ab65b0afa488a34c8921315..9ec0e4b41f33aa0895ff160a5379d16a7cb88d87 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -133,7 +133,7 @@ object Types { case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - object NAryType { + object NAryType extends SubTreeOps.Extractor[TypeTree] { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) @@ -142,6 +142,7 @@ object Types { case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + /* n-ary operators */ case t => Some(Nil, _ => t) } } diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index 07bdee913f21605fbc41f660af608c492e5ee1b5..060cb7fc6fcf83df0785002a1dbfc35f40918113 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -68,6 +68,7 @@ class Model(protected val mapping: Map[Identifier, Expr]) def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) def get(id: Identifier): Option[Expr] = mapping.get(id) def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def ids = mapping.keys def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } } diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 67d28f877019a3e4741df6d268c68fc2d86d17ee..4e87a28883a62455a79d24d0279e6ae84bc595cf 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -91,7 +91,8 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new Z3StringCapableSolver(ctx, program, (program: Program) => + new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program))) with TimeoutSolver) case "smt-z3-q" => SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..c57958bc726b9384b49ace471e44c2bc595d5358 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -0,0 +1,114 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Quantification._ +import purescala.Constructors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.DefOps +import purescala.TypeOps +import purescala.Extractors._ +import utils._ +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} +import templates._ +import evaluators._ +import Template._ +import leon.solvers.z3.Z3StringConversionReverse + +object Z3StringCapableSolver { + def convert(p: Program): ((Program, Map[FunDef, FunDef]), Z3StringConversionReverse, Map[Identifier, Identifier]) = { + val converter = new Z3StringConversionReverse { + def getProgram = p + + def convertId(id: Identifier): (Identifier, Variable) = { + id -> Variable(FreshIdentifier(id.name, convertType(id.getType))) + } + def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { + case Variable(id) if bindings contains id => bindings(id) + case Let(a, expr, body) if TypeOps.exists( _ == StringType)(a.getType) => + val new_a_bid = convertId(a) + val new_bindings = bindings + new_a_bid + val expr2 = convertToTarget(expr)(new_bindings) + val body2 = convertToTarget(expr)(new_bindings) + Let(new_a_bid._1, expr2, body2) + case StringConverted(p) => p + case Operator(es, builder) => + val rec = convertToTarget _ + val newEs = es.map(rec) + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + def targetApplication(fd: TypedFunDef, args: Seq[Expr])(implicit bindings: Map[Identifier, Expr]): Expr = { + FunctionInvocation(fd, args) + } + } + import converter._ + var globalIdMap = Map[Identifier, Identifier]() + (DefOps.replaceFunDefs(p)((fd: FunDef) => { + if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || + fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { + + val idMap = fd.params.map(vd => vd.id -> FreshIdentifier(vd.id.name, convertType(vd.id.getType))).toMap + globalIdMap ++= idMap + implicit val idVarMap = idMap.mapValues(id => Variable(id)) + + val newFd = fd.duplicate(FreshIdentifier(fd.id.name, convertType(fd.id.getType)), + fd.tparams, + fd.params.map(vd => ValDef(idMap(vd.id))), + convertType(fd.returnType)) + fd.body foreach { body => + newFd.body = Some(convertToTarget(body)) + } + Some(newFd) + } else None + }), converter, globalIdMap) + } +} + +class Z3StringCapableSolver(val context: LeonContext, val program: Program, f: Program => UnrollingSolver) extends Solver + with NaiveAssumptionSolver + with EvaluatingSolver + with QuantificationSolver { + + val ((new_program, mappings), converter, idMap) = Z3StringCapableSolver.convert(program) + + val idMapReverse = idMap.map(kv => kv._2 -> kv._1).toMap + val underlying = f(new_program) + + // Members declared in leon.solvers.EvaluatingSolver + val useCodeGen: Boolean = underlying.useCodeGen + + // Members declared in leon.utils.Interruptible + def interrupt(): Unit = underlying.interrupt() + def recoverInterrupt(): Unit = underlying.recoverInterrupt() + + // Members declared in leon.solvers.QuantificationSolver + def getModel: leon.solvers.HenkinModel = { + val model = underlying.getModel + val ids = model.ids.toSet + val exprs = ids.map(model.apply) + val original_ids = ids.map(idMapReverse) // Should exist. + val original_exprs = exprs.map(e => converter.StringConversion.reverse(e)) + new HenkinModel(original_ids.zip(original_exprs).toMap, model.doms) // TODO: Convert the domains as well + } + + // Members declared in leon.solvers.Solver + def assertCnstr(expression: leon.purescala.Expressions.Expr): Unit = { + underlying.assertCnstr(converter.convertToTarget(expression)(Map())) + } + def check: Option[Boolean] = underlying.check + def free(): Unit = underlying.free() + def name: String = "String" + underlying.name + def pop(): Unit = underlying.pop() + def push(): Unit = underlying.push() + def reset(): Unit = underlying.reset() +} \ No newline at end of file diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index aa0e5dad6506be945919cb22361cce649f40c077..b16083bad85958f5f2f906368a1938cb91dd25cd 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -9,7 +9,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 2b75f08f0480cf272515bb8d8393e01e29d4dbf1..cdfe0c9ed6462b322b760e635122f5c3b11d2923 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -11,7 +11,7 @@ import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import utils._ diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 3daf1ad4964ad73e8c4d9701ae4e65d0f4170897..68e966ad2a2e38582d633ee4fac1a81037d831b4 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -50,7 +50,13 @@ trait Z3StringTypeConversion { def convertType(t: TypeTree): TypeTree = t match { case StringType => listchar - case _ => t + case NAryType(subtypes, builder) => + builder(subtypes.map(convertType)) + } + def convertTypeBack(expected_type: TypeTree)(t: TypeTree): TypeTree = (expected_type, t) match { + case (StringType, `listchar`) => StringType + case (NAryType(ex, builder), NAryType(cur, builder2)) => + builder2(ex.zip(cur).map(ex_cur => convertTypeBack(ex_cur._1)(ex_cur._2))) } def convertToString(e: Expr)(implicit p: Program): String = stringBijection.cachedA(e) { @@ -59,7 +65,7 @@ trait Z3StringTypeConversion { case CaseClass(_, Seq()) => "" } } - def convertFromString(v: String) = + def convertFromString(v: String): Expr = stringBijection.cachedB(v) { v.toList.foldRight(CaseClass(nilchar, Seq())){ case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) @@ -68,11 +74,13 @@ trait Z3StringTypeConversion { } trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { + /** Method which can use recursively StringConverted in its body in unapply positions */ def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType + /** How the application (or function invocation) of a given fundef is performed in the target type. */ def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType object StringConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, TargetType]): Option[TargetType] = e match { + def unapply(e: Expr)(implicit replacement: Map[Identifier, TargetType]): Option[TargetType] = e match { case StringLiteral(v) => // No string support for z3 at this moment. val stringEncoding = convertFromString(v) @@ -91,4 +99,33 @@ trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { def apply(t: TypeTree): TypeTree = convertType(t) } +} + +trait Z3StringConversionReverse extends Z3StringConversion[Expr] { + + object StringConversion { + def reverse(e: Expr): Expr = unapply(e).getOrElse(e) + def unapply(e: Expr): Option[Expr] = e match { + case CaseClass(`conschar`, Seq(CharLiteral(c), l)) => + reverse(l) match { + case StringLiteral(s) => Some(StringLiteral(c + s)) + case _ => None + } + case CaseClass(`nilchar`, Seq()) => + Some(StringLiteral("")) + case FunctionInvocation(`list_size`, Seq(a)) => + Some(StringLength(reverse(a))) + case FunctionInvocation(`list_++`, Seq(a, b)) => + Some(StringConcat(reverse(a), reverse(b))) + case FunctionInvocation(`list_take`, + Seq(FunctionInvocation(`list_drop`, Seq(a, start)), length)) => + val rstart = reverse(start) + Some(SubString(reverse(a), rstart, Plus(rstart, reverse(length)))) + case purescala.Extractors.Operator(es, builder) => + Some(builder(es.map(reverse _))) + case _ => None + } + } + + } \ No newline at end of file diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 8053a8dc1e9956c9ef264c247d41f8d96ffbb934..17fdf48bc88d88cca9c49e63672af6f3c76afa48 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -5,7 +5,7 @@ package leon.utils import leon._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors.caseClassSelector diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 45bb36770cddca417ba51582cd7824fc09152199..5feb45dda12662fcd904df25453c5f41e76beb6e 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.leastUpperBound import leon.purescala.Types._ import leon.xlang.Expressions._