diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index ec0179bd9c5d33fd95ea6d60307296c6fb356add..3a51cd78aa9dcbe46da38771dff17a2418920144 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -91,6 +91,23 @@ trait SymbolOps { self: TypeOps => fixpoint(postMap(rec))(expr) } + /** Returns '''true''' iff the expression [[expr]] cannot lead to the evaluation of + * an [[Assume]] or [[Choose]] expression. */ + def isPure(expr: Expr): Boolean = { + val callees = collect { + case fi: FunctionInvocation => Set(fi.tfd.fd) + case _ => Set.empty[FunDef] + } (expr) + + val allCallees = callees ++ callees.flatMap(transitiveCallees) + !(expr +: allCallees.toSeq.map(_.fullBody)).exists { expr => + exists { + case (_: Assume) | (_: Choose) => true + case _ => false + }(expr) + } + } + private val typedIds: MutableMap[Type, List[Identifier]] = MutableMap.empty.withDefaultValue(List.empty) @@ -144,19 +161,94 @@ trait SymbolOps { self: TypeOps => } def outer(vars: Set[Variable], body: Expr): Expr = { + // this registers the argument images into subst + val tvars = vars map (v => v.copy(id = transformId(v.id, v.tpe))) + + def isLocal(e: Expr, path: Path): Boolean = { + val vs = variablesOf(e) + val bindings = path.bindings.map(p => p._1.toVariable -> p._2).toMap + (tvars & (vs.flatMap { v => + varSubst.get(v.id).map(Variable(_, v.tpe)).toSet ++ + bindings.get(v).toSet.flatMap(variablesOf) + })).isEmpty + } + + def isSatisfiable(path: Path): Option[Boolean] = path.conditions match { + case Seq() => Some(true) + case conds => + val (posConds, posRest) = conds.partition { + case IsInstanceOf(v: Variable, adt) if tvars(v) => true + case _ => false + } + + val (negConds, negRest) = posRest.partition { + case Not(IsInstanceOf(v: Variable, adt)) if tvars(v) => true + case _ => false + } + + if (negRest.nonEmpty) { + None + } else { + val posAdts = posConds.collect { + case IsInstanceOf(v: Variable, tpe) if tpe != tpe.getADT.root.toType => v -> tpe + }.groupBy(_._1).mapValues(_.map(_._2).toSet) + + val negAdts = negConds.collect { + case Not(IsInstanceOf(v: Variable, tpe)) => v -> tpe + }.groupBy(_._1).mapValues(_.map(_._2).toSet) + + val results = for (v: Variable <- (posAdts.keys ++ negAdts.keys).toSet) yield { + val pos = posAdts.getOrElse(v, Set.empty) + val neg = negAdts.getOrElse(v, Set.empty) + + if (pos.size > 1 || (pos & neg).nonEmpty) { + Some(false) + } else { + val constructors = ((pos ++ neg).head.getADT.root match { + case tsort: TypedADTSort => tsort.constructors + case tcons: TypedADTConstructor => Seq(tcons) + }).map(_.toType).toSet + + if (neg == constructors) { + Some(false) + } else { + None + } + } + } - object normalizer extends SelfTreeTransformer { - override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (transformId(id, tpe), tpe) + if (results.exists(_ contains false)) { + Some(false) + } else if (results.forall(_ contains true)) { + Some(true) + } else { + None + } + } + } - override def transform(e: Expr): Expr = e match { + object normalizer extends transformers.TransformerWithPC { + val trees: self.trees.type = self.trees + val symbols: self.symbols.type = self.symbols + val initEnv = Path.empty + + override protected def rec(e: Expr, path: Path): Expr = e match { case Variable(id, tpe) => Variable(transformId(id, tpe), tpe) - case Let(vd, e, b) if (!onlySimple || isSimple(e)) && (variablesOf(e) & vars).isEmpty => + case Let(vd, e, b) if ( + isLocal(e, path) && + (!onlySimple || isSimple(e)) && + ((isSatisfiable(path) contains true) || isPure(e)) + ) => val newId = getId(e) - transform(replaceFromSymbols(Map(vd.toVariable -> Variable(newId, vd.tpe)), b)) + rec(replaceFromSymbols(Map(vd.toVariable -> Variable(newId, vd.tpe)), b), path) - case expr if (!onlySimple || isSimple(expr)) && (variablesOf(expr) & vars).isEmpty => + case expr if ( + isLocal(expr, path) && + (!onlySimple || isSimple(expr)) && + ((isSatisfiable(path) contains true) || isPure(expr)) + ) => Variable(getId(expr), expr.getType) case f: Forall => @@ -167,12 +259,13 @@ trait SymbolOps { self: TypeOps => val newBody = outer(vars ++ l.args.map(_.toVariable), l.body) Lambda(l.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody) - case _ => super.transform(e) + case _ => + val (vs, es, tps, recons) = deconstructor.deconstruct(e) + val newVs = vs.map(v => v.copy(id = transformId(v.id, v.tpe))) + super.rec(recons(newVs, es, tps), path) } } - // this registers the argument images into subst - vars foreach (v => transformId(v.id, v.tpe)) normalizer.transform(body) }