diff --git a/src/main/scala/inox/Program.scala b/src/main/scala/inox/Program.scala index c478fc490e33ccab3cb08ce24b13333ef40575c7..15fc7e6d7e5b072cc989ee418fa79e3920ec0386 100644 --- a/src/main/scala/inox/Program.scala +++ b/src/main/scala/inox/Program.scala @@ -11,4 +11,10 @@ trait Program { implicit def implicitProgram: this.type = this implicit def printerOpts: trees.PrinterOptions = trees.PrinterOptions.fromSymbols(symbols, ctx) + + def transform(t: trees.TreeTransformer): Program { val trees: Program.this.trees.type } = new Program { + val trees: Program.this.trees.type = Program.this.trees + val symbols = Program.this.symbols.transform(t) + val ctx = Program.this.ctx + } } diff --git a/src/main/scala/inox/ast/CallGraph.scala b/src/main/scala/inox/ast/CallGraph.scala index 7cd906ba60b4f348c51f97a4744916ca665d110a..e4498f7af9d420915b0a228c12321fdde5b3ba63 100644 --- a/src/main/scala/inox/ast/CallGraph.scala +++ b/src/main/scala/inox/ast/CallGraph.scala @@ -11,22 +11,15 @@ trait CallGraph { import trees.exprOps._ protected val symbols: Symbols - private def collectCallsInPats(fd: FunDef)(p: Pattern): Set[(FunDef, FunDef)] = - (p match { - case u: UnapplyPattern => Set((fd, symbols.getFunction(u.fd))) - case _ => Set() - }) ++ p.subPatterns.flatMap(collectCallsInPats(fd)) - private def collectCalls(fd: FunDef)(e: Expr): Set[(FunDef, FunDef)] = e match { case f @ FunctionInvocation(id, tps, _) => Set((fd, symbols.getFunction(id))) - case MatchExpr(_, cases) => cases.toSet.flatMap((mc: MatchCase) => collectCallsInPats(fd)(mc.pattern)) case _ => Set() } lazy val graph: DiGraph[FunDef, SimpleEdge[FunDef]] = { var g = DiGraph[FunDef, SimpleEdge[FunDef]]() - for ((_, fd) <- symbols.functions; c <- collect(collectCalls(fd))(fd.fullBody)) { + for ((_, fd) <- symbols.functions; body <- fd.body; c <- collect(collectCalls(fd))(body)) { g += SimpleEdge(c._1, c._2) } diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala index 9b11aad35d0673864951ffb1b8fc47c8740631e3..e9bf6c88568941cde347ea2f98810b6392ef121a 100644 --- a/src/main/scala/inox/ast/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -47,26 +47,6 @@ trait Constructors { else bd } - /** $encodingof ``val (...binders...) = value; body`` which is translated to ``value match { case (...binders...) => body }``, and returns `body` if the identifiers are not bound in `body`. - * @see [[purescala.Expressions.Let]] - */ - def letTuple(binders: Seq[ValDef], value: Expr, body: Expr) = binders match { - case Nil => - body - case x :: Nil => - Let(x, value, body) - case xs => - require( - value.getType match { - case TupleType(args) => args.size == xs.size - case _ => false - }, - s"In letTuple: '$value' is being assigned as a tuple of arity ${xs.size}; yet its type is '${value.getType}' (body is '$body')" - ) - - LetPattern(TuplePattern(None,binders map { b => WildcardPattern(Some(b)) }), value, body) - } - /** Wraps the sequence of expressions as a tuple. If the sequence contains a single expression, it is returned instead. * @see [[purescala.Expressions.Tuple]] */ @@ -76,17 +56,6 @@ trait Constructors { case more => Tuple(more) } - /** Wraps the sequence of patterns as a tuple. If the sequence contains a single pattern, it is returned instead. - * If the sequence is empty, [[purescala.Expressions.LiteralPattern `LiteralPattern`]]`(None, `[[purescala.Expressions.UnitLiteral `UnitLiteral`]]`())` is returned. - * @see [[purescala.Expressions.TuplePattern]] - * @see [[purescala.Expressions.LiteralPattern]] - */ - def tuplePatternWrap(ps: Seq[Pattern]) = ps match { - case Seq() => LiteralPattern(None, UnitLiteral()) - case Seq(elem) => elem - case more => TuplePattern(None, more) - } - /** Wraps the sequence of types as a tuple. If the sequence contains a single type, it is returned instead. * If the sequence is empty, the [[purescala.Types.UnitType UnitType]] is returned. * @see [[purescala.Types.TupleType]] @@ -117,53 +86,15 @@ trait Constructors { /** Simplifies the provided case class selector. * @see [[purescala.Expressions.CaseClassSelector]] */ - def caseClassSelector(classType: ClassType, caseClass: Expr, selector: Identifier): Expr = { + def caseClassSelector(caseClass: Expr, selector: Identifier): Expr = { caseClass match { - case CaseClass(ct, fields) if ct == classType && !ct.tcd.hasInvariant => + case CaseClass(ct, fields) if !ct.tcd.hasInvariant => fields(ct.tcd.cd.asInstanceOf[CaseClassDef].selectorID2Index(selector)) case _ => - CaseClassSelector(classType, caseClass, selector) + CaseClassSelector(caseClass, selector) } } - /** $encoding of `case ... if ... => ... ` but simplified if possible, based on types of the encompassing [[purescala.Expressions.CaseClassPattern MatchExpr]]. - * @see [[purescala.Expressions.CaseClassPattern MatchExpr]] - * @see [[purescala.Expressions.CaseClassPattern CaseClassPattern]] - */ - private def filterCases(scrutType: Type, resType: Option[Type], cases: Seq[MatchCase]): Seq[MatchCase] = { - val casesFiltered = scrutType match { - case c: ClassType if !c.tcd.cd.isAbstract => - cases.filter(_.pattern match { - case CaseClassPattern(_, cct, _) if cct.id != c.id => false - case _ => true - }) - - case _ => - cases - } - - resType match { - case Some(tpe) => - casesFiltered.filter(c => symbols.typesCompatible(c.rhs.getType, tpe)) - case None => - casesFiltered - } - } - - /** $encodingof `... match { ... }` but simplified if possible. Simplifies to [[Error]] if no case can match the scrutined expression. - * @see [[purescala.Expressions.MatchExpr MatchExpr]] - */ - def matchExpr(scrutinee: Expr, cases: Seq[MatchCase]): Expr = { - val filtered = filterCases(scrutinee.getType, None, cases) - if (filtered.nonEmpty) - MatchExpr(scrutinee, filtered) - else - Error( - cases.headOption.map{ _.rhs.getType }.getOrElse(Untyped), - "No case matches the scrutinee" - ) - } - /** $encodingof `&&`-expressions with arbitrary number of operands, and simplified. * @see [[purescala.Expressions.And And]] */ @@ -240,23 +171,13 @@ trait Constructors { */ // @mk I simplified that because it seemed dangerous and unnessecary def equality(a: Expr, b: Expr): Expr = { - if (a.isInstanceOf[Terminal] && isPurelyFunctional(a) && a == b ) { + if (a.isInstanceOf[Terminal] && a == b ) { BooleanLiteral(true) } else { Equals(a, b) } } - def assertion(c: Expr, err: Option[String], res: Expr) = { - if (c == BooleanLiteral(true)) { - res - } else if (c == BooleanLiteral(false)) { - Error(res.getType, err.getOrElse("Assertion failed")) - } else { - Assert(c, err, res) - } - } - /** $encodingof simplified `fn(realArgs)` (function application). * Transforms * {{{ ((x: A, y: B) => g(x, y))(c, d) }}} @@ -283,9 +204,9 @@ trait Constructors { vd -> fresh.toVariable }.toMap - val (vds, bds) = defs.unzip - - letTuple(vds, tupleWrap(bds), exprOps.replaceFromSymbols(subst, body)) + defs.foldRight(exprOps.replaceFromSymbols(subst, body)) { + case ((vd, bd), body) => let(vd, bd, body) + } case _ => Application(fn, realArgs) @@ -359,28 +280,4 @@ trait Constructors { IsInstanceOf(expr, tpe) } } - - def req(pred: Expr, body: Expr) = pred match { - case BooleanLiteral(true) => body - case BooleanLiteral(false) => Error(body.getType, "Precondition failed") - case _ => Require(pred, body) - } - - def tupleWrapArg(fun: Expr) = fun.getType match { - case FunctionType(args, res) if args.size > 1 => - val newArgs = fun match { - case Lambda(args, _) => args - case _ => args map (tpe => ValDef(FreshIdentifier("x", alwaysShowUniqueID = true), tpe)) - } - val res = ValDef(FreshIdentifier("res", alwaysShowUniqueID = true), TupleType(args)) - val patt = TuplePattern(None, newArgs map (arg => WildcardPattern(Some(arg)))) - Lambda(Seq(res), MatchExpr(res.toVariable, Seq(SimpleCase(patt, application(fun, newArgs map (_.toVariable)))))) - case _ => - fun - } - - def ensur(e: Expr, pred: Expr) = { - Ensuring(e, tupleWrapArg(pred)) - } - } diff --git a/src/main/scala/inox/ast/DSL.scala b/src/main/scala/inox/ast/DSL.scala index fc33d002746a144a447bceb0ed526cf0558e86fe..03380ea15ee5d3bb82707de065f81220b2e8386a 100644 --- a/src/main/scala/inox/ast/DSL.scala +++ b/src/main/scala/inox/ast/DSL.scala @@ -84,15 +84,16 @@ trait DSL { // Misc. - def getField(cc: ClassType, selector: String) = { + def getField(selector: String) = { val id = for { - tcd <- cc.lookupClass + ct <- scala.util.Try(e.getType.asInstanceOf[ClassType]).toOption + tcd <- ct.lookupClass tccd <- scala.util.Try(tcd.toCase).toOption field <- tccd.cd.fields.find(_.id.name == selector) } yield { field.id } - CaseClassSelector(cc, e, id.get) + CaseClassSelector(e, id.get) } def ensures(other: Expr) = Ensuring(e, other) @@ -191,111 +192,6 @@ trait DSL { def in(e: Expr) = susp(e) } - def prec(e: Expr) = new BlockSuspension(Require(e, _)) - def assertion(e: Expr) = new BlockSuspension(Assert(e, None, _)) - def assertion(e: Expr, msg: String) = new BlockSuspension(Assert(e, Some(msg), _)) - - // Pattern-matching - implicit class PatternMatch(scrut: Expr) { - def matchOn(cases: MatchCase* ) = { - MatchExpr(scrut, cases.toList) - } - } - - /* Patterns */ - - // This introduces the rhs of a case given a pattern - implicit class PatternOps(pat: Pattern) { - - val guard: Option[Expr] = None - - def ==>(rhs: => Expr) = { - val Seq() = pat.binders - MatchCase(pat, guard, rhs) - } - def ==>(rhs: Variable => Expr) = { - val Seq(b1) = pat.binders - MatchCase(pat, guard, rhs(b1.toVariable)) - } - def ==>(rhs: (Variable, Variable) => Expr) = { - val Seq(b1, b2) = pat.binders - MatchCase(pat, guard, rhs(b1.toVariable, b2.toVariable)) - } - def ==>(rhs: (Variable, Variable, Variable) => Expr) = { - val Seq(b1, b2, b3) = pat.binders - MatchCase(pat, guard, rhs(b1.toVariable, b2.toVariable, b3.toVariable)) - } - def ==>(rhs: (Variable, Variable, Variable, Variable) => Expr) = { - val Seq(b1, b2, b3, b4) = pat.binders - MatchCase(pat, guard, - rhs(b1.toVariable, b2.toVariable, b3.toVariable, b4.toVariable)) - } - - def ~|~(g: Expr) = new PatternOpsWithGuard(pat, g) - } - - class PatternOpsWithGuard(pat: Pattern, g: Expr) extends PatternOps(pat) { - override val guard = Some(g) - override def ~|~(g: Expr) = sys.error("Redefining guard!") - } - - private def l2p[T](l: Literal[T]) = LiteralPattern(None, l) - // Literal patterns - def P(i: Int) = l2p(IntLiteral(i)) - def P(b: BigInt) = l2p(IntegerLiteral(b)) - def P(b: Boolean) = l2p(BooleanLiteral(b)) - def P(c: Char) = l2p(CharLiteral(c)) - def P() = l2p(UnitLiteral()) - // Binder-only patterns - def P(vd: ValDef) = WildcardPattern(Some(vd)) - - class CaseClassToPattern(ct: ClassType) { - def apply(ps: Pattern*) = CaseClassPattern(None, ct, ps.toSeq) - } - // case class patterns - def P(ct: ClassType) = new CaseClassToPattern(ct) - // Tuple patterns - def P(p1:Pattern, p2: Pattern, ps: Pattern*) = TuplePattern(None, p1 :: p2 :: ps.toList) - // Wildcard pattern - def __ = WildcardPattern(None) - // Attach binder to pattern - implicit class BinderToPattern(b: ValDef) { - def @@ (p: Pattern) = p.withBinder(b) - } - - // Instance-of patterns - implicit class TypeToInstanceOfPattern(ct: ClassType) { - def @:(vd: ValDef) = InstanceOfPattern(Some(vd), ct) - def @:(wp: WildcardPattern) = { - if (wp.binder.nonEmpty) sys.error("Instance of pattern with named wildcardpattern?") - else InstanceOfPattern(None, ct) - } // TODO Kinda dodgy... - } - - // TODO: Remove this at some point - private def testExpr(e1: Expr, e2: Expr, ct: ClassType)(implicit simpl: SimplificationLevel) = { - prec(e1) in - let("i" :: Untyped, e1) { i => - if_ (\("j" :: Untyped)(j => e1(j))) { - e1 + e2 + i + E(42) - } else_ { - assertion(E(true), "Weird things") in - ct(e1, e2) matchOn ( - P(ct)( - ("i" :: Untyped) @: ct, - P(42), - __ @: ct, - P("k" :: Untyped), - P(__, ( "j" :: Untyped) @@ P(42)) - ) ==> { - (i, j, k) => !e1 - }, - __ ~|~ e1 ==> e2 - ) - } - } - } ensures e2 - /* Types */ def T(tp1: Type, tp2: Type, tps: Type*) = TupleType(tp1 :: tp2 :: tps.toList) def T(id: Identifier) = new IdToClassType(id) diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 1b497884ac1ecf3efdcb7a9df033e06157a832f3..143033d387893de407d0ce7484066309e9826420 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -77,8 +77,10 @@ trait Definitions { self: Trees => override def hashCode: Int = super[VariableSymbol].hashCode } + type Symbols <: AbstractSymbols + /** A wrapper for a program. For now a program is simply a single object. */ - case class Symbols(classes: Map[Identifier, ClassDef], functions: Map[Identifier, FunDef]) + trait AbstractSymbols extends Tree with TypeOps with SymbolOps @@ -86,6 +88,9 @@ trait Definitions { self: Trees => with Constructors with Paths { + protected val classes: Map[Identifier, ClassDef] + protected val functions: Map[Identifier, FunDef] + private[ast] val trees: self.type = self protected val symbols: this.type = this @@ -116,6 +121,8 @@ trait Definitions { self: Trees => def getFunction(id: Identifier, tps: Seq[Type]): TypedFunDef = lookupFunction(id, tps).getOrElse(throw FunctionLookupException(id)) def asString: String = asString(PrinterOptions.fromSymbols(this, InoxContext.printNames)) + + def transform(t: TreeTransformer): Symbols } case class TypeParameterDef(tp: TypeParameter) extends Definition { @@ -307,25 +314,11 @@ trait Definitions { self: Trees => val tparams: Seq[TypeParameterDef], val params: Seq[ValDef], val returnType: Type, - val fullBody: Expr, + val body: Option[Expr], val flags: Set[FunctionFlag] ) extends Definition { - /* Body manipulation */ - - lazy val body: Option[Expr] = exprOps.withoutSpec(fullBody) - lazy val precondition = exprOps.preconditionOf(fullBody) - lazy val precOrTrue = precondition getOrElse BooleanLiteral(true) - - lazy val postcondition = exprOps.postconditionOf(fullBody) - lazy val postOrTrue = postcondition getOrElse { - val arg = ValDef(FreshIdentifier("res", alwaysShowUniqueID = true), returnType) - Lambda(Seq(arg), BooleanLiteral(true)) - } - def hasBody = body.isDefined - def hasPrecondition = precondition.isDefined - def hasPostcondition = postcondition.isDefined def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { @@ -409,16 +402,7 @@ trait Definitions { self: Trees => lazy val returnType: Type = translated(fd.returnType) - lazy val fullBody = translated(fd.fullBody) lazy val body = fd.body map translated - lazy val precondition = fd.precondition map translated - lazy val precOrTrue = translated(fd.precOrTrue) - lazy val postcondition = fd.postcondition map (p => translated(p).asInstanceOf[Lambda]) - lazy val postOrTrue = translated(fd.postOrTrue).asInstanceOf[Lambda] - - def hasImplementation = body.isDefined - def hasBody = hasImplementation - def hasPrecondition = precondition.isDefined - def hasPostcondition = postcondition.isDefined + def hasBody = body.isDefined } } diff --git a/src/main/scala/inox/ast/ExprOps.scala b/src/main/scala/inox/ast/ExprOps.scala index 21151ac9acd05d84054f116920c291190c90a853..1f85f14f75f8a8da9d787f1c2852e8d51801c005 100644 --- a/src/main/scala/inox/ast/ExprOps.scala +++ b/src/main/scala/inox/ast/ExprOps.scala @@ -51,7 +51,6 @@ trait ExprOps extends GenTreeOps { e match { case v: Variable => subvs + v case Let(vd, _, _) => subvs - vd.toVariable - case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders).map(_.toVariable) case Lambda(args, _) => subvs -- args.map(_.toVariable) case Forall(args, _) => subvs -- args.map(_.toVariable) case _ => subvs @@ -86,8 +85,7 @@ trait ExprOps extends GenTreeOps { * unrolling solver. See implementation for what this means exactly. */ def isSimple(e: Expr): Boolean = !exists { - case (_: Assert) | (_: Ensuring) | - (_: Forall) | (_: Lambda) | + case (_: Assume) | (_: Forall) | (_: Lambda) | (_: FunctionInvocation) | (_: Application) => true case _ => false } (e) @@ -101,56 +99,6 @@ trait ExprOps extends GenTreeOps { }(v) } - override def formulaSize(e: Expr): Int = e match { - case ml: MatchExpr => - super.formulaSize(e) + ml.cases.map(cs => patternOps.formulaSize(cs.pattern)).sum - case _ => - super.formulaSize(e) - } - - /** Returns if this expression behaves as a purely functional construct, - * i.e. always returns the same value (for the same environment) and has no side-effects - */ - def isPurelyFunctional(e: Expr): Boolean = !exists { - case _ : Error => true - case _ => false - }(e) - - /** Extracts the body without its specification - * - * [[Expressions.Expr]] trees contain its specifications as part of certain nodes. - * This function helps extracting only the body part of an expression - * - * @return An option type with the resulting expression if not [[Expressions.NoTree]] - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withoutSpec(expr: Expr): Option[Expr] = expr match { - case Let(i, e, b) => withoutSpec(b).map(Let(i, e, _)) - case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(Require(pre, b), post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(b, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case b => Option(b).filterNot(_.isInstanceOf[NoTree]) - } - - /** Returns the precondition of an expression wrapped in Option */ - def preconditionOf(expr: Expr): Option[Expr] = expr match { - case Let(i, e, b) => preconditionOf(b).map(Let(i, e, _).copiedFrom(expr)) - case Require(pre, _) => Some(pre) - case Ensuring(Require(pre, _), _) => Some(pre) - case b => None - } - - /** Returns the postcondition of an expression wrapped in Option */ - def postconditionOf(expr: Expr): Option[Lambda] = expr match { - case Let(i, e, b) => postconditionOf(b).map(l => l.copy(body = Let(i, e, l.body)).copiedFrom(expr)) - case Ensuring(_, post: Lambda) => Some(post) - case _ => None - } - - /** Returns a tuple of precondition, the raw body and the postcondition of an expression */ - def breakDownSpecs(e: Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) - def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = { val rec = preTraversalWithParent(f, Some(e)) _ diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index b454e70b22defd5de120b6ebcac47f34f10b5d9d..23cc7170e32db81187148a6010bd806b5fce8384 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -47,70 +47,12 @@ trait Expressions { self: Trees => } - /** Stands for an undefined Expr, similar to `???` or `null` */ - case class NoTree(tpe: Type) extends Expr with Terminal { - def getType(implicit s: Symbols): Type = tpe - } - - - /* Specifications */ - - /** Computational errors (unmatched case, taking min of an empty set, - * division by zero, etc.). Corresponds to the ``error[T](description)`` - * Leon library function. - * It should always be typed according to the expected type. - * - * @param tpe The type of this expression - * @param description The description of the error - */ - case class Error(tpe: Type, description: String) extends Expr with Terminal { - def getType(implicit s: Symbols): Type = tpe - } - - /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* - * - * @param pred The precondition formula inside ``require(...)`` - * @param body The body following the ``require(...)`` - */ - case class Require(pred: Expr, body: Expr) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = { - if (pred.getType == BooleanType) body.getType - else Untyped - } - } - - /** Postcondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *ensuring* - * - * @param body The body of the expression. It can contain at most one [[Expressions.Require]] sub-expression. - * @param pred The predicate to satisfy. It should be a function whose argument's type can handle the type of the body - */ - case class Ensuring(body: Expr, pred: Expr) extends Expr with CachingTyped { - require(pred.isInstanceOf[Lambda]) - - protected def computeType(implicit s: Symbols) = pred.getType match { - case FunctionType(Seq(bodyType), BooleanType) if s.isSubtypeOf(body.getType, bodyType) => - body.getType - case _ => - Untyped - } - - /** Converts this ensuring clause to the body followed by an assert statement */ - def toAssert(implicit s: Symbols): Expr = { - val res = ValDef(FreshIdentifier("res", true), getType) - Let(res, body, Assert( - s.application(pred, Seq(res.toVariable)), - Some("Postcondition failed @" + this.getPos), res.toVariable - )) - } - } - - /** Local assertions with customizable error message + /** Local assumption * - * @param pred The predicate, first argument of `assert(..., ...)` - * @param error An optional error string to display if the assert fails. Second argument of `assert(..., ...)` - * @param body The expression following `assert(..., ...)` + * @param pred The predicate to be assumed + * @param body The expression following `assume(pred)` */ - case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr with CachingTyped { + case class Assume(pred: Expr, body: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = { if (pred.getType == BooleanType) body.getType else Untyped @@ -197,159 +139,6 @@ trait Expressions { self: Trees => s.leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped } - /** $encodingof `... match { ... }` - * - * '''cases''' should be nonempty. If you are not sure about this, you should use - * [[purescala.Constructors#matchExpr purescala's constructor matchExpr]] - * - * @param scrutinee Expression to the left of the '''match''' keyword - * @param cases A sequence of cases to match `scrutinee` against - */ - case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr with CachingTyped { - require(cases.nonEmpty) - protected def computeType(implicit s: Symbols): Type = - s.leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped - } - - /** $encodingof `case pattern [if optGuard] => rhs` - * - * @param pattern The pattern just to the right of the '''case''' keyword - * @param optGuard An optional if-condition just to the left of the `=>` - * @param rhs The expression to the right of `=>` - * @see [[Expressions.MatchExpr]] - */ - case class MatchCase(pattern: Pattern, optGuard: Option[Expr], rhs: Expr) extends Tree { - def expressions: Seq[Expr] = optGuard.toList :+ rhs - } - - /** $encodingof a pattern after a '''case''' keyword. - * - * @see [[Expressions.MatchCase]] - */ - sealed abstract class Pattern extends Tree { - val subPatterns: Seq[Pattern] - val binder: Option[ValDef] - - private def subBinders = subPatterns.flatMap(_.binders) - def binders: Seq[ValDef] = binder.toSeq ++ subBinders - - def withBinder(b: ValDef) = { this match { - case Pattern(None, subs, builder) => builder(Some(b), subs) - case other => other - }}.copiedFrom(this) - } - - /** Pattern encoding `case binder: ct` - * - * If [[binder]] is empty, consider a wildcard `_` in its place. - */ - case class InstanceOfPattern(binder: Option[ValDef], ct: ClassType) extends Pattern { - val subPatterns = Seq() - } - - /** Pattern encoding `case _ => `, or `case binder => ` if identifier [[binder]] is present */ - case class WildcardPattern(binder: Option[ValDef]) extends Pattern { // c @ _ - val subPatterns = Seq() - } - - /** Pattern encoding `case binder @ ct(subPatterns...) =>` - * - * If [[binder]] is empty, consider a wildcard `_` in its place. - */ - case class CaseClassPattern(binder: Option[ValDef], ct: ClassType, subPatterns: Seq[Pattern]) extends Pattern - - /** Pattern encoding tuple pattern `case binder @ (subPatterns...) =>` - * - * If [[binder]] is empty, consider a wildcard `_` in its place. - */ - case class TuplePattern(binder: Option[ValDef], subPatterns: Seq[Pattern]) extends Pattern - - /** Pattern encoding like `case binder @ 0 => ...` or `case binder @ "Foo" => ...` - * - * If [[binder]] is empty, consider a wildcard `_` in its place. - */ - case class LiteralPattern[+T](binder: Option[ValDef], lit: Literal[T]) extends Pattern { - val subPatterns = Seq() - } - - /** A custom pattern defined through an object's `unapply` function */ - case class UnapplyPattern(binder: Option[ValDef], fd: Identifier, tps: Seq[Type], subPatterns: Seq[Pattern]) extends Pattern { - // Hacky, but ok - def optionType(implicit s: Symbols) = s.getFunction(fd, tps).returnType.asInstanceOf[ClassType] - def optionChildren(implicit s: Symbols): (ClassType, ClassType) = { - val children = optionType.tcd.asInstanceOf[TypedAbstractClassDef].descendants.sortBy(_.fields.size) - val Seq(noneType, someType) = children.map(_.toType) - (noneType, someType) - } - - def noneType(implicit s: Symbols): ClassType = optionChildren(s)._1 - def someType(implicit s: Symbols): ClassType = optionChildren(s)._2 - def someValue(implicit s: Symbols): ValDef = someType.tcd.asInstanceOf[TypedCaseClassDef].fields.head - - /** Construct a pattern matching against unapply(scrut) (as an if-expression) - * - * @param scrut The scrutinee of the pattern matching - * @param noneCase The expression that will happen if unapply(scrut) is None - * @param someCase How unapply(scrut).get will be handled in case it exists - */ - def patternMatch(scrut: Expr, noneCase: Expr, someCase: Expr => Expr)(implicit s: Symbols): Expr = { - // We use this hand-coded if-then-else because we don't want to generate - // match exhaustiveness checks in the program - val vd = ValDef(FreshIdentifier("unap", true), optionType) - Let( - vd, - FunctionInvocation(fd, tps, Seq(scrut)), - IfExpr( - IsInstanceOf(vd.toVariable, someType), - someCase(CaseClassSelector(someType, vd.toVariable, someValue.id)), - noneCase - ) - ) - } - - /** Inlined .get method */ - def get(scrut: Expr)(implicit s: Symbols) = patternMatch( - scrut, - Error(optionType.tps.head, "None.get"), - e => e - ) - - /** Selects Some.v field without type-checking. - * Use in a context where scrut.isDefined returns true. - */ - def getUnsafe(scrut: Expr)(implicit s: Symbols) = CaseClassSelector( - someType, - FunctionInvocation(fd, tps, Seq(scrut)), - someValue.id - ) - - def isSome(scrut: Expr)(implicit s: Symbols) = - IsInstanceOf(FunctionInvocation(fd, tps, Seq(scrut)), someType) - } - - // Extracts without taking care of the binder. (contrary to Extractos.Pattern) - object PatternExtractor extends TreeExtractor { - val trees: self.type = self - type SubTree = Pattern - - def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { - case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => - Some(Seq(), es => e) - case CaseClassPattern(vd, ct, subpatterns) => - Some(subpatterns, es => CaseClassPattern(vd, ct, es)) - case TuplePattern(vd, subpatterns) => - Some(subpatterns, es => TuplePattern(vd, es)) - case UnapplyPattern(vd, id, tps, subpatterns) => - Some(subpatterns, es => UnapplyPattern(vd, id, tps, es)) - case _ => None - } - } - - object patternOps extends GenTreeOps { - val trees: self.type = self - type SubTree = Pattern - val Deconstructor = PatternExtractor - } /** Literals */ @@ -455,15 +244,14 @@ trait Expressions { self: Trees => * If you are not sure about the requirement you should use * [[purescala.Constructors#caseClassSelector purescala's constructor caseClassSelector]] */ - case class CaseClassSelector(classType: ClassType, caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = classType.lookupClass match { - case Some(tcd: TypedCaseClassDef) => - val index = tcd.cd.selectorID2Index(selector) - if (classType == caseClass.getType) { + case class CaseClassSelector(caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = caseClass.getType match { + case ct: ClassType => ct.lookupClass match { + case Some(tcd: TypedCaseClassDef) => + val index = tcd.cd.selectorID2Index(selector) tcd.fieldsTypes(index) - } else { - Untyped - } + case _ => Untyped + } case _ => Untyped } } diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index 699874093161dd0526e0930b609163709698325d..4795bd3b1b7e1f6aade06d7db41a2dbe5eec2fde 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -44,8 +44,8 @@ trait Extractors { self: Trees => Some((Seq(t), (es: Seq[Expr]) => RealToString(es.head))) case SetCardinality(t) => Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) - case CaseClassSelector(cd, e, sel) => - Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(cd, es.head, sel))) + case CaseClassSelector(e, sel) => + Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(es.head, sel))) case IsInstanceOf(e, ct) => Some((Seq(e), (es: Seq[Expr]) => IsInstanceOf(es.head, ct))) case AsInstanceOf(e, ct) => @@ -118,12 +118,6 @@ trait Extractors { self: Trees => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case Let(binder, e, body) => Some(Seq(e, body), (es: Seq[Expr]) => Let(binder, es(0), es(1))) - case Require(pre, body) => - Some(Seq(pre, body), (es: Seq[Expr]) => Require(es(0), es(1))) - case Ensuring(body, post) => - Some(Seq(body, post), (es: Seq[Expr]) => Ensuring(es(0), es(1))) - case Assert(const, oerr, body) => - Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))) /* Other operators */ case fi @ FunctionInvocation(fd, tps, args) => Some((args, FunctionInvocation(fd, tps, _))) @@ -165,18 +159,6 @@ trait Extractors { self: Trees => Seq(cond, thenn, elze), { case Seq(c, t, e) => IfExpr(c, t, e) } )) - case m @ MatchExpr(scrut, cases) => Some(( - scrut +: cases.flatMap { _.expressions }, - (es: Seq[Expr]) => { - var i = 1 - val newcases = for (caze <- cases) yield caze match { - case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1)) - case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1)) - } - - MatchExpr(es.head, newcases) - } - )) /* Terminals */ case t: Terminal => Some(Seq[Expr](), (_:Seq[Expr]) => t) @@ -231,37 +213,6 @@ trait Extractors { self: Trees => } } - object SimpleCase { - def apply(p: Pattern, rhs: Expr) = MatchCase(p, None, rhs) - def unapply(c : MatchCase) = c match { - case MatchCase(p, None, rhs) => Some((p, rhs)) - case _ => None - } - } - - object GuardedCase { - def apply(p: Pattern, g: Expr, rhs: Expr) = MatchCase(p, Some(g), rhs) - def unapply(c : MatchCase) = c match { - case MatchCase(p, Some(g), rhs) => Some((p, g, rhs)) - case _ => None - } - } - - object Pattern { - def unapply(p: Pattern) : Option[( - Option[ValDef], - Seq[Pattern], - (Option[ValDef], Seq[Pattern]) => Pattern - )] = Option(p) map { - case InstanceOfPattern(b, ct) => (b, Seq(), (b, _) => InstanceOfPattern(b,ct)) - case WildcardPattern(b) => (b, Seq(), (b, _) => WildcardPattern(b)) - case CaseClassPattern(b, ct, subs) => (b, subs, (b, sp) => CaseClassPattern(b, ct, sp)) - case TuplePattern(b,subs) => (b, subs, (b, sp) => TuplePattern(b, sp)) - case LiteralPattern(b, l) => (b, Seq(), (b, _) => LiteralPattern(b, l)) - case UnapplyPattern(b, id, tps, subs) => (b, subs, (b, sp) => UnapplyPattern(b, id, tps, sp)) - } - } - def unwrapTuple(e: Expr, isTuple: Boolean)(implicit s: Symbols): Seq[Expr] = e.getType match { case TupleType(subs) if isTuple => for (ind <- 1 to subs.size) yield { s.tupleSelect(e, ind, isTuple) } @@ -279,39 +230,4 @@ trait Extractors { self: Trees => def unwrapTupleType(tp: Type, expectedSize: Int): Seq[Type] = unwrapTupleType(tp, expectedSize > 1) - - def unwrapTuplePattern(p: Pattern, isTuple: Boolean): Seq[Pattern] = p match { - case TuplePattern(_, subs) if isTuple => subs - case tp if !isTuple => Seq(tp) - case tp => sys.error(s"Calling unwrapTuplePattern on $p") - } - - def unwrapTuplePattern(p: Pattern, expectedSize: Int): Seq[Pattern] = - unwrapTuplePattern(p, expectedSize > 1) - - object LetPattern { - def apply(patt: Pattern, value: Expr, body: Expr) : Expr = { - patt match { - case WildcardPattern(Some(binder)) => Let(binder, value, body) - case _ => MatchExpr(value, List(SimpleCase(patt, body))) - } - } - - def unapply(me: MatchExpr) : Option[(Pattern, Expr, Expr)] = { - Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliasedSymbols(pattern.binders.toSet, exprOps.variablesOf(scrut)) => - ( pattern, scrut, body ) - } - } - } - - object LetTuple { - def unapply(me: MatchExpr) : Option[(Seq[ValDef], Expr, Expr)] = { - Option(me) collect { - case LetPattern(TuplePattern(None, subPatts), value, body) if - subPatts forall { case WildcardPattern(Some(_)) => true; case _ => false } => - (subPatts map { _.binder.get }, value, body ) - } - } - } } diff --git a/src/main/scala/inox/ast/Paths.scala b/src/main/scala/inox/ast/Paths.scala index e39e79757dca34b4529abbfbad6dafa2d8dc26d0..88e00a617abb0fa2490319b2b161c532257a9ca6 100644 --- a/src/main/scala/inox/ast/Paths.scala +++ b/src/main/scala/inox/ast/Paths.scala @@ -173,30 +173,30 @@ trait Paths { self: TypeOps with Constructors => /** Fold the path into an implication of `base`, namely `path ==> base` */ def implies(base: Expr) = distributiveClause(base, self.implies) - /** Folds the path into a `require` wrapping the expression `body` + /** Folds the path into an expression that shares the path's outer lets * - * The function takes additional optional parameters - * - [[pre]] if one wishes to mix a pre-existing precondition into the final - * [[leon.purescala.Expressions.Require]], and - * - [[post]] for mixing a postcondition ([[leon.purescala.Expressions.Ensuring]]) in. + * The folding shares all outer bindings in an wrapping sequence of + * let-expressions. The inner condition is then passed as the first + * argument of the [[recons]] function and must be shared out between + * the reconstructions of [[es]] which will only feature the bindings + * from the current path. + * + * This method is useful to reconstruct if-expressions or assumptions + * where the condition can be added to the expression in a position + * that implies further positions. */ - def specs(body: Expr, pre: Expr = BooleanLiteral(true), post: Expr = NoTree(BooleanType)) = { + def withShared(es: Seq[Expr], recons: (Expr, Seq[Expr]) => Expr): Expr = { val (outers, rest) = elements.span(_.isLeft) val cond = fold[Expr](BooleanLiteral(true), let, self.and(_, _))(rest) def wrap(e: Expr): Expr = { val bindings = rest.collect { case Left((vd, e)) => vd -> e } - val vdSubst = bindings.map(p => p._1 -> p._1.freshen).toMap - val replace = exprOps.replaceFromSymbols(vdSubst.mapValues(_.toVariable), _: Expr) - bindings.foldRight(replace(e)) { case ((vd, e), b) => let(vdSubst(vd), replace(e), b) } - } - - val req = Require(self.and(cond, wrap(pre)), wrap(body)) - val full = post match { - case l @ Lambda(args, body) => Ensuring(req, Lambda(args, wrap(body)).copiedFrom(l)) - case _ => req + val subst = bindings.map(p => p._1 -> p._1.toVariable.freshen).toMap + val replace = exprOps.replaceFromSymbols(subst, _: Expr) + bindings.foldRight(replace(e)) { case ((vd, e), b) => let(subst(vd).toVal, replace(e), b) } } + val full = recons(cond, es.map(wrap)) fold[Expr](full, let, (_, _) => scala.sys.error("Should never happen!"))(outers) } diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index 979cffc572ec800ba60e138a945bc8ae2c593851..3d7f09c6b40135adebbe257ffa05668b58677e65 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -96,25 +96,6 @@ trait Printers { self: Trees => p"""|val $b = $d |$e""" - case Require(pre, body) => - p"""|require($pre) - |$body""" - - case Assert(const, Some(err), body) => - p"""|assert($const, "$err") - |$body""" - - case Assert(const, None, body) => - p"""|assert($const) - |$body""" - - case Ensuring(body, post) => - p"""| { - | $body - |} ensuring { - | $post - |}""" - case Forall(args, e) => p"\u2200${nary(args)}. $e" @@ -160,11 +141,9 @@ trait Printers { self: Trees => case GenericValue(tp, id) => p"$tp#$id" case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"$t._$i" - case NoTree(tpe) => p"<empty tree>[$tpe]" - case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" case AsInstanceOf(e, ct) => p"""$e.asInstanceOf[$ct]""" case IsInstanceOf(e, cct) => p"$e.isInstanceOf[$cct]" - case CaseClassSelector(_, e, id) => p"$e.$id" + case CaseClassSelector(e, id) => p"$e.$id" case FunctionInvocation(id, tps, args) => p"${id}${nary(tps, ", ", "[", "]")}" @@ -240,49 +219,6 @@ trait Printers { self: Trees => |}""" } - case LetPattern(p,s,rhs) => - p"""|val $p = $s - |$rhs""" - - case MatchExpr(s, csc) => - optP { - p"""|$s match { - | ${nary(csc, "\n")} - |}""" - } - - // Cases - case MatchCase(pat, optG, rhs) => - p"|case $pat "; optG foreach { g => p"if $g "}; p"""=> - | $rhs""" - - // Patterns - case WildcardPattern(None) => p"_" - case WildcardPattern(Some(id)) => p"$id" - - case CaseClassPattern(ob, ct, subps) => - ob.foreach { b => p"$b @ " } - // Print only the classDef because we don't want type parameters in patterns - p"${ct.id}" - p"($subps)" - - case InstanceOfPattern(ob, cct) => - ob.foreach { b => p"$b : " } - // It's ok to print the whole type although scalac will complain about erasure - p"$cct" - - case TuplePattern(ob, subps) => - ob.foreach { b => p"$b @ " } - p"($subps)" - - case UnapplyPattern(ob, id, tps, subps) => - ob.foreach { b => p"$b @ " } - p"$id(${nary(subps)})" - - case LiteralPattern(ob, lit) => - ob foreach { b => p"$b @ " } - p"$lit" - // Types case Untyped => p"<untyped>" case UnitType => p"Unit" @@ -330,7 +266,12 @@ trait Printers { self: Trees => p"(${fd.params}): " } - p"${fd.returnType} = ${fd.fullBody}" + p"${fd.returnType} = " + + fd.body match { + case Some(body) => p"$body" + case None => p"???" + } case (tree: PrettyPrintable) => tree.printWith(ctx) @@ -365,25 +306,20 @@ trait Printers { self: Trees => } protected def isSimpleExpr(e: Expr): Boolean = e match { - case _: Let | LetPattern(_, _, _) | _: Assert | _: Require => false + case _: Let => false case p: PrettyPrintable => p.isSimpleExpr case _ => true } protected def noBracesSub(e: Expr): Seq[Expr] = e match { - case Assert(_, _, bd) => Seq(bd) case Let(_, _, bd) => Seq(bd) - case LetPattern(_, _, bd) => Seq(bd) - case Require(_, bd) => Seq(bd) case IfExpr(_, t, e) => Seq(t, e) // if-else always has braces anyway - case Ensuring(bd, pred) => Seq(bd, pred) case _ => Seq() } protected def requiresBraces(ex: Tree, within: Option[Tree]) = (ex, within) match { case (e: Expr, _) if isSimpleExpr(e) => false case (e: Expr, Some(within: Expr)) if noBracesSub(within) contains e => false - case (_: Expr, Some(_: MatchCase)) => false case (e: Expr, Some(_)) => true case _ => false } @@ -405,14 +341,11 @@ trait Printers { self: Trees => case (pa: PrettyPrintable, _) => pa.printRequiresParentheses(within) case (_, None) => false case (_, Some( - _: Ensuring | _: Assert | _: Require | _: Definition | _: MatchExpr | _: MatchCase | - _: Let | _: IfExpr | _ : CaseClass | _ : Lambda | _ : Tuple + _: Definition | _: Let | _: IfExpr | _ : CaseClass | _ : Lambda | _ : Tuple )) => false - case (_:Pattern, _) => false case (ex: StringConcat, Some(_: StringConcat)) => false case (_, Some(_: FunctionInvocation)) => false case (ie: IfExpr, _) => true - case (me: MatchExpr, _ ) => true case (e1: Expr, Some(e2: Expr)) if precedence(e1) > precedence(e2) => false case (_, _) => true } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 480a21b6b48b297de1c819e9fe6d2ee08d3e4a16..40e2618c7f0b8f8fb176792ac7edc8724dc464fc 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -24,7 +24,7 @@ import utils._ * operations on Leon expressions. * */ -trait SymbolOps extends TreeOps { self: TypeOps => +trait SymbolOps { self: TypeOps => import trees._ import trees.exprOps._ import symbols._ @@ -57,7 +57,7 @@ trait SymbolOps extends TreeOps { self: TypeOps => def step(e: Expr): Option[Expr] = e match { case Not(t) => Some(not(t)) case UMinus(t) => Some(uminus(t)) - case CaseClassSelector(cd, e, sel) => Some(caseClassSelector(cd, e, sel)) + case CaseClassSelector(e, sel) => Some(caseClassSelector(e, sel)) case AsInstanceOf(e, ct) => Some(asInstOf(e, ct)) case Equals(t1, t2) => Some(equality(t1, t2)) case Implies(t1, t2) => Some(implies(t1, t2)) @@ -67,7 +67,6 @@ trait SymbolOps extends TreeOps { self: TypeOps => case And(args) => Some(andJoin(args)) case Or(args) => Some(orJoin(args)) case Tuple(args) => Some(tupleWrap(args)) - case MatchExpr(scrut, cases) => Some(matchExpr(scrut, cases)) case _ => None } postMap(step)(expr) @@ -79,13 +78,10 @@ trait SymbolOps extends TreeOps { self: TypeOps => case TupleSelect(Let(id, v, b), ts) => Some(Let(id, v, tupleSelect(b, ts, true))) - case TupleSelect(LetTuple(ids, v, b), ts) => - Some(letTuple(ids, v, tupleSelect(b, ts, true))) + case CaseClassSelector(cc: CaseClass, id) => + Some(caseClassSelector(cc, id).copiedFrom(e)) - case CaseClassSelector(cct, cc: CaseClass, id) => - Some(caseClassSelector(cct, cc, id).copiedFrom(e)) - - case IfExpr(c, thenn, elze) if (thenn == elze) && isPurelyFunctional(c) => + case IfExpr(c, thenn, elze) if thenn == elze => Some(thenn) case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) => @@ -170,7 +166,7 @@ trait SymbolOps extends TreeOps { self: TypeOps => val n = new Normalizer // this registers the argument images into n.subst val bindings = args map n.transform - val normalized = n.transform(matchToIfThenElse(expr)) + val normalized = n.transform(expr) val freeVars = variablesOf(normalized) -- bindings.map(_.toVariable) val bodySubst = n.subst.filter(p => freeVars(p._1)) @@ -194,7 +190,6 @@ trait SymbolOps extends TreeOps { self: TypeOps => case v: Variable if s.isDefinedAt(v) => rec(s(v), s) case l @ Let(i,e,b) => rec(b, s + (i.toVariable -> rec(e, s))) case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).copiedFrom(i) - case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).copiedFrom(m) case n @ Deconstructor(args, recons) => var change = false val rargs = args.map(a => { @@ -213,18 +208,12 @@ trait SymbolOps extends TreeOps { self: TypeOps => case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) } - def inCase(cse: MatchCase, s: Map[Variable,Expr]) : MatchCase = { - import cse._ - MatchCase(pattern, optGuard map { rec(_, s) }, rec(rhs,s)) - } - rec(expr, Map.empty) } /** Lifts lets to top level. * * Does not push any used variable out of scope. - * Assumes no match expressions (i.e. matchToIfThenElse has been called on e) */ def liftLets(e: Expr): Expr = { @@ -247,187 +236,6 @@ trait SymbolOps extends TreeOps { self: TypeOps => defs.foldRight(bd){ case ((vd, e), body) => Let(vd, e, body) } } - /** Recursively transforms a pattern on a boolean formula expressing the conditions for the input expression, possibly including name binders - * - * For example, the following pattern on the input `i` - * {{{ - * case m @ MyCaseClass(t: B, (_, 7)) => - * }}} - * will yield the following condition before simplification (to give some flavour) - * - * {{{and(IsInstanceOf(MyCaseClass, i), and(Equals(m, i), InstanceOfClass(B, i.t), equals(i.k.arity, 2), equals(i.k._2, 7))) }}} - * - * Pretty-printed, this would be: - * {{{ - * i.instanceOf[MyCaseClass] && m == i && i.t.instanceOf[B] && i.k.instanceOf[Tuple2] && i.k._2 == 7 - * }}} - * - * @see [[purescala.Expressions.Pattern]] - */ - def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Path = { - def bind(ob: Option[ValDef], to: Expr): Path = { - if (!includeBinders) { - Path.empty - } else { - ob.map(v => Path.empty withBinding (v -> to)).getOrElse(Path.empty) - } - } - - def rec(in: Expr, pattern: Pattern): Path = { - pattern match { - case WildcardPattern(ob) => - bind(ob, in) - - case InstanceOfPattern(ob, ct) => - val tcd = ct.tcd - if (tcd.root == tcd) { - bind(ob, in) - } else { - Path(IsInstanceOf(in, ct)) merge bind(ob, in) - } - - case CaseClassPattern(ob, cct, subps) => - val tccd = cct.tcd.toCase - assert(tccd.fields.size == subps.size) - val pairs = tccd.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) - Path(IsInstanceOf(in, cct)) merge bind(ob, in) merge subTests - - case TuplePattern(ob, subps) => - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map { - case (p, i) => rec(tupleSelect(in, i+1, subps.size), p) - } - bind(ob, in) merge subTests - - case up @ UnapplyPattern(ob, id, tps, subps) => - val subs = unwrapTuple(up.get(in), subps.size).zip(subps) map (rec _).tupled - bind(ob, in) withCond up.isSome(in) merge subs - - case LiteralPattern(ob, lit) => - Path(Equals(in, lit)) merge bind(ob, in) - } - } - - rec(in, pattern) - } - - /** Converts the pattern applied to an input to a map between identifiers and expressions */ - def mapForPattern(in: Expr, pattern: Pattern): Map[Variable,Expr] = { - def bindIn(ov: Option[ValDef], cast: Option[ClassType] = None): Map[Variable, Expr] = ov match { - case None => Map() - case Some(v) => Map(v.toVariable -> cast.map(asInstOf(in, _)).getOrElse(in)) - } - - pattern match { - case CaseClassPattern(b, ct, subps) => - val tcd = ct.tcd.toCase - assert(tcd.fields.size == subps.size) - val pairs = tcd.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ct, asInstOf(in, ct), p._1), p._2)) - val together = subMaps.flatten.toMap - bindIn(b, Some(ct)) ++ together - - case TuplePattern(b, subps) => - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} - val map = maps.flatten.toMap - bindIn(b) ++ map - - case up @ UnapplyPattern(b, _, _, subps) => - bindIn(b) ++ unwrapTuple(up.getUnsafe(in), subps.size).zip(subps).flatMap { - case (e, p) => mapForPattern(e, p) - }.toMap - - case InstanceOfPattern(b, ct) => - bindIn(b, Some(ct)) - - case other => - bindIn(other.binder) - } - } - - /** Rewrites all pattern-matching expressions into if-then-else expressions - * Introduces additional error conditions. Does not introduce additional variables. - */ - def matchToIfThenElse(expr: Expr): Expr = { - - def rewritePM(e: Expr): Option[Expr] = e match { - case m @ MatchExpr(scrut, cases) => - // println("Rewriting the following PM: " + e) - - val condsAndRhs = for (cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) - val realCond = cse.optGuard match { - case Some(g) => patCond withCond replaceFromSymbols(map, g) - case None => patCond - } - val newRhs = replaceFromSymbols(map, cse.rhs) - (realCond.toClause, newRhs, cse) - } - - val bigIte = condsAndRhs.foldRight[Expr](Error(m.getType, "Match is non-exhaustive").copiedFrom(m))((p1, ex) => { - if(p1._1 == BooleanLiteral(true)) { - p1._2 - } else { - IfExpr(p1._1, p1._2, ex).copiedFrom(p1._3) - } - }) - - Some(bigIte) - - case _ => None - } - - preMap(rewritePM)(expr) - } - - /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], - * concatenates the path condition with the newly induced conditions. - * Each case holds the conditions on other previous cases as negative. - * @note The guard of the final case is NOT included in the Paths. - * - * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] - * @see [[purescala.ExprOps#mapForPattern mapForPattern]] - */ - def matchExprCaseConditions(m: MatchExpr, path: Path): Seq[Path] = { - val MatchExpr(scrut, cases) = m - var pcSoFar = path - - for (c <- cases) yield { - val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar merge cond - - // These contain no binders defined in this MatchCase - val condSafe = conditionForPattern(scrut, c.pattern) - val g = c.optGuard getOrElse BooleanLiteral(true) - val gSafe = replaceFromSymbols(mapForPattern(scrut, c.pattern), g) - pcSoFar = pcSoFar merge (condSafe withCond gSafe).negate - - localCond - } - } - - /** Condition to pass this match case, expressed w.r.t scrut only */ - def matchCaseCondition(scrut: Expr, c: MatchCase): Path = { - - val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false) - - c.optGuard match { - case Some(g) => - // guard might refer to binders - val map = mapForPattern(scrut, c.pattern) - patternC withCond replaceFromSymbols(map, g) - - case None => - patternC - } - } - private def hasInstance(tcd: TypedClassDef): Boolean = { val recursive = Set(tcd, tcd.root) @@ -585,42 +393,10 @@ trait SymbolOps extends TreeOps { self: TypeOps => rec(v, path) ++ rec(b, path withBinding (i -> v)) - case Ensuring(Require(pre, body), Lambda(Seq(arg), post)) => - rec(pre, path) ++ - rec(body, path withCond pre) ++ - rec(post, path withCond pre withBinding (arg -> body)) - - case Ensuring(body, Lambda(Seq(arg), post)) => - rec(body, path) ++ - rec(post, path withBinding (arg -> body)) - - case Require(pre, body) => - rec(pre, path) ++ - rec(body, path withCond pre) - - case Assert(pred, err, body) => + case Assume(pred, body) => rec(pred, path) ++ rec(body, path withCond pred) - case MatchExpr(scrut, cases) => - val rs = rec(scrut, path) - var soFar = path - - rs ++ cases.flatMap { c => - val patternPathPos = conditionForPattern(scrut, c.pattern, includeBinders = true) - val patternPathNeg = conditionForPattern(scrut, c.pattern, includeBinders = false) - val map = mapForPattern(scrut, c.pattern) - val guardOrTrue = c.optGuard.getOrElse(BooleanLiteral(true)) - val guardMapped = replaceFromSymbols(map, guardOrTrue) - - val rc = rec((patternPathPos withCond guardOrTrue).fullClause, soFar) - val subPath = soFar merge (patternPathPos withCond guardOrTrue) - val rrhs = rec(c.rhs, subPath) - - soFar = soFar merge (patternPathNeg withCond guardMapped).negate - rc ++ rrhs - } - case IfExpr(cond, thenn, elze) => rec(cond, path) ++ rec(thenn, path withCond cond) ++ @@ -786,96 +562,6 @@ trait SymbolOps extends TreeOps { self: TypeOps => } } - /* ================= - * Body manipulation - * ================= - */ - - /** Returns whether a particular [[Expressions.Expr]] contains specification - * constructs, namely [[Expressions.Require]] and [[Expressions.Ensuring]]. - */ - def hasSpec(e: Expr): Boolean = exists { - case Require(_, _) => true - case Ensuring(_, _) => true - case Let(i, e, b) => hasSpec(b) - case _ => false - } (e) - - /** Merges the given [[Path]] into the provided [[Expressions.Expr]]. - * - * This method expects to run on a [[Definitions.FunDef.fullBody]] and merges into - * existing pre- and postconditions. - * - * @param expr The current body - * @param path The path that should be wrapped around the given body - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPath(expr: Expr, path: Path): Expr = expr match { - case Let(i, e, b) => withPath(b, path withBinding (i -> e)) - case Require(pre, b) => path specs (b, pre) - case Ensuring(Require(pre, b), post) => path specs (b, pre, post) - case Ensuring(b, post) => path specs (b, post = post) - case b => path specs b - } - - /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one. - * - * If no precondition is provided, removes any existing precondition. - * Else, wraps the expression with a [[Expressions.Require]] clause referring to the new precondition. - * - * @param expr The current expression - * @param pred An optional precondition. Setting it to None removes any precondition. - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = (pred, expr) match { - case (Some(newPre), Require(pre, b)) => req(newPre, b) - case (Some(newPre), Ensuring(Require(pre, b), p)) => Ensuring(req(newPre, b), p) - case (Some(newPre), Ensuring(b, p)) => Ensuring(req(newPre, b), p) - case (Some(newPre), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) - case (Some(newPre), b) => req(newPre, b) - case (None, Require(pre, b)) => b - case (None, Ensuring(Require(pre, b), p)) => Ensuring(b, p) - case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) - case (None, b) => b - } - - /** Replaces the postcondition of an existing [[Expressions.Expr]] with a new one. - * - * If no postcondition is provided, removes any existing postcondition. - * Else, wraps the expression with a [[Expressions.Ensuring]] clause referring to the new postcondition. - * - * @param expr The current expression - * @param oie An optional postcondition. Setting it to None removes any postcondition. - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPostcondition(expr: Expr, oie: Option[Expr]): Expr = (oie, expr) match { - case (Some(npost), Ensuring(b, post)) => ensur(b, npost) - case (Some(npost), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) - case (Some(npost), b) => ensur(b, npost) - case (None, Ensuring(b, p)) => b - case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) - case (None, b) => b - } - - /** Adds a body to a specification - * - * @param expr The specification expression [[Expressions.Ensuring]] or [[Expressions.Require]]. If none of these, the argument is discarded. - * @param body An option of [[Expressions.Expr]] possibly containing an expression body. - * @return The post/pre condition with the body. If no body is provided, returns [[Expressions.NoTree]] - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withBody(expr: Expr, body: Option[Expr]): Expr = expr match { - case Let(i, e, b) if hasSpec(b) => Let(i, e, withBody(b, body)) - case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) - case Ensuring(Require(pre, _), post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), post) - case Ensuring(_, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), post) - case _ => body.getOrElse(NoTree(expr.getType)) - } - object InvocationExtractor { private def flatInvocation(expr: Expr): Option[(Identifier, Seq[Type], Seq[Expr])] = expr match { case fi @ FunctionInvocation(id, tps, args) => Some((id, tps, args)) @@ -930,8 +616,6 @@ trait SymbolOps extends TreeOps { self: TypeOps => IfExpr(cond, apply(thenn, args), apply(elze, args)) case Let(i, e, b) => Let(i, e, apply(b, args)) - case LetTuple(is, es, b) => - letTuple(is, es, apply(b, args)) //case l @ Lambda(params, body) => // l.withParamSubst(args, body) case _ => Application(expr, args) @@ -968,11 +652,7 @@ trait SymbolOps extends TreeOps { self: TypeOps => rec(lift(expr), true) } - liftToLambdas( - matchToIfThenElse( - expr - ) - ) + liftToLambdas(expr) } // Use this only to debug isValueOfType diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala index 1fbfbf6a056dfbc27e38c49103a11752e895cbf0..81729beb71c36bbf91c44c2b40f84d5f6b2995b4 100644 --- a/src/main/scala/inox/ast/TreeOps.scala +++ b/src/main/scala/inox/ast/TreeOps.scala @@ -2,69 +2,178 @@ package inox package ast -trait TreeOps { - private[ast] val trees: Trees - import trees._ +trait TreeOps { self: Trees => trait TreeTransformer { def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, transform(tpe)) def transform(v: Variable): Variable = { val (id, tpe) = transform(v.id, v.tpe) - Variable(id, tpe).copiedFrom(v) + if ((id ne v.id) || (tpe ne v.tpe)) { + Variable(id, tpe).copiedFrom(v) + } else { + v + } } def transform(vd: ValDef): ValDef = { val (id, tpe) = transform(vd.id, vd.tpe) - ValDef(id, tpe).copiedFrom(vd) + if ((id ne vd.id) || (tpe ne vd.tpe)) { + ValDef(id, tpe).copiedFrom(vd) + } else { + vd + } + } + + @inline + private def transformChanged(vds: Seq[ValDef]): (Seq[ValDef], Boolean) = { + var changed = false + val newVds = vds.map { vd => + val newVd = transform(vd) + if (vd ne newVd) changed = true + newVd + } + + (newVds, changed) + } + + @inline + private def transformChanged(args: Seq[Expr]): (Seq[Expr], Boolean) = { + var changed = false + val newArgs = args.map { arg => + val newArg = transform(arg) + if (arg ne newArg) changed = true + newArg + } + + (newArgs, changed) + } + + @inline + private def transformChanged(tps: Seq[Type]): (Seq[Type], Boolean) = { + var changed = false + val newTps = tps.map { tp => + val newTp = transform(tp) + if (tp ne newTp) changed = true + newTp + } + + (newTps, changed) } def transform(e: Expr): Expr = e match { case v: Variable => transform(v) case Lambda(args, body) => - Lambda(args map transform, transform(body)).copiedFrom(e) + val (newArgs, changedArgs) = transformChanged(args) + val newBody = transform(body) + if (changedArgs || (body ne newBody)) { + Lambda(newArgs, newBody).copiedFrom(e) + } else { + e + } case Forall(args, body) => - Forall(args map transform, transform(body)).copiedFrom(e) + val (newArgs, changedArgs) = transformChanged(args) + val newBody = transform(body) + if (changedArgs || (body ne newBody)) { + Forall(newArgs, newBody).copiedFrom(e) + } else { + e + } case Let(vd, expr, body) => - Let(transform(vd), transform(expr), transform(body)).copiedFrom(e) + val newVd = transform(vd) + val newExpr = transform(expr) + val newBody = transform(body) + if ((vd ne newVd) || (expr ne newExpr) || (body ne newBody)) { + Let(newVd, newExpr, newBody).copiedFrom(e) + } else { + e + } - case CaseClass(cct, args) => - CaseClass(transform(cct).asInstanceOf[ClassType], args map transform).copiedFrom(e) + case CaseClass(ct, args) => + val newCt = transform(ct) + val (newArgs, changedArgs) = transformChanged(args) + if ((ct ne newCt) || changedArgs) { + CaseClass(newCt.asInstanceOf[ClassType], newArgs).copiedFrom(e) + } else { + e + } - case CaseClassSelector(cct, caseClass, selector) => - CaseClassSelector(transform(cct).asInstanceOf[ClassType], transform(caseClass), selector).copiedFrom(e) + case CaseClassSelector(cc, selector) => + val newCc = transform(cc) + if (cc ne newCc) { + CaseClassSelector(cc, selector).copiedFrom(e) + } else { + e + } case FunctionInvocation(id, tps, args) => - FunctionInvocation(id, tps map transform, args map transform).copiedFrom(e) + val (newTps, changedTps) = transformChanged(tps) + val (newArgs, changedArgs) = transformChanged(args) + if (changedTps || changedArgs) { + FunctionInvocation(id, newTps, newArgs).copiedFrom(e) + } else { + e + } case IsInstanceOf(expr, ct) => - IsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + val newExpr = transform(expr) + val newCt = transform(ct) + if ((expr ne newExpr) || (ct ne newCt)) { + IsInstanceOf(newExpr, newCt.asInstanceOf[ClassType]).copiedFrom(e) + } else { + e + } case AsInstanceOf(expr, ct) => - AsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) - - case MatchExpr(scrutinee, cases) => - MatchExpr(transform(scrutinee), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { - MatchCase(transform(pattern), guard.map(transform), transform(rhs)).copiedFrom(cse) - }).copiedFrom(e) + val newExpr = transform(expr) + val newCt = transform(ct) + if ((expr ne newExpr) || (ct ne newCt)) { + AsInstanceOf(newExpr, newCt.asInstanceOf[ClassType]).copiedFrom(e) + } else { + e + } case FiniteSet(es, tpe) => - FiniteSet(es map transform, transform(tpe)).copiedFrom(e) + val (newArgs, changed) = transformChanged(es) + val newTpe = transform(tpe) + if (changed || (tpe ne newTpe)) { + FiniteSet(newArgs, newTpe).copiedFrom(e) + } else { + e + } case FiniteBag(es, tpe) => - FiniteBag(es map { case (k, v) => transform(k) -> v }, transform(tpe)).copiedFrom(e) + var changed = false + val newEs = es.map { case (k, v) => + val newK = transform(k) + if (k ne newK) changed = true + newK -> v + } + val newTpe = transform(tpe) + if (changed || (tpe ne newTpe)) { + FiniteBag(newEs, newTpe).copiedFrom(e) + } else { + e + } case FiniteMap(pairs, from, to) => - FiniteMap(pairs map { case (k, v) => transform(k) -> transform(v) }, transform(from), transform(to)).copiedFrom(e) - - case NoTree(tpe) => - NoTree(transform(tpe)).copiedFrom(e) - - case Error(tpe, desc) => - Error(transform(tpe), desc).copiedFrom(e) + var changed = false + val newPairs = pairs.map { case (k, v) => + val newK = transform(k) + val newV = transform(v) + if ((k ne newK) || (v ne newV)) changed = true + newK -> newV + } + val newFrom = transform(from) + val newTo = transform(to) + if (changed || (from ne newFrom) || (to ne newTo)) { + FiniteMap(newPairs, newFrom, newTo).copiedFrom(e) + } else { + e + } case Operator(es, builder) => val newEs = es map transform @@ -77,27 +186,14 @@ trait TreeOps { case e => e } - def transform(pat: Pattern): Pattern = pat match { - case InstanceOfPattern(binder, ct) => - InstanceOfPattern(binder map transform, transform(ct).asInstanceOf[ClassType]).copiedFrom(pat) - - case CaseClassPattern(binder, ct, subs) => - CaseClassPattern(binder map transform, transform(ct).asInstanceOf[ClassType], subs map transform).copiedFrom(pat) - - case TuplePattern(binder, subs) => - TuplePattern(binder map transform, subs map transform).copiedFrom(pat) - - case UnapplyPattern(binder, id, tps, subs) => - UnapplyPattern(binder map transform, id, tps map transform, subs map transform).copiedFrom(pat) - - case PatternExtractor(subs, builder) => - builder(subs map transform).copiedFrom(pat) - - case p => p - } - def transform(tpe: Type): Type = tpe match { - case NAryType(ts, builder) => builder(ts map transform).copiedFrom(tpe) + case NAryType(ts, builder) => + val newTs = ts map transform + if ((newTs zip ts).exists { case (bef, aft) => aft ne bef }) { + builder(ts map transform).copiedFrom(tpe) + } else { + tpe + } } } @@ -122,8 +218,7 @@ trait TreeOps { traverse(cct) args foreach traverse - case CaseClassSelector(cct, caseClass, selector) => - traverse(cct) + case CaseClassSelector(caseClass, selector) => traverse(caseClass) case FunctionInvocation(id, tps, args) => @@ -138,14 +233,6 @@ trait TreeOps { traverse(expr) traverse(ct) - case MatchExpr(scrutinee, cases) => - traverse(scrutinee) - for (cse @ MatchCase(pattern, guard, rhs) <- cases) { - traverse(pattern) - guard foreach traverse - traverse(rhs) - } - case FiniteSet(es, tpe) => es foreach traverse traverse(tpe) @@ -159,39 +246,12 @@ trait TreeOps { traverse(from) traverse(to) - case NoTree(tpe) => - traverse(tpe) - - case Error(tpe, desc) => - traverse(tpe) - case Operator(es, builder) => es foreach traverse case e => } - def traverse(pat: Pattern): Unit = pat match { - case InstanceOfPattern(binder, ct) => - traverse(ct) - - case CaseClassPattern(binder, ct, subs) => - traverse(ct) - subs foreach traverse - - case TuplePattern(binder, subs) => - subs foreach traverse - - case UnapplyPattern(binder, id, tps, subs) => - tps foreach traverse - subs foreach traverse - - case PatternExtractor(subs, builder) => - subs foreach traverse - - case pat => - } - def traverse(tpe: Type): Unit = tpe match { case NAryType(ts, builder) => ts foreach traverse diff --git a/src/main/scala/inox/ast/Trees.scala b/src/main/scala/inox/ast/Trees.scala index 321ff6921c7fead1bfba19685f78fdeecce3c3a5..69cc86967634c24f6f851bc659d17210df68158c 100644 --- a/src/main/scala/inox/ast/Trees.scala +++ b/src/main/scala/inox/ast/Trees.scala @@ -8,7 +8,13 @@ import scala.language.implicitConversions case object DebugSectionTrees extends DebugSection("trees") -trait Trees extends Expressions with Extractors with Types with Definitions with Printers { +trait Trees + extends Expressions + with Extractors + with Types + with Definitions + with Printers + with TreeOps { class Unsupported(t: Tree, msg: String)(implicit ctx: InoxContext) extends Exception(s"${t.asString(PrinterOptions.fromContext(ctx))}@${t.getPos} $msg") diff --git a/src/main/scala/inox/evaluators/DefaultEvaluator.scala b/src/main/scala/inox/evaluators/DefaultEvaluator.scala deleted file mode 100644 index 90fb4e0b91ebe303874f9a173adc8a137af8414d..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/evaluators/DefaultEvaluator.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package evaluators - -import purescala.Definitions.Program - -class DefaultEvaluator(ctx: LeonContext, prog: Program, bank: EvaluationBank = new EvaluationBank) - extends RecursiveEvaluator(ctx, prog, bank, 50000) - with HasDefaultGlobalContext - with HasDefaultRecContext diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index d1eba7f1ed700c4d8c38fee825f566c158f5ee37..794fd17fadd2b062275b126358bf41697f0c04f9 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -51,14 +51,10 @@ trait RecursiveEvaluator //println(s"Eval $i to $first") e(b)(rctx.withNewVar(i, first), gctx) - case Assert(cond, oerr, body) => - e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) - - case en @ Ensuring(body, post) => - e(en.toAssert) - - case Error(tpe, desc) => - throw RuntimeError("Error reached in evaluation: " + desc) + case Assume(cond, body) => + if (e(cond) != BooleanLiteral(true)) + throw RuntimeError("Assumption did not hold @" + expr.getPos) + e(body) case IfExpr(cond, thenn, elze) => val first = e(cond) @@ -80,16 +76,6 @@ trait RecursiveEvaluator // build a mapping for the function... val frame = rctx.withNewVars(tfd.paramSubst(evArgs)) - if (tfd.hasPrecondition) { - e(tfd.precondition.get)(frame, gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => - throw RuntimeError("Precondition violation for " + tfd.id.asString + " reached in evaluation.: " + tfd.precondition.get.asString) - case other => - throw RuntimeError(typeErrorMsg(other, BooleanType)) - } - } - // @nv TODO: choose evaluation /* @nv TODO: should we do this differently?? lambdas?? @@ -100,17 +86,9 @@ trait RecursiveEvaluator val callResult: Expr = tfd.body match { case Some(body) => e(body)(frame, gctx) - case None => onSpecInvocation(tfd.postOrTrue) - } - - tfd.postcondition match { - case Some(post) => - e(application(post, Seq(callResult)))(frame, gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.asString + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } case None => + // @nv TODO: this isn't right... + throw RuntimeError("Cannot evaluate bodyless function") } callResult @@ -199,13 +177,13 @@ trait RecursiveEvaluator val le = e(expr) BooleanLiteral(isSubtypeOf(le.getType, ct)) - case CaseClassSelector(ct1, expr, sel) => + case CaseClassSelector(expr, sel) => e(expr) match { - case CaseClass(ct2, args) if ct1 == ct2 => args(ct1.tcd.cd match { + case CaseClass(ct, args) => args(ct.tcd.cd match { case ccd: CaseClassDef => ccd.selectorID2Index(sel) case _ => throw RuntimeError("Unexpected case class type") }) - case le => throw EvalError(typeErrorMsg(le, ct1)) + case le => throw EvalError(typeErrorMsg(le, expr.getType)) } case Plus(l,r) => @@ -557,15 +535,6 @@ trait RecursiveEvaluator throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType))) } - case MatchExpr(scrut, cases) => - val rscrut = e(scrut) - cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { - case Some(Some((c, mappings))) => - e(c.rhs)(rctx.withNewVars(mappings), gctx) - case _ => - throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases:\n"+cases) - } - case gl: GenericValue => gl case fl : FractionLiteral => normalizeFraction(fl) case l : Literal[_] => l @@ -573,81 +542,5 @@ trait RecursiveEvaluator case other => throw EvalError("Unhandled case in Evaluator : [" + other.getClass + "] " + other.asString) } - - def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[ValDef, Expr])] = { - - def matchesPattern(pat: Pattern, expr: Expr): Option[Map[ValDef, Expr]] = (pat, expr) match { - case (InstanceOfPattern(ob, pct), e) => - if (isSubtypeOf(e.getType, pct)) { - Some(obind(ob, e)) - } else { - None - } - case (WildcardPattern(ob), e) => - Some(obind(ob, e)) - - case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) => - if (pct == ct) { - val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } - if (res.forall(_.isDefined)) { - Some(obind(ob, expr) ++ res.flatten.flatten) - } else { - None - } - } else { - None - } - case (up @ UnapplyPattern(ob, id, tps, subs), scrut) => - val tfd = getFunction(id, tps) - val (noneType, someType) = up.optionChildren - - e(FunctionInvocation(id, tps, Seq(scrut))) match { - case CaseClass(`noneType`, Seq()) => - None - case CaseClass(`someType`, Seq(arg)) => - val res = subs zip unwrapTuple(arg, subs.size) map { - case (s,a) => matchesPattern(s,a) - } - if (res.forall(_.isDefined)) { - Some(obind(ob, expr) ++ res.flatten.flatten) - } else { - None - } - case other => - throw EvalError(typeErrorMsg(other, tfd.returnType)) - } - case (TuplePattern(ob, subs), Tuple(args)) => - if (subs.size == args.size) { - val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } - if (res.forall(_.isDefined)) { - Some(obind(ob, expr) ++ res.flatten.flatten) - } else { - None - } - } else { - None - } - case (LiteralPattern(ob, l1) , l2 : Literal[_]) if l1 == l2 => - Some(obind(ob,l1)) - case _ => None - } - - def obind(ovd: Option[ValDef], e: Expr): Map[ValDef, Expr] = { - ovd.map(vd => vd -> e).toMap - } - - caze match { - case SimpleCase(p, rhs) => - matchesPattern(p, scrut).map(r => - (caze, r) - ) - - case GuardedCase(p, g, rhs) => - for { - r <- matchesPattern(p, scrut) - if e(g)(rctx.withNewVars(r), gctx) == BooleanLiteral(true) - } yield (caze, r) - } - } } diff --git a/src/main/scala/inox/evaluators/SolvingEvaluator.scala b/src/main/scala/inox/evaluators/SolvingEvaluator.scala index 267328f68f2a14e886d11f8071889840f7821d30..3d8e64fe51feb5ce98f5eac98747fbc798c4bc9b 100644 --- a/src/main/scala/inox/evaluators/SolvingEvaluator.scala +++ b/src/main/scala/inox/evaluators/SolvingEvaluator.scala @@ -33,12 +33,14 @@ trait SolvingEvaluator extends Evaluator { case o @ InoxOption(opt, _) if opt == optForallCache => o }.toSeq : _*) + import solver.SolverResponses._ + solver.assertCnstr(body) val res = solver.check(model = true) timer.stop() res match { - case solver.Model(model) => + case Model(model) => valuateWithModel(model)(vd) case _ => @@ -60,16 +62,18 @@ trait SolvingEvaluator extends Evaluator { InoxOption(optForallCache)(cache) ) + import solver.SolverResponses._ + solver.assertCnstr(Not(forall.body)) val res = solver.check(model = true) timer.stop() res match { - case solver.Unsat() => + case Unsat() => cache(forall) = true true - case solver.Model(model) => + case Model(model) => cache(forall) = false eval(Not(forall.body), model) match { case EvaluationResults.Successful(BooleanLiteral(true)) => false diff --git a/src/main/scala/inox/grammars/utils/Helpers.scala b/src/main/scala/inox/grammars/utils/Helpers.scala index 40611043515530c4c496a3a3ca429d9a0865c4dd..2cc09813127ce1539fd7b0089b2d2f4994e07c07 100644 --- a/src/main/scala/inox/grammars/utils/Helpers.scala +++ b/src/main/scala/inox/grammars/utils/Helpers.scala @@ -44,7 +44,7 @@ trait Helpers { self: GrammarsUniverse => def terminatingCalls(prog: Program, wss: Seq[FunctionInvocation], pc: Path, tpe: Option[Type], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { def subExprsOf(expr: Expr, v: EVariable): Option[(EVariable, Expr)] = expr match { - case CaseClassSelector(cct, r, _) => subExprsOf(r, v) + case CaseClassSelector(r, _) => subExprsOf(r, v) case (r: EVariable) if leastUpperBound(r.getType, v.getType).isDefined => Some(r -> v) case _ => None } @@ -53,7 +53,7 @@ trait Helpers { self: GrammarsUniverse => val one = IntegerLiteral(1) val knownSmallers = (pc.bindings.flatMap { // @nv: used to check both Equals(id, selector) and Equals(selector, id) - case (id, s @ CaseClassSelector(cct, r, _)) => subExprsOf(s, id.toVariable) + case (id, s @ CaseClassSelector(r, _)) => subExprsOf(s, id.toVariable) case _ => None } ++ pc.conditions.flatMap { case GreaterThan(v: EVariable, `z`) => diff --git a/src/main/scala/inox/solvers/theories/ArrayEncoder.scala b/src/main/scala/inox/solvers/theories/ArrayEncoder.scala deleted file mode 100644 index a91b88da2cca4e9ef639164f26d145057c7dd04c..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/theories/ArrayEncoder.scala +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers -package theories - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ -import purescala.Types._ -import purescala.Definitions._ -import leon.utils.Bijection -import leon.purescala.TypeOps - -class ArrayEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { - - private val arrayTypeParam = TypeParameter.fresh("A") - val ArrayCaseClass = new CaseClassDef(FreshIdentifier("InternalArray"), Seq(TypeParameterDef(arrayTypeParam)), None, false) - val rawArrayField = FreshIdentifier("raw", RawArrayType(Int32Type, arrayTypeParam)) - val lengthField = FreshIdentifier("length", Int32Type) - ArrayCaseClass.setFields(Seq(ValDef(rawArrayField), ValDef(lengthField))) - - val encoder = new Encoder { - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { - case al @ ArrayLength(a) => - val ArrayType(base) = a.getType - val ra = transform(a) - Some(CaseClassSelector(ArrayCaseClass.typed(Seq(base)), ra, lengthField)) - - case al @ ArraySelect(a, i) => - val ArrayType(base) = a.getType - val ra = transform(a) - val ri = transform(i) - val raw = CaseClassSelector(transform(a.getType).asInstanceOf[CaseClassType], ra, rawArrayField) - Some(RawArraySelect(raw, ri)) - - case al @ ArrayUpdated(a, i, e) => - val ra = transform(a) - val ri = transform(i) - val re = transform(e) - - val length = CaseClassSelector(transform(a.getType).asInstanceOf[CaseClassType], ra, lengthField) - val raw = CaseClassSelector(transform(a.getType).asInstanceOf[CaseClassType], ra, rawArrayField) - - Some(CaseClass(transform(a.getType).asInstanceOf[CaseClassType], Seq(RawArrayUpdated(raw, ri, re), length))) - - case a @ FiniteArray(elems, oDef, size) => - - val tpe @ ArrayType(to) = a.getType - val default: Expr = transform(oDef.getOrElse(simplestValue(to))) - - val raw = RawArrayValue(Int32Type, elems.map { - case (k, v) => IntLiteral(k) -> transform(v) - }, default) - Some(CaseClass(ArrayCaseClass.typed(Seq(to)), Seq(raw, transform(size)))) - - case _ => None - } - - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case ArrayType(base) => Some(ArrayCaseClass.typed(Seq(base))) - case _ => None - } - - } - - - val decoder = new Decoder { - - private def fromRaw(raw: Expr, length: Expr): Expr = { - val RawArrayValue(baseType, elems, default) = raw - val IntLiteral(size) = length - if (size < 0) { - throw new Exception("Cannot build an array of negative length: " + size) - } else if (size > 10) { - val definedElements = elems.collect { - case (IntLiteral(i), value) => (i, value) - } - finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) - - } else { - val entries = for (i <- 0 to size - 1) yield elems.getOrElse(IntLiteral(i), default) - - finiteArray(entries, None, baseType) - } - } - - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { - case cc @ CaseClass(cct, args) if cct.classDef == ArrayCaseClass => - val Seq(rawArray, length) = args - val leonArray = fromRaw(rawArray, length) - Some(leonArray) - // Some(StringLiteral(convertToString(cc)).copiedFrom(cc)) - //case FunctionInvocation(SizeI, Seq(a)) => - // Some(StringLength(transform(a)).copiedFrom(e)) - //case FunctionInvocation(Size, Seq(a)) => - // Some(StringBigLength(transform(a)).copiedFrom(e)) - //case FunctionInvocation(Concat, Seq(a, b)) => - // Some(StringConcat(transform(a), transform(b)).copiedFrom(e)) - //case FunctionInvocation(SliceI, Seq(a, from, to)) => - // Some(SubString(transform(a), transform(from), transform(to)).copiedFrom(e)) - //case FunctionInvocation(Slice, Seq(a, from, to)) => - // Some(BigSubString(transform(a), transform(from), transform(to)).copiedFrom(e)) - //case FunctionInvocation(TakeI, Seq(FunctionInvocation(DropI, Seq(a, start)), length)) => - // val rstart = transform(start) - // Some(SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e)) - //case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) => - // val rstart = transform(start) - // Some(BigSubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e)) - //case FunctionInvocation(TakeI, Seq(a, length)) => - // Some(SubString(transform(a), IntLiteral(0), transform(length)).copiedFrom(e)) - //case FunctionInvocation(Take, Seq(a, length)) => - // Some(BigSubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e)) - //case FunctionInvocation(DropI, Seq(a, count)) => - // val ra = transform(a) - // Some(SubString(ra, transform(count), StringLength(ra)).copiedFrom(e)) - //case FunctionInvocation(Drop, Seq(a, count)) => - // val ra = transform(a) - // Some(BigSubString(ra, transform(count), StringBigLength(ra)).copiedFrom(e)) - //case FunctionInvocation(FromInt, Seq(a)) => - // Some(Int32ToString(transform(a)).copiedFrom(e)) - //case FunctionInvocation(FromBigInt, Seq(a)) => - // Some(IntegerToString(transform(a)).copiedFrom(e)) - //case FunctionInvocation(FromChar, Seq(a)) => - // Some(CharToString(transform(a)).copiedFrom(e)) - //case FunctionInvocation(FromBoolean, Seq(a)) => - // Some(BooleanToString(transform(a)).copiedFrom(e)) - case _ => None - } - - - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - //case String | StringCons | StringNil => Some(StringType) - case _ => None - } - - //override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { - // case CaseClassPattern(b, StringNil, Seq()) => - // val newBinder = b map transform - // (LiteralPattern(newBinder , StringLiteral("")), (b zip newBinder).filter(p => p._1 != p._2).toMap) - - // case CaseClassPattern(b, StringCons, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { - // case (LiteralPattern(_, StringLiteral(s)), binders) => - // val newBinder = b map transform - // (LiteralPattern(newBinder, StringLiteral(elem + s)), (b zip newBinder).filter(p => p._1 != p._2).toMap ++ binders) - // case (e, binders) => - // throw new Unsupported(pat, "Failed to parse pattern back as string: " + e)(ctx) - // } - - // case _ => super.transform(pat) - //} - } -} - diff --git a/src/main/scala/inox/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala index de77b4aa96520fd674e69145e16126acacd2ed99..4c6f3a51d312fdc6a6aff0b845b29561bc0771f7 100644 --- a/src/main/scala/inox/solvers/theories/StringEncoder.scala +++ b/src/main/scala/inox/solvers/theories/StringEncoder.scala @@ -13,9 +13,23 @@ import leon.utils.Bijection import leon.purescala.TypeOps class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { - val String = p.library.lookupUnique[ClassDef]("leon.theories.String").typed - val StringCons = p.library.lookupUnique[CaseClassDef]("leon.theories.StringCons").typed - val StringNil = p.library.lookupUnique[CaseClassDef]("leon.theories.StringNil").typed + + val StringID = FreshIdentifier("String") + val StringNilID = FreshIdentifier("StringNil") + val StringConsID = FreshIdentifier("StringCons") + + val StringConsHeadID = FreshIdentifier("head") + val StringConsTailID = FreshIdentifier("tail") + + val String = new AbstractClassDef(StringID, Seq.empty Seq(StringConsID, StringNilID), Set.empty).typed + val StringNil = new CaseClassDef(StringNilID, Seq.empty, Some(StringID), Seq.empty, Set.empty).typed + val StringCons = new CaseClassDef(StringConsID, Seq.empty, Some(StringID), Seq( + ValDef(StringConsHeadID, CharType), + ValDef(StringConsTailID, ClassType(StringID, Seq.empty)) + ), Set.empty).typed + + val SizeID = FreshIdentifier("size") + val Size = new FunDef() val Size = p.library.lookupUnique[FunDef]("leon.theories.String.size").typed val Take = p.library.lookupUnique[FunDef]("leon.theories.String.take").typed diff --git a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala index f3b4d6ea0b026d3296bfe41b7a85c96665a2ba27..def0c986993eb04a2c0c36759e51a662d5a1091c 100644 --- a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala @@ -1,82 +1,67 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon +package inox package solvers package theories -import purescala.Common._ -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ - import utils._ -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} - -trait TheoryEncoder { self => - protected val encoder: Encoder - protected val decoder: Decoder +trait TheoryEncoder { + val trees: Trees + import trees._ - private val idMap = new Bijection[Identifier, Identifier] - private val fdMap = new Bijection[FunDef , FunDef ] - private val cdMap = new Bijection[ClassDef , ClassDef ] + private type SameTrees = TheoryEncoder { + val trees: TheoryEncoder.this.trees.type + } - def encode(id: Identifier): Identifier = encoder.transform(id) - def decode(id: Identifier): Identifier = decoder.transform(id) + protected val encoder: TreeTransformer + protected val decoder: TreeTransformer - def encode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = encoder.transform(expr) - def decode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = decoder.transform(expr) + def encode(v: Variable): Variable = encoder.transform(v) + def decode(v: Variable): Variable = decoder.transform(v) - def encode(tpe: TypeTree): TypeTree = encoder.transform(tpe) - def decode(tpe: TypeTree): TypeTree = decoder.transform(tpe) + def encode(expr: Expr): Expr = encoder.transform(expr) + def decode(expr: Expr): Expr = decoder.transform(expr) - def encode(fd: FunDef): FunDef = encoder.transform(fd) - def decode(fd: FunDef): FunDef = decoder.transform(fd) + def encode(tpe: Type): Type = encoder.transform(tpe) + def decode(tpe: Type): Type = decoder.transform(tpe) - protected class Encoder extends purescala.DefinitionTransformer(idMap, fdMap, cdMap) - protected class Decoder extends purescala.DefinitionTransformer(idMap.swap, fdMap.swap, cdMap.swap) + def >>(that: SameTrees): SameTrees = new TheoryEncoder { + val trees: TheoryEncoder.this.trees.type = TheoryEncoder.this.trees - def >>(that: TheoryEncoder): TheoryEncoder = new TheoryEncoder { - val encoder = new Encoder { - override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = { - val mapSeq = bindings.toSeq - val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.encoder.transform(id.getType)) } - val e2 = self.encoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) - Some(that.encoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap)) + val encoder = new TreeTransformer { + override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { + val (id1, tpe1) = TheoryEncoder.this.transform(id, tpe) + that.transform(id1, tpe1) } - override def transformType(tpe: TypeTree): Option[TypeTree] = Some(that.encoder.transform(self.encoder.transform(tpe))) - - override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { - val (pat2, bindings) = self.encoder.transform(pat) - val (pat3, bindings2) = that.encoder.transform(pat2) - (pat3, bindings2.map { case (id, id2) => id -> bindings2(id2) }) - } + override def transform(expr: Expr): Expr = that.transform(TheoryEncoder.this.transform(expr)) + override def transform(tpe: Type): Type = that.transform(TheoryEncoder.this.transform(expr)) } - val decoder = new Decoder { - override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = { - val mapSeq = bindings.toSeq - val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.decoder.transform(id.getType)) } - val e2 = that.decoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) - Some(self.decoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap)) + val decoder = new TreeTransformer { + override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { + val (id1, tpe1) = that.transform(id, tpe) + TheoryEncoder.this.transform(id1, tpe1) } - override def transformType(tpe: TypeTree): Option[TypeTree] = Some(self.decoder.transform(that.decoder.transform(tpe))) - - override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { - val (pat2, bindings) = that.decoder.transform(pat) - val (pat3, bindings2) = self.decoder.transform(pat2) - (pat3, bindings.map { case (id, id2) => id -> bindings2(id2) }) - } + override def transform(expr: Expr): Expr = TheoryEncoder.this.transform(that.transform(expr)) + override def transform(tpe: Type): Type = TheoryEncoder.this.transform(that.transform(tpe)) } } } -class NoEncoder extends TheoryEncoder { - val encoder = new Encoder - val decoder = new Decoder +trait NoEncoder extends TheoryEncoder { + + private object NoTransformer extends trees.TreeTransformer { + override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, tpe) + override def transform(v: Variable): Variable = v + override def transform(vd: ValDef): ValDef = vd + override def transform(expr: Expr): Expr = expr + override def transform(tpe: Type): Type = tpe + } + + val encoder = NoTransformer + val decoder = NoTransformer } diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala index 1db7f9eee28d70e43843867a3535ebd5145c5ec7..de89d2274b15f731ab815a3603b547d143da7668 100644 --- a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala @@ -90,7 +90,7 @@ trait DatatypeTemplates { self: Templates => def unrollFields(tcd: TypedCaseClassDef): Seq[Expr] = tcd.fields.map { vd => val tpe = tcd.toType - typeUnroller(CaseClassSelector(tpe, AsInstanceOf(expr, tpe), vd.id)) + typeUnroller(CaseClassSelector(AsInstanceOf(expr, tpe), vd.id)) } val fields: Seq[Expr] = if (tcd != tcd.root) { @@ -185,7 +185,7 @@ trait DatatypeTemplates { self: Templates => rec(pathVar, expr) val (idT, pathVarT) = (encodeSymbol(v), encodeSymbol(pathVar)) - val encoder: Expr => Encoded = encodeExpr(condVars + (v -> idT) + (pathVar -> pathVarT)) + val encoder: Expr => Encoded = mkEncoder(condVars + (v -> idT) + (pathVar -> pathVarT)) var clauses: Clauses = Seq.empty var calls: CallBlockers = Map.empty diff --git a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala index 1af37ec6bf9c0edca7ea359b87258b9c65415641..5466a41d8fe15cedac35a64c6468c1693d94434c 100644 --- a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala @@ -156,7 +156,6 @@ trait FunctionTemplates { self: Templates => // We connect it to the defBlocker: blocker => defBlocker if (defBlocker != blocker) { newCls += mkImplies(blocker, defBlocker) - impliesBlocker(blocker, defBlocker) } ctx.reporter.debug("Unrolling behind "+call+" ("+newCls.size+")") diff --git a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala index 83eb4548aca9d9f51770fdc4edf9a1194a0223e1..2fd60dba7d678966445f47d420b90c6263fb76c7 100644 --- a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala @@ -76,21 +76,12 @@ trait LambdaTemplates { self: Templates => } new LambdaTemplate( - ids, - pathVar, - arguments, - condVars, - exprVars, - condTree, - clauses, - blockers, - applications, - matchers, - lambdas, - quantifications, + ids, pathVar, arguments, + condVars, exprVars, condTree, + clauses, blockers, applications, matchers, + lambdas, quantifications, structure, - lambda, - lambdaString + lambda, lambdaString ) } } @@ -207,7 +198,7 @@ trait LambdaTemplates { self: Templates => val quantifications: Seq[QuantificationTemplate], val structure: LambdaStructure, val lambda: Lambda, - stringRepr: () => String) extends Template { + private[unrolling] val stringRepr: () => String) extends Template { val args = arguments.map(_._2) val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] @@ -231,15 +222,23 @@ trait LambdaTemplates { self: Templates => ids._1 -> idT, pathVar, arguments, condVars, exprVars, condTree, clauses map substituter, // make sure the body-defining clause is inlined! blockers, applications, matchers, lambdas, quantifications, - structure, lambda, stringRepr - ) + structure, lambda, stringRepr) } private lazy val str : String = stringRepr() override def toString : String = str - override def instantiate(substMap: Map[Encoded, Arg]): Clauses = { - super.instantiate(substMap) ++ instantiateAxiom(this, substMap) + /** When instantiating closure templates, we want to preserve the condition + * under which the associated closure can be evaluated in the program + * (namely `pathVar._2`), as well as the condition under which the current + * application can take place. We therefore have + * {{{ + * aVar && pathVar._2 ==> instantiation + * }}} + */ + override def instantiate(aVar: Encoded, args: Seq[Arg]): Clauses = { + val (freshBlocker, eqClauses) = encodeBlockers(Set(aVar, pathVar._2)) + eqClauses ++ super.instantiate(freshBlocker, args) } } @@ -275,14 +274,12 @@ trait LambdaTemplates { self: Templates => private def typeUnroller(blocker: Encoded, app: App): Clauses = typeBlockers.get(app.encoded) match { case Some(typeBlocker) => - impliesBlocker(blocker, typeBlocker) Seq(mkImplies(blocker, typeBlocker)) case None => val App(caller, tpe @ FirstOrderFunctionType(_, to), args, value) = app val typeBlocker = encodeSymbol(Variable(FreshIdentifier("t"), BooleanType)) typeBlockers += value -> typeBlocker - impliesBlocker(blocker, typeBlocker) val clauses = registerSymbol(typeBlocker, value, to) @@ -290,7 +287,6 @@ trait LambdaTemplates { self: Templates => (blocker, Seq.empty) } else { val firstB = encodeSymbol(Variable(FreshIdentifier("b_free", true), BooleanType)) - impliesBlocker(firstB, typeBlocker) typeEnablers += firstB val nextB = encodeSymbol(Variable(FreshIdentifier("b_or", true), BooleanType)) @@ -502,7 +498,6 @@ trait LambdaTemplates { self: Templates => val enabler = if (equals == trueT) b else mkAnd(equals, b) newCls += mkImplies(enabler, lambdaBlocker) - impliesBlocker(b, lambdaBlocker) ctx.reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") for (cl <- newCls) { diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala index a35fbb3195be9f14d2faba94628c72203259388c..250075f49bfc89dbcf8c0228546b055a0bca3040 100644 --- a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala @@ -5,6 +5,7 @@ package solvers package unrolling import utils._ +import evaluators._ import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} @@ -16,8 +17,6 @@ trait QuantificationTemplates { self: Templates => import lambdasManager._ import quantificationsManager._ - def hasQuantifiers = quantifications.nonEmpty - /* -- Extraction helpers -- */ object QuantificationMatcher { @@ -27,8 +26,8 @@ trait QuantificationTemplates { self: Templates => case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) case None => None } - case Application(caller, args) => Some((caller, args)) - case _ => None + case Application(caller, args) => Some((caller, args)) + case _ => None } def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { @@ -60,12 +59,32 @@ trait QuantificationTemplates { self: Templates => /* -- Quantifier template definitions -- */ - class QuantificationTemplate( + /** Represents the polarity of the quantification within the considered + * formula. Positive and negative polarity enable optimizations during + * quantifier instantiation. + * + * Unknown polarity is treated conservatively (subsumes both positive and + * negative cases). + */ + sealed abstract class Polarity { + def substitute(substituter: Encoded => Encoded): Polarity = this match { + case Positive(guardVar) => Positive(guardVar) + case Negative(insts) => Negative(insts._1 -> substituter(insts._2)) + case Unknown(qs, q2s, insts, guardVar) => Unknown(qs._1 -> substituter(qs._2), q2s, insts, guardVar) + } + } + + case class Positive(guardVar: Encoded) extends Polarity + case class Negative(insts: (Variable, Encoded)) extends Polarity + case class Unknown( + qs: (Variable, Encoded), + q2s: (Variable, Encoded), + insts: (Variable, Encoded), + guardVar: Encoded) extends Polarity + + class QuantificationTemplate private[QuantificationTemplates] ( val pathVar: (Variable, Encoded), - val qs: (Variable, Encoded), - val q2s: (Variable, Encoded), - val insts: (Variable, Encoded), - val guardVar: Encoded, + val polarity: Polarity, val quantifiers: Seq[(Variable, Encoded)], val condVars: Map[Variable, Encoded], val exprVars: Map[Variable, Encoded], @@ -75,313 +94,255 @@ trait QuantificationTemplates { self: Templates => val applications: Apps, val matchers: Matchers, val lambdas: Seq[LambdaTemplate], - val structure: Forall, - val dependencies: Map[Variable, Encoded], - val forall: Forall, + val quantifications: Seq[QuantificationTemplate], + val key: (Seq[ValDef], Expr, Seq[Encoded]), + val body: Expr, stringRepr: () => String) { lazy val start = pathVar._2 - lazy val key: (Forall, Seq[Encoded]) = (structure, { - var cls: Seq[Encoded] = Seq.empty - exprOps.preTraversal { - case v: Variable => cls ++= dependencies.get(v) - case _ => - } (structure) - cls - }) - - def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new QuantificationTemplate( - pathVar._1 -> substituter(start), - qs, q2s, insts, guardVar, quantifiers, condVars, exprVars, condTree, - clauses.map(substituter), - blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, - applications.map { case (b, apps) => substituter(b) -> apps.map(_.substitute(substituter, msubst)) }, - matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, - lambdas.map(_.substitute(substituter, msubst)), - structure, dependencies.map { case (id, value) => id -> substituter(value) }, - forall, stringRepr) + lazy val mapping: Map[Variable, Encoded] = polarity match { + case Positive(_) => Map.empty + case Negative(insts) => Map(insts) + case Unknown(qs, _, _, _) => Map(qs) + } + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): QuantificationTemplate = + new QuantificationTemplate(pathVar._1 -> substituter(start), polarity.substitute(substituter), + quantifiers, condVars, exprVars, condTree, clauses.map(substituter), + blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + applications.map { case (b, apps) => substituter(b) -> apps.map(_.substitute(substituter, msubst)) }, + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, + lambdas.map(_.substitute(substituter, msubst)), + quantifications.map(_.substitute(substituter, msubst)), + (key._1, key._2, key._3.map(substituter)), + body, stringRepr) private lazy val str : String = stringRepr() override def toString : String = str } object QuantificationTemplate { + def templateKey(quantifiers: Seq[ValDef], expr: Expr, substMap: Map[Variable, Encoded]): (Seq[ValDef], Expr, Seq[Encoded]) = { + val (vals, struct, deps) = normalizeStructure(quantifiers, expr) + val encoder = mkEncoder(substMap) _ + val depClosures = deps.toSeq.sortBy(_._1.id.uniqueName).map(p => encoder(p._2)) + (vals, struct, depClosures) + } + def apply( pathVar: (Variable, Encoded), - qs: (Variable, Encoded), - q2: Variable, - inst: Variable, - guard: Variable, + optPol: Option[Boolean], + p: Expr, quantifiers: Seq[(Variable, Encoded)], condVars: Map[Variable, Encoded], exprVars: Map[Variable, Encoded], condTree: Map[Variable, Set[Variable]], guardedExprs: Map[Variable, Seq[Expr]], lambdas: Seq[LambdaTemplate], + quantifications: Seq[QuantificationTemplate], baseSubstMap: Map[Variable, Encoded], - dependencies: Map[Variable, Encoded], proposition: Forall - ): QuantificationTemplate = { + ): (Option[Variable], QuantificationTemplate) = { + + val (optVar, polarity, extraGuarded, extraSubst) = optPol match { + case Some(true) => + val guard: Variable = Variable(FreshIdentifier("guard", true), BooleanType) + val guards = guard -> encodeSymbol(guard) + (None, Positive(guards._2), Map(pathVar._1 -> Seq(Implies(guard, p))), Map(guards)) - val q2s: (Variable, Encoded) = q2 -> encodeSymbol(q2) - val insts: (Variable, Encoded) = inst -> encodeSymbol(inst) - val guards: (Variable, Encoded) = guard -> encodeSymbol(guard) + case Some(false) => + val inst: Variable = Variable(FreshIdentifier("inst", true), BooleanType) + val insts = inst -> encodeSymbol(inst) + (Some(inst), Negative(insts), Map(pathVar._1 -> Seq(Equals(inst, p))), Map(insts)) + + case None => + val q: Variable = Variable(FreshIdentifier("q", true), BooleanType) + val q2: Variable = Variable(FreshIdentifier("qo", true), BooleanType) + val inst: Variable = Variable(FreshIdentifier("inst", true), BooleanType) + val guard: Variable = Variable(FreshIdentifier("guard", true), BooleanType) + + val qs = q -> encodeSymbol(q) + val q2s = q2 -> encodeSymbol(q2) + val insts = inst -> encodeSymbol(inst) + val guards = guard -> encodeSymbol(guard) + + val polarity = Unknown(qs, q2s, insts, guards._2) + val extraGuarded = Map(pathVar._1 -> Seq(Equals(inst, Implies(guard, p)), Equals(q, And(q2, inst)))) + val extraSubst = Map(qs, q2s, insts, guards) + (Some(q), polarity, extraGuarded, extraSubst) + } + + val substMap = baseSubstMap ++ extraSubst + val allGuarded = guardedExprs merge extraGuarded val (clauses, blockers, applications, matchers, templateString) = - Template.encode(pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, - substMap = baseSubstMap + q2s + insts + guards + qs) + Template.encode(pathVar, quantifiers, condVars, exprVars, allGuarded, + lambdas, quantifications, substMap = substMap) - val (structuralQuant, deps) = normalizeStructure(proposition) - val keyDeps = deps.map { case (id, dep) => id -> encodeExpr(dependencies)(dep) } + val key = templateKey(proposition.args, proposition.body, substMap) - new QuantificationTemplate( - pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, - clauses, blockers, applications, matchers, lambdas, structuralQuant, keyDeps, proposition, - () => "Template for " + proposition + " is :\n" + templateString()) + (optVar, new QuantificationTemplate( + pathVar, polarity, quantifiers, condVars, exprVars, condTree, + clauses, blockers, applications, matchers, lambdas, quantifications, key, + proposition.body, () => "Template for " + proposition + " is :\n" + templateString())) } } private[unrolling] object quantificationsManager extends Manager { - val quantifications = new IncrementalSeq[MatcherQuantification] - - private[QuantificationTemplates] val instCtx = new InstantiationContext + val quantifications = new IncrementalSeq[Quantification] val ignoredMatchers = new IncrementalSeq[(Int, Set[Encoded], Matcher)] - val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[Encoded], Map[Encoded,Arg])]] - val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[Encoded], Map[Encoded,Arg])]] + val handledMatchers = new IncrementalSeq[(Set[Encoded], Matcher)] + + val ignoredSubsts = new IncrementalMap[Quantification, Set[(Int, Set[Encoded], Map[Encoded,Arg])]] + val handledSubsts = new IncrementalMap[Quantification, Set[(Set[Encoded], Map[Encoded,Arg])]] val lambdaAxioms = new IncrementalSet[(LambdaStructure, Seq[(Variable, Encoded)])] - val templates = new IncrementalMap[(Expr, Seq[Encoded]), Encoded] + val templates = new IncrementalMap[(Seq[ValDef], Expr, Seq[Encoded]), Map[Encoded, Encoded]] val incrementals: Seq[IncrementalState] = Seq( - quantifications, instCtx, ignoredMatchers, ignoredSubsts, - handledSubsts, lambdaAxioms, templates) + quantifications, ignoredMatchers, handledMatchers, ignoredSubsts, handledSubsts, lambdaAxioms, templates) private def assumptions: Seq[Encoded] = - quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq + quantifications.collect { case q: GeneralQuantification => q.currentQ2Var }.toSeq + def satisfactionAssumptions = assumptions def refutationAssumptions = assumptions - } - - private var currentGen = 0 - - private sealed abstract class MatcherKey(val tpe: Type) - private case class CallerKey(caller: Encoded, tt: Type) extends MatcherKey(tt) - private case class LambdaKey(lambda: Lambda, tt: Type) extends MatcherKey(tt) - private case class TypeKey(tt: Type) extends MatcherKey(tt) - - private def matcherKey(caller: Encoded, tpe: Type): MatcherKey = tpe match { - case ft: FunctionType if knownFree(ft)(caller) => CallerKey(caller, tpe) - case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure.lambda, tpe) - case _ => TypeKey(tpe) - } - - @inline - private def correspond(qm: Matcher, m: Matcher): Boolean = - correspond(qm, m.caller, m.tpe) - - private def correspond(qm: Matcher, caller: Encoded, tpe: Type): Boolean = { - val qkey = matcherKey(qm.caller, qm.tpe) - val key = matcherKey(caller, tpe) - qkey == key || (qkey.tpe == key.tpe && (qkey.isInstanceOf[TypeKey] || key.isInstanceOf[TypeKey])) - } - class VariableNormalizer { - private val varMap: MutableMap[Type, Seq[Encoded]] = MutableMap.empty - private val varSet: MutableSet[Encoded] = MutableSet.empty - - def normalize(ids: Seq[Variable]): Seq[Encoded] = { - val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => - val prev = varMap.get(tpe) match { - case Some(seq) => seq - case None => Seq.empty - } - - if (prev.size >= idst.size) { - idst zip prev.take(idst.size) - } else { - val (handled, newIds) = idst.splitAt(prev.size) - val uIds = newIds.map(id => id -> encodeSymbol(id)) - - varMap(tpe) = prev ++ uIds.map(_._2) - varSet ++= uIds.map(_._2) - - (handled zip prev) ++ uIds - } - }.toMap - - ids.map(mapping) + def unrollGeneration: Option[Int] = { + val gens: Seq[Int] = ignoredMatchers.toSeq.map(_._1) ++ ignoredSubsts.flatMap(p => p._2.map(_._1)) + if (gens.isEmpty) None else Some(gens.min) } - def normalSubst(qs: Seq[(Variable, Encoded)]): Map[Encoded, Encoded] = { - (qs.map(_._2) zip normalize(qs.map(_._1))).toMap - } + // promoting blockers makes no sense in this context + def promoteBlocker(b: Encoded): Boolean = false - def contains(idT: Encoded): Boolean = varSet(idT) - def get(tpe: Type): Option[Seq[Encoded]] = varMap.get(tpe) - } + def unroll: Clauses = { + val clauses = new scala.collection.mutable.ListBuffer[Encoded] - private val abstractNormalizer = new VariableNormalizer - private val concreteNormalizer = new VariableNormalizer + for (e @ (gen, bs, m) <- ignoredMatchers.toSeq if gen == currentGeneration) { + clauses ++= instantiateMatcher(bs, m) + ignoredMatchers -= e + } - def isQuantifier(idT: Encoded): Boolean = abstractNormalizer.contains(idT) + for (q <- quantifications.toSeq) { + val (release, keep) = ignoredSubsts(q).partition(_._1 == currentGeneration) + for ((_, bs, subst) <- release) clauses ++= q.instantiateSubst(bs, subst) + ignoredSubsts += q -> keep + } - def typeInstantiations: Map[Type, MatcherSet] = instCtx.map.instantiations.collect { - case (TypeKey(tpe), matchers) => tpe -> matchers + clauses.toSeq + } } - def lambdaInstantiations: Map[Lambda, MatcherSet] = instCtx.map.instantiations.collect { - case (LambdaKey(lambda, tpe), matchers) => lambda -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) - } + def instantiateMatcher(blocker: Encoded, matcher: Matcher): Clauses = + instantiateMatcher(Set(blocker), matcher) - def partialInstantiations: Map[Encoded, MatcherSet] = instCtx.map.instantiations.collect { - case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) + @inline + private def instantiateMatcher(blockers: Set[Encoded], matcher: Matcher): Clauses = { + handledMatchers += blockers -> matcher + quantifications.flatMap(_.instantiate(blockers, matcher)) } - private def maxDepth(m: Matcher): Int = 1 + (0 +: m.args.map { - case Right(ma) => maxDepth(ma) - case _ => 0 - }).max - - private def totalDepth(m: Matcher): Int = 1 + m.args.map { - case Right(ma) => totalDepth(ma) - case _ => 0 - }.sum - - private def encodeEnablers(es: Set[Encoded]): Encoded = - if (es.isEmpty) trueT else mkAnd(es.toSeq.sortBy(_.toString) : _*) - - private type MatcherSet = Set[(Encoded, Matcher)] - - private class Context private(ctx: Map[Matcher, Set[Set[Encoded]]]) extends Iterable[(Set[Encoded], Matcher)] { - def this() = this(Map.empty) - - def apply(p: (Set[Encoded], Matcher)): Boolean = ctx.get(p._2) match { - case None => false - case Some(blockerSets) => blockerSets(p._1) || blockerSets.exists(set => set.subsetOf(p._1)) - } - - def +(p: (Set[Encoded], Matcher)): Context = if (apply(p)) this else { - val prev = ctx.getOrElse(p._2, Set.empty) - val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 - new Context(ctx + (p._2 -> newSet)) - } - - def ++(that: Context): Context = that.foldLeft(this)((ctx, p) => ctx + p) + def hasQuantifiers: Boolean = quantifications.nonEmpty - def iterator = ctx.toSeq.flatMap { case (m, bss) => bss.map(bs => bs -> m) }.iterator - def toMatchers: MatcherSet = this.map(p => encodeEnablers(p._1) -> p._2).toSet + def getInstantiationsWithBlockers = quantifications.toSeq.flatMap { + case q: GeneralQuantification => q.instantiations.toSeq + case _ => Seq.empty } - private class ContextMap( - private var tpeMap: MutableMap[Type, Context] = MutableMap.empty, - private var funMap: MutableMap[MatcherKey, Context] = MutableMap.empty - ) extends IncrementalState { - private val stack = new MutableStack[(MutableMap[Type, Context], MutableMap[MatcherKey, Context])] - - def clear(): Unit = { - stack.clear() - tpeMap.clear() - funMap.clear() - } - - def reset(): Unit = clear() - - def push(): Unit = { - stack.push((tpeMap, funMap)) - tpeMap = tpeMap.clone - funMap = funMap.clone - } - - def pop(): Unit = { - val (ptpeMap, pfunMap) = stack.pop() - tpeMap = ptpeMap - funMap = pfunMap - } - - def +=(p: (Set[Encoded], Matcher)): Unit = matcherKey(p._2.caller, p._2.tpe) match { - case TypeKey(tpe) => tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) + p - case key => funMap(key) = funMap.getOrElse(key, new Context) + p - } - - def merge(that: ContextMap): this.type = { - for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) ++ values - for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, new Context) ++ values - this - } - - def get(caller: Encoded, tpe: Type): Context = - funMap.getOrElse(matcherKey(caller, tpe), new Context) ++ tpeMap.getOrElse(tpe, new Context) - - def get(key: MatcherKey): Context = key match { - case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) - case key => funMap.getOrElse(key, new Context) ++ tpeMap.getOrElse(key.tpe, new Context) - } - - def instantiations: Map[MatcherKey, MatcherSet] = - (funMap.toMap ++ tpeMap.map { case (tpe,ms) => TypeKey(tpe) -> ms }).mapValues(_.toMatchers) + private sealed trait MatcherKey + private case class FunctionKey(tfd: TypedFunDef) extends MatcherKey + private sealed abstract class TypedKey(val tpe: Type) extends MatcherKey + private case class CallerKey(caller: Encoded, tt: Type) extends TypedKey(tt) + private case class LambdaKey(lambda: Lambda, tt: Type) extends TypedKey(tt) + private case class TypeKey(tt: Type) extends TypedKey(tt) + + private def matcherKey(key: Either[(Encoded, Type), TypedFunDef]): MatcherKey = key match { + case Right(tfd) => FunctionKey(tfd) + case Left((caller, ft: FunctionType)) if knownFree(ft)(caller) => CallerKey(caller, ft) + case Left((caller, ft: FunctionType)) if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure.lambda, ft) + case Left((_, tpe)) => TypeKey(tpe) } - private class InstantiationContext private ( - private var _instantiated : Context, val map : ContextMap - ) extends IncrementalState { + @inline + private def matcherKey(m: Matcher): MatcherKey = matcherKey(m.key) - private val stack = new MutableStack[Context] + private def correspond(k1: MatcherKey, k2: MatcherKey): Boolean = + k1 == k2 || ((k1, k2) match { + case (TypeKey(tp1), TypeKey(tp2)) => tp1 == tp2 + case _ => false + }) - def this() = this(new Context, new ContextMap) + @inline + private def correspond(m1: Matcher, m2: Matcher): Boolean = + correspond(matcherKey(m1), matcherKey(m2)) - def clear(): Unit = { - stack.clear() - map.clear() - _instantiated = new Context - } + private class GroundSet private( + private val map: MutableMap[Arg, MutableSet[Set[Encoded]]] + ) extends Iterable[(Set[Encoded], Arg)] { - def reset(): Unit = clear() + def this() = this(MutableMap.empty) - def push(): Unit = { - stack.push(_instantiated) - map.push() + def apply(p: (Set[Encoded], Arg)): Boolean = map.get(p._2) match { + case Some(blockerSets) => blockerSets(p._1) || + // we assume here that iterating through the powerset of `p._1` + // will be significantly faster then iterating through `blockerSets` + p._1.subsets.exists(set => blockerSets(set)) + case None => false } - def pop(): Unit = { - _instantiated = stack.pop() - map.pop() + def +=(p: (Set[Encoded], Arg)): Unit = if (!this(p)) map.get(p._2) match { + case Some(blockerSets) => blockerSets += p._1 + case None => map(p._2) = MutableSet.empty + p._1 } - def instantiated: Context = _instantiated - def apply(p: (Set[Encoded], Matcher)): Boolean = _instantiated(p) + def iterator: Iterator[(Set[Encoded], Arg)] = new collection.AbstractIterator[(Set[Encoded], Arg)] { + private val mapIt: Iterator[(Arg, MutableSet[Set[Encoded]])] = GroundSet.this.map.iterator + private var setIt: Iterator[Set[Encoded]] = Iterator.empty + private var current: Arg = _ - def corresponding(m: Matcher): Context = map.get(m.caller, m.tpe) - - def instantiate(blockers: Set[Encoded], matcher: Matcher)(qs: MatcherQuantification*): Clauses = { - if (this(blockers -> matcher)) { - Seq.empty + def hasNext = mapIt.hasNext || setIt.hasNext + def next: (Set[Encoded], Arg) = if (setIt.hasNext) { + val bs = setIt.next + bs -> current } else { - map += (blockers -> matcher) - _instantiated += (blockers -> matcher) - qs.flatMap(_.instantiate(blockers, matcher)) + val (e, bss) = mapIt.next + current = e + setIt = bss.iterator + next } } - def merge(that: InstantiationContext): this.type = { - _instantiated ++= that._instantiated - map.merge(that.map) - this + override def clone: GroundSet = { + val newMap: MutableMap[Arg, MutableSet[Set[Encoded]]] = MutableMap.empty + for ((e, bss) <- map) { + newMap += e -> bss.clone + } + new GroundSet(newMap) } } - private[solvers] trait MatcherQuantification { + private def totalDepth(m: Matcher): Int = 1 + m.args.map { + case Right(ma) => totalDepth(ma) + case _ => 0 + }.sum + + private def encodeEnablers(es: Set[Encoded]): Encoded = + if (es.isEmpty) trueT else mkAnd(es.toSeq.sortBy(_.toString) : _*) + + private[solvers] trait Quantification { val pathVar: (Variable, Encoded) val quantifiers: Seq[(Variable, Encoded)] - val matchers: Set[Matcher] - val allMatchers: Matchers val condVars: Map[Variable, Encoded] val exprVars: Map[Variable, Encoded] val condTree: Map[Variable, Set[Variable]] val clauses: Clauses val blockers: Calls val applications: Apps + val matchers: Matchers val lambdas: Seq[LambdaTemplate] + val quantifications: Seq[QuantificationTemplate] val holds: Encoded val body: Expr @@ -389,144 +350,120 @@ trait QuantificationTemplates { self: Templates => lazy val quantified: Set[Encoded] = quantifiers.map(_._2).toSet lazy val start = pathVar._2 - private lazy val depth = matchers.map(maxDepth).max - private lazy val transMatchers: Set[Matcher] = (for { - (b, ms) <- allMatchers.toSeq - m <- ms if !matchers(m) && maxDepth(m) <= depth - } yield m).toSet - - /* Build a mapping from applications in the quantified statement to all potential concrete - * applications previously encountered. Also make sure the current `app` is in the mapping - * as other instantiations have been performed previously when the associated applications - * were first encountered. - */ - private def mappings(bs: Set[Encoded], matcher: Matcher): Set[Set[(Set[Encoded], Matcher, Matcher)]] = { - /* 1. select an application in the quantified proposition for which the current app can - * be bound when generating the new constraints - */ - matchers.filter(qm => correspond(qm, matcher)) - - /* 2. build the instantiation mapping associated to the chosen current application binding */ - .flatMap { bindingMatcher => - - /* 2.1. select all potential matches for each quantified application */ - val matcherToInstances = matchers - .map(qm => if (qm == bindingMatcher) { - bindingMatcher -> Set(bs -> matcher) - } else { - qm -> instCtx.corresponding(qm) - }).toMap - - /* 2.2. based on the possible bindings for each quantified application, build a set of - * instantiation mappings that can be used to instantiate all necessary constraints - */ - val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[Encoded], Matcher, Matcher)]]](Set(Set.empty)) { - case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { - case (bs, m) => mappings.map(mapping => mapping + ((bs, qm, m))) - } : _*) - } + private val constraints: Seq[(Encoded, MatcherKey, Int)] = (for { + (_, ms) <- matchers + m <- ms + (arg,i) <- m.args.zipWithIndex + q <- arg.left.toOption if quantified(q) + } yield (q, matcherKey(m), i)).toSeq - allMappings - } - } + private val groupedConstraints: Map[Encoded, Seq[(MatcherKey, Int)]] = + constraints.groupBy(_._1).map(p => p._1 -> p._2.map(p2 => (p2._2, p2._3))) - private def extractSubst(mapping: Set[(Set[Encoded], Matcher, Matcher)]): (Set[Encoded], Map[Encoded,Arg], Boolean) = { - var constraints: Set[Encoded] = Set.empty - var eqConstraints: Set[(Encoded, Encoded)] = Set.empty - var subst: Map[Encoded, Arg] = Map.empty + private val grounds: Map[Encoded, GroundSet] = quantified.map(q => q -> new GroundSet).toMap - var matcherEqs: Set[(Encoded, Encoded)] = Set.empty - def strictnessCnstr(qarg: Arg, arg: Arg): Unit = (qarg, arg) match { - case (Right(qam), Right(am)) => (qam.args zip am.args).foreach(p => strictnessCnstr(p._1, p._2)) - case _ => matcherEqs += qarg.encoded -> arg.encoded - } + def instantiate(bs: Set[Encoded], m: Matcher): Clauses = { - for { - (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping - _ = constraints ++= bs - (qarg, arg) <- (qargs zip args) - _ = strictnessCnstr(qarg, arg) - } qarg match { - case Left(quant) if !quantified(quant) || subst.isDefinedAt(quant) => - eqConstraints += (quant -> arg.encoded) - case Left(quant) if quantified(quant) => - subst += quant -> arg - case Right(qam) => - eqConstraints += (qam.encoded -> arg.encoded) - } + /* Build mappings from quantifiers to all potential ground values previously encountered. */ + val quantToGround = (for ((q, constraints) <- groupedConstraints) yield { + q -> (grounds(q).toSet ++ constraints.flatMap { case (key, i) => + if (correspond(matcherKey(m), key)) Some(bs -> m.args(i)) else None + }) + }).toMap + + /* Transform the map to sequences into a sequence of maps making sure that the current + * matcher is part of the mapping (otherwise, instantiation has already taken place). */ + var mappings: Seq[(Set[Encoded], Map[Encoded, Arg])] = Seq.empty + for ((q, constraints) <- groupedConstraints; + (key, i) <- constraints if correspond(matcherKey(m), key) && !grounds(q)(bs -> m.args(i))) { + mappings ++= (quantified - q).foldLeft(Seq(bs -> Map(q -> m.args(i)))) { + case (maps, oq) => for { + (bs, map) <- maps + groundSet <- quantToGround.get(oq).toSeq + (ibs, inst) <- groundSet + } yield (bs ++ ibs, map + (oq -> inst)) + } - val substituter = mkSubstituter(subst.mapValues(_.encoded)) - val substConstraints = constraints.filter(_ != trueT).map(substituter) - val substEqs = eqConstraints.map(p => substituter(p._1) -> p._2) - .filter(p => p._1 != p._2).map(p => mkEquals(p._1, p._2)) - val enablers = substConstraints ++ substEqs - val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) + // register ground instantiation for future instantiations + grounds(q) += bs -> m.args(i) + } - (enablers, subst, isStrict) + instantiateSubsts(mappings) } - def instantiate(bs: Set[Encoded], matcher: Matcher): Clauses = { - var clauses: Clauses = Seq.empty + def ensureGrounds: Clauses = { + /* Build mappings from quantifiers to all potential ground values previously encountered + * AND the constants we're introducing to make sure grounds are non-empty. */ + val quantToGround = (for (q <- quantified) yield { + val groundsSet = grounds(q).toSet + q -> (groundsSet ++ (if (groundsSet.isEmpty) Some(Set.empty[Encoded] -> Left(q)) else None)) + }).toMap - for (mapping <- mappings(bs, matcher)) { - val (enablers, subst, isStrict) = extractSubst(mapping) + /* Generate the sequence of all relevant instantiation mappings */ + var mappings: Seq[(Set[Encoded], Map[Encoded, Arg])] = Seq.empty + for (q <- quantified if grounds(q).isEmpty) { + mappings ++= (quantified - q).foldLeft(Seq(Set.empty[Encoded] -> Map[Encoded, Arg](q -> Left(q)))) { + case (maps, oq) => for ((bs, map) <- maps; (ibs, inst) <- quantToGround(oq)) yield (bs ++ ibs, map + (oq -> inst)) + } - if (!skip(subst)) { - if (!isStrict) { - val msubst = subst.collect { case (c, Right(m)) => c -> m } - val substituter = mkSubstituter(subst.mapValues(_.encoded)) - ignoredSubsts(this) += ((currentGen + 3, enablers, subst)) - } else { - clauses ++= instantiateSubst(enablers, subst, strict = true) - } + grounds(q) += Set.empty[Encoded] -> Left(q) + } + + instantiateSubsts(mappings) + } + + private def instantiateSubsts(substs: Seq[(Set[Encoded], Map[Encoded, Arg])]): Clauses = { + val instantiation = new scala.collection.mutable.ListBuffer[Encoded] + for (p @ (bs, subst) <- substs if !handledSubsts(this)(p)) { + if (subst.values.exists(_.isRight)) { + ignoredSubsts += this -> (ignoredSubsts.getOrElse(this, Set.empty) + ((currentGeneration + 3, bs, subst))) + } else { + instantiation ++= instantiateSubst(bs, subst) } } - clauses + instantiation.toSeq } - def instantiateSubst(enablers: Set[Encoded], subst: Map[Encoded, Arg], strict: Boolean = false): Clauses = { - if (handledSubsts(this)(enablers -> subst)) { - Seq.empty - } else { - handledSubsts(this) += enablers -> subst + def instantiateSubst(bs: Set[Encoded], subst: Map[Encoded, Arg]): Clauses = { + handledSubsts += this -> (handledSubsts.getOrElse(this, Set.empty) + (bs -> subst)) + val instantiation = new scala.collection.mutable.ListBuffer[Encoded] - var clauses: Clauses = Seq.empty - val (enabler, optEnabler) = freshBlocker(enablers) - if (optEnabler.isDefined) { - clauses :+= mkEquals(enabler, optEnabler.get) - } + val (enabler, enablerClauses) = encodeBlockers(bs) + instantiation ++= enablerClauses - val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) - val (substMap, cls) = Template.substitution( - condVars, exprVars, condTree, lambdas, Seq.empty, baseSubst, pathVar._1, enabler) - clauses ++= cls + val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) + val (substMap, substClauses) = Template.substitution( + condVars, exprVars, condTree, lambdas, quantifications, baseSubst, pathVar._1, enabler) + instantiation ++= substClauses - val msubst = substMap.collect { case (c, Right(m)) => c -> m } - val substituter = mkSubstituter(substMap.mapValues(_.encoded)) - registerBlockers(substituter) + val msubst = substMap.collect { case (c, Right(m)) => c -> m } + val substituter = mkSubstituter(substMap.mapValues(_.encoded)) + registerBlockers(substituter) - clauses ++= Template.instantiate(clauses, blockers, applications, Map.empty, substMap) + // matcher instantiation must be manually controlled here to avoid never-ending loops + instantiation ++= Template.instantiate(clauses, blockers, applications, Map.empty, substMap) - for ((b,ms) <- allMatchers; m <- ms) { - val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) - val sm = m.substitute(substituter, msubst) + for ((b,ms) <- matchers; m <- ms) { + val sb = bs ++ (if (b == start) Set.empty else Set(substituter(b))) + val sm = m.substitute(substituter, msubst) - if (strict && (matchers(m) || transMatchers(m))) { - clauses ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) - } else if (!matchers(m)) { - ignoredMatchers += ((currentGen + 2 + totalDepth(m), sb, sm)) - } - } + def abs(i: Int): Int = if (i < 0) -i else i + val nextGeneration: Int = currentGeneration + + 2 * (abs(totalDepth(sm) - totalDepth(m)) + (if (b == start) 0 else 1)) - clauses + if (nextGeneration == currentGeneration) { + instantiation ++= instantiateMatcher(sb, sm) + } else { + ignoredMatchers += ((nextGeneration, sb, sm)) + } } + + instantiation.toSeq } protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] - protected def skip(subst: Map[Encoded, Arg]): Boolean = false - protected def registerBlockers(substituter: Encoded => Encoded): Unit = () def checkForall: Option[String] = { @@ -583,32 +520,32 @@ trait QuantificationTemplates { self: Templates => } } - private class Quantification ( + private class GeneralQuantification ( val pathVar: (Variable, Encoded), val qs: (Variable, Encoded), val q2s: (Variable, Encoded), val insts: (Variable, Encoded), val guardVar: Encoded, val quantifiers: Seq[(Variable, Encoded)], - val matchers: Set[Matcher], - val allMatchers: Matchers, val condVars: Map[Variable, Encoded], val exprVars: Map[Variable, Encoded], val condTree: Map[Variable, Set[Variable]], val clauses: Clauses, val blockers: Calls, val applications: Apps, + val matchers: Matchers, val lambdas: Seq[LambdaTemplate], - val template: QuantificationTemplate) extends MatcherQuantification { + val quantifications: Seq[QuantificationTemplate], + val body: Expr) extends Quantification { private var _currentQ2Var: Encoded = qs._2 def currentQ2Var = _currentQ2Var val holds = qs._2 - val body = template.forall.body private var _insts: Map[Encoded, Set[Encoded]] = Map.empty def instantiations = _insts + private val blocker = Variable(FreshIdentifier("b_fresh", true), BooleanType) protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] = { val nextQ2Var = encodeSymbol(q2s._1) @@ -626,349 +563,192 @@ trait QuantificationTemplates { self: Templates => } } - private lazy val blockerSymbol = Variable(FreshIdentifier("blocker", true), BooleanType) - private lazy val enablersToBlocker: MutableMap[Set[Encoded], Encoded] = MutableMap.empty - private lazy val blockerToEnablers: MutableMap[Encoded, Set[Encoded]] = MutableMap.empty - private def freshBlocker(enablers: Set[Encoded]): (Encoded, Option[Encoded]) = enablers.toSeq match { - case Seq(b) if isBlocker(b) => (b, None) - case _ => - val last = enablersToBlocker.get(enablers).orElse { - val initialEnablers = enablers.flatMap(e => blockerToEnablers.getOrElse(e, Set(e))) - enablersToBlocker.get(initialEnablers) - } - - last match { - case Some(b) => (b, None) - case None => - val nb = encodeSymbol(blockerSymbol) - enablersToBlocker += enablers -> nb - blockerToEnablers += nb -> enablers - for (b <- enablers if isBlocker(b)) impliesBlocker(b, nb) - blocker(nb) - - (nb, Some(encodeEnablers(enablers))) - } - } - - private class LambdaAxiom ( + private class Axiom ( val pathVar: (Variable, Encoded), - val blocker: Encoded, val guardVar: Encoded, val quantifiers: Seq[(Variable, Encoded)], - val matchers: Set[Matcher], - val allMatchers: Map[Encoded, Set[Matcher]], val condVars: Map[Variable, Encoded], val exprVars: Map[Variable, Encoded], val condTree: Map[Variable, Set[Variable]], val clauses: Clauses, val blockers: Calls, val applications: Apps, + val matchers: Matchers, val lambdas: Seq[LambdaTemplate], - val template: LambdaTemplate) extends MatcherQuantification { + val quantifications: Seq[QuantificationTemplate], + val body: Expr) extends Quantification { - val holds = start - val body = template.lambda.body + val holds = trueT protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] = { - Map(guardVar -> start, blocker -> enabler) - } - - override protected def skip(subst: Map[Encoded, Arg]): Boolean = { - val substituter = mkSubstituter(subst.mapValues(_.encoded)) - val msubst = subst.collect { case (c, Right(m)) => c -> m } - allMatchers.forall { case (b, ms) => - ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter, msubst))) - } + Map(guardVar -> enabler) } } - private def extractQuorums( - quantified: Set[Encoded], - matchers: Set[Matcher], - lambdas: Seq[LambdaTemplate] - ): Seq[Set[Matcher]] = { - val extMatchers: Set[Matcher] = { - def rec(templates: Seq[LambdaTemplate]): Set[Matcher] = - templates.foldLeft(Set.empty[Matcher]) { - case (matchers, template) => matchers ++ template.matchers.flatMap(_._2) ++ rec(template.lambdas) - } + def instantiateAxiom(template: LambdaTemplate): Clauses = { + val quantifiers = template.arguments.map { p => p._1.freshen -> encodeSymbol(p._1) } - matchers ++ rec(lambdas) - } + val app = mkApplication(template.ids._1, quantifiers.map(_._1)) + val appT = mkEncoder(quantifiers.toMap + template.ids)(app) + val selfMatcher = Matcher(Left(template.ids._2 -> template.tpe), quantifiers.map(p => Left(p._2)), appT) - val quantifiedMatchers = for { - m @ Matcher(_, _, args, _) <- extMatchers - if args exists (_.left.exists(quantified)) - } yield m + val blocker = Variable(FreshIdentifier("blocker", true), BooleanType) + val blockerT = encodeSymbol(blocker) - extractQuorums(quantifiedMatchers, quantified, - (m: Matcher) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, - (m: Matcher) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) - } + val guard = Variable(FreshIdentifier("guard", true), BooleanType) + val guardT = encodeSymbol(guard) - def instantiateAxiom(template: LambdaTemplate, substMap: Map[Encoded, Arg]): Clauses = { - def quantifiedMatcher(m: Matcher): Boolean = m.args.exists(a => a match { - case Left(v) => isQuantifier(v) - case Right(m) => quantifiedMatcher(m) - }) + val enablingClause = mkEquals(mkAnd(guardT, template.start), blockerT) - val quantified = template.arguments flatMap { - case (id, idT) => substMap(idT) match { - case Left(v) if isQuantifier(v) => Some(id) - case Right(m) if quantifiedMatcher(m) => Some(id) - case _ => None - } - } - - val quantifiers = quantified zip abstractNormalizer.normalize(quantified) - val key = template.structure -> quantifiers - - if (quantifiers.isEmpty || lambdaAxioms(key)) { - Seq.empty - } else { - lambdaAxioms += key - val blockerT = encodeSymbol(blockerSymbol) - - val guard = Variable(FreshIdentifier("guard", true), BooleanType) - val guardT = encodeSymbol(guard) - - val substituter = mkSubstituter(substMap.mapValues(_.encoded) + (template.start -> blockerT)) - val msubst = substMap.collect { case (c, Right(m)) => c -> m } + /* Compute Axiom's unique key to avoid redudant instantiations */ - val allMatchers = template.matchers map { case (b, ms) => - substituter(b) -> ms.map(_.substitute(substituter, msubst)) - } - - val qMatchers = allMatchers.flatMap(_._2).toSet - - val encArgs = template.args map (arg => Left(arg).substitute(substituter, msubst)) - val app = Application(template.ids._1, template.arguments.map(_._1)) - val appT = encodeExpr((template.arguments.map(_._1) zip encArgs.map(_.encoded)).toMap + template.ids)(app) - val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs, appT) - - val instMatchers = allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)) - - val enablingClause = mkImplies(guardT, blockerT) - - val condVars = template.condVars map { case (id, idT) => id -> substituter(idT) } - val exprVars = template.exprVars map { case (id, idT) => id -> substituter(idT) } - val clauses = (template.clauses map substituter) :+ enablingClause - val blockers = template.blockers map { case (b, fis) => - substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) - } - - val applications = template.applications map { case (b, apps) => - substituter(b) -> apps.map(app => app.copy( - caller = substituter(app.caller), - args = app.args.map(_.substitute(substituter, msubst)) - )) - } - - val lambdas = template.lambdas map (_.substitute(substituter, msubst)) - - val quantified = quantifiers.map(_._2).toSet - val matchQuorums = extractQuorums(quantified, qMatchers, lambdas) + def flattenLambda(e: Expr): (Seq[ValDef], Expr) = e match { + case Lambda(args, body) => + val (recArgs, recBody) = flattenLambda(body) + (args ++ recArgs, recBody) + case _ => (Seq.empty, e) + } - var instantiation: Clauses = Seq.empty + val (structArgs, structBody) = flattenLambda(template.structure.lambda) + assert(quantifiers.size == structArgs.size, "Expecting lambda templates to contain flattened lamdbas") - for (matchers <- matchQuorums) { - val axiom = new LambdaAxiom(template.pathVar._1 -> substituter(template.start), - blockerT, guardT, quantifiers, matchers, instMatchers, condVars, exprVars, template.condTree, - clauses, blockers, applications, lambdas, template) + val lambdaBody = exprOps.replaceFromSymbols((structArgs zip quantifiers.map(_._1)).toMap, structBody) + val quantBody = Equals(app, lambdaBody) - quantifications += axiom - handledSubsts += axiom -> MutableSet.empty - ignoredSubsts += axiom -> MutableSet.empty + val sortedDeps = exprOps.variablesOf(quantBody).toSeq.sortBy(_.id.uniqueName) + val substMap = (sortedDeps zip template.structure.dependencies).toMap + template.ids - val newCtx = new InstantiationContext() - for ((b,m) <- instCtx.instantiated) { - instantiation ++= newCtx.instantiate(b, m)(axiom) - } - instCtx.merge(newCtx) - } + val key = QuantificationTemplate.templateKey(quantifiers.map(_._1.toVal), quantBody, substMap) - instantiation ++= instantiateConstants(quantifiers, qMatchers) + val substituter = mkSubstituter((template.args zip quantifiers.map(_._2)).toMap + (template.start -> blockerT)) + val msubst = Map.empty[Encoded, Matcher] - instantiation - } + instantiateQuantification(new QuantificationTemplate( + template.pathVar, + Positive(guardT), + quantifiers, + template.condVars + (blocker -> blockerT), + template.exprVars, + template.condTree, + (template.clauses map substituter) :+ enablingClause, + template.blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + template.applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, msubst)) }, + template.matchers.map { case (b, ms) => + substituter(b) -> ms.map(_.substitute(substituter, msubst)) + } merge Map(blockerT -> Set(selfMatcher)), + template.lambdas.map(_.substitute(substituter, msubst)), + template.quantifications.map(_.substitute(substituter, msubst)), + key, quantBody, template.stringRepr))._2 // mapping is guaranteed empty!! } - def instantiateQuantification(template: QuantificationTemplate): (Encoded, Clauses) = { + def instantiateQuantification(template: QuantificationTemplate): (Map[Encoded, Encoded], Clauses) = { templates.get(template.key) match { - case Some(idT) => - (idT, Seq.empty) + case Some(map) => + (map, Seq.empty) case None => - val qT = encodeSymbol(template.qs._1) - val quantified = template.quantifiers.map(_._2).toSet - val matcherSet = template.matchers.flatMap(_._2).toSet - val matchQuorums = extractQuorums(quantified, matcherSet, template.lambdas) - - var clauses: Clauses = Seq.empty - - val qs = for (matchers <- matchQuorums) yield { - val newQ = encodeSymbol(template.qs._1) - val substituter = mkSubstituter(Map(template.qs._2 -> newQ)) - - val quantification = new Quantification( - template.pathVar, - template.qs._1 -> newQ, - template.q2s, template.insts, template.guardVar, - template.quantifiers, matchers, template.matchers, - template.condVars, template.exprVars, template.condTree, - template.clauses map substituter, // one clause depends on 'q' (and therefore 'newQ') - template.blockers, template.applications, template.lambdas, template) - - quantifications += quantification - handledSubsts += quantification -> MutableSet.empty - ignoredSubsts += quantification -> MutableSet.empty - - val newCtx = new InstantiationContext() - for ((b,m) <- instCtx.instantiated) { - clauses ++= newCtx.instantiate(b, m)(quantification) - } - instCtx.merge(newCtx) - - quantification.qs._2 - } + val clauses = new scala.collection.mutable.ListBuffer[Encoded] + val mapping: Map[Encoded, Encoded] = template.polarity match { + case Positive(guardVar) => + val axiom = new Axiom(template.pathVar, guardVar, + template.quantifiers, template.condVars, template.exprVars, template.condTree, + template.clauses, template.blockers, template.applications, template.matchers, + template.lambdas, template.quantifications, template.body) + + quantifications += axiom + + for ((bs,m) <- handledMatchers) { + clauses ++= axiom.instantiate(bs, m) + } - clauses :+= { - val newQs = - if (qs.isEmpty) trueT - else if (qs.size == 1) qs.head - else mkAnd(qs : _*) - mkImplies(template.start, mkEquals(qT, newQs)) - } + clauses ++= axiom.ensureGrounds + Map.empty - clauses ++= instantiateConstants(template.quantifiers, matcherSet) + case Negative(insts) => + val instT = encodeSymbol(insts._1) + val (substMap, substClauses) = Template.substitution( + template.condVars, template.exprVars, template.condTree, + template.lambdas, template.quantifications, + Map(insts._2 -> Left(instT)), template.pathVar._1, template.pathVar._2) + clauses ++= substClauses - templates += template.key -> qT - (qT, clauses) - } - } + // this will call `instantiateMatcher` on all matchers in `template.matchers` + val instClauses = Template.instantiate(template.clauses, + template.blockers, template.applications, template.matchers, substMap) + clauses ++= instClauses - def instantiateMatcher(blocker: Encoded, matcher: Matcher): Clauses = { - instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) - } + Map(insts._2 -> instT) - def canUnfoldQuantifiers: Boolean = ignoredSubsts.nonEmpty || ignoredMatchers.nonEmpty + case Unknown(qs, q2s, insts, guardVar) => + val qT = encodeSymbol(qs._1) + val substituter = mkSubstituter(Map(qs._2 -> qT)) - def instantiateIgnored(force: Boolean = false): Clauses = { - currentGen = if (!force) currentGen + 1 else { - val gens = ignoredSubsts.toSeq.flatMap(_._2).map(_._1) ++ ignoredMatchers.toSeq.map(_._1) - if (gens.isEmpty) currentGen else gens.min - } + val quantification = new GeneralQuantification(template.pathVar, + qs._1 -> qT, q2s, insts, guardVar, + template.quantifiers, template.condVars, template.exprVars, template.condTree, + template.clauses map substituter, // one clause depends on 'qs._2' (and therefore 'qT') + template.blockers, template.applications, template.matchers, + template.lambdas, template.quantifications, template.body) - var clauses: Clauses = Seq.empty + quantifications += quantification - val matchersToRelease = ignoredMatchers.toList.flatMap { case e @ (gen, b, m) => - if (gen == currentGen) { - ignoredMatchers -= e - Some(b -> m) - } else { - None - } - } + for ((bs,m) <- handledMatchers) { + clauses ++= quantification.instantiate(bs, m) + } - for ((bs,m) <- matchersToRelease) { - clauses ++= instCtx.instantiate(bs, m)(quantifications.toSeq : _*) - } + for ((b,ms) <- template.matchers; m <- ms) { + clauses ++= instantiateMatcher(b, m) + } - val substsToRelease = quantifications.toList.flatMap { q => - val qsubsts = ignoredSubsts(q) - qsubsts.toList.flatMap { case e @ (gen, enablers, subst) => - if (gen == currentGen) { - qsubsts -= e - Some((q, enablers, subst)) - } else { - None + clauses ++= quantification.ensureGrounds + Map(qs._2 -> qT) } - } - } - for ((q, enablers, subst) <- substsToRelease) { - clauses ++= q.instantiateSubst(enablers, subst, strict = false) + templates += template.key -> mapping + (mapping, clauses.toSeq) } - - clauses } - private def instantiateConstants(quantifiers: Seq[(Variable, Encoded)], matchers: Set[Matcher]): Clauses = { - var clauses: Clauses = Seq.empty - - for (normalizer <- List(abstractNormalizer, concreteNormalizer)) { - val quantifierSubst = normalizer.normalSubst(quantifiers) - val substituter = mkSubstituter(quantifierSubst) - - for { - m <- matchers - sm = m.substitute(substituter, Map.empty) - if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } clauses ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) - - def unifyMatchers(matchers: Seq[Matcher]): Clauses = matchers match { - case sm +: others => - var clauses: Clauses = Seq.empty - for (pm <- others if correspond(pm, sm)) { - val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) - val mismatches = encodedArgs.zipWithIndex.collect { - case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) - }.toMap - - def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { - case idx +: xs => - val (p1, p2) = mismatches(idx) - val newPartials = Seq(idx) +: partials.map { seq => - if (mismatches(seq.head)._1 == p2) idx +: seq - else if (mismatches(seq.last)._2 == p1) seq :+ idx - else seq - } - - val (closed, remaining) = newPartials.partition { seq => - mismatches(seq.head)._1 == mismatches(seq.last)._2 - } - closed ++ extractChains(xs, partials ++ remaining) - - case _ => Seq.empty - } - - val chains = extractChains(mismatches.keys.toSeq, Seq.empty) - val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => - val res = seq.min - mapping ++ seq.map(i => i -> res) - } - - def extractArgs(args: Seq[Arg]): Seq[Arg] = - (0 until args.size).map(i => args(positions.getOrElse(i, i))) - - clauses ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) - clauses ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) - } - - clauses ++ unifyMatchers(others) + def promoteQuantifications: Unit = { + val optGen = quantificationsManager.unrollGeneration + if (!optGen.isDefined) + throw FatalError("Attempting to promote inexistent quantifiers") - case _ => Seq.empty - } - - if (normalizer == abstractNormalizer) { - val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) - clauses ++= unifyMatchers(substMatchers.toSeq) - } + val diff = (currentGeneration - optGen.get) max 0 + val currentMatchers = ignoredMatchers.toSeq + ignoredMatchers.clear + for ((gen, bs, m) <- currentMatchers) { + ignoredMatchers += ((gen - diff, bs, m)) } - clauses + for (q <- quantifications) { + ignoredSubsts += q -> ignoredSubsts(q).map { case (gen, bs, subst) => (gen - diff, bs, subst) } + } } + def requiresFiniteRangeCheck: Boolean = + ignoredMatchers.nonEmpty || ignoredSubsts.exists(_._2.nonEmpty) + def getFiniteRangeClauses: Clauses = { val clauses = new scala.collection.mutable.ListBuffer[Encoded] val keyClause = MutableMap.empty[MatcherKey, (Clauses, Encoded)] for ((_, bs, m) <- ignoredMatchers) { - val key = matcherKey(m.caller, m.tpe) - val QuantificationTypeMatcher(argTypes, _) = key.tpe + val key = matcherKey(m) + val argTypes = key match { + case tk: TypedKey => + val QuantificationTypeMatcher(argTypes, _) = tk.tpe + argTypes + case FunctionKey(tfd) => + tfd.params.map(_.getType) ++ (tfd.returnType match { + case tpe @ QuantificationTypeMatcher(argTypes, _) if tpe.isInstanceOf[FunctionType] => + argTypes + case _ => Seq.empty + }) + } val (values, clause) = keyClause.getOrElse(key, { - val insts = instCtx.map.get(key).toMatchers + val insts = handledMatchers.filter(hm => correspond(matcherKey(hm._2), key)) val guard = Variable(FreshIdentifier("guard", true), BooleanType) val elems = argTypes.map(tpe => Variable(FreshIdentifier("elem", true), tpe)) @@ -978,10 +758,16 @@ trait QuantificationTemplates { self: Templates => val guardP = guard -> encodeSymbol(guard) val elemsP = elems.map(e => e -> encodeSymbol(e)) val valuesP = values.map(v => v -> encodeSymbol(v)) - val exprT = encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + val exprT = mkEncoder(elemsP.toMap ++ valuesP + guardP)(expr) - val disjuncts = insts.toSeq.map { case (b, im) => - val bp = if (m.caller != im.caller) mkAnd(mkEquals(m.caller, im.caller), b) else b + val disjuncts = insts.toSeq.map { case (bs, im) => + val cond = (m.key, im.key) match { + case (Left((mcaller, _)), Left((imcaller, _))) if mcaller != imcaller => + Some(mkEquals(mcaller, imcaller)) + case _ => None + } + + val bp = encodeEnablers(bs ++ cond) val subst = (elemsP.map(_._2) zip im.args.map(_.encoded)).toMap + (guardP._2 -> bp) mkSubstituter(subst)(exprT) } @@ -1005,10 +791,10 @@ trait QuantificationTemplates { self: Templates => val guardP = guard -> encodeSymbol(guard) val elemsP = elems.map(e => e -> encodeSymbol(e)) val valuesP = values.map(v => v -> encodeSymbol(v)) - val exprT = encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + val exprT = mkEncoder(elemsP.toMap ++ valuesP + guardP)(expr) val disjunction = handledSubsts(q) match { - case set if set.isEmpty => encodeExpr(Map.empty)(BooleanLiteral(false)) + case set if set.isEmpty => mkEncoder(Map.empty)(BooleanLiteral(false)) case set => mkOr(set.toSeq.map { case (enablers, subst) => val b = if (enablers.isEmpty) trueT else mkAnd(enablers.toSeq : _*) val substMap = (elemsP.map(_._2) zip q.quantifiers.map(p => subst(p._2).encoded)).toMap + (guardP._2 -> b) @@ -1023,228 +809,20 @@ trait QuantificationTemplates { self: Templates => } } - def isQuantified(e: Arg): Boolean = e match { - case Left(t) => isQuantifier(t) - case Right(m) => m.args.exists(isQuantified) - } - - for ((key, ctx) <- instCtx.map.instantiations) { - val QuantificationTypeMatcher(argTypes, _) = key.tpe - - for { - (tpe, idx) <- argTypes.zipWithIndex - quants <- abstractNormalizer.get(tpe) if quants.nonEmpty - (b, m) <- ctx - arg = m.args(idx) if !isQuantified(arg) - } clauses += mkAnd(quants.map(q => mkNot(mkEquals(q, arg.encoded))) : _*) - - val byPosition: Iterable[Seq[Encoded]] = ctx.flatMap { case (b, m) => - if (b != trueT) Seq.empty else m.args.zipWithIndex - }.groupBy(_._2).map(p => p._2.toSeq.flatMap { - case (a, _) => if (isQuantified(a)) Some(a.encoded) else None - }).filter(_.nonEmpty) - - for ((a +: as) <- byPosition; a2 <- as) { - clauses += mkEquals(a, a2) - } - } - clauses.toSeq } - trait ModelView { - protected val vars: Map[Variable, Encoded] - protected val evaluator: evaluators.DeterministicEvaluator - - protected def get(id: Variable): Option[Expr] - protected def eval(elem: Encoded, tpe: Type): Option[Expr] - - implicit lazy val context = evaluator.context - lazy val reporter = context.reporter - - private def extract(b: Encoded, m: Matcher): Option[Seq[Expr]] = { - val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = eval(b, BooleanType) - optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => - val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } - if (optArgs.forall(_.isDefined)) Some(optArgs.map(_.get)) - else None - } - } - - private def functionsOf(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = { - - def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)], - recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = - (subs.flatMap(_._1), (exprs: Seq[Expr]) => { - var curr = exprs - recons(subs.map { case (es, recons) => - val (used, remaining) = curr.splitAt(es.size) - curr = remaining - recons(used) - }) - }) - - def rec(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match { - case (_: Lambda) => - (Seq(expr -> path), (es: Seq[Expr]) => es.head) - - case Tuple(es) => reconstruct(es.zipWithIndex.map { - case (e, i) => rec(e, TupleSelect(path, i + 1)) - }, Tuple) - - case CaseClass(cct, es) => reconstruct((cct.classDef.fieldsIds zip es).map { - case (id, e) => rec(e, CaseClassSelector(cct, path, id)) - }, CaseClass(cct, _)) - - case _ => (Seq.empty, (es: Seq[Expr]) => expr) - } - - rec(expr, path) - } - - def getTotalModel: Model = { - - def checkForalls(quantified: Set[Variable], body: Expr): Option[String] = { - val matchers = exprOps.collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (body) - - if (matchers.isEmpty) - return Some("No matchers found.") - - val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Variable]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case v: Variable if quantified(v) => Set(v) - case _ => Set.empty[Variable] - })) - } - - val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) - if (bijectiveMappings.size > 1) - return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) - - def quantifiedArg(e: Expr): Boolean = e match { - case v: Variable => quantified(v) - case QuantificationMatcher(_, args) => args.forall(quantifiedArg) - case _ => false - } - - exprOps.postTraversal(m => m match { - case QuantificationMatcher(_, args) => - val qArgs = args.filter(quantifiedArg) - - if (qArgs.nonEmpty && qArgs.size < args.size) - return Some("Mixed ground and quantified arguments in " + m.asString) - - case Operator(es, _) if es.collect { case v: Variable if quantified(v) => v }.nonEmpty => - return Some("Invalid operation on quantifiers " + m.asString) - - case (_: Equals) | (_: And) | (_: Or) | (_: Implies) | (_: Not) => // OK - - case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => - return Some("Unandled implications from operation " + m.asString) - - case _ => - }) (body) - - body match { - case v: Variable if quantified(v) => - Some("Unexpected free quantifier " + id.asString) - case _ => None - } - } - - val issues: Iterable[(Seq[Variable], Expr, String)] = for { - q <- quantifications.view - if eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) - msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) - } yield (q.quantifiers.map(_._1), q.body, msg) - - if (issues.nonEmpty) { - val (quantifiers, body, msg) = issues.head - reporter.warning("Model soundness not guaranteed for \u2200" + - quantifiers.map(_.asString).mkString(",") + ". " + body.asString+" :\n => " + msg) - } - - val types = typeInstantiations - val partials = partialInstantiations - - def extractCond(params: Seq[Variable], args: Seq[(Encoded, Expr)], structure: Map[Encoded, Variable]): Seq[Expr] = (params, args) match { - case (id +: rparams, (v, arg) +: rargs) => - if (isQuantifier(v)) { - structure.get(v) match { - case Some(pid) => Equals(id, pid) +: extractCond(rparams, rargs, structure) - case None => extractCond(rparams, rargs, structure + (v -> id)) - } - } else { - Equals(id, arg) +: extractCond(rparams, rargs, structure) - } - case _ => Seq.empty + def getGroundInstantiations(e: Encoded, tpe: Type): Seq[(Encoded, Seq[Encoded])] = { + val bestTpe = bestRealType(tpe) + handledMatchers.flatMap { case (bs, m) => + val enabler = encodeEnablers(bs) + val optArgs = matcherKey(m) match { + case TypeKey(tpe) if bestTpe == tpe => Some(m.args.map(_.encoded)) + case CallerKey(caller, tpe) if e == caller => Some(m.args.map(_.encoded)) + case _ => None } - new Model(vars.map { case (id, idT) => - val value = get(id).getOrElse(simplestValue(id.getType)) - val (functions, recons) = functionsOf(value, id) - - id -> recons(functions.map { case (f, path) => - val encoded = encodeExpr(Map(id -> idT))(path) - val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] - val params = tpe.from.map(tpe => Variable(FreshIdentifier("x", true), tpe)) - partials.get(encoded).orElse(types.get(tpe)).map { domain => - val conditionals = domain.flatMap { case (b, m) => - extract(b, m).map { args => - val result = evaluator.eval(application(f, args)).result.getOrElse { - scala.sys.error("Unexpectedly failed to evaluate " + application(f, args)) - } - - val cond = if (m.args.exists(arg => isQuantifier(arg.encoded))) { - extractCond(params, m.args.map(_.encoded) zip args, Map.empty) - } else { - (params zip args).map(p => Equals(p._1, p._2)) - } - - cond -> result - } - }.toMap - - if (conditionals.isEmpty) f match { - case FiniteLambda(mapping, dflt, tpe) => - Lambda(params.map(ValDef(_)), mapping.foldRight(dflt) { case ((es, v), elze) => - IfExpr(andJoin((params zip es).map(p => Equals(p._1.toVariable, p._2))), v, elze) - }) - case _ => f - } else { - val ((_, dflt)) +: rest = conditionals.toSeq.sortBy { case (conds, _) => - (conds.flatMap(variablesOf).toSet.size, conds.size) - } - - val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => - if (conds.isEmpty) elze else (elze match { - case pres if res == pres => res - case _ => IfExpr(andJoin(conds), res, elze) - }) - } - - Lambda(params.map(_.toVal), body) - } - }.getOrElse(f) - }) - }) + optArgs.map(args => enabler -> args) } } - - def getModel(vs: Map[Variable, Encoded], ev: DeterministicEvaluator, _get: Variable => Option[Expr], _eval: (Encoded, Type) => Option[Expr]) = new ModelView { - val vars: Map[Variable, Encoded] = vs - val evaluator: DeterministicEvaluator = ev - - def get(id: Variable): Option[Expr] = _get(id) - def eval(elem: Encoded, tpe: Type): Option[Expr] = _eval(elem, tpe) - } - - def getInstantiationsWithBlockers = quantifications.toSeq.flatMap { - case q: Quantification => q.instantiations.toSeq - case _ => Seq.empty - } } diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index a240d8efec54e22bd4cb7b33432c69fba2e65d93..68041401851d74cb684ffd4603361bdf670c2690 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -12,8 +12,6 @@ trait TemplateGenerator { self: Templates => import program.trees._ import program.symbols._ - val assumePreHolds: Boolean - private type TemplateClauses = ( Map[Variable, Encoded], Map[Variable, Encoded], @@ -42,11 +40,7 @@ trait TemplateGenerator { self: Templates => return cache(tfd) } - // The precondition if it exists. - val prec : Option[Expr] = tfd.precondition.map(p => simplifyHOFunctions(matchToIfThenElse(p))) - - val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) - val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b)) + val lambdaBody : Option[Expr] = tfd.body.map(simplifyHOFunctions) val funDefArgs: Seq[Variable] = tfd.params.map(_.toVariable) val lambdaArguments: Seq[Variable] = lambdaBody.map(lambdaArgs).toSeq.flatten @@ -54,14 +48,7 @@ trait TemplateGenerator { self: Templates => val invocationEqualsBody : Seq[Expr] = lambdaBody match { case Some(body) => - val bs = liftedEquals(invocation, body, lambdaArguments) :+ Equals(invocation, body) - - if(prec.isDefined) { - bs.map(Implies(prec.get, _)) - } else { - bs - } - + liftedEquals(invocation, body, lambdaArguments) :+ Equals(invocation, body) case _ => Seq.empty } @@ -74,32 +61,9 @@ trait TemplateGenerator { self: Templates => val substMap : Map[Variable, Encoded] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = + val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) - // Now the postcondition. - val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = tfd.postcondition match { - case Some(post) => - val newPost : Expr = simplifyHOFunctions(application(matchToIfThenElse(post), Seq(invocation))) - - val postHolds : Expr = - if(tfd.hasPrecondition) { - if (assumePreHolds) { - And(prec.get, newPost) - } else { - Implies(prec.get, newPost) - } - } else { - newPost - } - - val (postConds, postExprs, postTree, postGuarded, postLambdas, postQuantifications) = mkClauses(start, postHolds, substMap) - (bodyConds ++ postConds, bodyExprs ++ postExprs, bodyTree merge postTree, bodyGuarded merge postGuarded, bodyLambdas ++ postLambdas, bodyQuantifications ++ postQuantifications) - - case None => - (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) - } - val template = FunctionTemplate(tfd, pathVar, arguments, condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) cache += tfd -> template @@ -107,7 +71,7 @@ trait TemplateGenerator { self: Templates => } private def lambdaArgs(expr: Expr): Seq[Variable] = expr match { - case Lambda(args, body) => args.map(_.id.freshen) ++ lambdaArgs(body) + case Lambda(args, body) => args.map(_.toVariable.freshen) ++ lambdaArgs(body) case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda chains") case _ => Seq.empty } @@ -127,52 +91,13 @@ trait TemplateGenerator { self: Templates => rec(invocation, body, args, inlineFirst) } - private def minimalFlattening(inits: Set[Variable], conj: Expr): (Set[Variable], Expr) = { - var mapping: Map[Expr, Expr] = Map.empty - var quantified: Set[Variable] = inits - var quantifierEqualities: Seq[(Expr, Variable)] = Seq.empty - - val newConj = exprOps.postMap { - case expr if mapping.isDefinedAt(expr) => - Some(mapping(expr)) - - case expr @ QuantificationMatcher(c, args) => - val isMatcher = args.exists { case v: Variable => quantified(v) case _ => false } - val isRelevant = (exprOps.variablesOf(expr) & quantified).nonEmpty - if (!isMatcher && isRelevant) { - val newArgs = args.map { - case arg @ QuantificationMatcher(_, _) if (exprOps.variablesOf(arg) & quantified).nonEmpty => - val v = Variable(FreshIdentifier("flat", true), arg.getType) - quantifierEqualities :+= (arg -> v) - quantified += v - v - case arg => arg - } - - val newExpr = exprOps.replace((args zip newArgs).toMap, expr) - mapping += expr -> newExpr - Some(newExpr) - } else { - None - } - - case _ => None - } (conj) - - val flatConj = implies(andJoin(quantifierEqualities.map { - case (arg, id) => Equals(arg, id) - }), newConj) - - (quantified, flatConj) - } - - def mkClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded]): TemplateClauses = { - val (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) = mkExprClauses(pathVar, expr, substMap) + def mkClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded], polarity: Option[Boolean] = None): TemplateClauses = { + val (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) = mkExprClauses(pathVar, expr, substMap, polarity) val allGuarded = guardedExprs + (pathVar -> (p +: guardedExprs.getOrElse(pathVar, Seq.empty))) (condVars, exprVars, condTree, allGuarded, lambdas, quantifications) } - private def mkExprClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded]): (Expr, TemplateClauses) = { + private def mkExprClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded], polarity: Option[Boolean] = None): (Expr, TemplateClauses) = { var condVars = Map[Variable, Encoded]() var condTree = Map[Variable, Set[Variable]](pathVar -> Set.empty).withDefaultValue(Set.empty) @@ -212,214 +137,197 @@ trait TemplateGenerator { self: Templates => var lambdas = Seq[LambdaTemplate]() @inline def registerLambda(lambda: LambdaTemplate) : Unit = lambdas :+= lambda - def rec(pathVar: Variable, expr: Expr): Expr = { - expr match { - case a @ Assert(cond, err, body) => - rec(pathVar, IfExpr(cond, body, Error(body.getType, err getOrElse "assertion failed"))) + def rec(pathVar: Variable, expr: Expr, pol: Option[Boolean]): Expr = expr match { + case a @ Assume(cond, body) => + val e = rec(pathVar, cond, Some(true)) + storeGuarded(pathVar, e) + rec(pathVar, body, pol) + + case l @ Let(i, e: Lambda, b) => + val re = rec(pathVar, e, None) // guaranteed variable! + val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> re), b), pol) + rb + + case l @ Let(i, e, b) => + val newExpr : Variable = Variable(FreshIdentifier("lt", true), i.getType) + storeExpr(newExpr) + val re = rec(pathVar, e, None) + storeGuarded(pathVar, Equals(newExpr, re)) + val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> newExpr), b), pol) + rb + + case n @ Not(e) if n.getType == BooleanType => + Not(rec(pathVar, e, pol.map(!_))) + + case i @ Implies(lhs, rhs) => + if (!exprOps.isSimple(i)) { + rec(pathVar, Or(Not(lhs), rhs), pol) + } else { + implies(rec(pathVar, lhs, None), rec(pathVar, rhs, None)) + } - case e @ Ensuring(_, _) => - rec(pathVar, e.toAssert) + case a @ And(parts) if a.getType == BooleanType => + val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) + partitions.map(andJoin) match { + case Seq(e) => e + case seq => + val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) + storeExpr(newExpr) - case l @ Let(i, e: Lambda, b) => - val re = rec(pathVar, e) // guaranteed variable! - val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> re), b)) - rb + def recAnd(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { + case x :: Nil => + storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x, pol))) - case l @ Let(i, e, b) => - val newExpr : Variable = Variable(FreshIdentifier("lt", true), i.getType) - storeExpr(newExpr) - val re = rec(pathVar, e) - storeGuarded(pathVar, Equals(newExpr, re)) - val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> newExpr), b)) - rb - - case m : MatchExpr => sys.error("'MatchExpr's should have been eliminated before generating templates.") - - case i @ Implies(lhs, rhs) => - if (!exprOps.isSimple(i)) { - rec(pathVar, Or(Not(lhs), rhs)) - } else { - implies(rec(pathVar, lhs), rec(pathVar, rhs)) - } + case x :: xs => + val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) + storeCond(pathVar, newBool) - case a @ And(parts) => - val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) - partitions.map(andJoin) match { - case Seq(e) => e - case seq => - val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) - storeExpr(newExpr) + val xrec = rec(pathVar, x, pol) + iff(and(pathVar, xrec), newBool) + iff(and(pathVar, not(xrec)), not(newExpr)) - def recAnd(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { - case x :: Nil => - storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x))) + recAnd(newBool, xs) - case x :: xs => - val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) - storeCond(pathVar, newBool) + case Nil => scala.sys.error("Should never happen!") + } - val xrec = rec(pathVar, x) - iff(and(pathVar, xrec), newBool) - iff(and(pathVar, not(xrec)), not(newExpr)) + recAnd(pathVar, seq) + newExpr + } - recAnd(newBool, xs) + case o @ Or(parts) if o.getType == BooleanType => + val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) + partitions.map(orJoin) match { + case Seq(e) => e + case seq => + val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) + storeExpr(newExpr) - case Nil => scala.sys.error("Should never happen!") - } + def recOr(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { + case x :: Nil => + storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x, None))) - recAnd(pathVar, seq) - newExpr - } + case x :: xs => + val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) + storeCond(pathVar, newBool) - case o @ Or(parts) => - val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) - partitions.map(orJoin) match { - case Seq(e) => e - case seq => - val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) - storeExpr(newExpr) + val xrec = rec(pathVar, x, None) + iff(and(pathVar, xrec), newExpr) + iff(and(pathVar, not(xrec)), newBool) - def recOr(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { - case x :: Nil => - storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x))) + recOr(newBool, xs) - case x :: xs => - val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) - storeCond(pathVar, newBool) + case Nil => scala.sys.error("Should never happen!") + } - val xrec = rec(pathVar, x) - iff(and(pathVar, xrec), newExpr) - iff(and(pathVar, not(xrec)), newBool) + recOr(pathVar, seq) + newExpr + } - recOr(newBool, xs) + case i @ IfExpr(cond, thenn, elze) => { + if(exprOps.isSimple(i)) { + i + } else { + val newBool1 : Variable = Variable(FreshIdentifier("b", true), BooleanType) + val newBool2 : Variable = Variable(FreshIdentifier("b", true), BooleanType) + val newExpr : Variable = Variable(FreshIdentifier("e", true), i.getType) - case Nil => scala.sys.error("Should never happen!") - } + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) - recOr(pathVar, seq) - newExpr - } + storeExpr(newExpr) - case i @ IfExpr(cond, thenn, elze) => { - if(exprOps.isSimple(i)) { - i - } else { - val newBool1 : Variable = Variable(FreshIdentifier("b", true), BooleanType) - val newBool2 : Variable = Variable(FreshIdentifier("b", true), BooleanType) - val newExpr : Variable = Variable(FreshIdentifier("e", true), i.getType) + val crec = rec(pathVar, cond, None) + val trec = rec(newBool1, thenn, None) + val erec = rec(newBool2, elze, None) - storeCond(pathVar, newBool1) - storeCond(pathVar, newBool2) + iff(and(pathVar, cond), newBool1) + iff(and(pathVar, not(cond)), newBool2) - storeExpr(newExpr) + storeGuarded(newBool1, Equals(newExpr, trec)) + storeGuarded(newBool2, Equals(newExpr, erec)) + newExpr + } + } - val crec = rec(pathVar, cond) - val trec = rec(newBool1, thenn) - val erec = rec(newBool2, elze) + case l @ Lambda(args, body) => + val idArgs : Seq[Variable] = lambdaArgs(l) + val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id))) + + val lid = Variable(FreshIdentifier("lambda", true), bestRealType(l.getType)) + val clauses = liftedEquals(lid, l, idArgs, inlineFirst = true) + + val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars + val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idArgs zip trArgs) + val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = + clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst)) + + val ids: (Variable, Encoded) = lid -> storeLambda(lid) + + val (struct, deps) = normalizeStructure(l) + val sortedDeps = deps.toSeq.sortBy(_._1.id.uniqueName) + + val (dependencies, (depConds, depExprs, depTree, depGuarded, depLambdas, depQuants)) = + sortedDeps.foldLeft[(Seq[Encoded], TemplateClauses)](Seq.empty -> emptyClauses) { + case ((dependencies, clsSet), (id, expr)) => + if (!exprOps.isSimple(expr)) { + val encoded = encodeSymbol(id) + val (e, cls @ (_, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst) + val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping) + (dependencies :+ mkEncoder(clauseSubst)(e), clsSet ++ cls) + } else { + (dependencies :+ mkEncoder(localSubst)(expr), clsSet) + } + } - iff(and(pathVar, cond), newBool1) - iff(and(pathVar, not(cond)), newBool2) + val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode( + pathVar -> encodedCond(pathVar), Seq.empty, + depConds, depExprs, depGuarded, depLambdas, depQuants, localSubst) - storeGuarded(newBool1, Equals(newExpr, trec)) - storeGuarded(newBool2, Equals(newExpr, erec)) - newExpr - } + val depClosures: Seq[Encoded] = { + val vars = exprOps.variablesOf(l) + var cls: Seq[Variable] = Seq.empty + exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (l) + cls.distinct.map(localSubst) } - case l @ Lambda(args, body) => - val idArgs : Seq[Variable] = lambdaArgs(l) - val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id))) + val structure = new LambdaStructure( + struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, + depConds, depExprs, depTree, depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants) - val lid = Variable(FreshIdentifier("lambda", true), bestRealType(l.getType)) - val clauses = liftedEquals(lid, l, idArgs, inlineFirst = true) + val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar), + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, + lambdaGuarded, lambdaTemplates, lambdaQuants, structure, localSubst, l) + registerLambda(template) + lid - val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars - val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idArgs zip trArgs) - val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = - clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst)) - - val ids: (Variable, Encoded) = lid -> storeLambda(lid) - - val (struct, deps) = normalizeStructure(l) - - val (dependencies, (depConds, depExprs, depTree, depGuarded, depLambdas, depQuants)) = - deps.foldLeft[(Seq[Encoded], TemplateClauses)](Seq.empty -> emptyClauses) { - case ((dependencies, clsSet), (id, expr)) => - if (!exprOps.isSimple(expr)) { - val encoded = encodeSymbol(id) - val (e, cls @ (_, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst) - val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.map(_.qs) - (dependencies :+ encodeExpr(clauseSubst)(e), clsSet ++ cls) - } else { - (dependencies :+ encodeExpr(localSubst)(expr), clsSet) - } - } + case f @ Forall(args, body) => + val TopLevelAnds(conjuncts) = body - val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode( - pathVar -> encodedCond(pathVar), Seq.empty, - depConds, depExprs, depGuarded, depLambdas, depQuants, localSubst) + val conjunctQs = conjuncts.map { conjunct => + val vars = exprOps.variablesOf(conjunct) + val quantifiers = args.map(_.toVariable).filter(vars).toSet - val depClosures: Seq[Encoded] = { - val vars = exprOps.variablesOf(l) - var cls: Seq[Variable] = Seq.empty - exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (l) - cls.distinct.map(localSubst) - } + val idQuantifiers : Seq[Variable] = quantifiers.toSeq + val trQuantifiers : Seq[Encoded] = idQuantifiers.map(encodeSymbol) - val structure = new LambdaStructure( - struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, - depConds, depExprs, depTree, depClauses, depCalls, depApps, depLambdas, depMatchers, depQuants) - - val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, - lambdaGuarded, lambdaTemplates, lambdaQuants, structure, localSubst, l) - registerLambda(template) - lid - - case f @ Forall(args, body) => - val TopLevelAnds(conjuncts) = body - - val conjunctQs = conjuncts.map { conjunct => - val vars = exprOps.variablesOf(conjunct) - val inits = args.map(_.toVariable).filter(vars).toSet - val (quantifiers, flatConj) = minimalFlattening(inits, conjunct) - - val idQuantifiers : Seq[Variable] = quantifiers.toSeq - val trQuantifiers : Seq[Encoded] = idQuantifiers.map(encodeSymbol) - - val q: Variable = Variable(FreshIdentifier("q", true), BooleanType) - val q2: Variable = Variable(FreshIdentifier("qo", true), BooleanType) - val inst: Variable = Variable(FreshIdentifier("inst", true), BooleanType) - val guard: Variable = Variable(FreshIdentifier("guard", true), BooleanType) - - val clause = Equals(inst, Implies(guard, flatConj)) - - val qs: (Variable, Encoded) = q -> encodeSymbol(q) - val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars - val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idQuantifiers zip trQuantifiers) - val (p, (qConds, qExprs, qTree, qGuarded, qTemplates, qQuants)) = mkExprClauses(pathVar, flatConj, clauseSubst) - assert(qQuants.isEmpty, "Unhandled nested quantification in "+clause) - - val allGuarded = qGuarded + (pathVar -> (Seq( - Equals(inst, Implies(guard, p)), - Equals(q, And(q2, inst)) - ) ++ qGuarded.getOrElse(pathVar, Seq.empty))) - - val dependencies: Map[Variable, Encoded] = vars.filterNot(quantifiers).map(id => id -> localSubst(id)).toMap - val template = QuantificationTemplate(pathVar -> encodedCond(pathVar), - qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, allGuarded, qTemplates, localSubst, - dependencies, Forall(quantifiers.toSeq.sortBy(_.id.uniqueName).map(_.toVal), flatConj)) - registerQuantification(template) - q - } + val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars + val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idQuantifiers zip trQuantifiers) + val (p, (qConds, qExprs, qTree, qGuarded, qLambdas, qQuants)) = mkExprClauses(pathVar, conjunct, clauseSubst) + + val (optVar, template) = QuantificationTemplate(pathVar -> encodedCond(pathVar), + pol, p, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, qGuarded, qLambdas, qQuants, + localSubst, Forall(quantifiers.toSeq.sortBy(_.id.uniqueName).map(_.toVal), conjunct)) + registerQuantification(template) + optVar.getOrElse(BooleanLiteral(true)) + } - andJoin(conjunctQs) + andJoin(conjunctQs) - case Operator(as, r) => r(as.map(a => rec(pathVar, a))) - } + case Operator(as, r) => r(as.map(a => rec(pathVar, a, None))) } - val p = rec(pathVar, expr) + val p = rec(pathVar, expr, polarity) (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) } diff --git a/src/main/scala/inox/solvers/unrolling/Templates.scala b/src/main/scala/inox/solvers/unrolling/Templates.scala index 3679b9cbdbe4b429774a8fd3dd621a615135a7cf..fbb3da51eef6f286ae7b0fe29fb9142ff6286fea 100644 --- a/src/main/scala/inox/solvers/unrolling/Templates.scala +++ b/src/main/scala/inox/solvers/unrolling/Templates.scala @@ -23,7 +23,7 @@ trait Templates extends TemplateGenerator type Encoded <: Printable def encodeSymbol(v: Variable): Encoded - def encodeExpr(bindings: Map[Variable, Encoded])(e: Expr): Encoded + def mkEncoder(bindings: Map[Variable, Encoded])(e: Expr): Encoded def mkSubstituter(map: Map[Encoded, Encoded]): Encoded => Encoded def mkNot(e: Encoded): Encoded @@ -34,7 +34,7 @@ trait Templates extends TemplateGenerator def extractNot(e: Encoded): Option[Encoded] - private[unrolling] lazy val trueT = encodeExpr(Map.empty)(BooleanLiteral(true)) + private[unrolling] lazy val trueT = mkEncoder(Map.empty)(BooleanLiteral(true)) private var currentGen: Int = 0 protected def currentGeneration: Int = currentGen @@ -72,40 +72,57 @@ trait Templates extends TemplateGenerator private val condImplies = new IncrementalMap[Encoded, Set[Encoded]].withDefaultValue(Set.empty) private val condImplied = new IncrementalMap[Encoded, Set[Encoded]].withDefaultValue(Set.empty) + private val condEquals = new IncrementalBijection[Encoded, Set[Encoded]] - val incrementals: Seq[IncrementalState] = managers ++ Seq(condImplies, condImplied) + val incrementals: Seq[IncrementalState] = managers ++ Seq(condImplies, condImplied, condEquals) protected def freshConds( - path: (Variable, Encoded), + pathVar: Encoded, condVars: Map[Variable, Encoded], tree: Map[Variable, Set[Variable]]): Map[Encoded, Encoded] = { val subst = condVars.map { case (v, idT) => idT -> encodeSymbol(v) } - val mapping = condVars.mapValues(subst) + path + val mapping = condVars.mapValues(subst) + + for ((parent, children) <- tree) { + mapping.get(parent) match { + case None => // enabling condition, corresponds to pathVar + for (child <- children) { + val ec = mapping(child) + condImplied += ec -> (condImplies(ec) + pathVar) + } - for ((parent, children) <- tree; ep = mapping(parent); child <- children) { - val ec = mapping(child) - condImplies += ep -> (condImplies(ep) + ec) - condImplied += ec -> (condImplied(ec) + ep) + case Some(ep) => + for (child <- children) { + val ec = mapping(child) + condImplies += ep -> (condImplies(ep) + ec) + condImplied += ec -> (condImplied(ec) + ep) + } + } } subst } - protected def blocker(b: Encoded): Unit = condImplies += (b -> Set.empty) - protected def isBlocker(b: Encoded): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) + private val sym = Variable(FreshIdentifier("bs", true), BooleanType) + protected def encodeBlockers(bs: Set[Encoded]): (Encoded, Clauses) = bs.toSeq match { + case Seq(b) if condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) || condEquals.containsA(b) => + (b, Seq.empty) + + case _ => + val flatBs = fixpoint((bs: Set[Encoded]) => bs.flatMap(b => condEquals.getBorElse(b, Set(b))))(bs) + condEquals.getA(flatBs) match { + case Some(b) => (b, Seq.empty) + case None => + val b = encodeSymbol(sym) + condEquals += (b -> flatBs) + (b, Seq(mkEquals(b, if (flatBs.isEmpty) trueT else mkAnd(flatBs.toSeq : _*)))) + } + } + protected def blockerParents(b: Encoded): Set[Encoded] = condImplied(b) protected def blockerChildren(b: Encoded): Set[Encoded] = condImplies(b) - protected def impliesBlocker(b1: Encoded, b2: Encoded): Unit = impliesBlocker(b1, Set(b2)) - protected def impliesBlocker(b1: Encoded, b2s: Set[Encoded]): Unit = { - val fb2s = b2s.filter(_ != b1) - condImplies += b1 -> (condImplies(b1) ++ fb2s) - for (b2 <- fb2s) { - condImplied += b2 -> (condImplies(b2) + b1) - } - } - def promoteBlocker(b: Encoded, force: Boolean = false): Boolean = { var seen: Set[Encoded] = Set.empty var promoted: Boolean = false @@ -182,14 +199,17 @@ trait Templates extends TemplateGenerator } /** Represents an E-matching matcher that will be used to instantiate relevant quantified propositions */ - case class Matcher(caller: Encoded, tpe: Type, args: Seq[Arg], encoded: Encoded) { - override def toString: String = caller.asString + args.map { + case class Matcher(key: Either[(Encoded, Type), TypedFunDef], args: Seq[Arg], encoded: Encoded) { + override def toString: String = (key match { + case Left((c, tpe)) => c.asString + ":" + tpe.asString + case Right(tfd) => tfd.signature + }) + args.map { case Right(m) => m.toString case Left(v) => v.asString }.mkString("(", ",", ")") def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): Matcher = copy( - caller = substituter(caller), + key = key.left.map(p => substituter(p._1) -> p._2), args = args.map(_.substitute(substituter, msubst)), encoded = substituter(encoded) ) @@ -267,31 +287,16 @@ trait Templates extends TemplateGenerator override def toString : String = "Instantiated template" } - object Template { - private def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { - case FunctionType(from, to) => - val (curr, next) = args.splitAt(from.size) - mkApplication(Application(caller, curr), next) - case _ => - assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}") - caller - } - - private def invocationMatcher(encoder: Expr => Encoded)(tfd: TypedFunDef, args: Seq[Expr]): Matcher = { - assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") - - def rec(e: Expr, args: Seq[Expr]): Expr = e.getType match { - case FunctionType(from, to) => - val (appArgs, outerArgs) = args.splitAt(from.size) - rec(Application(e, appArgs), outerArgs) - case _ if args.isEmpty => e - case _ => scala.sys.error("Should never happen") - } + private[unrolling] def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { + case FunctionType(from, to) => + val (curr, next) = args.splitAt(from.size) + mkApplication(Application(caller, curr), next) + case _ => + assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}") + caller + } - val (fiArgs, appArgs) = args.splitAt(tfd.params.size) - val app @ Application(caller, arguments) = rec(tfd.applied(fiArgs), appArgs) - Matcher(encoder(caller), bestRealType(caller.getType), arguments.map(arg => Left(encoder(arg))), encoder(app)) - } + object Template { def encode( pathVar: (Variable, Encoded), @@ -307,19 +312,23 @@ trait Templates extends TemplateGenerator ) : (Clauses, Calls, Apps, Matchers, () => String) = { val idToTrId : Map[Variable, Encoded] = - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ + lambdas.map(_.ids) ++ quantifications.flatMap(_.mapping) - val encoder : Expr => Encoded = encodeExpr(idToTrId) + val encoder : Expr => Encoded = mkEncoder(idToTrId) val optIdCall = optCall.map(tfd => Call(tfd, arguments.map(p => Left(p._2)))) val optIdApp = optApp.map { case (idT, tpe) => val v = Variable(FreshIdentifier("x", true), tpe) - val encoded = encodeExpr(Map(v -> idT) ++ arguments)(mkApplication(v, arguments.map(_._1))) + val encoded = mkEncoder(Map(v -> idT) ++ arguments)(mkApplication(v, arguments.map(_._1))) App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) } - lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) - .map(tfd => invocationMatcher(encoder)(tfd, arguments.map(_._1))) + lazy val optIdMatcher = optCall.map { tfd => + val (fiArgs, appArgs) = arguments.map(_._1).splitAt(tfd.params.size) + val encoded = mkEncoder(arguments.toMap)(mkApplication(tfd.applied(fiArgs), appArgs)) + Matcher(Right(tfd), arguments.map(p => Left(p._2)), encoded) + } val (clauses, blockers, applications, matchers) = { var clauses : Clauses = Seq.empty @@ -345,7 +354,7 @@ trait Templates extends TemplateGenerator case None => Left(encoder(arg)) }) - Some(expr -> Matcher(encoder(c), bestRealType(c.getType), encodedArgs, encoder(expr))) + Some(expr -> Matcher(Left(encoder(c) -> bestRealType(c.getType)), encodedArgs, encoder(expr))) case _ => None }) }(e) @@ -374,14 +383,14 @@ trait Templates extends TemplateGenerator val apps = appInfos.filter(i => Some(i) != optIdApp) if (apps.nonEmpty) applications += b -> apps - val matchs = (matchInfos.filter { case m @ Matcher(_, _, _, menc) => + val matchs = matchInfos.filter { case m @ Matcher(_, _, menc) => !optIdApp.exists { case App(_, _, _, aenc) => menc == aenc } - } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None)) + } if (matchs.nonEmpty) matchers += b -> matchs } - (clauses, blockers, applications, matchers) + (clauses, blockers, applications, matchers merge optIdMatcher.map(m => pathVar._1 -> Set(m)).toMap) } val encodedBlockers : Calls = blockers.map(p => idToTrId(p._1) -> p._2) @@ -427,7 +436,8 @@ trait Templates extends TemplateGenerator ): (Map[Encoded, Arg], Clauses) = { val freshSubst = exprVars.map { case (v, vT) => vT -> encodeSymbol(v) } ++ - freshConds(pathVar -> aVar, condVars, condTree) + freshConds(aVar, condVars, condTree) + val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } var subst = freshSubst.mapValues(Left(_)) ++ baseSubst @@ -450,20 +460,22 @@ trait Templates extends TemplateGenerator val substMap = subst.mapValues(_.encoded) val substLambda = lambda.substitute(mkSubstituter(substMap), matcherSubst) val (idT, cls) = instantiateLambda(substLambda) - clauses ++= cls subst += lambda.ids._2 -> Left(idT) + clauses ++= cls seen += lambda } } for (l <- lambdas) extractSubst(l) - for (q <- quantifications) { + // instantiate positive quantifications last to avoid introducing + // extra quantifier instantiations that arise due to empty domains + for (q <- quantifications.sortBy(_.polarity.isInstanceOf[Positive])) { val substMap = subst.mapValues(_.encoded) val substQuant = q.substitute(mkSubstituter(substMap), matcherSubst) - val (qT, cls) = instantiateQuantification(substQuant) + val (map, cls) = instantiateQuantification(substQuant) + subst ++= map.mapValues(Left(_)) clauses ++= cls - subst += q.qs._2 -> Left(qT) } (subst, clauses) @@ -499,18 +511,17 @@ trait Templates extends TemplateGenerator } } - def instantiateExpr(expr: Expr): Clauses = { - val subst = exprOps.variablesOf(expr).map(v => v -> encodeSymbol(v)).toMap + def instantiateExpr(expr: Expr, bindings: Map[Variable, Encoded]): Clauses = { val start = Variable(FreshIdentifier("start", true), BooleanType) val encodedStart = encodeSymbol(start) - val tpeClauses = subst.flatMap { case (v, s) => registerSymbol(encodedStart, s, v.getType) }.toSeq + val tpeClauses = bindings.flatMap { case (v, s) => registerSymbol(encodedStart, s, v.getType) }.toSeq val (condVars, exprVars, condTree, guardedExprs, lambdas, quants) = - mkClauses(start, expr, subst + (start -> encodedStart)) + mkClauses(start, expr, bindings + (start -> encodedStart)) val (clauses, calls, apps, matchers, _) = Template.encode( - start -> encodedStart, subst.toSeq, condVars, exprVars, guardedExprs, lambdas, quants) + start -> encodedStart, bindings.toSeq, condVars, exprVars, guardedExprs, lambdas, quants) val (substMap, substClauses) = Template.substitution( condVars, exprVars, condTree, lambdas, quants, Map.empty, start, encodedStart) diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 0e34ca2226c67caa2fcbfd9c4730f1c93d0e084c..f75a9ff68391f9a08e0822735147d7bd1b5fb9ef 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -43,7 +43,7 @@ trait AbstractUnrollingSolver val assumePreHolds = options.findOptionOrDefault(optAssumePre) val silentErrors = options.findOptionOrDefault(optSilentErrors) - def check(model: Boolean = false, cores: Boolean = false): SolverResponse = + def check(model: Boolean = false, cores: Boolean = false): SolverResponses.SolverResponse = checkAssumptions(model = model, cores = cores)(Set.empty) private val constraints = new IncrementalSeq[Expr]() @@ -87,7 +87,7 @@ trait AbstractUnrollingSolver declareVariable(theories.encode(v)) }).toMap - val newClauses = unrollingBank.getClauses(expression, bindings) + val newClauses = templates.instantiateExpr(expression, bindings) for (cl <- newClauses) { solverAssert(cl) } @@ -150,32 +150,16 @@ trait AbstractUnrollingSolver case v: Variable => false case _ => true } - - private[AbstractUnrollingSolver] def extract(b: Encoded, m: templates.Matcher): Option[Seq[Expr]] = { - val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = eval(b, BooleanType) - optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => - val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } - if (optArgs.forall(_.isDefined)) { - Some(optArgs.map(_.get)) - } else { - None - } - } - } } private def emit(silenceErrors: Boolean)(msg: String) = if (silenceErrors) reporter.debug(msg) else reporter.warning(msg) - private def validateModel(model: ModelWrapper, assumptions: Seq[Expr], silenceErrors: Boolean): Boolean = { + private def validateModel(model: Map[ValDef, Expr], assumptions: Seq[Expr], silenceErrors: Boolean): Boolean = { val expr = andJoin(assumptions ++ constraints) // we have to check case class constructors in model for ADT invariants - val newExpr = freeVars.toSeq.foldLeft(expr) { case (e, (v, _)) => - val value = model.get(v).getOrElse(simplestValue(v.getType)) - let(v.toVal, value, e) - } + val newExpr = model.toSeq.foldLeft(expr) { case (e, (v, value)) => let(v, value, e) } evaluator.eval(newExpr) match { case EvaluationResults.Successful(BooleanLiteral(true)) => @@ -196,10 +180,109 @@ trait AbstractUnrollingSolver } } - private def getTotalModel: Model = { - val wrapped = solverGetModel - val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) - view.getTotalModel + private def extractSimpleModel(wrapper: ModelWrapper): Map[ValDef, Expr] = { + freeVars.toMap.map { case (v, _) => v.toVal -> wrapper.get(v).getOrElse(simplestValue(v.getType)) } + } + + private def extractTotalModel(wrapper: ModelWrapper): Map[ValDef, Expr] = { + def functionsOf(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = { + def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)], + recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = + (subs.flatMap(_._1), (exprs: Seq[Expr]) => { + var curr = exprs + recons(subs.map { case (es, recons) => + val (used, remaining) = curr.splitAt(es.size) + curr = remaining + recons(used) + }) + }) + + def rec(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match { + case (_: Lambda) => + (Seq(expr -> selector), (es: Seq[Expr]) => es.head) + + case Tuple(es) => reconstruct(es.zipWithIndex.map { + case (e, i) => rec(e, TupleSelect(selector, i + 1)) + }, Tuple) + + case CaseClass(cct, es) => reconstruct((cct.tcd.toCase.fields zip es).map { + case (vd, e) => rec(e, CaseClassSelector(selector, vd.id)) + }, CaseClass(cct, _)) + + case _ => (Seq.empty, (es: Seq[Expr]) => expr) + } + + rec(expr, selector) + } + + import templates.{QuantificationTypeMatcher => QTM} + freeVars.toMap.map { case (v, idT) => + val value = wrapper.get(v).getOrElse(simplestValue(v.getType)) + val (functions, recons) = functionsOf(value, v) + + v.toVal -> recons(functions.map { case (f, selector) => + val encoded = templates.mkEncoder(Map(v -> idT))(selector) + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val QTM(from, to) = tpe + + if (from.isEmpty) f else { + val params = from.map(tpe => Variable(FreshIdentifier("x", true), tpe)) + val app = templates.mkApplication(selector, params) + + val allImages = templates.getGroundInstantiations(encoded, tpe).flatMap { case (b, eArgs) => + wrapper.eval(b, BooleanType).filter(_ == BooleanLiteral(true)).flatMap { _ => + val optArgs = (eArgs zip from).map { case (arg, tpe) => wrapper.eval(arg, tpe) } + val eApp = templates.mkEncoder(Map(v -> idT) ++ (params zip eArgs))(app) + val optResult = wrapper.eval(eApp, to) + + if (optArgs.forall(_.isDefined) && optResult.isDefined) { + val args = optArgs.map(_.get) + val result = optResult.get + Some(args -> result) + } else { + None + } + } + } + + val default = if (allImages.isEmpty) { + def rec(e: Expr): Expr = e match { + case Lambda(_, body) => rec(body) + case IfExpr(_, _, elze) => rec(elze) + case e => e + } + + rec(f) + } else { + val optDefault = allImages.collectFirst { + case (firstArg +: otherArgs, result) if otherArgs.forall { o => + evaluator.eval(Equals(firstArg, o)).result == Some(BooleanLiteral(true)) + } => result + } + + optDefault.getOrElse { + val app = templates.mkApplication(f, Seq.fill(from.size)(allImages.head._1.head)) + evaluator.eval(app).result.getOrElse { + scala.sys.error("Unexpectedly failed to evaluate " + app.asString) + } + } + } + + val body = allImages.foldRight(default) { case ((args, result), elze) => + IfExpr(andJoin((params zip args).map(p => Equals(p._1, p._2))), result, elze) + } + + def mkLambda(params: Seq[ValDef], body: Expr): Lambda = body.getType match { + case FunctionType(from, to) => + val (rest, curr) = params.splitAt(params.size - from.size) + mkLambda(rest, Lambda(curr, body)) + case _ => Lambda(params, body) + } + + mkLambda(params.map(_.toVal), body) + } + }) + } } def checkAssumptions(model: Boolean = false, cores: Boolean = false)(assumptions: Set[Expr]) = { @@ -207,7 +290,7 @@ trait AbstractUnrollingSolver val assumptionsSeq : Seq[Expr] = assumptions.toSeq val encodedAssumptions : Seq[Encoded] = assumptionsSeq.map { expr => val vars = exprOps.variablesOf(expr) - templates.encodeExpr(vars.map(v => theories.encode(v) -> freeVars(v)).toMap)(expr) + templates.mkEncoder(vars.map(v => theories.encode(v) -> freeVars(v)).toMap)(expr) } val encodedToAssumptions : Map[Encoded, Expr] = (encodedAssumptions zip assumptionsSeq).toMap @@ -221,7 +304,7 @@ trait AbstractUnrollingSolver sealed abstract class CheckState class CheckResult(val response: SolverResponses.SolverResponse) extends CheckState - case class Validate(resp: Sat) extends CheckState + case class Validate(model: Option[Map[ValDef, Expr]]) extends CheckState case object ModelCheck extends CheckState case object FiniteRangeCheck extends CheckState case object InstantiateQuantifiers extends CheckState @@ -233,7 +316,7 @@ trait AbstractUnrollingSolver def apply(resp: Response): CheckResult = new CheckResult(resp match { case Unknown => SolverResponses.Unknown case Sat(None) => SolverResponses.SatResponse - case Sat(Some(model)) => SolverResponses.SatResponseWithModel(model) + case Sat(Some(model)) => SolverResponses.SatResponseWithModel(extractSimpleModel(model)) case Unsat(None) => SolverResponses.UnsatResponse case Unsat(Some(core)) => SolverResponses.UnsatResponseWithCores(encodedCoreToCore(core)) }) @@ -253,7 +336,7 @@ trait AbstractUnrollingSolver case ModelCheck => reporter.debug(" - Running search...") - val withModel = model && !templates.hasIgnored + val withModel = model && !templates.requiresFiniteRangeCheck val withCores = cores || unrollUnsatCores val timer = ctx.timers.solvers.check.start() @@ -271,10 +354,10 @@ trait AbstractUnrollingSolver CheckResult(Unknown) case sat: Sat => - if (templates.hasIgnored) { + if (templates.requiresFiniteRangeCheck) { FiniteRangeCheck } else { - Validate(sat) + Validate(sat.model.map(extractSimpleModel)) } case _: Unsat if !templates.canUnroll => @@ -309,35 +392,35 @@ trait AbstractUnrollingSolver case Abort() => CheckResult(Unknown) - case sat: Sat => - Validate(sat) + case Sat(optModel) => + Validate(optModel.map(extractTotalModel)) case _ => InstantiateQuantifiers } } - case Validate(sat) => sat match { - case Sat(None) => CheckResult(SolverResponses.SatResponse) - case Sat(Some(model)) => + case Validate(optModel) => optModel match { + case None => CheckResult(SolverResponses.SatResponse) + case Some(model) => val valid = !checkModels || validateModel(model, assumptionsSeq, silenceErrors = silentErrors) if (valid) { - CheckResult(model) + CheckResult(SolverResponses.SatResponseWithModel(model)) } else { reporter.error( "Something went wrong. The model should have been valid, yet we got this: " + - model.asString + + model.toString + " for formula " + andJoin(assumptionsSeq ++ constraints).asString) CheckResult(Unknown) } } case InstantiateQuantifiers => - if (!templates.canUnfoldQuantifiers) { + if (!templates.quantificationsManager.unrollGeneration.isDefined) { reporter.error("Something went wrong. The model is not transitive yet we can't instantiate!?") CheckResult(Unknown) } else { - // TODO: promote ignored quantifiers! + templates.promoteQuantifications Unroll } @@ -366,12 +449,12 @@ trait AbstractUnrollingSolver CheckResult(res) case Sat(Some(model)) if feelingLucky => - if (validateModel(model, assumptionsSeq, silenceErrors = true)) { + if (validateModel(extractSimpleModel(model), assumptionsSeq, silenceErrors = true)) { CheckResult(res) } else { for { (inst, bs) <- templates.getInstantiationsWithBlockers - if !model.isTrue(inst) + if model.eval(inst, BooleanType) == Some(BooleanLiteral(false)) b <- bs } templates.promoteBlocker(b, force = true) @@ -427,13 +510,10 @@ trait UnrollingSolver extends AbstractUnrollingSolver { type Encoded = Expr def encodeSymbol(v: Variable): Expr = v.freshen - def encodeExpr(bindings: Map[Variable, Expr])(e: Expr): Expr = { + def mkEncoder(bindings: Map[Variable, Expr])(e: Expr): Expr = exprOps.replaceFromSymbols(bindings, e) - } - - def substitute(substMap: Map[Expr, Expr]): Expr => Expr = { + def mkSubstituter(substMap: Map[Expr, Expr]): Expr => Expr = (e: Expr) => exprOps.replace(substMap, e) - } def mkNot(e: Expr) = not(e) def mkOr(es: Expr*) = orJoin(es) @@ -473,14 +553,6 @@ trait UnrollingSolver extends AbstractUnrollingSolver { r } - def solverUnsatCore = None - - def solverGetModel: ModelWrapper = new ModelWrapper { - val model = solver.getModel - def modelEval(elem: Expr, tpe: Type): Option[Expr] = evaluator.eval(elem, model).result - override def toString = model.toMap.mkString("\n") - } - override def dbg(msg: => Any) = solver.dbg(msg) override def push(): Unit = { diff --git a/src/main/scala/inox/utils/Bijection.scala b/src/main/scala/inox/utils/Bijection.scala index 009f4d9cf938a7ae6c6383e223e5bc3543251946..ed84073a4088a39ce849dbcfcc930425e9c1b357 100644 --- a/src/main/scala/inox/utils/Bijection.scala +++ b/src/main/scala/inox/utils/Bijection.scala @@ -53,7 +53,7 @@ class Bijection[A, B] extends Iterable[(A, B)] { def getA(b: B): Option[A] = b2a.get(b) def getB(a: A): Option[B] = a2b.get(a) - + def getAorElse(b: B, orElse: =>A): A = b2a.getOrElse(b, orElse) def getBorElse(a: A, orElse: =>B): B = a2b.getOrElse(a, orElse)