diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 75c691edf1f7e5a4ca056bf8c8448bff9b5ead77..f7240f1b763951aee607241cc07bf41a5c535e4a 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -8,7 +8,7 @@ import purescala.Constructors._ import purescala.ExprOps._ import purescala.Expressions.Pattern import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.isSubtypeOf import purescala.Types._ import purescala.Common._ import purescala.Expressions._ diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 9cc6dd132036ffdde4a5ef70e2be87e6276f9f3c..385157e4d64c835841279f0970705f6a5aa10b1c 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -6,7 +6,7 @@ package evaluators import purescala.Constructors._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.{leastUpperBound, isSubtypeOf} import purescala.Types._ import purescala.Common.Identifier import purescala.Definitions.{TypedFunDef, Program} diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index c5877e1e4d3705c67ebe51eb97c57daf8da58b33..e3fb828617ceca48f6b5287cafe3494d9692b570 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -22,7 +22,7 @@ import Common._ import Extractors._ import Constructors._ import ExprOps._ -import TypeOps._ +import TypeOps.{leastUpperBound, typesCompatible, typeParamsOf, canBeSubtypeOf} import xlang.Expressions.{Block => LeonBlock, _} import xlang.ExprOps._ diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 578f52bff2bb4ba3c2fc79a5e5a4d2e85ce8c6ee..2eeb5a4d44fbde9af89ae5d5c103bed1ba1b75ed 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,7 +275,15 @@ object DefOps { None } - + /** Clones the given program by replacing some functions by other functions. + * + * @param p The original program + * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. + * May be called once each time a function appears (definition and invocation), + * so make sure to output the same if the argument is the same. + * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. + * By default it is the function invocation using the new function definition. + * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { @@ -297,7 +305,6 @@ object DefOps { df match { case f : FunDef => val newF = fdMap(f) - newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF) newF case d => d @@ -307,7 +314,11 @@ object DefOps { } ) }) - + for(fd <- newP.definedFunctions) { + if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) + } + } (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } @@ -320,32 +331,38 @@ object DefOps { }(e) } - def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false val res = p.copy(units = for (u <- p.units) yield { u.copy( - defs = u.defs.map { + defs = u.defs.flatMap { case m: ModuleDef => val newdefs = for (df <- m.defs) yield { df match { case `after` => found = true - after +: fds.toSeq - case d => - Seq(d) + after +: cds.toSeq + case d => Seq(d) } } - m.copy(defs = newdefs.flatten) - case d => d + Seq(m.copy(defs = newdefs.flatten)) + case `after` => + found = true + after +: cds.toSeq + case d => Seq(d) } ) }) if (!found) { - println("addFunDefs could not find anchor function!") + println("addDefs could not find anchor definition!") } res } + + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) + + def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) // @Note: This function does not filter functions in classdefs def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 724f6a0fca6e4b2c91c1d288013ef8fb9a817b41..ec304c0aef6dd48e9dd01edd7d3ed25f190d1a82 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -19,328 +19,19 @@ import solvers._ * * The generic operations lets you apply operations on a whole tree * expression. You can look at: - * - [[ExprOps.fold foldRight]] - * - [[ExprOps.preTraversal preTraversal]] - * - [[ExprOps.postTraversal postTraversal]] - * - [[ExprOps.preMap preMap]] - * - [[ExprOps.postMap postMap]] - * - [[ExprOps.genericTransform genericTransform]] + * - [[SubTreeOps.fold foldRight]] + * - [[SubTreeOps.preTraversal preTraversal]] + * - [[SubTreeOps.postTraversal postTraversal]] + * - [[SubTreeOps.preMap preMap]] + * - [[SubTreeOps.postMap postMap]] + * - [[SubTreeOps.genericTransform genericTransform]] * * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * */ -object ExprOps { - - /* ======== - * Core API - * ======== - * - * All these functions should be stable, tested, and used everywhere. Modify - * with care. - */ - - - /** Does a right tree fold - * - * A right tree fold applies the input function to the subnodes first (from left - * to right), and combine the results along with the current node value. - * - * @param f a function that takes the current node and the seq - * of results form the subtrees. - * @param e The Expr on which to apply the fold. - * @return The expression after applying `f` on all subtrees. - * @note the computation is lazy, hence you should not rely on side-effects of `f` - */ - def fold[T](f: (Expr, Seq[T]) => T)(e: Expr): T = { - val rec = fold(f) _ - val Operator(es, _) = e - - //Usages of views makes the computation lazy. (which is useful for - //contains-like operations) - f(e, es.view.map(rec)) - } - - /** Pre-traversal of the tree. - * - * Invokes the input function on every node '''before''' visiting - * children. Traverse children from left to right subtrees. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def preTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = preTraversal(f) _ - val Operator(es, _) = e - f(e) - es.foreach(rec) - } - - /** Post-traversal of the tree. - * - * Invokes the input function on every node '''after''' visiting - * children. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def postTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = postTraversal(f) _ - val Operator(es, _) = e - es.foreach(rec) - f(e) - } - - /** Pre-transformation of the tree. - * - * Takes a partial function of replacements and substitute - * '''before''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, d) // And not Add(a, f) because it only substitute once for each level. - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed - */ - def preMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = preMap(f, applyRec) _ - - val newV = if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (e) - } else { - f(e) getOrElse e - } - - val Operator(es, builder) = newV - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(newV) - } else { - newV - } - } - - /** Post-transformation of the tree. - * - * Takes a partial function of replacements. - * Substitutes '''after''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e - * }}} - * will yield: - * {{{ - * Add(a, Minus(e, c)) - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) - */ - def postMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = postMap(f, applyRec) _ - - val Operator(es, builder) = e - val newEs = es.map(rec) - val newV = { - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - } - - if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (newV) - } else { - f(newV) getOrElse newV - } - - } - - - /** Pre-transformation of the tree, with a context value from "top-down". - * - * Takes a partial function of replacements. - * Substitutes '''before''' recursing down the trees. The function returns - * an option of the new value, as well as the new context to be used for - * the recursion in its children. The context is "lost" when going back up, - * so changes made by one node will not be see by its siblings. - */ - def preMapWithContext[C](f: (Expr, C) => (Option[Expr], C), applyRec: Boolean = false) - (e: Expr, c: C): Expr = { - - def rec(expr: Expr, context: C): Expr = { - - val (newV, newCtx) = { - if(applyRec) { - var ctx = context - val finalV = fixpoint{ e: Expr => { - val res = f(e, ctx) - ctx = res._2 - res._1.getOrElse(e) - }} (expr) - (finalV, ctx) - } else { - val res = f(expr, context) - (res._1.getOrElse(expr), res._2) - } - } - - val Operator(es, builder) = newV - val newEs = es.map(e => rec(e, newCtx)) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(newV) - } else { - newV - } - - } - - rec(e, c) - } - - - /** Applies functions and combines results in a generic way - * - * Start with an initial value, and apply functions to nodes before - * and after the recursion in the children. Combine the results of - * all children and apply a final function on the resulting node. - * - * @param pre a function applied on a node before doing a recursion in the children - * @param post a function applied to the node built from the recursive application to - all children - * @param combiner a function to combine the resulting values from all children with - the current node - * @param init the initial value - * @param expr the expression on which to apply the transform - * - * @see [[simpleTransform]] - * @see [[simplePreTransform]] - * @see [[simplePostTransform]] - */ - def genericTransform[C](pre: (Expr, C) => (Expr, C), - post: (Expr, C) => (Expr, C), - combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { - - def rec(eIn: Expr, cIn: C): (Expr, C) = { - - val (expr, ctx) = pre(eIn, cIn) - val Operator(es, builder) = expr - val (newExpr, newC) = { - val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes).copiedFrom(expr) - - (newE, combiner(newE, cs)) - } - - post(newExpr, newC) - } - - rec(expr, init) - } - - /* - * ============= - * Auxiliary API - * ============= - * - * Convenient methods using the Core API. - */ - - /** Checks if the predicate holds in some sub-expression */ - def exists(matcher: Expr => Boolean)(e: Expr): Boolean = { - fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) - } - - /** Collects a set of objects from all sub-expressions */ - def collect[T](matcher: Expr => Set[T])(e: Expr): Set[T] = { - fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - def collectPreorder[T](matcher: Expr => Seq[T])(e: Expr): Seq[T] = { - fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - /** Returns a set of all sub-expressions matching the predicate */ - def filter(matcher: Expr => Boolean)(e: Expr): Set[Expr] = { - collect[Expr] { e => Set(e) filter matcher }(e) - } - - /** Counts how many times the predicate holds in sub-expressions */ - def count(matcher: Expr => Int)(e: Expr): Int = { - fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) - } - - /** Replaces bottom-up sub-expressions by looking up for them in a map */ - def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - postMap(substs.lift)(expr) - } - - /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ - def replaceSeq(substs: Seq[(Expr, Expr)], expr: Expr): Expr = { - var res = expr - for (s <- substs) { - res = replace(Map(s), res) - } - res - } - +object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { /** Replaces bottom-up sub-identifiers by looking up for them in a map */ def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = { postMap({ @@ -375,7 +66,7 @@ object ExprOps { Lambda(args, rec(binders ++ args.map(_.id), bd)) case Forall(args, bd) => Forall(args, rec(binders ++ args.map(_.id), bd)) - case Operator(subs, builder) => + case Deconstructor(subs, builder) => builder(subs map (rec(binders, _))) }).copiedFrom(e) @@ -418,8 +109,7 @@ object ExprOps { case _ => Set() }(expr) } - - /** Returns all Function calls found in the expression */ + def nestedFunDefsOf(expr: Expr): Set[FunDef] = { collect[FunDef] { case LetDef(fds, _) => fds.toSet @@ -746,7 +436,7 @@ object ExprOps { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Operator(args, recons) => { + case n @ Deconstructor(args, recons) => { var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -1255,7 +945,7 @@ object ExprOps { def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case nop@Operator(ts, op) => { + case nop@Deconstructor(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -1406,7 +1096,7 @@ object ExprOps { formulaSize(rhs) + og.map(formulaSize).getOrElse(0) + patternSize(p) }.sum - case Operator(es, _) => + case Deconstructor(es, _) => es.map(formulaSize).sum+1 } @@ -1720,7 +1410,7 @@ object ExprOps { // TODO: Seems a lot is missing, like Literals - case Same(Operator(es1, _), Operator(es2, _)) => + case Same(Deconstructor(es1, _), Deconstructor(es2, _)) => (es1.size == es2.size) && (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } @@ -2059,7 +1749,7 @@ object ExprOps { f(e, initParent) - val Operator(es, _) = e + val Deconstructor(es, _) = e es foreach rec } @@ -2152,7 +1842,7 @@ object ExprOps { case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case Operator(es, recons) => recons(es.map(rec(_, build))) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) } rec(lift(expr), true) diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 76a75fdf263248e7045ebeb959fb65750b4767ad..88edef68586474c24e7a13b70a22e9beed7441b8 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -360,6 +360,22 @@ object Expressions { someValue.id ) } + + object PatternExtractor extends SubTreeOps.Extractor[Pattern] { + def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { + case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => + Some(Seq(), es => e) + case CaseClassPattern(binder, ct, subpatterns) => + Some(subpatterns, es => CaseClassPattern(binder, ct, es)) + case TuplePattern(binder, subpatterns) => + Some(subpatterns, es => TuplePattern(binder, es)) + case UnapplyPattern(binder, unapplyFun, subpatterns) => + Some(subpatterns, es => UnapplyPattern(binder, unapplyFun, es)) + case _ => None + } + } + + object PatternOps extends { val Deconstructor = PatternExtractor } with SubTreeOps[Pattern] /** Symbolic I/O examples as a match/case. * $encodingof `out == (in match { cases; case _ => out })` diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index e2581dd8cdb33e3d025e8206d72dd32fc4ca59f7..cfa4780efccad338b5df44400b53da7264a2c2d7 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -7,12 +7,11 @@ import Expressions._ import Common._ import Types._ import Constructors._ -import ExprOps._ -import Definitions.Program +import Definitions.{Program, AbstractClassDef, CaseClassDef} object Extractors { - object Operator { + object Operator extends SubTreeOps.Extractor[Expr] { def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { /* Unary operators */ case Not(t) => @@ -250,6 +249,8 @@ object Extractors { None } } + + // Extractors for types are available at Types.NAryType trait Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] @@ -367,7 +368,7 @@ object Extractors { def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => + case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, ExprOps.variablesOf(scrut)) => ( pattern, scrut, body ) } } diff --git a/src/main/scala/leon/purescala/SubTreeOps.scala b/src/main/scala/leon/purescala/SubTreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..140bd5edc2ff5f316a7afb0d90442df12b78ced8 --- /dev/null +++ b/src/main/scala/leon/purescala/SubTreeOps.scala @@ -0,0 +1,327 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Expressions.Expr +import Types.TypeTree +import Common._ +import utils._ + +object SubTreeOps { + trait Extractor[SubTree <: Tree] { + def unapply(e: SubTree): Option[(Seq[SubTree], (Seq[SubTree]) => SubTree)] + } +} +trait SubTreeOps[SubTree <: Tree] { + val Deconstructor: SubTreeOps.Extractor[SubTree] + + /* ======== + * Core API + * ======== + * + * All these functions should be stable, tested, and used everywhere. Modify + * with care. + */ + + /** Does a right tree fold + * + * A right tree fold applies the input function to the subnodes first (from left + * to right), and combine the results along with the current node value. + * + * @param f a function that takes the current node and the seq + * of results form the subtrees. + * @param e The value on which to apply the fold. + * @return The expression after applying `f` on all subtrees. + * @note the computation is lazy, hence you should not rely on side-effects of `f` + */ + def fold[T](f: (SubTree, Seq[T]) => T)(e: SubTree): T = { + val rec = fold(f) _ + val Deconstructor(es, _) = e + + //Usages of views makes the computation lazy. (which is useful for + //contains-like operations) + f(e, es.view.map(rec)) + } + + + /** Pre-traversal of the tree. + * + * Invokes the input function on every node '''before''' visiting + * children. Traverse children from left to right subtrees. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def preTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = preTraversal(f) _ + val Deconstructor(es, _) = e + f(e) + es.foreach(rec) + } + + /** Post-traversal of the tree. + * + * Invokes the input function on every node '''after''' visiting + * children. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def postTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = postTraversal(f) _ + val Deconstructor(es, _) = e + es.foreach(rec) + f(e) + } + + /** Pre-transformation of the tree. + * + * Takes a partial function of replacements and substitute + * '''before''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, d) // And not Add(a, f) because it only substitute once for each level. + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed + */ + def preMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = preMap(f, applyRec) _ + + val newV = if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (e) + } else { + f(e) getOrElse e + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(rec) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + } + + + /** Post-transformation of the tree. + * + * Takes a partial function of replacements. + * Substitutes '''after''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e + * }}} + * will yield: + * {{{ + * Add(a, Minus(e, c)) + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) + */ + def postMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = postMap(f, applyRec) _ + + val Deconstructor(es, builder) = e + val newEs = es.map(rec) + val newV = { + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + + if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (newV) + } else { + f(newV) getOrElse newV + } + + } + + + /** Applies functions and combines results in a generic way + * + * Start with an initial value, and apply functions to nodes before + * and after the recursion in the children. Combine the results of + * all children and apply a final function on the resulting node. + * + * @param pre a function applied on a node before doing a recursion in the children + * @param post a function applied to the node built from the recursive application to + all children + * @param combiner a function to combine the resulting values from all children with + the current node + * @param init the initial value + * @param expr the expression on which to apply the transform + * + * @see [[simpleTransform]] + * @see [[simplePreTransform]] + * @see [[simplePostTransform]] + */ + def genericTransform[C](pre: (SubTree, C) => (SubTree, C), + post: (SubTree, C) => (SubTree, C), + combiner: (SubTree, Seq[C]) => C)(init: C)(expr: SubTree) = { + + def rec(eIn: SubTree, cIn: C): (SubTree, C) = { + + val (expr, ctx) = pre(eIn, cIn) + val Deconstructor(es, builder) = expr + val (newExpr, newC) = { + val (nes, cs) = es.map{ rec(_, ctx)}.unzip + val newE = builder(nes).copiedFrom(expr) + + (newE, combiner(newE, cs)) + } + + post(newExpr, newC) + } + + rec(expr, init) + } + + /** Pre-transformation of the tree, with a context value from "top-down". + * + * Takes a partial function of replacements. + * Substitutes '''before''' recursing down the trees. The function returns + * an option of the new value, as well as the new context to be used for + * the recursion in its children. The context is "lost" when going back up, + * so changes made by one node will not be see by its siblings. + */ + def preMapWithContext[C](f: (SubTree, C) => (Option[SubTree], C), applyRec: Boolean = false) + (e: SubTree, c: C): SubTree = { + + def rec(expr: SubTree, context: C): SubTree = { + + val (newV, newCtx) = { + if(applyRec) { + var ctx = context + val finalV = fixpoint{ e: SubTree => { + val res = f(e, ctx) + ctx = res._2 + res._1.getOrElse(e) + }} (expr) + (finalV, ctx) + } else { + val res = f(expr, context) + (res._1.getOrElse(expr), res._2) + } + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(e => rec(e, newCtx)) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + + } + + rec(e, c) + } + + /* + * ============= + * Auxiliary API + * ============= + * + * Convenient methods using the Core API. + */ + + /** Checks if the predicate holds in some sub-expression */ + def exists(matcher: SubTree => Boolean)(e: SubTree): Boolean = { + fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) + } + + /** Collects a set of objects from all sub-expressions */ + def collect[T](matcher: SubTree => Set[T])(e: SubTree): Set[T] = { + fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + def collectPreorder[T](matcher: SubTree => Seq[T])(e: SubTree): Seq[T] = { + fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + /** Returns a set of all sub-expressions matching the predicate */ + def filter(matcher: SubTree => Boolean)(e: SubTree): Set[SubTree] = { + collect[SubTree] { e => Set(e) filter matcher }(e) + } + + /** Counts how many times the predicate holds in sub-expressions */ + def count(matcher: SubTree => Int)(e: SubTree): Int = { + fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) + } + + /** Replaces bottom-up sub-expressions by looking up for them in a map */ + def replace(substs: Map[SubTree,SubTree], expr: SubTree) : SubTree = { + postMap(substs.lift)(expr) + } + + /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ + def replaceSeq(substs: Seq[(SubTree, SubTree)], expr: SubTree): SubTree = { + var res = expr + for (s <- substs) { + res = replace(Map(s), res) + } + res + } + +} \ No newline at end of file diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 20cca91e1b8f4ebd2f2541d32d772cf55fbdfa79..14cf3b8250682a71285fffb3fd34e2e25608cdb1 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -11,16 +11,16 @@ import Extractors._ import Constructors._ import ExprOps.preMap -object TypeOps { +object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree] { def typeDepth(t: TypeTree): Int = t match { case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } - def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { - case tp: TypeParameter => Set(tp) - case _ => - val NAryType(subs, _) = t - subs.flatMap(typeParamsOf).toSet + def typeParamsOf(t: TypeTree): Set[TypeParameter] = { + collect[TypeParameter]({ + case tp: TypeParameter => Set(tp) + case _ => Set.empty + })(t) } def canBeSubtypeOf( @@ -313,7 +313,7 @@ object TypeOps { val returnType = tpeSub(fd.returnType) val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = preMap { + val subCalls = ExprOps.preMap { case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) case _ => diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 3a0a85bb24045df18ab65b0afa488a34c8921315..9ec0e4b41f33aa0895ff160a5379d16a7cb88d87 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -133,7 +133,7 @@ object Types { case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - object NAryType { + object NAryType extends SubTreeOps.Extractor[TypeTree] { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) @@ -142,6 +142,7 @@ object Types { case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + /* n-ary operators */ case t => Some(Nil, _ => t) } } diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index 07bdee913f21605fbc41f660af608c492e5ee1b5..060cb7fc6fcf83df0785002a1dbfc35f40918113 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -68,6 +68,7 @@ class Model(protected val mapping: Map[Identifier, Expr]) def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) def get(id: Identifier): Option[Expr] = mapping.get(id) def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def ids = mapping.keys def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } } diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala index fa11ab6613bd65b196cce87ee062c3c56f0b95f9..4f56903c578da6c2d1b71f19268e8885cb93e99e 100644 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -25,7 +25,7 @@ class HenkinModelBuilder(domains: HenkinDomains) override def result = new HenkinModel(mapBuilder.result, domains) } -trait QuantificationSolver { +trait QuantificationSolver extends Solver { val program: Program def getModel: HenkinModel diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 67d28f877019a3e4741df6d268c68fc2d86d17ee..42d7ead0235a13ad3465dcd1fe64fd80db33ed25 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -79,10 +79,12 @@ object SolverFactory { def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => - SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) + // Previously: new FairZ3Solver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver) case "unrollz3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) case "enum" => SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) @@ -91,10 +93,12 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) case "smt-z3-q" => - SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) + // Previously: new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) case "smt-cvc4" => SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..602873aba9df1859650ee94486c59d9487ec7fbd --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -0,0 +1,239 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Quantification._ +import purescala.Constructors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.DefOps +import purescala.TypeOps +import purescala.Extractors._ +import utils._ +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} +import templates._ +import evaluators._ +import Template._ +import leon.solvers.z3.Z3StringConversion +import leon.utils.Bijection +import leon.solvers.z3.StringEcoSystem + +object Z3StringCapableSolver { + def convert(p: Program, force: Boolean = false): (Program, Option[Z3StringConversion]) = { + val converter = new Z3StringConversion(p) + import converter.Forward._ + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + var hasStrings = false + val program_with_strings = converter.getProgram + val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => { + globalFdMap.get(fd).map(_._2).orElse( + if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || + fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { + val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap + val newFdId = convertId(fd.id) + val newFd = fd.duplicate(newFdId, + fd.tparams, + fd.params.map(vd => ValDef(idMap(vd.id))), + convertType(fd.returnType)) + globalFdMap += fd -> ((idMap, newFd)) + hasStrings = hasStrings || (program_with_strings.library.escape.get != fd) + Some(newFd) + } else None + ) + }) + if(!hasStrings && !force) { + (p, None) + } else { + converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) + for((fd, (idMap, newFd)) <- globalFdMap) { + implicit val idVarMap = idMap.mapValues(id => Variable(id)) + newFd.fullBody = convertExpr(newFd.fullBody) + } + (new_program, Some(converter)) + } + } +} +trait ForcedProgramConversion { self: Z3StringCapableSolver[_] => + override def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = { + Z3StringCapableSolver.convert(p, true) + } +} + +abstract class Z3StringCapableSolver[+TUnderlying <: Solver](val context: LeonContext, val program: Program, + val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) +extends Solver { + def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = Z3StringCapableSolver.convert(p) + protected val (new_program, someConverter) = convertProgram(program) + + val underlying = underlyingConstructor(new_program, someConverter) + + def getModel: leon.solvers.Model = { + val model = underlying.getModel + someConverter match { + case None => model + case Some(converter) => + println("Conversion") + val ids = model.ids.toSeq + val exprs = ids.map(model.apply) + import converter.Backward._ + val original_ids = ids.map(convertId) + val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } + new Model(original_ids.zip(original_exprs).toMap) + } + } + + // Members declared in leon.utils.Interruptible + def interrupt(): Unit = underlying.interrupt() + def recoverInterrupt(): Unit = underlying.recoverInterrupt() + + // Members declared in leon.solvers.Solver + def assertCnstr(expression: Expr): Unit = { + someConverter.map{converter => + import converter.Forward._ + val newExpression = convertExpr(expression)(Map()) + underlying.assertCnstr(newExpression) + }.getOrElse(underlying.assertCnstr(expression)) + } + def getUnsatCore: Set[Expr] = { + someConverter.map{converter => + import converter.Backward._ + underlying.getUnsatCore map (e => convertExpr(e)(Map())) + }.getOrElse(underlying.getUnsatCore) + } + def check: Option[Boolean] = underlying.check + def free(): Unit = underlying.free() + def pop(): Unit = underlying.pop() + def push(): Unit = underlying.push() + def reset(): Unit = underlying.reset() + def name: String = underlying.name +} + +import z3._ + +trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self: Z3StringCapableSolver[TUnderlying] => + // Members declared in leon.solvers.EvaluatingSolver + val useCodeGen: Boolean = underlying.useCodeGen +} + +trait Z3StringQuantificationSolver[TUnderlying <: QuantificationSolver] extends QuantificationSolver { self: Z3StringCapableSolver[TUnderlying] => + // Members declared in leon.solvers.QuantificationSolver + override def getModel: leon.solvers.HenkinModel = { + val model = underlying.getModel + someConverter map { converter => + val ids = model.ids.toSeq + val exprs = ids.map(model.apply) + import converter.Backward._ + val original_ids = ids.map(convertId) + val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } + + val new_domain = new HenkinDomains( + model.doms.lambdas.map(kv => + (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, + model.doms.tpes.map(kv => + (convertType(kv._1), + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap + ) + + new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain) + } getOrElse model + } +} + +trait EvaluatorCheckConverter extends DeterministicEvaluator { + def converter: Z3StringConversion + abstract override def check(expression: Expr, model: solvers.Model) : CheckResult = { + val c = converter + import c.Backward._ // Because the evaluator is going to be called by the underlying solver, but it will use the original program + super.check(convertExpr(expression)(Map()), convertModel(model)) + } +} + +class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) + extends CodeGenEvaluator(context, originalProgram) with EvaluatorCheckConverter { + override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { + import converter._ + super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) + .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) + ) + } +} + +class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) extends DefaultEvaluator(context, originalProgram) with EvaluatorCheckConverter { + override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = { + import converter._ + Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model))) + } +} + + +class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program, + originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) { + override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator + someConverter match { + case Some(converter) => + if (useCodeGen) { + new ConvertibleCodeGenEvaluator(context, originalProgram, converter) + } else { + new ConvertibleDefaultEvaluator(context, originalProgram, converter) + } + case None => + if (useCodeGen) { + new CodeGenEvaluator(context, program) + } else { + new DefaultEvaluator(context, program) + } + } + } +} + + +class Z3StringFairZ3Solver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, + (prgm: Program, someConverter: Option[Z3StringConversion]) => + new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) + with Z3StringEvaluatingSolver[FairZ3Solver] + with Z3StringQuantificationSolver[FairZ3Solver] { + // Members declared in leon.solvers.z3.AbstractZ3Solver + protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + +class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new UnrollingSolver(context, program, underlyingSolverConstructor(program))) + with Z3StringNaiveAssumptionSolver[UnrollingSolver] + with Z3StringEvaluatingSolver[UnrollingSolver] + with Z3StringQuantificationSolver[UnrollingSolver] { + override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore +} + +class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 47017bcf1471770f1b1e9fb574a81bed34f4c515..d05b45c35e2e3847be4e1a6650c779de91da91d9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -104,6 +104,7 @@ trait SMTLIBTarget extends Interruptible { interpreter.eval(cmd) match { case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => reporter.warning(s"Unexpected error from $targetName solver: $msg") + //println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) // Store that there was an error. Now all following check() // invocations will return None addError() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 1731b94ae5f4f91b87248deddc50db5339915552..3d4a06a838a5057a693d85754bb5113e1ce7d0ae 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -8,15 +8,15 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ -import purescala.Definitions._ + import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx -import leon.solvers.z3.Z3StringConversion -trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { +trait SMTLIBZ3Target extends SMTLIBTarget { + def targetName = "z3" def interpreterOps(ctx: LeonContext) = { @@ -40,11 +40,11 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { override protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { - convertType(tpe) match { + tpe match { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) - case t => + case _ => super.declareSort(t) } } @@ -69,13 +69,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(t: Term, expected_otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - val otpe = expected_otpe match { - case Some(StringType) => Some(listchar) - case _ => expected_otpe - } - val res = (t, otpe) match { + (t, otpe) match { case (SimpleSymbol(s), Some(tp: TypeParameter)) => val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) @@ -100,16 +96,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case _ => super.fromSMT(t, otpe) } - expected_otpe match { - case Some(StringType) => - StringLiteral(convertToString(res)(program)) - case _ => res - } - } - - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = toSMT(e) - def targetApplication(tfd: TypedFunDef, args: Seq[Term])(implicit bindings: Map[Identifier, Term]): Term = { - FunctionApplication(declareFunction(tfd), args) } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { @@ -146,7 +132,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - case StringConverted(result) => result case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index aa0e5dad6506be945919cb22361cce649f40c077..59be52775406a78c22a55773dfd161db003b21bb 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -9,7 +9,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ @@ -207,7 +207,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { - assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean") + assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean." + ( + purescala.ExprOps.fold[String]{ (e, se) => + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + }(expr) + )) val prev = guardedExprs.getOrElse(guardVar, Nil) diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 2b75f08f0480cf272515bb8d8393e01e29d4dbf1..cdfe0c9ed6462b322b760e635122f5c3b11d2923 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -11,7 +11,7 @@ import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import utils._ diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 04496ae9ce8d76f33f3498c3e4cde84052a51049..7c3486ff533a630bbd1bce6bde31d86969163e7c 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -264,322 +264,311 @@ trait AbstractZ3Solver extends Solver { case other => throw SolverUnsupportedError(other, this) } - - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - implicit var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } - new Z3StringConversion[Z3AST] { - def getProgram = AbstractZ3Solver.this.program - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - rec(e) - } - def targetApplication(tfd: TypedFunDef, args: Seq[Z3AST])(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - z3.mkApp(functionDefToDecl(tfd), args: _*) + + def rec(ex: Expr): Z3AST = ex match { + + // TODO: Leave that as a specialization? + case LetTuple(ids, e, b) => { + z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => + val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) + entry } - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => { - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry - } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb - } - - case p @ Passes(_, _, _) => - rec(p.asConstraint) - - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) - - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - - case Waypoint(_, e, _) => rec(e) - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => { - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => + val rb = rec(b) + z3Vars = z3Vars -- ids + rb + } + + case p @ Passes(_, _, _) => + rec(p.asConstraint) + + case me @ MatchExpr(s, cs) => + rec(matchToIfThenElse(me)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + + case Waypoint(_, e, _) => rec(e) + case a @ Assert(cond, err, body) => + rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) + + case e @ Error(tpe, _) => { + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => + ast + case None => { + variables.getB(v) match { + case Some(ast) => ast - case None => { - variables.getB(v) match { - case Some(ast) => - ast - - case None => - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } - } - } - - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec): _*) - case Or(exs) => z3.mkOr(exs.map(rec): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) - case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) - case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => { - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - } - case Remainder(l, r) => { - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - } - case Modulo(l, r) => { - z3.mkMod(rec(l), rec(r)) - } - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) - case BVNot(e) => z3.mkBVNot(rec(e)) - case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) - case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) - case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) - case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) - case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) - case LessThan(l, r) => l.getType match { - case IntegerType => z3.mkLT(rec(l), rec(r)) - case RealType => z3.mkLT(rec(l), rec(r)) - case Int32Type => z3.mkBVSlt(rec(l), rec(r)) - case CharType => z3.mkBVSlt(rec(l), rec(r)) - } - case LessEquals(l, r) => l.getType match { - case IntegerType => z3.mkLE(rec(l), rec(r)) - case RealType => z3.mkLE(rec(l), rec(r)) - case Int32Type => z3.mkBVSle(rec(l), rec(r)) - case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") - } - case GreaterThan(l, r) => l.getType match { - case IntegerType => z3.mkGT(rec(l), rec(r)) - case RealType => z3.mkGT(rec(l), rec(r)) - case Int32Type => z3.mkBVSgt(rec(l), rec(r)) - case CharType => z3.mkBVSgt(rec(l), rec(r)) - } - case GreaterEquals(l, r) => l.getType match { - case IntegerType => z3.mkGE(rec(l), rec(r)) - case RealType => z3.mkGE(rec(l), rec(r)) - case Int32Type => z3.mkBVSge(rec(l), rec(r)) - case CharType => z3.mkBVSge(rec(l), rec(r)) - } - - case StringConverted(result) => - result - - case u : UnitLiteral => - val tpe = normalizeType(u.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor() - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor(es.map(rec): _*) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val selector = selectors.toB((tpe, i-1)) - selector(rec(t)) - - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) - constructor(args.map(rec): _*) - - case c @ CaseClassSelector(cct, cc, sel) => - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) - selector(rec(cc)) - - case AsInstanceOf(expr, ct) => - rec(expr) - - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) - } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val content = selectors.toB((tpe, 1))(sa) - - z3.mkSelect(content, rec(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val ssize = selectors.toB((tpe, 0))(sa) - val scontent = selectors.toB((tpe, 1))(sa) - - val newcontent = z3.mkStore(scontent, rec(i), rec(e)) - - val constructor = constructors.toB(tpe) - - constructor(ssize, newcontent) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val sa = rec(a) - selectors.toB((tpe, 0))(sa) - - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = normalizeType(arr.getType) - typeToSort(at) - - val default = oDefault.getOrElse(simplestValue(base)) - - val ar = rec(RawArrayValue(Int32Type, elems.map { - case (i, e) => IntLiteral(i) -> e - }, default)) - - constructors.toB(at)(rec(length), ar) - - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) - - case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) - val funDecl = lambdas.cachedB(ft) { - val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) - val returnSort = typeToSort(to) - - val name = FreshIdentifier("dynLambda").uniqueName - z3.mkFreshFuncDecl(name, sortSeq, returnSort) - } - z3.mkApp(funDecl, (caller +: args).map(rec): _*) - - case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - - case RawArrayValue(keyTpe, elems, default) => - val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) - - elems.foldLeft(ar) { - case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) - } - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) - - rec(RawArrayValue(from, elems.map{ - case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }, CaseClass(library.noneType(t), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - // Really ?!? We don't check that it is actually != None? - selectors.toB(library.someType(t), 0)(el) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - testers.toB(library.someType(t))(el) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) - typeToSort(mt) - - elems.foldLeft(rec(m1)) { case (m, (k,v)) => - z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) - } - - - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) - - case other => - unsupported(other) + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST } - }.rec(expr) + } + } + + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec): _*) + case Or(exs) => z3.mkOr(exs.map(rec): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) + case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => { + val rl = rec(l) + val rr = rec(r) + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + } + case Remainder(l, r) => { + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => { + z3.mkMod(rec(l), rec(r)) + } + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + + case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) + case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) + case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) + case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) + case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) + + case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) + case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) + case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) + case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) + case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) + case BVUMinus(e) => z3.mkBVNeg(rec(e)) + case BVNot(e) => z3.mkBVNot(rec(e)) + case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) + case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) + case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) + case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) + case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) + case LessThan(l, r) => l.getType match { + case IntegerType => z3.mkLT(rec(l), rec(r)) + case RealType => z3.mkLT(rec(l), rec(r)) + case Int32Type => z3.mkBVSlt(rec(l), rec(r)) + case CharType => z3.mkBVSlt(rec(l), rec(r)) + } + case LessEquals(l, r) => l.getType match { + case IntegerType => z3.mkLE(rec(l), rec(r)) + case RealType => z3.mkLE(rec(l), rec(r)) + case Int32Type => z3.mkBVSle(rec(l), rec(r)) + case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") + } + case GreaterThan(l, r) => l.getType match { + case IntegerType => z3.mkGT(rec(l), rec(r)) + case RealType => z3.mkGT(rec(l), rec(r)) + case Int32Type => z3.mkBVSgt(rec(l), rec(r)) + case CharType => z3.mkBVSgt(rec(l), rec(r)) + } + case GreaterEquals(l, r) => l.getType match { + case IntegerType => z3.mkGE(rec(l), rec(r)) + case RealType => z3.mkGE(rec(l), rec(r)) + case Int32Type => z3.mkBVSge(rec(l), rec(r)) + case CharType => z3.mkBVSge(rec(l), rec(r)) + } + + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = constructors.toB(ct) + constructor(args.map(rec): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = selectors.toB(cct, c.selectorIndex) + selector(rec(cc)) + + case AsInstanceOf(expr, ct) => + rec(expr) + + case IsInstanceOf(e, act: AbstractClassType) => + act.knownCCDescendants match { + case Seq(cct) => + rec(IsInstanceOf(e, cct)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + } + + case IsInstanceOf(e, cct: CaseClassType) => + typeToSort(cct) // Making sure the sort is defined + val tester = testers.toB(cct) + tester(rec(e)) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + + case fa @ Application(caller, args) => + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) + + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) + } + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val MapType(_, t) = normalizeType(m.getType) + + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }, CaseClass(library.noneType(t), Seq()))) + + case MapApply(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(_, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } + + + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case other => + unsupported(other) + } + + rec(expr) } protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { - def rec(t: Z3AST, expected_tpe: TypeTree): Expr = { + + def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - val tpe = Z3StringTypeConversion.convert(expected_tpe)(program) - val res = kind match { + kind match { case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { @@ -769,11 +758,6 @@ trait AbstractZ3Solver extends Solver { } case _ => unsound(t, "unexpected AST") } - expected_tpe match { - case StringType => - StringLiteral(Z3StringTypeConversion.convertToString(res)(program)) - case _ => res - } } rec(tree, normalizeType(tpe)) @@ -790,8 +774,7 @@ trait AbstractZ3Solver extends Solver { } def idToFreshZ3Id(id: Identifier): Z3AST = { - val correctType = Z3StringTypeConversion.convert(id.getType)(program) - z3.mkFreshConst(id.uniqueName, typeToSort(correctType)) + z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) } def reset() = { diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 3daf1ad4964ad73e8c4d9701ae4e65d0f4170897..21df018db70935ad63b6c22e7d6bc77894005b23 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -7,51 +7,132 @@ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ import purescala.Definitions._ -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} -import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.interpreters.Z3Interpreter -import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator +import leon.evaluators.EvaluationResults -object Z3StringTypeConversion { - def convert(t: TypeTree)(implicit p: Program) = new Z3StringTypeConversion { def getProgram = p }.convertType(t) - def convertToString(e: Expr)(implicit p: Program) = new Z3StringTypeConversion{ def getProgram = p }.convertToString(e) -} - -trait Z3StringTypeConversion { - val stringBijection = new Bijection[String, Expr]() +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } - lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Cons in Z3 solver") + val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d } - lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Nil in Z3 solver") + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) + + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd } - lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find List in Z3 solver") + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd } - def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { - case Some(fd) => fd - case _ => throw new Exception("Could not find function "+s+" in program") + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd + } + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd } - lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) - lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) - lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) - lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) - lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) - private lazy val program = getProgram + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } - def getProgram: Program + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) +} + +class Z3StringConversion(val p: Program) extends Z3StringConverters { + val stringBijection = new Bijection[String, Expr]() + + import StringEcoSystem._ + + lazy val listchar = StringList.typed + lazy val conschar = StringCons.typed + lazy val nilchar = StringNil.typed + + lazy val list_size = StringSize.typed + lazy val list_++ = StringListConcat.typed + lazy val list_take = StringTake.typed + lazy val list_drop = StringDrop.typed + lazy val list_slice = StringSlice.typed - def convertType(t: TypeTree): TypeTree = t match { - case StringType => listchar - case _ => t + def getProgram = program_with_string_methods + + lazy val program_with_string_methods = { + val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) + DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } + def convertToString(e: Expr)(implicit p: Program): String = stringBijection.cachedA(e) { e match { @@ -59,7 +140,7 @@ trait Z3StringTypeConversion { case CaseClass(_, Seq()) => "" } } - def convertFromString(v: String) = + def convertFromString(v: String): Expr = stringBijection.cachedB(v) { v.toList.foldRight(CaseClass(nilchar, Seq())){ case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) @@ -67,28 +148,226 @@ trait Z3StringTypeConversion { } } -trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType - def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType +trait Z3StringConverters { self: Z3StringConversion => + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalFdMap = new Bijection[FunDef, FunDef]() - object StringConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, TargetType]): Option[TargetType] = e match { + trait BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef + def hasIdConversion(id: Identifier): Boolean + def convertId(id: Identifier): Identifier + def isTypeToConvert(tpe: TypeTree): Boolean + def convertType(tpe: TypeTree): TypeTree + def convertPattern(pattern: Pattern): Pattern + def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr + + object PatternConverted { + def unapply(e: Pattern): Option[Pattern] = Some(e match { + case InstanceOfPattern(binder, ct) => + InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) + case WildcardPattern(binder) => + WildcardPattern(binder.map(convertId)) + case CaseClassPattern(binder, ct, subpatterns) => + CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) + case TuplePattern(binder, subpatterns) => + TuplePattern(binder.map(convertId), subpatterns map convertPattern) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => + UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) + case PatternExtractor(es, builder) => + builder(es map convertPattern) + }) + } + + object ExprConverted { + def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { + case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) + case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) + case Variable(id) => e + case pl@PartialLambda(mappings, default, tpe) => + PartialLambda( + mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), + convertExpr(kv._2))), + default.map(d => convertExpr(d)), convertType(tpe).asInstanceOf[FunctionType]) + case Lambda(args, body) => + println("Converting Lambda :" + e) + val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() + val new_args = for(arg <- args) yield { + val in = arg.getType + val new_id = convertId(arg.id) + if(new_id ne arg.id) { + new_bindings += (arg.id -> new_id) + ValDef(new_id) + } else arg + } + val res = Lambda(new_args, convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) + res + case Let(a, expr, body) if isTypeToConvert(a.getType) => + val new_a = convertId(a) + val new_bindings = bindings + (a -> Variable(new_a)) + val expr2 = convertExpr(expr)(new_bindings) + val body2 = convertExpr(body)(new_bindings) + Let(new_a, expr2, body2).copiedFrom(e) + case CaseClass(CaseClassType(ccd, tpes), args) => + CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) + case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => + CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) + case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => + MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case This(ct: ClassType) => + This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case Tuple(args) => + Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case Operator(es, builder) => + val rec = convertExpr _ + val newEs = es.map(rec) + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + }) + } + + def convertModel(model: Model): Model = { + new Model(model.ids.map{i => + val id = convertId(i) + id -> convertExpr(model(i))(Map()) + }.toMap) + } + + def convertResult(result: EvaluationResults.Result[Expr]) = { + result match { + case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map())) + case result => result + } + } + } + + object Forward extends BidirectionalConverters { + /* The conversion between functions should already have taken place */ + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getBorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsA(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getB(id) match { + case Some(idB) => idB + case None => + if(isTypeToConvert(id.getType)) { + val new_id = FreshIdentifier(id.name, convertType(id.getType)) + mappedVariables += (id -> new_id) + new_id + } else id + } + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(StringType == _)(tpe) + def convertType(tpe: TypeTree): TypeTree = + TypeOps.preMap{ case StringType => Some(StringList.typed) case e => None}(tpe) + def convertPattern(e: Pattern): Pattern = e match { + case LiteralPattern(binder, StringLiteral(s)) => + s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => + CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + case PatternConverted(e) => e + } + + /** Method which can use recursively StringConverted in its body in unapply positions */ + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { + case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => // No string support for z3 at this moment. val stringEncoding = convertFromString(v) - Some(convertToTarget(stringEncoding)) + convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - Some(targetApplication(list_size, Seq(convertToTarget(a)))) + FunctionInvocation(list_size, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - Some(targetApplication(list_++, Seq(convertToTarget(a), convertToTarget(b)))) + FunctionInvocation(list_++, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - Some(targetApplication(list_take, - Seq(targetApplication(list_drop, Seq(convertToTarget(a), convertToTarget(start))), convertToTarget(length)))) + FunctionInvocation(list_take, + Seq(FunctionInvocation(list_drop, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - Some(targetApplication(list_slice, Seq(convertToTarget(a), convertToTarget(start), convertToTarget(end)))) - case _ => None + FunctionInvocation(list_slice, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case ExprConverted(e) => e + } + } + + object Backward extends BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getAorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsB(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getA(id) match { + case Some(idA) => idA + case None => + if(isTypeToConvert(id.getType)) { + val old_type = convertType(id.getType) + val old_id = FreshIdentifier(id.name, old_type) + mappedVariables += (old_id -> id) + old_id + } else id + } + } + def convertIdToMapping(id: Identifier): (Identifier, Variable) = { + id -> Variable(convertId(id)) } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) + def convertType(tpe: TypeTree): TypeTree = { + TypeOps.preMap{ + case StringList | StringCons | StringNil => Some(StringType) + case e => None}(tpe) + } + def convertPattern(e: Pattern): Pattern = e match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e + } - def apply(t: TypeTree): TypeTree = convertType(t) + + + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)(self.p)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e + } } } \ No newline at end of file diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 57a62b665c797b10fab2d099fabd3a722f6e7d27..3680930639a2cfba46490d4a21bab7772d7fd0c8 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -11,8 +11,13 @@ class Bijection[A, B] { b2a += b -> a } - def +=(t: (A,B)): Unit = { - this += (t._1, t._2) + def +=(t: (A,B)): this.type = { + +=(t._1, t._2) + this + } + + def ++=(t: Iterable[(A,B)]) = { + (this /: t){ case (b, elem) => b += elem } } def clear(): Unit = { @@ -22,6 +27,9 @@ class Bijection[A, B] { def getA(b: B) = b2a.get(b) def getB(a: A) = a2b.get(a) + + def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse) + def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse) def toA(b: B) = getA(b).get def toB(a: A) = getB(a).get diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 8053a8dc1e9956c9ef264c247d41f8d96ffbb934..17fdf48bc88d88cca9c49e63672af6f3c76afa48 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -5,7 +5,7 @@ package leon.utils import leon._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors.caseClassSelector diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index a298e90b16f405a19856c1a38a276a04153fd500..6b7f7cc6ee3c00827289313ed60e111b6ec3a640 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.leastUpperBound import leon.purescala.Types._ import leon.xlang.Expressions._ diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index d568e471f08eb4cf1a558675419daabd9e9c940a..c5571fa4f9b2def4f33a586247aa7eb213483a10 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -22,13 +22,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm) with ForcedProgramConversion ) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm)) with ForcedProgramConversion ) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm)) with ForcedProgramConversion ) ) else Nil) } @@ -78,7 +78,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { } } case _ => - fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!?") + fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!? Solver was "+solver.getClass) } } finally { solver.free() diff --git a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala index 827596fc6cd116aefc752a37aeb30647e3ebdf99..4f74b00c5f93a3125caece106809dfd4182fa8a8 100644 --- a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala @@ -6,7 +6,7 @@ import leon.test._ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.isSubtypeOf import leon.purescala.Definitions._ import leon.purescala.ExprOps._ diff --git a/testcases/verification/strings/invalid/CompatibleListChar.scala b/testcases/verification/strings/invalid/CompatibleListChar.scala new file mode 100644 index 0000000000000000000000000000000000000000..86eec34cddcee8055c34d5ffc791b7bbf7a397e7 --- /dev/null +++ b/testcases/verification/strings/invalid/CompatibleListChar.scala @@ -0,0 +1,29 @@ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object CompatibleListChar { + def rec[T](l : List[T], f : T => String): String = l match { + case Cons(head, tail) => f(head) + rec(tail, f) + case Nil() => "" + } + def customToString[T](l : List[T], p: List[Char], d: String, fd: String => String, fp: List[Char] => String, pf: String => List[Char], f : T => String): String = rec(l, f) ensuring { + (res : String) => (p == Nil[Char]() || d == "" || fd(d) == "" || fp(p) == "" || pf(d) == Nil[Char]()) && ((l, res) passes { + case Cons(a, Nil()) => f(a) + }) + } + def customPatternMatching(s: String): Boolean = { + s match { + case "" => true + case b => List(b) match { + case Cons("", Nil()) => true + case Cons(s, Nil()) => false // StrOps.length(s) < BigInt(2) // || (s == "\u0000") //+ "a" + case Cons(_, Cons(_, Nil())) => true + case _ => false + } + case _ => false + } + } holds +} \ No newline at end of file