diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index 2700994fb7fe01b207a65301584b87d438786e1d..b454e70b22defd5de120b6ebcac47f34f10b5d9d 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -124,6 +124,7 @@ trait Expressions { self: Trees => case class Variable(id: Identifier, tpe: Type) extends Expr with Terminal with VariableSymbol { /** Transforms this [[Variable]] into a [[Definitions.ValDef ValDef]] */ def toVal = to[ValDef] + def freshen = Variable(id.freshen, tpe) } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 1e01ffb56a2f07f6e44e02233e6788a53b865899..480a21b6b48b297de1c819e9fe6d2ee08d3e4a16 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -571,7 +571,7 @@ trait SymbolOps extends TreeOps { self: TypeOps => postMap(transform, applyRec = true)(expr) } - def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { + def collectWithPaths[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { def rec(expr: Expr, path: Path): Seq[(T, Path)] = { val seq = if (f.isDefinedAt(expr)) { diff --git a/src/main/scala/inox/solvers/Solver.scala b/src/main/scala/inox/solvers/Solver.scala index 20d6efe7a960285f7ca3c857bbecef3b5a801816..d4bf0d11b1febdeaf6a7c66cfb1423b3bf0f283d 100644 --- a/src/main/scala/inox/solvers/Solver.scala +++ b/src/main/scala/inox/solvers/Solver.scala @@ -29,57 +29,61 @@ trait Solver extends Interruptible { import program._ import program.trees._ - sealed trait SolverResponse - case object Unknown extends SolverResponse - - sealed trait SolverUnsatResponse extends SolverResponse - case object UnsatResponse extends SolverUnsatResponse - case class UnsatResponseWithCores(cores: Set[Expr]) extends SolverUnsatResponse - - sealed trait SolverSatResponse extends SolverResponse - case object SatResponse extends SolverSatResponse - case class SatResponseWithModel(model: Map[ValDef, Expr]) extends SolverSatResponse - - object Check { - def unapply(resp: SolverResponse): Option[Boolean] = resp match { - case _: SolverUnsatResponse => Some(false) - case _: SolverSatResponse => Some(true) - case Unknown => None + object SolverResponses { + sealed trait SolverResponse + case object Unknown extends SolverResponse + + sealed trait SolverUnsatResponse extends SolverResponse + case object UnsatResponse extends SolverUnsatResponse + case class UnsatResponseWithCores(cores: Set[Expr]) extends SolverUnsatResponse + + sealed trait SolverSatResponse extends SolverResponse + case object SatResponse extends SolverSatResponse + case class SatResponseWithModel(model: Map[ValDef, Expr]) extends SolverSatResponse + + object Check { + def unapply(resp: SolverResponse): Option[Boolean] = resp match { + case _: SolverUnsatResponse => Some(false) + case _: SolverSatResponse => Some(true) + case Unknown => None + } } - } - object Sat { - def unapply(resp: SolverSatResponse): Boolean = resp match { - case SatResponse => true - case SatResponseWithModel(_) => throw FatalError("Unexpected sat response with model") - case _ => false + object Sat { + def unapply(resp: SolverSatResponse): Boolean = resp match { + case SatResponse => true + case SatResponseWithModel(_) => throw FatalError("Unexpected sat response with model") + case _ => false + } } - } - object Model { - def unapply(resp: SolverSatResponse): Option[Map[ValDef, Expr]] = resp match { - case SatResponseWithModel(model) => Some(model) - case SatResponse => throw FatalError("Unexpected sat response without model") - case _ => None + object Model { + def unapply(resp: SolverSatResponse): Option[Map[ValDef, Expr]] = resp match { + case SatResponseWithModel(model) => Some(model) + case SatResponse => throw FatalError("Unexpected sat response without model") + case _ => None + } } - } - object Unsat { - def unapply(resp: SolverUnsatResponse): Boolean = resp match { - case UnsatResponse => true - case UnsatResponseWithCores(_) => throw FatalError("Unexpected unsat response with cores") - case _ => false + object Unsat { + def unapply(resp: SolverUnsatResponse): Boolean = resp match { + case UnsatResponse => true + case UnsatResponseWithCores(_) => throw FatalError("Unexpected unsat response with cores") + case _ => false + } } - } - object Core { - def unapply(resp: SolverUnsatResponse): Option[Set[Expr]] = resp match { - case UnsatResponseWithCores(cores) => Some(cores) - case UnsatResponse => throw FatalError("Unexpected unsat response with cores") - case _ => None + object Core { + def unapply(resp: SolverUnsatResponse): Option[Set[Expr]] = resp match { + case UnsatResponseWithCores(cores) => Some(cores) + case UnsatResponse => throw FatalError("Unexpected unsat response with cores") + case _ => None + } } } + import SolverResponses._ + object SolverUnsupportedError { def msg(t: Tree, reason: Option[String]) = { s"(of ${t.getClass}) is unsupported by solver ${name}" + reason.map(":\n " + _ ).getOrElse("") diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeManager.scala b/src/main/scala/inox/solvers/unrolling/DatatypeManager.scala deleted file mode 100644 index c435b2c04d9827b5a31c41419407b6e24766919c..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/DatatypeManager.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.bestRealType - -import utils._ -import utils.SeqUtils._ -import Instantiation._ -import Template._ - -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} - -case class FreshFunction(expr: Expr) extends Expr with Extractable { - val getType = BooleanType - val extract = Some(Seq(expr), (exprs: Seq[Expr]) => FreshFunction(exprs.head)) -} - -object DatatypeTemplate { - - def apply[T]( - encoder: TemplateEncoder[T], - manager: DatatypeManager[T], - tpe: TypeTree - ) : DatatypeTemplate[T] = { - val id = FreshIdentifier("x", tpe, true) - val expr = matchToIfThenElse(manager.typeUnroller(Variable(id))) - - val pathVar = FreshIdentifier("b", BooleanType, true) - - var condVars = Map[Identifier, T]() - var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) - def storeCond(pathVar: Identifier, id: Identifier): Unit = { - condVars += id -> encoder.encodeId(id) - condTree += pathVar -> (condTree(pathVar) + id) - } - - var guardedExprs = Map[Identifier, Seq[Expr]]() - def storeGuarded(pathVar: Identifier, expr: Expr): Unit = { - val prev = guardedExprs.getOrElse(pathVar, Nil) - guardedExprs += pathVar -> (expr +: prev) - } - - def requireDecomposition(e: Expr): Boolean = exists { - case _: FunctionInvocation => true - case _ => false - } (e) - - def rec(pathVar: Identifier, expr: Expr): Unit = expr match { - case i @ IfExpr(cond, thenn, elze) if requireDecomposition(i) => - val newBool1: Identifier = FreshIdentifier("b", BooleanType, true) - val newBool2: Identifier = FreshIdentifier("b", BooleanType, true) - - storeCond(pathVar, newBool1) - storeCond(pathVar, newBool2) - - storeGuarded(pathVar, or(Variable(newBool1), Variable(newBool2))) - storeGuarded(pathVar, or(not(Variable(newBool1)), not(Variable(newBool2)))) - storeGuarded(pathVar, Equals(Variable(newBool1), cond)) - - rec(newBool1, thenn) - rec(newBool2, elze) - - case a @ And(es) if requireDecomposition(a) => - val partitions = groupWhile(es)(!requireDecomposition(_)) - for (e <- partitions.map(andJoin)) rec(pathVar, e) - - case _ => - storeGuarded(pathVar, expr) - } - - rec(pathVar, expr) - - val (idT, pathVarT) = (encoder.encodeId(id), encoder.encodeId(pathVar)) - val (clauses, blockers, _, functions, _, _) = Template.encode(encoder, - pathVar -> pathVarT, Seq(id -> idT), condVars, Map.empty, guardedExprs, Seq.empty, Seq.empty) - - new DatatypeTemplate[T](encoder, manager, - pathVar -> pathVarT, idT, condVars, condTree, clauses, blockers, functions) - } -} - -class DatatypeTemplate[T] private ( - val encoder: TemplateEncoder[T], - val manager: DatatypeManager[T], - val pathVar: (Identifier, T), - val argument: T, - val condVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val functions: Set[(T, FunctionType, T)]) extends Template[T] { - - val args = Seq(argument) - val exprVars = Map.empty[Identifier, T] - val applications = Map.empty[T, Set[App[T]]] - val lambdas = Seq.empty[LambdaTemplate[T]] - val matchers = Map.empty[T, Set[Matcher[T]]] - val quantifications = Seq.empty[QuantificationTemplate[T]] - - def instantiate(blocker: T, arg: T): Instantiation[T] = instantiate(blocker, Seq(Left(arg))) -} - -class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { - - private val classCache: MutableMap[ClassType, FunDef] = MutableMap.empty - - private def classTypeUnroller(ct: ClassType): FunDef = classCache.get(ct) match { - case Some(fd) => fd - case None => - val param = ValDef(FreshIdentifier("x", ct)) - val fd = new FunDef(FreshIdentifier("unroll"+ct.classDef.id), Seq.empty, Seq(param), BooleanType) - classCache += ct -> fd - - val matchers = ct match { - case (act: AbstractClassType) => act.knownCCDescendants - case (cct: CaseClassType) => Seq(cct) - } - - fd.body = Some(MatchExpr(param.toVariable, matchers.map { cct => - val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) - val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) - MatchCase(pattern, None, expr) - })) - - fd - } - - private val requireChecking: MutableSet[ClassType] = MutableSet.empty - private val requireCache: MutableMap[TypeTree, Boolean] = MutableMap.empty - - private def requireTypeUnrolling(tpe: TypeTree): Boolean = requireCache.get(tpe) match { - case Some(res) => res - case None => - val res = tpe match { - case ft: FunctionType => true - case ct: CaseClassType if ct.classDef.hasParent => true - case ct: ClassType if requireChecking(ct.root) => false - case ct: ClassType => - requireChecking += ct.root - val classTypes = ct.root +: (ct.root match { - case (act: AbstractClassType) => act.knownDescendants - case (cct: CaseClassType) => Seq.empty - }) - - classTypes.exists(ct => ct.classDef.hasInvariant || ct.fieldsTypes.exists(requireTypeUnrolling)) - - case BooleanType | UnitType | CharType | IntegerType | - RealType | Int32Type | StringType | (_: TypeParameter) => false - - case at: ArrayType => true - - case NAryType(tpes, _) => tpes.exists(requireTypeUnrolling) - } - - requireCache += tpe -> res - res - } - - def typeUnroller(expr: Expr): Expr = expr.getType match { - case tpe if !requireTypeUnrolling(tpe) => - BooleanLiteral(true) - - case ct: ClassType => - val inv = if (ct.classDef.root.hasInvariant) { - Seq(FunctionInvocation(ct.root.invariant.get, Seq(expr))) - } else { - Seq.empty - } - - val subtype = if (ct != ct.root) { - Seq(IsInstanceOf(expr, ct)) - } else { - Seq.empty - } - - val induct = if (!ct.classDef.isInductive) { - val matchers = ct.root match { - case (act: AbstractClassType) => act.knownCCDescendants - case (cct: CaseClassType) => Seq(cct) - } - - val cases = matchers.map { cct => - val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) - val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) - MatchCase(pattern, None, expr) - } - - if (cases.forall(_.rhs == BooleanLiteral(true))) None else Some(MatchExpr(expr, cases)) - } else { - val fd = classTypeUnroller(ct.root) - Some(FunctionInvocation(fd.typed, Seq(expr))) - } - - andJoin(inv ++ subtype ++ induct) - - case TupleType(tpes) => - andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2 + 1)))) - - case FunctionType(_, _) => - FreshFunction(expr) - - case at: ArrayType => - GreaterEquals(ArrayLength(expr), IntLiteral(0)) - - case _ => scala.sys.error("TODO") - } - - private val typeCache: MutableMap[TypeTree, DatatypeTemplate[T]] = MutableMap.empty - - protected def typeTemplate(tpe: TypeTree): DatatypeTemplate[T] = typeCache.getOrElse(tpe, { - val template = DatatypeTemplate(encoder, this, tpe) - typeCache += tpe -> template - template - }) -} - diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala new file mode 100644 index 0000000000000000000000000000000000000000..1db7f9eee28d70e43843867a3535ebd5145c5ec7 --- /dev/null +++ b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala @@ -0,0 +1,322 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package inox +package solvers +package unrolling + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +trait DatatypeTemplates { self: Templates => + import program._ + import program.trees._ + import program.symbols._ + + import datatypesManager._ + + type Functions = Set[(Encoded, FunctionType, Encoded)] + + /** Represents a type unfolding of a free variable (or input) in the unfolding procedure */ + case class TemplateTypeInfo(tcd: TypedAbstractClassDef, arg: Encoded) { + override def toString = tcd.toType.asString + "(" + arg.asString + ")" + } + + private val cache: MutableMap[Type, DatatypeTemplate] = MutableMap.empty + private def mkTemplate(tpe: Type): DatatypeTemplate = cache.getOrElseUpdate(tpe, DatatypeTemplate(tpe)) + + def registerSymbol(start: Encoded, sym: Encoded, tpe: Type): Clauses = { + mkTemplate(tpe).instantiate(start, sym) + } + + object DatatypeTemplate { + + private val requireChecking: MutableSet[TypedClassDef] = MutableSet.empty + private val requireCache: MutableMap[Type, Boolean] = MutableMap.empty + + private def requireTypeUnrolling(tpe: Type): Boolean = requireCache.get(tpe) match { + case Some(res) => res + case None => + val res = tpe match { + case ft: FunctionType => true + case ct: ClassType => ct.tcd match { + case tccd: TypedCaseClassDef => tccd.parent.isDefined + case tcd if requireChecking(tcd.root) => false + case tcd => + requireChecking += tcd.root + val classTypes = tcd.root +: (tcd.root match { + case (tacd: TypedAbstractClassDef) => tacd.descendants + case _ => Seq.empty + }) + + classTypes.exists(ct => ct.hasInvariant || (ct match { + case tccd: TypedCaseClassDef => tccd.fieldsTypes.exists(requireTypeUnrolling) + case _ => false + })) + } + + case BooleanType | UnitType | CharType | IntegerType | + RealType | StringType | (_: BVType) | (_: TypeParameter) => false + + case NAryType(tpes, _) => tpes.exists(requireTypeUnrolling) + } + + requireCache += tpe -> res + res + } + + private case class FreshFunction(expr: Expr) extends Expr with Extractable { + def getType(implicit s: Symbols) = BooleanType + val extract = Some(Seq(expr), (exprs: Seq[Expr]) => FreshFunction(exprs.head)) + } + + private case class InductiveType(tcd: TypedAbstractClassDef, expr: Expr) extends Expr with Extractable { + def getType(implicit s: Symbols) = BooleanType + val extract = Some(Seq(expr), (exprs: Seq[Expr]) => InductiveType(tcd, exprs.head)) + } + + private def typeUnroller(expr: Expr): Expr = expr.getType match { + case tpe if !requireTypeUnrolling(tpe) => + BooleanLiteral(true) + + case ct: ClassType => + val tcd = ct.tcd + + val inv: Seq[Expr] = if (tcd.hasInvariant) { + Seq(tcd.invariant.get.applied(Seq(expr))) + } else { + Seq.empty + } + + def unrollFields(tcd: TypedCaseClassDef): Seq[Expr] = tcd.fields.map { vd => + val tpe = tcd.toType + typeUnroller(CaseClassSelector(tpe, AsInstanceOf(expr, tpe), vd.id)) + } + + val fields: Seq[Expr] = if (tcd != tcd.root) { + IsInstanceOf(expr, tcd.toType) +: unrollFields(tcd.toCase) + } else { + val isInductive = tcd.cd match { + case acd: AbstractClassDef => acd.isInductive + case _ => false + } + + if (!isInductive) { + val matchers = tcd.root match { + case (act: TypedAbstractClassDef) => act.descendants + case (cct: TypedCaseClassDef) => Seq(cct) + } + + val thens = matchers.map(tcd => tcd -> andJoin(unrollFields(tcd))) + + if (thens.forall(_._2 == BooleanLiteral(true))) { + Seq.empty + } else { + val (ifs :+ last) = thens + Seq(ifs.foldRight(last._2) { case ((tcd, thenn), res) => + IfExpr(IsInstanceOf(expr, tcd.toType), thenn, res) + }) + } + } else { + Seq(InductiveType(tcd.toAbstract, expr)) + } + } + + andJoin(inv ++ fields) + + case TupleType(tpes) => + andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2 + 1)))) + + case FunctionType(_, _) => + FreshFunction(expr) + + case _ => scala.sys.error("TODO") + } + + def apply(tpe: Type): DatatypeTemplate = { + val v = Variable(FreshIdentifier("x", true), tpe) + val expr = typeUnroller(v) + + val pathVar = Variable(FreshIdentifier("b", true), BooleanType) + + var condVars = Map[Variable, Encoded]() + var condTree = Map[Variable, Set[Variable]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Variable, v: Variable): Unit = { + condVars += v -> encodeSymbol(v) + condTree += pathVar -> (condTree(pathVar) + v) + } + + var guardedExprs = Map[Variable, Seq[Expr]]() + def storeGuarded(pathVar: Variable, expr: Expr): Unit = { + val prev = guardedExprs.getOrElse(pathVar, Nil) + guardedExprs += pathVar -> (expr +: prev) + } + + def iff(e1: Expr, e2: Expr): Unit = storeGuarded(pathVar, Equals(e1, e2)) + + def requireDecomposition(e: Expr): Boolean = exprOps.exists { + case _: FunctionInvocation => true + case _: InductiveType => true + case _ => false + } (e) + + def rec(pathVar: Variable, expr: Expr): Unit = expr match { + case i @ IfExpr(cond, thenn, elze) if requireDecomposition(i) => + val newBool1: Variable = Variable(FreshIdentifier("b", true), BooleanType) + val newBool2: Variable = Variable(FreshIdentifier("b", true), BooleanType) + + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) + + iff(and(pathVar, cond), newBool1) + iff(and(pathVar, not(cond)), newBool2) + + rec(newBool1, thenn) + rec(newBool2, elze) + + case a @ And(es) if requireDecomposition(a) => + val partitions = SeqUtils.groupWhile(es)(!requireDecomposition(_)) + for (e <- partitions.map(andJoin)) rec(pathVar, e) + + case _ => + storeGuarded(pathVar, expr) + } + + rec(pathVar, expr) + + val (idT, pathVarT) = (encodeSymbol(v), encodeSymbol(pathVar)) + val encoder: Expr => Encoded = encodeExpr(condVars + (v -> idT) + (pathVar -> pathVarT)) + + var clauses: Clauses = Seq.empty + var calls: CallBlockers = Map.empty + var types: Map[Encoded, Set[TemplateTypeInfo]] = Map.empty + var functions: Functions = Set.empty + + for ((b, es) <- guardedExprs) { + var callInfos : Set[Call] = Set.empty + var typeInfos : Set[TemplateTypeInfo] = Set.empty + + for (e <- es) { + val collected = collectWithPaths { + case expr @ (_: InductiveType | _: FreshFunction) => expr + } (e) + + def clean(e: Expr) = exprOps.postMap { + case _: InductiveType => Some(BooleanLiteral(true)) + case _: FreshFunction => Some(BooleanLiteral(true)) + case _ => None + } (e) + + functions ++= collected.collect { case (FreshFunction(f), path) => + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val cleanPath = path.map(clean) + (encoder(and(b, cleanPath.toPath)), tpe, encoder(f)) + } + + typeInfos ++= collected.collect { case (InductiveType(tcd, e), path) => + assert(path.isEmpty, "Inductive datatype unfolder should be implied directly by the blocker") + TemplateTypeInfo(tcd, encoder(e)) + } + + val cleanExpr = clean(e) + callInfos ++= firstOrderCallsOf(cleanExpr).map { case (id, tps, args) => + Call(getFunction(id, tps), args.map(arg => Left(encoder(arg)))) + } + + clauses :+= encoder(Implies(b, cleanExpr)) + } + + if (typeInfos.nonEmpty) types += encoder(b) -> typeInfos + if (callInfos.nonEmpty) calls += encoder(b) -> callInfos + } + + new DatatypeTemplate(pathVar -> pathVarT, idT, condVars, condTree, clauses, calls, types, functions) + } + } + + class DatatypeTemplate private ( + val pathVar: (Variable, Encoded), + val argument: Encoded, + val condVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: CallBlockers, + val types: Map[Encoded, Set[TemplateTypeInfo]], + val functions: Functions) extends Template { + + val args = Seq(argument) + val exprVars = Map.empty[Variable, Encoded] + val applications = Map.empty[Encoded, Set[App]] + val lambdas = Seq.empty[LambdaTemplate] + val matchers = Map.empty[Encoded, Set[Matcher]] + val quantifications = Seq.empty[QuantificationTemplate] + + def instantiate(blocker: Encoded, arg: Encoded): Clauses = { + instantiate(blocker, Seq(Left(arg))) + } + + override def instantiate(substMap: Map[Encoded, Arg]): Clauses = { + val substituter = mkSubstituter(substMap.mapValues(_.encoded)) + var clauses: Clauses = Seq.empty + + for ((b,tpe,f) <- functions) { + clauses ++= registerFunction(substituter(b), tpe, substituter(f)) + } + + for ((b, tps) <- types; bp = substituter(b); tp <- tps) { + val stp = tp.copy(arg = substituter(tp.arg)) + val gen = nextGeneration(currentGeneration) + val notBp = mkNot(bp) + + typeInfos.get(bp) match { + case Some((exGen, origGen, _, exTps)) => + val minGen = gen min exGen + typeInfos += bp -> (minGen, origGen, notBp, exTps + stp) + case None => + typeInfos += bp -> (gen, gen, notBp, Set(stp)) + } + } + + clauses ++ super.instantiate(substMap) + } + } + + private[unrolling] object datatypesManager extends Manager { + val typeInfos = new IncrementalMap[Encoded, (Int, Int, Encoded, Set[TemplateTypeInfo])] + + val incrementals: Seq[IncrementalState] = Seq(typeInfos) + + def unrollGeneration: Option[Int] = + if (typeInfos.isEmpty) None + else Some(typeInfos.values.map(_._1).min) + + def satisfactionAssumptions: Seq[Encoded] = typeInfos.map(_._2._3).toSeq + def refutationAssumptions: Seq[Encoded] = Seq.empty + + def promoteBlocker(b: Encoded): Boolean = { + if (typeInfos contains b) { + val (_, origGen, notB, tps) = typeInfos(b) + typeInfos += b -> (currentGeneration, origGen, notB, tps) + true + } else { + false + } + } + + def unroll: Clauses = if (typeInfos.isEmpty) Seq.empty else { + val blockers = typeInfos.filter(_._2._1 <= currentGeneration).toSeq.map(_._1) + + val newClauses = new scala.collection.mutable.ListBuffer[Encoded] + + val newTypeInfos = blockers.flatMap(id => typeInfos.get(id).map(id -> _)) + typeInfos --= blockers + + for ((blocker, (gen, _, _, tps)) <- newTypeInfos; info @ TemplateTypeInfo(tcd, arg) <- tps) { + val template = mkTemplate(tcd.toType) + newClauses ++= template.instantiate(blocker, arg) + } + + newClauses.toSeq + } + } +} diff --git a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala new file mode 100644 index 0000000000000000000000000000000000000000..1af37ec6bf9c0edca7ea359b87258b9c65415641 --- /dev/null +++ b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala @@ -0,0 +1,175 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package solvers +package unrolling + +import utils._ + +import scala.collection.generic.CanBuildFrom + +trait FunctionTemplates { self: Templates => + import program._ + import program.trees._ + import program.symbols._ + + import functionsManager._ + + object FunctionTemplate { + + def apply( + tfd: TypedFunDef, + pathVar: (Variable, Encoded), + arguments: Seq[(Variable, Encoded)], + condVars: Map[Variable, Encoded], + exprVars: Map[Variable, Encoded], + condTree: Map[Variable, Set[Variable]], + guardedExprs: Map[Variable, Seq[Expr]], + lambdas: Seq[LambdaTemplate], + quantifications: Seq[QuantificationTemplate] + ) : FunctionTemplate = { + + val (clauses, blockers, applications, matchers, templateString) = + Template.encode(pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, + optCall = Some(tfd)) + + val funString : () => String = () => { + "Template for def " + tfd.signature + + "(" + tfd.params.map(a => a.id + " : " + a.getType).mkString(", ") + ") : " + + tfd.returnType + " is :\n" + templateString() + } + + new FunctionTemplate( + tfd, + pathVar, + arguments.map(_._2), + condVars, + exprVars, + condTree, + clauses, + blockers, + applications, + matchers, + lambdas, + quantifications, + funString + ) + } + } + + class FunctionTemplate private( + val tfd: TypedFunDef, + val pathVar: (Variable, Encoded), + val args: Seq[Encoded], + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val matchers: Matchers, + val lambdas: Seq[LambdaTemplate], + val quantifications: Seq[QuantificationTemplate], + stringRepr: () => String) extends Template { + + private lazy val str : String = stringRepr() + override def toString : String = str + } + + def instantiateCall(blocker: Encoded, fi: Call): Clauses = { + val gen = nextGeneration(currentGeneration) + val notBlocker = mkNot(blocker) + + callInfos.get(blocker) match { + case Some((exGen, origGen, _, exFis)) => + // PS: when recycling `b`s, this assertion becomes dangerous. + // It's better to simply take the max of the generations. + // assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen) + + val minGen = gen min exGen + + callInfos += blocker -> (minGen, origGen, notBlocker, exFis + fi) + case None => + callInfos += blocker -> (gen, gen, notBlocker, Set(fi)) + } + + Seq.empty + } + + private[unrolling] object functionsManager extends Manager { + val incrementals: Seq[IncrementalState] = Seq(callInfos, defBlockers) + + // Function instantiations have their own defblocker + private[FunctionTemplates] val defBlockers = new IncrementalMap[Call, Encoded]() + + // Keep which function invocation is guarded by which guard, + // also specify the generation of the blocker. + private[FunctionTemplates] val callInfos = new IncrementalMap[Encoded, (Int, Int, Encoded, Set[Call])]() + + def unrollGeneration: Option[Int] = + if (callInfos.isEmpty) None + else Some(callInfos.values.map(_._1).min) + + def satisfactionAssumptions: Seq[Encoded] = callInfos.map(_._2._3).toSeq + def refutationAssumptions: Seq[Encoded] = Seq.empty + + def promoteBlocker(b: Encoded): Boolean = { + if (callInfos contains b) { + val (_, origGen, notB, fis) = callInfos(b) + callInfos += b -> (currentGeneration, origGen, notB, fis) + true + } else { + false + } + } + + def unroll: Clauses = if (callInfos.isEmpty) Seq.empty else { + val blockers = callInfos.filter(_._2._1 <= currentGeneration).toSeq.map(_._1) + + val newClauses = new scala.collection.mutable.ListBuffer[Encoded] + + val newCallInfos = blockers.flatMap(id => callInfos.get(id).map(id -> _)) + callInfos --= blockers + + for ((blocker, (gen, _, _, calls)) <- newCallInfos; call @ Call(tfd, args) <- calls) { + val newCls = new scala.collection.mutable.ListBuffer[Encoded] + + val defBlocker = defBlockers.get(call) match { + case Some(defBlocker) => + // we already have defBlocker => f(args) = body + defBlocker + + case None => + // we need to define this defBlocker and link it to definition + val defBlocker = encodeSymbol(Variable(FreshIdentifier("d", true), BooleanType)) + defBlockers += call -> defBlocker + + val template = mkTemplate(tfd) + //reporter.debug(template) + + val newExprs = template.instantiate(defBlocker, args) + + newCls ++= newExprs + defBlocker + } + + // We connect it to the defBlocker: blocker => defBlocker + if (defBlocker != blocker) { + newCls += mkImplies(blocker, defBlocker) + impliesBlocker(blocker, defBlocker) + } + + ctx.reporter.debug("Unrolling behind "+call+" ("+newCls.size+")") + for (cl <- newCls) { + ctx.reporter.debug(" . "+cl) + } + + newClauses ++= newCls + } + + ctx.reporter.debug(s" - ${newClauses.size} new clauses") + + newClauses.toSeq + } + } +} diff --git a/src/main/scala/inox/solvers/unrolling/LambdaManager.scala b/src/main/scala/inox/solvers/unrolling/LambdaManager.scala deleted file mode 100644 index 485edfde6c3a938b831e9a8e3f34e1da21f648f1..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/LambdaManager.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.bestRealType - -import utils._ -import utils.SeqUtils._ -import Instantiation._ -import Template._ - -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} - -/** Represents an application of a first-class function in the unfolding procedure */ -case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]], encoded: T) { - override def toString = "(" + caller + " : " + tpe + ")" + args.map(_.encoded).mkString("(", ",", ")") - def substitute(substituter: T => T, msubst: Map[T, Matcher[T]]): App[T] = copy( - caller = substituter(caller), - args = args.map(_.substitute(substituter, msubst)), - encoded = substituter(encoded) - ) -} - -/** Constructor object for [[LambdaTemplate]] - * - * The [[apply]] methods performs some pre-processing before creating - * an instance of [[LambdaTemplate]]. - */ -object LambdaTemplate { - - def apply[T]( - ids: (Identifier, T), - encoder: TemplateEncoder[T], - manager: QuantificationManager[T], - pathVar: (Identifier, T), - arguments: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - condTree: Map[Identifier, Set[Identifier]], - guardedExprs: Map[Identifier, Seq[Expr]], - quantifications: Seq[QuantificationTemplate[T]], - lambdas: Seq[LambdaTemplate[T]], - structure: LambdaStructure[T], - baseSubstMap: Map[Identifier, T], - lambda: Lambda - ) : LambdaTemplate[T] = { - - val id = ids._2 - val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, functions, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, - substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) - - assert(functions.isEmpty, "Only synthetic type explorers should introduce functions!") - - val lambdaString : () => String = () => { - "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() - } - - new LambdaTemplate[T]( - ids, - encoder, - manager, - pathVar, - arguments, - condVars, - exprVars, - condTree, - clauses, - blockers, - applications, - lambdas, - matchers, - quantifications, - structure, - lambda, - lambdaString - ) - } -} - -/** Semi-template used for hardcore function equality - * - * Function equality, while unhandled in general, can be very useful for certain - * proofs that refer specifically to first-class functions. In order to support - * such proofs, flexible notions of equality on first-class functions are - * necessary. These are provided by [[LambdaStructure]] which, much like a - * [[Template]], will generate clauses that represent equality between two - * functions. - * - * To support complex cases of equality where closed portions of the first-class - * function rely on complex program features (function calls, introducing lambdas, - * foralls, etc.), we use a structure that resembles a [[Template]] that is - * instantiated when function equality is of interest. - * - * Note that lambda creation now introduces clauses to determine equality between - * closed portions (that are independent of the lambda arguments). - */ -class LambdaStructure[T] private[unrolling] ( - /** @see [[Template.encoder]] */ - val encoder: TemplateEncoder[T], - /** @see [[Template.manager]] */ - val manager: QuantificationManager[T], - - /** The normalized lambda that is shared between all "equal" first-class functions. - * First-class function equality is conditionned on `lambda` equality. - * - * @see [[dependencies]] for the other component of equality between first-class functions - */ - val lambda: Lambda, - - /** The closed expressions (independent of the arguments to [[lambda]] contained in - * the first-class function. Equality is conditioned on equality of `dependencies` - * (inside the solver). - * - * @see [[lambda]] for the other component of equality between first-class functions - */ - val dependencies: Seq[T], - val pathVar: (Identifier, T), - - /** The set of closed variables that exist in the associated lambda. - * - * This set is necessary to determine whether other closures have been - * captured by this particular closure when deciding the order of - * lambda instantiations in [[Template.substitution]]. - */ - val closures: Seq[T], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Calls[T], - val applications: Apps[T], - val lambdas: Seq[LambdaTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], - val quantifications: Seq[QuantificationTemplate[T]]) { - - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]) = new LambdaStructure[T]( - encoder, manager, lambda, - dependencies.map(substituter), - pathVar._1 -> substituter(pathVar._2), - closures.map(substituter), condVars, exprVars, condTree, - clauses.map(substituter), - blockers.map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy( - args = fi.args.map(_.substitute(substituter, matcherSubst)))) }, - applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, matcherSubst)) }, - lambdas.map(_.substitute(substituter, matcherSubst)), - matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, matcherSubst)) }, - quantifications.map(_.substitute(substituter, matcherSubst))) - - /** The [[key]] value (tuple of [[lambda]] and [[dependencies]]) is used - * to determine syntactic equality between lambdas. If the keys of two - * closures are equal, then they must necessarily be equal in every model. - * - * The [[instantiation]] consists of the clause set instantiation (in the - * sense of [[Template.instantiate]] that is required for [[dependencies]] - * to make sense in the solver (introduces blockers, quantifications, other - * lambdas, etc.) Since [[dependencies]] CHANGE during instantiation and - * [[key]] makes no sense without the associated instantiation, the implicit - * contract here is that whenever a new key appears during unfolding, its - * associated instantiation MUST be added to the set of instantiations - * managed by the solver. However, if an identical pre-existing key has - * already been found, then the associated instantiations must already appear - * in those handled by the solver. - */ - lazy val (key, instantiation) = { - val (substMap, substInst) = Template.substitution[T](encoder, manager, - condVars, exprVars, condTree, quantifications, lambdas, Set.empty, Map.empty, pathVar._1, pathVar._2) - val tmplInst = Template.instantiate(encoder, manager, clauses, blockers, applications, matchers, substMap) - - val key = (lambda, dependencies.map(encoder.substitute(substMap.mapValues(_.encoded)))) - val instantiation = substInst ++ tmplInst - (key, instantiation) - } - - override def equals(that: Any): Boolean = that match { - case (struct: LambdaStructure[T]) => key == struct.key - case _ => false - } - - override def hashCode: Int = key.hashCode -} - -class LambdaTemplate[T] private ( - val ids: (Identifier, T), - val encoder: TemplateEncoder[T], - val manager: QuantificationManager[T], - val pathVar: (Identifier, T), - val arguments: Seq[(Identifier, T)], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Clauses[T], - val blockers: Calls[T], - val applications: Apps[T], - val lambdas: Seq[LambdaTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val structure: LambdaStructure[T], - val lambda: Lambda, - stringRepr: () => String) extends Template[T] { - - val args = arguments.map(_._2) - val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] - val functions: Set[(T, FunctionType, T)] = Set.empty - - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): LambdaTemplate[T] = { - val newStart = substituter(start) - val newClauses = clauses.map(substituter) - val newBlockers = blockers.map { case (b, fis) => - val bp = if (b == start) newStart else b - bp -> fis.map(fi => fi.copy( - args = fi.args.map(_.substitute(substituter, matcherSubst)) - )) - } - - val newApplications = applications.map { case (b, fas) => - val bp = if (b == start) newStart else b - bp -> fas.map(_.substitute(substituter, matcherSubst)) - } - - val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) - - val newMatchers = matchers.map { case (b, ms) => - val bp = if (b == start) newStart else b - bp -> ms.map(_.substitute(substituter, matcherSubst)) - } - - val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) - - val newStructure = structure.substitute(substituter, matcherSubst) - - new LambdaTemplate[T]( - ids._1 -> substituter(ids._2), - encoder, - manager, - pathVar._1 -> newStart, - arguments, - condVars, - exprVars, - condTree, - newClauses, - newBlockers, - newApplications, - newLambdas, - newMatchers, - newQuantifications, - newStructure, - lambda, - stringRepr - ) - } - - def withId(idT: T): LambdaTemplate[T] = { - val substituter = encoder.substitute(Map(ids._2 -> idT)) - new LambdaTemplate[T]( - ids._1 -> idT, encoder, manager, pathVar, arguments, condVars, exprVars, condTree, - clauses map substituter, // make sure the body-defining clause is inlined! - blockers, applications, lambdas, matchers, quantifications, - structure, lambda, stringRepr - ) - } - - private lazy val str : String = stringRepr() - override def toString : String = str - - override def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { - super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) - } -} - -class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(encoder) { - private[unrolling] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) - - protected[unrolling] val byID = new IncrementalMap[T, LambdaTemplate[T]] - protected val byType = new IncrementalMap[FunctionType, Map[LambdaStructure[T], LambdaTemplate[T]]].withDefaultValue(Map.empty) - protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) - protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) - protected val maybeFree = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) - protected val freeBlockers = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) - - private val instantiated = new IncrementalSet[(T, App[T])] - - override protected def incrementals: List[IncrementalState] = - super.incrementals ++ List(byID, byType, applications, knownFree, maybeFree, freeBlockers, instantiated) - - def registerFunction(b: T, tpe: FunctionType, f: T): Instantiation[T] = { - val ft = bestRealType(tpe).asInstanceOf[FunctionType] - val bs = fixpoint((bs: Set[T]) => bs ++ bs.flatMap(blockerParents))(Set(b)) - - val (known, neqClauses) = if ((bs intersect typeEnablers).nonEmpty) { - maybeFree += ft -> (maybeFree(ft) + (b -> f)) - (false, byType(ft).values.toSeq.map { t => - encoder.mkImplies(b, encoder.mkNot(encoder.mkEquals(t.ids._2, f))) - }) - } else { - knownFree += ft -> (knownFree(ft) + f) - (true, byType(ft).values.toSeq.map(t => encoder.mkNot(encoder.mkEquals(t.ids._2, f)))) - } - - val extClauses = freeBlockers(tpe).map { case (oldB, freeF) => - val equals = encoder.mkEquals(f, freeF) - val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) - val extension = encoder.mkOr(if (known) equals else encoder.mkAnd(b, equals), nextB) - encoder.mkEquals(oldB, extension) - } - - val instantiation = Instantiation.empty[T] withClauses (neqClauses ++ extClauses) - - applications(tpe).foldLeft(instantiation) { - case (instantiation, (app @ (_, App(caller, _, args, _)))) => - val equals = encoder.mkAnd(b, encoder.mkEquals(f, caller)) - instantiation withApp (app -> TemplateAppInfo(f, equals, args)) - } - } - - def assumptions: Seq[T] = freeBlockers.toSeq.flatMap(_._2.map(p => encoder.mkNot(p._1))) - - private val typeBlockers = new IncrementalMap[T, T]() - private val typeEnablers: MutableSet[T] = MutableSet.empty - - private def typeUnroller(blocker: T, app: App[T]): Instantiation[T] = typeBlockers.get(app.encoded) match { - case Some(typeBlocker) => - implies(blocker, typeBlocker) - (Seq(encoder.mkImplies(blocker, typeBlocker)), Map.empty, Map.empty) - - case None => - val App(caller, tpe @ FirstOrderFunctionType(_, to), args, value) = app - val typeBlocker = encoder.encodeId(FreshIdentifier("t", BooleanType)) - typeBlockers += value -> typeBlocker - implies(blocker, typeBlocker) - - val template = typeTemplate(to) - val instantiation = template.instantiate(typeBlocker, value) - - val (b, extClauses) = if (knownFree(tpe) contains caller) { - (blocker, Seq.empty) - } else { - val firstB = encoder.encodeId(FreshIdentifier("b_free", BooleanType, true)) - implies(firstB, typeBlocker) - typeEnablers += firstB - - val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) - freeBlockers += tpe -> (freeBlockers(tpe) + (nextB -> caller)) - - val clause = encoder.mkEquals(firstB, encoder.mkAnd(blocker, encoder.mkOr( - knownFree(tpe).map(idT => encoder.mkEquals(caller, idT)).toSeq ++ - maybeFree(tpe).map { case (b, idT) => encoder.mkAnd(b, encoder.mkEquals(caller, idT)) } :+ - nextB : _*))) - (firstB, Seq(clause)) - } - - instantiation withClauses extClauses withClause encoder.mkImplies(b, typeBlocker) - } - - def instantiateLambda(template: LambdaTemplate[T]): (T, Instantiation[T]) = { - byType(template.tpe).get(template.structure) match { - case Some(template) => - (template.ids._2, Instantiation.empty) - - case None => - val idT = encoder.encodeId(template.ids._1) - val newTemplate = template.withId(idT) - - // make sure the new lambda isn't equal to any free lambda var - val instantiation = newTemplate.structure.instantiation withClauses ( - equalityClauses(newTemplate) ++ - knownFree(newTemplate.tpe).map(f => encoder.mkNot(encoder.mkEquals(idT, f))).toSeq ++ - maybeFree(newTemplate.tpe).map { p => - encoder.mkImplies(p._1, encoder.mkNot(encoder.mkEquals(idT, p._2))) - }) - - byID += idT -> newTemplate - byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.structure -> newTemplate)) - - val inst = applications(newTemplate.tpe).foldLeft(instantiation) { - case (instantiation, app @ (_, App(caller, _, args, _))) => - val equals = encoder.mkEquals(idT, caller) - instantiation withApp (app -> TemplateAppInfo(newTemplate, equals, args)) - } - - (idT, inst) - } - } - - def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { - val App(caller, tpe @ FunctionType(_, to), args, encoded) = app - - val instantiation: Instantiation[T] = if (byID contains caller) { - Instantiation.empty - } else { - typeUnroller(blocker, app) - } - - val key = blocker -> app - if (instantiated(key)) { - instantiation - } else { - instantiated += key - - if (knownFree(tpe) contains caller) { - instantiation - } else if (byID contains caller) { - instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) - } else { - - // make sure that even if byType(tpe) is empty, app is recorded in blockers - // so that UnrollingBank will generate the initial block! - val init = instantiation withApps Map(key -> Set.empty) - val inst = byType(tpe).values.foldLeft(init) { - case (instantiation, template) => - val equals = encoder.mkEquals(template.ids._2, caller) - instantiation withApp (key -> TemplateAppInfo(template, equals, args)) - } - - applications += tpe -> (applications(tpe) + key) - - inst - } - } - } - - private def equalityClauses(template: LambdaTemplate[T]): Seq[T] = { - byType(template.tpe).values.map { that => - val equals = encoder.mkEquals(template.ids._2, that.ids._2) - encoder.mkImplies( - encoder.mkAnd(template.pathVar._2, that.pathVar._2), - if (template.structure.lambda == that.structure.lambda) { - val pairs = template.structure.dependencies zip that.structure.dependencies - val filtered = pairs.filter(p => p._1 != p._2) - if (filtered.isEmpty) { - equals - } else { - val eqs = filtered.map(p => encoder.mkEquals(p._1, p._2)) - encoder.mkEquals(encoder.mkAnd(eqs : _*), equals) - } - } else { - encoder.mkNot(equals) - }) - }.toSeq - } -} - diff --git a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala new file mode 100644 index 0000000000000000000000000000000000000000..83eb4548aca9d9f51770fdc4edf9a1194a0223e1 --- /dev/null +++ b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala @@ -0,0 +1,520 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package solvers +package unrolling + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +trait LambdaTemplates { self: Templates => + import program._ + import program.trees._ + import program.symbols._ + + import lambdasManager._ + + /** Represents a POTENTIAL application of a first-class function in the unfolding procedure + * + * The [[equals]] condition guards the application for equality of the concrete caller with + * the provided template in order to perform dynamic dispatch. + */ + case class TemplateAppInfo(template: Either[LambdaTemplate, Encoded], equals: Encoded, args: Seq[Arg]) { + override def toString = { + val caller = template match { + case Left(tmpl) => tmpl.ids._2.asString + case Right(c) => c.asString + } + + caller + "|" + equals.asString + args.map { + case Right(m) => m.toString + case Left(v) => v.asString + }.mkString("(", ",", ")") + } + } + + object TemplateAppInfo { + def apply(template: LambdaTemplate, equals: Encoded, args: Seq[Arg]): TemplateAppInfo = + TemplateAppInfo(Left(template), equals, args) + + def apply(caller: Encoded, equals: Encoded, args: Seq[Arg]): TemplateAppInfo = + TemplateAppInfo(Right(caller), equals, args) + } + + + /** Constructor object for [[LambdaTemplate]] + * + * The [[apply]] methods performs some pre-processing before creating + * an instance of [[LambdaTemplate]]. + */ + object LambdaTemplate { + + def apply( + ids: (Variable, Encoded), + pathVar: (Variable, Encoded), + arguments: Seq[(Variable, Encoded)], + condVars: Map[Variable, Encoded], + exprVars: Map[Variable, Encoded], + condTree: Map[Variable, Set[Variable]], + guardedExprs: Map[Variable, Seq[Expr]], + lambdas: Seq[LambdaTemplate], + quantifications: Seq[QuantificationTemplate], + structure: LambdaStructure, + baseSubstMap: Map[Variable, Encoded], + lambda: Lambda + ) : LambdaTemplate = { + + val id = ids._2 + val tpe = ids._1.getType.asInstanceOf[FunctionType] + val (clauses, blockers, applications, matchers, templateString) = + Template.encode(pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, + substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + + val lambdaString : () => String = () => { + "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() + } + + new LambdaTemplate( + ids, + pathVar, + arguments, + condVars, + exprVars, + condTree, + clauses, + blockers, + applications, + matchers, + lambdas, + quantifications, + structure, + lambda, + lambdaString + ) + } + } + + /** Semi-template used for hardcore function equality + * + * Function equality, while unhandled in general, can be very useful for certain + * proofs that refer specifically to first-class functions. In order to support + * such proofs, flexible notions of equality on first-class functions are + * necessary. These are provided by [[LambdaStructure]] which, much like a + * [[Template]], will generate clauses that represent equality between two + * functions. + * + * To support complex cases of equality where closed portions of the first-class + * function rely on complex program features (function calls, introducing lambdas, + * foralls, etc.), we use a structure that resembles a [[Template]] that is + * instantiated when function equality is of interest. + * + * Note that lambda creation now introduces clauses to determine equality between + * closed portions (that are independent of the lambda arguments). + */ + class LambdaStructure private[unrolling] ( + /** The normalized lambda that is shared between all "equal" first-class functions. + * First-class function equality is conditionned on `lambda` equality. + * + * @see [[dependencies]] for the other component of equality between first-class functions + */ + val lambda: Lambda, + + /** The closed expressions (independent of the arguments to [[lambda]] contained in + * the first-class function. Equality is conditioned on equality of `dependencies` + * (inside the solver). + * + * @see [[lambda]] for the other component of equality between first-class functions + */ + val dependencies: Seq[Encoded], + val pathVar: (Variable, Encoded), + + /** The set of closed variables that exist in the associated lambda. + * + * This set is necessary to determine whether other closures have been + * captured by this particular closure when deciding the order of + * lambda instantiations in [[Template.substitution]]. + */ + val closures: Seq[Encoded], + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val matchers: Matchers, + val lambdas: Seq[LambdaTemplate], + val quantifications: Seq[QuantificationTemplate]) { + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new LambdaStructure( + lambda, + dependencies.map(substituter), + pathVar._1 -> substituter(pathVar._2), + closures.map(substituter), condVars, exprVars, condTree, + clauses.map(substituter), + blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, msubst)) }, + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, + lambdas.map(_.substitute(substituter, msubst)), + quantifications.map(_.substitute(substituter, msubst))) + + /** The [[key]] value (tuple of [[lambda]] and [[dependencies]]) is used + * to determine syntactic equality between lambdas. If the keys of two + * closures are equal, then they must necessarily be equal in every model. + * + * The [[instantiation]] consists of the clause set instantiation (in the + * sense of [[Template.instantiate]] that is required for [[dependencies]] + * to make sense in the solver (introduces blockers, quantifications, other + * lambdas, etc.) Since [[dependencies]] CHANGE during instantiation and + * [[key]] makes no sense without the associated instantiation, the implicit + * contract here is that whenever a new key appears during unfolding, its + * associated instantiation MUST be added to the set of instantiations + * managed by the solver. However, if an identical pre-existing key has + * already been found, then the associated instantiations must already appear + * in those handled by the solver. + */ + lazy val (key, instantiation) = { + val (substMap, substInst) = Template.substitution( + condVars, exprVars, condTree, lambdas, quantifications, Map.empty, pathVar._1, pathVar._2) + val tmplInst = Template.instantiate(clauses, blockers, applications, matchers, substMap) + + val substituter = mkSubstituter(substMap.mapValues(_.encoded)) + val key = (lambda, dependencies.map(substituter)) + val instantiation = substInst ++ tmplInst + (key, instantiation) + } + + override def equals(that: Any): Boolean = that match { + case (struct: LambdaStructure) => key == struct.key + case _ => false + } + + override def hashCode: Int = key.hashCode + } + + class LambdaTemplate private ( + val ids: (Variable, Encoded), + val pathVar: (Variable, Encoded), + val arguments: Seq[(Variable, Encoded)], + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val matchers: Matchers, + val lambdas: Seq[LambdaTemplate], + val quantifications: Seq[QuantificationTemplate], + val structure: LambdaStructure, + val lambda: Lambda, + stringRepr: () => String) extends Template { + + val args = arguments.map(_._2) + val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): LambdaTemplate = new LambdaTemplate( + ids._1 -> substituter(ids._2), + pathVar._1 -> substituter(pathVar._2), + arguments, condVars, exprVars, condTree, + clauses.map(substituter), + blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + applications.map { case (b, apps) => substituter(b) -> apps.map(_.substitute(substituter, msubst)) }, + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, + lambdas.map(_.substitute(substituter, msubst)), + quantifications.map(_.substitute(substituter, msubst)), + structure.substitute(substituter, msubst), + lambda, stringRepr) + + def withId(idT: Encoded): LambdaTemplate = { + val substituter = mkSubstituter(Map(ids._2 -> idT)) + new LambdaTemplate( + ids._1 -> idT, pathVar, arguments, condVars, exprVars, condTree, + clauses map substituter, // make sure the body-defining clause is inlined! + blockers, applications, matchers, lambdas, quantifications, + structure, lambda, stringRepr + ) + } + + private lazy val str : String = stringRepr() + override def toString : String = str + + override def instantiate(substMap: Map[Encoded, Arg]): Clauses = { + super.instantiate(substMap) ++ instantiateAxiom(this, substMap) + } + } + + def registerFunction(b: Encoded, tpe: FunctionType, f: Encoded): Clauses = { + val ft = bestRealType(tpe).asInstanceOf[FunctionType] + val bs = fixpoint((bs: Set[Encoded]) => bs ++ bs.flatMap(blockerParents))(Set(b)) + + val (known, neqClauses) = if ((bs intersect typeEnablers).nonEmpty) { + maybeFree += ft -> (maybeFree(ft) + (b -> f)) + (false, byType(ft).values.toSeq.map { t => + mkImplies(b, mkNot(mkEquals(t.ids._2, f))) + }) + } else { + knownFree += ft -> (knownFree(ft) + f) + (true, byType(ft).values.toSeq.map(t => mkNot(mkEquals(t.ids._2, f)))) + } + + val extClauses = freeBlockers(tpe).map { case (oldB, freeF) => + val equals = mkEquals(f, freeF) + val nextB = encodeSymbol(Variable(FreshIdentifier("b_or", true), BooleanType)) + val extension = mkOr(if (known) equals else mkAnd(b, equals), nextB) + mkEquals(oldB, extension) + } + + lazy val gen = nextGeneration(currentGeneration) + for (app @ (_, App(caller, _, args, _)) <- applications(tpe)) { + val equals = mkAnd(b, mkEquals(f, caller)) + registerAppBlocker(gen, app, Right(f), equals, args) + } + + neqClauses ++ extClauses + } + + private def typeUnroller(blocker: Encoded, app: App): Clauses = typeBlockers.get(app.encoded) match { + case Some(typeBlocker) => + impliesBlocker(blocker, typeBlocker) + Seq(mkImplies(blocker, typeBlocker)) + + case None => + val App(caller, tpe @ FirstOrderFunctionType(_, to), args, value) = app + val typeBlocker = encodeSymbol(Variable(FreshIdentifier("t"), BooleanType)) + typeBlockers += value -> typeBlocker + impliesBlocker(blocker, typeBlocker) + + val clauses = registerSymbol(typeBlocker, value, to) + + val (b, extClauses) = if (knownFree(tpe) contains caller) { + (blocker, Seq.empty) + } else { + val firstB = encodeSymbol(Variable(FreshIdentifier("b_free", true), BooleanType)) + impliesBlocker(firstB, typeBlocker) + typeEnablers += firstB + + val nextB = encodeSymbol(Variable(FreshIdentifier("b_or", true), BooleanType)) + freeBlockers += tpe -> (freeBlockers(tpe) + (nextB -> caller)) + + val clause = mkEquals(firstB, mkAnd(blocker, mkOr( + knownFree(tpe).map(idT => mkEquals(caller, idT)).toSeq ++ + maybeFree(tpe).map { case (b, idT) => mkAnd(b, mkEquals(caller, idT)) } :+ + nextB : _*))) + (firstB, Seq(clause)) + } + + clauses ++ extClauses :+ mkImplies(b, typeBlocker) + } + + private def registerAppBlocker(gen: Int, key: (Encoded, App), template: Either[LambdaTemplate, Encoded], equals: Encoded, args: Seq[Arg]): Unit = { + val info = TemplateAppInfo(template, equals, args) + appInfos.get(key) match { + case Some((exGen, origGen, b, notB, exInfo)) => + val minGen = gen min exGen + appInfos += key -> (minGen, origGen, b, notB, exInfo + info) + + case None => + val b = appBlockers(key) + val notB = mkNot(b) + appInfos += key -> (gen, gen, b, notB, Set(info)) + blockerToApps += b -> key + } + } + + def instantiateLambda(template: LambdaTemplate): (Encoded, Clauses) = { + byType(template.tpe).get(template.structure) match { + case Some(template) => + (template.ids._2, Seq.empty) + + case None => + val idT = encodeSymbol(template.ids._1) + val newTemplate = template.withId(idT) + + // make sure the new lambda isn't equal to any free lambda var + val clauses = newTemplate.structure.instantiation ++ + equalityClauses(newTemplate) ++ + knownFree(newTemplate.tpe).map(f => mkNot(mkEquals(idT, f))).toSeq ++ + maybeFree(newTemplate.tpe).map { p => + mkImplies(p._1, mkNot(mkEquals(idT, p._2))) + } + + byID += idT -> newTemplate + byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.structure -> newTemplate)) + + val gen = nextGeneration(currentGeneration) + for (app @ (_, App(caller, _, args, _)) <- applications(newTemplate.tpe)) { + val equals = mkEquals(idT, caller) + registerAppBlocker(gen, app, Left(newTemplate), equals, args) + } + + (idT, clauses) + } + } + + def instantiateApp(blocker: Encoded, app: App): Clauses = { + val App(caller, tpe @ FunctionType(_, to), args, encoded) = app + + val clauses: Clauses = + if (byID contains caller) Seq.empty + else typeUnroller(blocker, app) + + val key = blocker -> app + if (instantiated(key)) clauses else { + instantiated += key + + if (knownFree(tpe) contains caller) { + clauses + } else if (byID contains caller) { + // we register this app at the CURRENT generation to increase the performance + // of fold-style higher-order functions (the first-class function will be + // dispatched immediately after the fold-style function unrolling) + registerAppBlocker(currentGeneration, key, Left(byID(caller)), trueT, args) + clauses + } else { + val freshAppClause = if (appBlockers.isDefinedAt(key)) None else { + val firstB = encodeSymbol(Variable(FreshIdentifier("b_lambda", true), BooleanType)) + val clause = mkImplies(mkNot(firstB), mkNot(blocker)) + + appBlockers += key -> firstB + Some(clause) + } + + lazy val gen = nextGeneration(currentGeneration) + for (template <- byType(tpe).values) { + val equals = mkEquals(template.ids._2, caller) + registerAppBlocker(gen, key, Left(template), equals, args) + } + + applications += tpe -> (applications(tpe) + key) + + clauses ++ freshAppClause + } + } + } + + private def equalityClauses(template: LambdaTemplate): Clauses = { + byType(template.tpe).values.map { that => + val equals = mkEquals(template.ids._2, that.ids._2) + mkImplies( + mkAnd(template.pathVar._2, that.pathVar._2), + if (template.structure.lambda == that.structure.lambda) { + val pairs = template.structure.dependencies zip that.structure.dependencies + val filtered = pairs.filter(p => p._1 != p._2) + if (filtered.isEmpty) { + equals + } else { + val eqs = filtered.map(p => mkEquals(p._1, p._2)) + mkEquals(mkAnd(eqs : _*), equals) + } + } else { + mkNot(equals) + }) + }.toSeq + } + + private[unrolling] object lambdasManager extends Manager { + // Function instantiations have their own defblocker + val lambdaBlockers = new IncrementalMap[TemplateAppInfo, Encoded]() + + // Keep which function invocation is guarded by which guard, + // also specify the generation of the blocker. + val appInfos = new IncrementalMap[(Encoded, App), (Int, Int, Encoded, Encoded, Set[TemplateAppInfo])]() + val appBlockers = new IncrementalMap[(Encoded, App), Encoded]() + val blockerToApps = new IncrementalMap[Encoded, (Encoded, App)]() + + val byID = new IncrementalMap[Encoded, LambdaTemplate] + val byType = new IncrementalMap[FunctionType, Map[LambdaStructure, LambdaTemplate]].withDefaultValue(Map.empty) + val applications = new IncrementalMap[FunctionType, Set[(Encoded, App)]].withDefaultValue(Set.empty) + val knownFree = new IncrementalMap[FunctionType, Set[Encoded]].withDefaultValue(Set.empty) + val maybeFree = new IncrementalMap[FunctionType, Set[(Encoded, Encoded)]].withDefaultValue(Set.empty) + val freeBlockers = new IncrementalMap[FunctionType, Set[(Encoded, Encoded)]].withDefaultValue(Set.empty) + + val instantiated = new IncrementalSet[(Encoded, App)] + + val typeBlockers = new IncrementalMap[Encoded, Encoded]() + val typeEnablers: MutableSet[Encoded] = MutableSet.empty + + override val incrementals: Seq[IncrementalState] = Seq( + lambdaBlockers, appInfos, appBlockers, blockerToApps, + byID, byType, applications, knownFree, maybeFree, freeBlockers, + instantiated, typeBlockers) + + def unrollGeneration: Option[Int] = + if (appInfos.isEmpty) None + else Some(appInfos.values.map(_._1).min) + + private def assumptions: Seq[Encoded] = freeBlockers.toSeq.flatMap(_._2.map(p => mkNot(p._1))) + def satisfactionAssumptions = appInfos.map(_._2._4).toSeq ++ assumptions + def refutationAssumptions = assumptions + + def promoteBlocker(b: Encoded): Boolean = { + if (blockerToApps contains b) { + val app = blockerToApps(b) + val (_, origGen, _, notB, infos) = appInfos(app) + appInfos += app -> (currentGeneration, origGen, b, notB, infos) + true + } else { + false + } + } + + def unroll: Clauses = if (appInfos.isEmpty) Seq.empty else { + val newClauses = new scala.collection.mutable.ListBuffer[Encoded] + + val blockers = appInfos.values.filter(_._1 <= currentGeneration).toSeq.map(_._3) + val apps = blockers.flatMap(blocker => blockerToApps.get(blocker)) + val thisAppInfos = apps.map(app => app -> { + val (gen, _, _, _, infos) = appInfos(app) + (gen, infos) + }) + + blockerToApps --= blockers + appInfos --= apps + + for ((app, (_, infos)) <- thisAppInfos if infos.nonEmpty) { + val nextB = encodeSymbol(Variable(FreshIdentifier("b_lambda", true), BooleanType)) + val extension = mkOr((infos.map(_.equals).toSeq :+ nextB) : _*) + val clause = mkEquals(appBlockers(app), extension) + + appBlockers += app -> nextB + + ctx.reporter.debug(" -> extending lambda blocker: " + clause) + newClauses += clause + } + + for ((app @ (b, _), (gen, infos)) <- thisAppInfos; + info @ TemplateAppInfo(tmpl, equals, args) <- infos; + template <- tmpl.left) { + val newCls = new scala.collection.mutable.ListBuffer[Encoded] + + val lambdaBlocker = lambdaBlockers.get(info) match { + case Some(lambdaBlocker) => lambdaBlocker + + case None => + val lambdaBlocker = encodeSymbol(Variable(FreshIdentifier("d", true), BooleanType)) + lambdaBlockers += info -> lambdaBlocker + + val newExprs = template.instantiate(lambdaBlocker, args) + + newCls ++= newExprs + lambdaBlocker + } + + val enabler = if (equals == trueT) b else mkAnd(equals, b) + newCls += mkImplies(enabler, lambdaBlocker) + impliesBlocker(b, lambdaBlocker) + + ctx.reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") + for (cl <- newCls) { + ctx.reporter.debug(" . "+cl) + } + + newClauses ++= newCls + } + + ctx.reporter.debug(s" - ${newClauses.size} new clauses") + + newClauses + } + } +} diff --git a/src/main/scala/inox/solvers/unrolling/Quantification.scala b/src/main/scala/inox/solvers/unrolling/Quantification.scala deleted file mode 100644 index 82558fe87e2ccdf0966f124fbf0cc3a4e76833f0..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/Quantification.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Definitions._ -import Expressions._ -import Constructors._ -import Extractors._ -import ExprOps._ -import Types._ - -import evaluators._ - -object Quantification { - - def extractQuorums[A,B]( - matchers: Set[A], - quantified: Set[B], - margs: A => Set[A], - qargs: A => Set[B] - ): Seq[Set[A]] = { - def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) - def allQArgs(m: A): Set[B] = qargs(m) ++ margs(m).flatMap(allQArgs) - val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap - val reverseMap : Map[A, Set[A]] = expandedMap.toSeq - .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs - .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set - .map { case (m, ms) => // filter redundant matchers - val allM = allQArgs(m) - m -> ms.filter(rm => allQArgs(rm) != allM) - } - - def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { - if (qss.contains(quantified)) { - Seq(mSet) - } else { - var res = Seq.empty[Set[A]] - val rest = oms.scanLeft(List.empty[A])((acc, o) => o :: acc).drop(1) - for ((m :: ms) <- rest if margs(m).forall(mSet)) { - val qas = qargs(m) - if (!qss.exists(qs => qs.subsetOf(qas) || qas.subsetOf(qs))) { - res ++= rec(ms, mSet + m, qss ++ qss.map(_ ++ qas) :+ qas) - } - } - res - } - } - - val oms = expandedMap.toSeq.sortBy(p => -p._2.size).map(_._1) - val res = rec(oms, Set.empty, Seq.empty) - - res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) - } - - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Path, Expr, Seq[Expr])]] = { - object QMatcher { - def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { - case QuantificationMatcher(expr, args) => - if (args.exists { case Variable(id) => quantified(id) case _ => false }) { - Some(expr -> args) - } else { - None - } - case _ => None - } - } - - val allMatchers = CollectorWithPaths { case QMatcher(expr, args) => expr -> args }.traverse(expr) - val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet - - extractQuorums(matchers, quantified, - (p: (Path, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, - (p: (Path, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) - } - - object Domains { - def empty = new Domains(Map.empty, Map.empty) - } - - class Domains (_lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { - val lambdas = _lambdas.map { case (lambda, domain) => - val (nl, _) = normalizeStructure(lambda) - nl -> domain - } - - def get(e: Expr): Set[Seq[Expr]] = { - val specialized: Set[Seq[Expr]] = e match { - case FiniteLambda(mapping, _, _) => mapping.map(_._1).toSet - case l: Lambda => - val (nl, _) = normalizeStructure(l) - lambdas.getOrElse(nl, Set.empty) - case _ => Set.empty - } - specialized ++ tpes.getOrElse(e.getType, Set.empty) - } - } - - object QuantificationMatcher { - private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, args) => None - case Application(caller: Application, args) => flatApplication(caller) match { - case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) - case None => None - } - case Application(caller, args) => Some((caller, args)) - case _ => None - } - - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case IsTyped(a: Application, ft: FunctionType) => None - case Application(e, args) => flatApplication(expr) - case ArraySelect(arr, index) => Some(arr -> Seq(index)) - case MapApply(map, key) => Some(map -> Seq(key)) - case ElementOfSet(elem, set) => Some(set -> Seq(elem)) - case _ => None - } - } - - object QuantificationTypeMatcher { - private def flatType(tpe: TypeTree): (Seq[TypeTree], TypeTree) = tpe match { - case FunctionType(from, to) => - val (nextArgs, finalTo) = flatType(to) - (from ++ nextArgs, finalTo) - case _ => (Seq.empty, tpe) - } - - def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { - case FunctionType(from, to) => Some(flatType(tpe)) - case ArrayType(base) => Some(Seq(Int32Type) -> base) - case MapType(from, to) => Some(Seq(from) -> to) - case SetType(base) => Some(Seq(base) -> BooleanType) - case _ => None - } - } - - sealed abstract class ForallStatus { - def isValid: Boolean - } - - case object ForallValid extends ForallStatus { - def isValid = true - } - - sealed abstract class ForallInvalid(msg: String) extends ForallStatus { - def isValid = false - def getMessage: String = msg - } - - case class NoMatchers(expr: String) extends ForallInvalid("No matchers available for E-Matching in " + expr) - case class ComplexArgument(expr: String) extends ForallInvalid("Unhandled E-Matching pattern in " + expr) - case class NonBijectiveMapping(expr: String) extends ForallInvalid("Non-bijective mapping for quantifiers in " + expr) - case class InvalidOperation(expr: String) extends ForallInvalid("Invalid operation on quantifiers in " + expr) - - def checkForall(quantified: Set[Identifier], body: Expr)(implicit ctx: LeonContext): ForallStatus = { - val TopLevelAnds(conjuncts) = body - for (conjunct <- conjuncts) { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.isEmpty) return NoMatchers(conjunct.asString) - - val complexArgs = matchers.flatMap { case (_, args) => - args.flatMap(arg => arg match { - case QuantificationMatcher(_, _) => None - case Variable(id) => None - case _ if (variablesOf(arg) & quantified).nonEmpty => Some(arg) - case _ => None - }) - } - - if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head.asString) - - val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } - - val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) - if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1.asString) - - val matcherSet = matcherToQuants.filter(_._2.nonEmpty).keys.toSet - - val qs = fold[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case QuantificationMatcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val matcherArgs = matcherSet & es.toSet - if (q.nonEmpty && !(q.size == 1 && matcherArgs.isEmpty && m.getType == BooleanType)) - return InvalidOperation(m.asString) - else Set.empty - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - - ForallValid - } -} diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationManager.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala similarity index 51% rename from src/main/scala/inox/solvers/unrolling/QuantificationManager.scala rename to src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala index 797f11aa1475d063cca706bc7b14da81a3a55dc8..a35fbb3195be9f14d2faba94628c72203259388c 100644 --- a/src/main/scala/inox/solvers/unrolling/QuantificationManager.scala +++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala @@ -1,198 +1,196 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package unrolling -import leon.utils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps._ -import purescala.Quantification.{QuantificationTypeMatcher => QTM, QuantificationMatcher => QM, Domains} +import utils._ -import evaluators._ +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} -import Instantiation._ -import Template._ +trait QuantificationTemplates { self: Templates => + import program._ + import program.trees._ + import program.symbols._ -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} + import lambdasManager._ + import quantificationsManager._ + + def hasQuantifiers = quantifications.nonEmpty -case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Arg[T]], encoded: T) { - override def toString = caller + args.map { - case Right(m) => m.toString - case Left(v) => v.toString - }.mkString("(",",",")") - - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): Matcher[T] = copy( - caller = substituter(caller), - args = args.map { - case Left(v) => matcherSubst.get(v) match { - case Some(m) => Right(m) - case None => Left(substituter(v)) + /* -- Extraction helpers -- */ + + object QuantificationMatcher { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, args) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None } - case Right(m) => Right(m.substitute(substituter, matcherSubst)) - }, - encoded = substituter(encoded) - ) -} + case Application(caller, args) => Some((caller, args)) + case _ => None + } + + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case IsTyped(a: Application, ft: FunctionType) => None + case Application(e, args) => flatApplication(expr) + case MapApply(map, key) => Some(map -> Seq(key)) + case MultiplicityInBag(elem, bag) => Some(bag -> Seq(elem)) + case ElementOfSet(elem, set) => Some(set -> Seq(elem)) + case _ => None + } + } + + object QuantificationTypeMatcher { + private def flatType(tpe: Type): (Seq[Type], Type) = tpe match { + case FunctionType(from, to) => + val (nextArgs, finalTo) = flatType(to) + (from ++ nextArgs, finalTo) + case _ => (Seq.empty, tpe) + } -class QuantificationTemplate[T]( - val quantificationManager: QuantificationManager[T], - val pathVar: (Identifier, T), - val qs: (Identifier, T), - val q2s: (Identifier, T), - val insts: (Identifier, T), - val guardVar: T, - val quantifiers: Seq[(Identifier, T)], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Seq[LambdaTemplate[T]], - val structure: Forall, - val dependencies: Map[Identifier, T], - val forall: Forall, - stringRepr: () => String) { - - lazy val start = pathVar._2 - lazy val key: (Forall, Seq[T]) = (structure, { - var cls: Seq[T] = Seq.empty - purescala.ExprOps.preTraversal { - case Variable(id) => cls ++= dependencies.get(id) - case _ => - } (structure) - cls - }) - - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): QuantificationTemplate[T] = { - new QuantificationTemplate[T]( - quantificationManager, + def unapply(tpe: Type): Option[(Seq[Type], Type)] = tpe match { + case FunctionType(from, to) => Some(flatType(tpe)) + case MapType(from, to) => Some(Seq(from) -> to) + case BagType(base) => Some(Seq(base) -> IntegerType) + case SetType(base) => Some(Seq(base) -> BooleanType) + case _ => None + } + } + + /* -- Quantifier template definitions -- */ + + class QuantificationTemplate( + val pathVar: (Variable, Encoded), + val qs: (Variable, Encoded), + val q2s: (Variable, Encoded), + val insts: (Variable, Encoded), + val guardVar: Encoded, + val quantifiers: Seq[(Variable, Encoded)], + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val matchers: Matchers, + val lambdas: Seq[LambdaTemplate], + val structure: Forall, + val dependencies: Map[Variable, Encoded], + val forall: Forall, + stringRepr: () => String) { + + lazy val start = pathVar._2 + lazy val key: (Forall, Seq[Encoded]) = (structure, { + var cls: Seq[Encoded] = Seq.empty + exprOps.preTraversal { + case v: Variable => cls ++= dependencies.get(v) + case _ => + } (structure) + cls + }) + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new QuantificationTemplate( pathVar._1 -> substituter(start), - qs, - q2s, - insts, - guardVar, - quantifiers, - condVars, - exprVars, - condTree, + qs, q2s, insts, guardVar, quantifiers, condVars, exprVars, condTree, clauses.map(substituter), - blockers.map { case (b, fis) => - substituter(b) -> fis.map(fi => fi.copy( - args = fi.args.map(_.substitute(substituter, matcherSubst)) - )) - }, - applications.map { case (b, apps) => - substituter(b) -> apps.map(app => app.copy( - caller = substituter(app.caller), - args = app.args.map(_.substitute(substituter, matcherSubst)) - )) - }, - matchers.map { case (b, ms) => - substituter(b) -> ms.map(_.substitute(substituter, matcherSubst)) - }, - lambdas.map(_.substitute(substituter, matcherSubst)), - structure, - dependencies.map { case (id, value) => id -> substituter(value) }, - forall, - stringRepr - ) + blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + applications.map { case (b, apps) => substituter(b) -> apps.map(_.substitute(substituter, msubst)) }, + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, + lambdas.map(_.substitute(substituter, msubst)), + structure, dependencies.map { case (id, value) => id -> substituter(value) }, + forall, stringRepr) + + private lazy val str : String = stringRepr() + override def toString : String = str } - private lazy val str : String = stringRepr() - override def toString : String = str -} - -object QuantificationTemplate { - def apply[T]( - encoder: TemplateEncoder[T], - quantificationManager: QuantificationManager[T], - pathVar: (Identifier, T), - qs: (Identifier, T), - q2: Identifier, - inst: Identifier, - guard: Identifier, - quantifiers: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - condTree: Map[Identifier, Set[Identifier]], - guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Seq[LambdaTemplate[T]], - baseSubstMap: Map[Identifier, T], - dependencies: Map[Identifier, T], - proposition: Forall - ): QuantificationTemplate[T] = { - - val q2s: (Identifier, T) = q2 -> encoder.encodeId(q2) - val insts: (Identifier, T) = inst -> encoder.encodeId(inst) - val guards: (Identifier, T) = guard -> encoder.encodeId(guard) - - val (clauses, blockers, applications, functions, matchers, templateString) = - Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, - substMap = baseSubstMap + q2s + insts + guards + qs) - - val (structuralQuant, deps) = normalizeStructure(proposition) - val keyDeps = deps.map { case (id, dep) => id -> encoder.encodeExpr(dependencies)(dep) } - - new QuantificationTemplate[T](quantificationManager, - pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, - clauses, blockers, applications, matchers, lambdas, structuralQuant, keyDeps, proposition, - () => "Template for " + proposition + " is :\n" + templateString()) + object QuantificationTemplate { + def apply( + pathVar: (Variable, Encoded), + qs: (Variable, Encoded), + q2: Variable, + inst: Variable, + guard: Variable, + quantifiers: Seq[(Variable, Encoded)], + condVars: Map[Variable, Encoded], + exprVars: Map[Variable, Encoded], + condTree: Map[Variable, Set[Variable]], + guardedExprs: Map[Variable, Seq[Expr]], + lambdas: Seq[LambdaTemplate], + baseSubstMap: Map[Variable, Encoded], + dependencies: Map[Variable, Encoded], + proposition: Forall + ): QuantificationTemplate = { + + val q2s: (Variable, Encoded) = q2 -> encodeSymbol(q2) + val insts: (Variable, Encoded) = inst -> encodeSymbol(inst) + val guards: (Variable, Encoded) = guard -> encodeSymbol(guard) + + val (clauses, blockers, applications, matchers, templateString) = + Template.encode(pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, + substMap = baseSubstMap + q2s + insts + guards + qs) + + val (structuralQuant, deps) = normalizeStructure(proposition) + val keyDeps = deps.map { case (id, dep) => id -> encodeExpr(dependencies)(dep) } + + new QuantificationTemplate( + pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, + clauses, blockers, applications, matchers, lambdas, structuralQuant, keyDeps, proposition, + () => "Template for " + proposition + " is :\n" + templateString()) + } } -} -class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private[solvers] val quantifications = new IncrementalSeq[MatcherQuantification] + private[unrolling] object quantificationsManager extends Manager { + val quantifications = new IncrementalSeq[MatcherQuantification] - private val instCtx = new InstantiationContext + private[QuantificationTemplates] val instCtx = new InstantiationContext - private val ignoredMatchers = new IncrementalSeq[(Int, Set[T], Matcher[T])] - private val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[T], Map[T,Arg[T]])]] - private val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[T], Map[T,Arg[T]])]] + val ignoredMatchers = new IncrementalSeq[(Int, Set[Encoded], Matcher)] + val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[Encoded], Map[Encoded,Arg])]] + val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[Encoded], Map[Encoded,Arg])]] - private val lambdaAxioms = new IncrementalSet[(LambdaStructure[T], Seq[(Identifier, T)])] - private val templates = new IncrementalMap[(Expr, Seq[T]), T] + val lambdaAxioms = new IncrementalSet[(LambdaStructure, Seq[(Variable, Encoded)])] + val templates = new IncrementalMap[(Expr, Seq[Encoded]), Encoded] - override protected def incrementals: List[IncrementalState] = - List(quantifications, instCtx, ignoredMatchers, ignoredSubsts, - handledSubsts, lambdaAxioms, templates) ++ super.incrementals + val incrementals: Seq[IncrementalState] = Seq( + quantifications, instCtx, ignoredMatchers, ignoredSubsts, + handledSubsts, lambdaAxioms, templates) + + private def assumptions: Seq[Encoded] = + quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq + def satisfactionAssumptions = assumptions + def refutationAssumptions = assumptions + } private var currentGen = 0 - private sealed abstract class MatcherKey(val tpe: TypeTree) - private case class CallerKey(caller: T, tt: TypeTree) extends MatcherKey(tt) - private case class LambdaKey(lambda: Lambda, tt: TypeTree) extends MatcherKey(tt) - private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) + private sealed abstract class MatcherKey(val tpe: Type) + private case class CallerKey(caller: Encoded, tt: Type) extends MatcherKey(tt) + private case class LambdaKey(lambda: Lambda, tt: Type) extends MatcherKey(tt) + private case class TypeKey(tt: Type) extends MatcherKey(tt) - private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { + private def matcherKey(caller: Encoded, tpe: Type): MatcherKey = tpe match { case ft: FunctionType if knownFree(ft)(caller) => CallerKey(caller, tpe) case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure.lambda, tpe) case _ => TypeKey(tpe) } @inline - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = + private def correspond(qm: Matcher, m: Matcher): Boolean = correspond(qm, m.caller, m.tpe) - private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = { + private def correspond(qm: Matcher, caller: Encoded, tpe: Type): Boolean = { val qkey = matcherKey(qm.caller, qm.tpe) val key = matcherKey(caller, tpe) qkey == key || (qkey.tpe == key.tpe && (qkey.isInstanceOf[TypeKey] || key.isInstanceOf[TypeKey])) } class VariableNormalizer { - private val varMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty - private val varSet: MutableSet[T] = MutableSet.empty + private val varMap: MutableMap[Type, Seq[Encoded]] = MutableMap.empty + private val varSet: MutableSet[Encoded] = MutableSet.empty - def normalize(ids: Seq[Identifier]): Seq[T] = { + def normalize(ids: Seq[Variable]): Seq[Encoded] = { val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => val prev = varMap.get(tpe) match { case Some(seq) => seq @@ -203,7 +201,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage idst zip prev.take(idst.size) } else { val (handled, newIds) = idst.splitAt(prev.size) - val uIds = newIds.map(id => id -> encoder.encodeId(id)) + val uIds = newIds.map(id => id -> encodeSymbol(id)) varMap(tpe) = prev ++ uIds.map(_._2) varSet ++= uIds.map(_._2) @@ -215,58 +213,55 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage ids.map(mapping) } - def normalSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { + def normalSubst(qs: Seq[(Variable, Encoded)]): Map[Encoded, Encoded] = { (qs.map(_._2) zip normalize(qs.map(_._1))).toMap } - def contains(idT: T): Boolean = varSet(idT) - def get(tpe: TypeTree): Option[Seq[T]] = varMap.get(tpe) + def contains(idT: Encoded): Boolean = varSet(idT) + def get(tpe: Type): Option[Seq[Encoded]] = varMap.get(tpe) } private val abstractNormalizer = new VariableNormalizer private val concreteNormalizer = new VariableNormalizer - def isQuantifier(idT: T): Boolean = abstractNormalizer.contains(idT) + def isQuantifier(idT: Encoded): Boolean = abstractNormalizer.contains(idT) - override def assumptions: Seq[T] = super.assumptions ++ - quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - - def typeInstantiations: Map[TypeTree, Matchers] = instCtx.map.instantiations.collect { + def typeInstantiations: Map[Type, MatcherSet] = instCtx.map.instantiations.collect { case (TypeKey(tpe), matchers) => tpe -> matchers } - def lambdaInstantiations: Map[Lambda, Matchers] = instCtx.map.instantiations.collect { + def lambdaInstantiations: Map[Lambda, MatcherSet] = instCtx.map.instantiations.collect { case (LambdaKey(lambda, tpe), matchers) => lambda -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - def partialInstantiations: Map[T, Matchers] = instCtx.map.instantiations.collect { + def partialInstantiations: Map[Encoded, MatcherSet] = instCtx.map.instantiations.collect { case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - private def maxDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + private def maxDepth(m: Matcher): Int = 1 + (0 +: m.args.map { case Right(ma) => maxDepth(ma) case _ => 0 }).max - private def totalDepth(m: Matcher[T]): Int = 1 + m.args.map { + private def totalDepth(m: Matcher): Int = 1 + m.args.map { case Right(ma) => totalDepth(ma) case _ => 0 }.sum - private def encodeEnablers(es: Set[T]): T = - if (es.isEmpty) trueT else encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) + private def encodeEnablers(es: Set[Encoded]): Encoded = + if (es.isEmpty) trueT else mkAnd(es.toSeq.sortBy(_.toString) : _*) - private type Matchers = Set[(T, Matcher[T])] + private type MatcherSet = Set[(Encoded, Matcher)] - private class Context private(ctx: Map[Matcher[T], Set[Set[T]]]) extends Iterable[(Set[T], Matcher[T])] { + private class Context private(ctx: Map[Matcher, Set[Set[Encoded]]]) extends Iterable[(Set[Encoded], Matcher)] { def this() = this(Map.empty) - def apply(p: (Set[T], Matcher[T])): Boolean = ctx.get(p._2) match { + def apply(p: (Set[Encoded], Matcher)): Boolean = ctx.get(p._2) match { case None => false case Some(blockerSets) => blockerSets(p._1) || blockerSets.exists(set => set.subsetOf(p._1)) } - def +(p: (Set[T], Matcher[T])): Context = if (apply(p)) this else { + def +(p: (Set[Encoded], Matcher)): Context = if (apply(p)) this else { val prev = ctx.getOrElse(p._2, Set.empty) val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 new Context(ctx + (p._2 -> newSet)) @@ -275,14 +270,14 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage def ++(that: Context): Context = that.foldLeft(this)((ctx, p) => ctx + p) def iterator = ctx.toSeq.flatMap { case (m, bss) => bss.map(bs => bs -> m) }.iterator - def toMatchers: Matchers = this.map(p => encodeEnablers(p._1) -> p._2).toSet + def toMatchers: MatcherSet = this.map(p => encodeEnablers(p._1) -> p._2).toSet } private class ContextMap( - private var tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, + private var tpeMap: MutableMap[Type, Context] = MutableMap.empty, private var funMap: MutableMap[MatcherKey, Context] = MutableMap.empty ) extends IncrementalState { - private val stack = new MutableStack[(MutableMap[TypeTree, Context], MutableMap[MatcherKey, Context])] + private val stack = new MutableStack[(MutableMap[Type, Context], MutableMap[MatcherKey, Context])] def clear(): Unit = { stack.clear() @@ -304,7 +299,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage funMap = pfunMap } - def +=(p: (Set[T], Matcher[T])): Unit = matcherKey(p._2.caller, p._2.tpe) match { + def +=(p: (Set[Encoded], Matcher)): Unit = matcherKey(p._2.caller, p._2.tpe) match { case TypeKey(tpe) => tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) + p case key => funMap(key) = funMap.getOrElse(key, new Context) + p } @@ -315,7 +310,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage this } - def get(caller: T, tpe: TypeTree): Context = + def get(caller: Encoded, tpe: Type): Context = funMap.getOrElse(matcherKey(caller, tpe), new Context) ++ tpeMap.getOrElse(tpe, new Context) def get(key: MatcherKey): Context = key match { @@ -323,7 +318,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case key => funMap.getOrElse(key, new Context) ++ tpeMap.getOrElse(key.tpe, new Context) } - def instantiations: Map[MatcherKey, Matchers] = + def instantiations: Map[MatcherKey, MatcherSet] = (funMap.toMap ++ tpeMap.map { case (tpe,ms) => TypeKey(tpe) -> ms }).mapValues(_.toMatchers) } @@ -354,19 +349,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } def instantiated: Context = _instantiated - def apply(p: (Set[T], Matcher[T])): Boolean = _instantiated(p) + def apply(p: (Set[Encoded], Matcher)): Boolean = _instantiated(p) - def corresponding(m: Matcher[T]): Context = map.get(m.caller, m.tpe) + def corresponding(m: Matcher): Context = map.get(m.caller, m.tpe) - def instantiate(blockers: Set[T], matcher: Matcher[T])(qs: MatcherQuantification*): Instantiation[T] = { + def instantiate(blockers: Set[Encoded], matcher: Matcher)(qs: MatcherQuantification*): Clauses = { if (this(blockers -> matcher)) { - Instantiation.empty[T] + Seq.empty } else { map += (blockers -> matcher) _instantiated += (blockers -> matcher) - var instantiation = Instantiation.empty[T] - for (q <- qs) instantiation ++= q.instantiate(blockers, matcher) - instantiation + qs.flatMap(_.instantiate(blockers, matcher)) } } @@ -378,26 +371,26 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } private[solvers] trait MatcherQuantification { - val pathVar: (Identifier, T) - val quantifiers: Seq[(Identifier, T)] - val matchers: Set[Matcher[T]] - val allMatchers: Map[T, Set[Matcher[T]]] - val condVars: Map[Identifier, T] - val exprVars: Map[Identifier, T] - val condTree: Map[Identifier, Set[Identifier]] - val clauses: Seq[T] - val blockers: Map[T, Set[TemplateCallInfo[T]]] - val applications: Map[T, Set[App[T]]] - val lambdas: Seq[LambdaTemplate[T]] - - val holds: T + val pathVar: (Variable, Encoded) + val quantifiers: Seq[(Variable, Encoded)] + val matchers: Set[Matcher] + val allMatchers: Matchers + val condVars: Map[Variable, Encoded] + val exprVars: Map[Variable, Encoded] + val condTree: Map[Variable, Set[Variable]] + val clauses: Clauses + val blockers: Calls + val applications: Apps + val lambdas: Seq[LambdaTemplate] + + val holds: Encoded val body: Expr - lazy val quantified: Set[T] = quantifiers.map(_._2).toSet + lazy val quantified: Set[Encoded] = quantifiers.map(_._2).toSet lazy val start = pathVar._2 private lazy val depth = matchers.map(maxDepth).max - private lazy val transMatchers: Set[Matcher[T]] = (for { + private lazy val transMatchers: Set[Matcher] = (for { (b, ms) <- allMatchers.toSeq m <- ms if !matchers(m) && maxDepth(m) <= depth } yield m).toSet @@ -407,7 +400,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage * as other instantiations have been performed previously when the associated applications * were first encountered. */ - private def mappings(bs: Set[T], matcher: Matcher[T]): Set[Set[(Set[T], Matcher[T], Matcher[T])]] = { + private def mappings(bs: Set[Encoded], matcher: Matcher): Set[Set[(Set[Encoded], Matcher, Matcher)]] = { /* 1. select an application in the quantified proposition for which the current app can * be bound when generating the new constraints */ @@ -427,7 +420,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage /* 2.2. based on the possible bindings for each quantified application, build a set of * instantiation mappings that can be used to instantiate all necessary constraints */ - val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[T], Matcher[T], Matcher[T])]]](Set(Set.empty)) { + val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[Encoded], Matcher, Matcher)]]](Set(Set.empty)) { case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { case (bs, m) => mappings.map(mapping => mapping + ((bs, qm, m))) } : _*) @@ -437,13 +430,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Arg[T]], Boolean) = { - var constraints: Set[T] = Set.empty - var eqConstraints: Set[(T, T)] = Set.empty - var subst: Map[T, Arg[T]] = Map.empty + private def extractSubst(mapping: Set[(Set[Encoded], Matcher, Matcher)]): (Set[Encoded], Map[Encoded,Arg], Boolean) = { + var constraints: Set[Encoded] = Set.empty + var eqConstraints: Set[(Encoded, Encoded)] = Set.empty + var subst: Map[Encoded, Arg] = Map.empty - var matcherEqs: Set[(T, T)] = Set.empty - def strictnessCnstr(qarg: Arg[T], arg: Arg[T]): Unit = (qarg, arg) match { + var matcherEqs: Set[(Encoded, Encoded)] = Set.empty + def strictnessCnstr(qarg: Arg, arg: Arg): Unit = (qarg, arg) match { case (Right(qam), Right(am)) => (qam.args zip am.args).foreach(p => strictnessCnstr(p._1, p._2)) case _ => matcherEqs += qarg.encoded -> arg.encoded } @@ -462,18 +455,18 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage eqConstraints += (qam.encoded -> arg.encoded) } - val substituter = encoder.substitute(subst.mapValues(_.encoded)) + val substituter = mkSubstituter(subst.mapValues(_.encoded)) val substConstraints = constraints.filter(_ != trueT).map(substituter) val substEqs = eqConstraints.map(p => substituter(p._1) -> p._2) - .filter(p => p._1 != p._2).map(p => encoder.mkEquals(p._1, p._2)) + .filter(p => p._1 != p._2).map(p => mkEquals(p._1, p._2)) val enablers = substConstraints ++ substEqs val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) (enablers, subst, isStrict) } - def instantiate(bs: Set[T], matcher: Matcher[T]): Instantiation[T] = { - var instantiation = Instantiation.empty[T] + def instantiate(bs: Set[Encoded], matcher: Matcher): Clauses = { + var clauses: Clauses = Seq.empty for (mapping <- mappings(bs, matcher)) { val (enablers, subst, isStrict) = extractSubst(mapping) @@ -481,110 +474,162 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if (!skip(subst)) { if (!isStrict) { val msubst = subst.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(subst.mapValues(_.encoded)) + val substituter = mkSubstituter(subst.mapValues(_.encoded)) ignoredSubsts(this) += ((currentGen + 3, enablers, subst)) } else { - instantiation ++= instantiateSubst(enablers, subst, strict = true) + clauses ++= instantiateSubst(enablers, subst, strict = true) } } } - instantiation + clauses } - def instantiateSubst(enablers: Set[T], subst: Map[T, Arg[T]], strict: Boolean = false): Instantiation[T] = { + def instantiateSubst(enablers: Set[Encoded], subst: Map[Encoded, Arg], strict: Boolean = false): Clauses = { if (handledSubsts(this)(enablers -> subst)) { - Instantiation.empty[T] + Seq.empty } else { handledSubsts(this) += enablers -> subst - var instantiation = Instantiation.empty[T] + var clauses: Clauses = Seq.empty val (enabler, optEnabler) = freshBlocker(enablers) if (optEnabler.isDefined) { - instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + clauses :+= mkEquals(enabler, optEnabler.get) } val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) - val (substMap, inst) = Template.substitution[T](encoder, QuantificationManager.this, - condVars, exprVars, condTree, Seq.empty, lambdas, Set.empty, baseSubst, pathVar._1, enabler) - instantiation ++= inst + val (substMap, cls) = Template.substitution( + condVars, exprVars, condTree, lambdas, Seq.empty, baseSubst, pathVar._1, enabler) + clauses ++= cls val msubst = substMap.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(substMap.mapValues(_.encoded)) + val substituter = mkSubstituter(substMap.mapValues(_.encoded)) registerBlockers(substituter) - instantiation ++= Template.instantiate(encoder, QuantificationManager.this, - clauses, blockers, applications, Map.empty, substMap) + clauses ++= Template.instantiate(clauses, blockers, applications, Map.empty, substMap) for ((b,ms) <- allMatchers; m <- ms) { val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) val sm = m.substitute(substituter, msubst) if (strict && (matchers(m) || transMatchers(m))) { - instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) + clauses ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) } else if (!matchers(m)) { ignoredMatchers += ((currentGen + 2 + totalDepth(m), sb, sm)) } } - instantiation + clauses } } - protected def instanceSubst(enabler: T): Map[T, T] + protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] + + protected def skip(subst: Map[Encoded, Arg]): Boolean = false + + protected def registerBlockers(substituter: Encoded => Encoded): Unit = () + + def checkForall: Option[String] = { + val quantified = quantifiers.map(_._1).toSet + + val matchers = exprOps.collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) + case _ => Set.empty + } (body) - protected def skip(subst: Map[T, Arg[T]]): Boolean = false + if (matchers.isEmpty) + return Some("No matchers found.") - protected def registerBlockers(substituter: T => T): Unit = () + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Variable]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case v: Variable if quantified(v) => Set(v) + case _ => Set.empty[Variable] + })) + } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) + return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) + + def quantifiedArg(e: Expr): Boolean = e match { + case v: Variable => quantified(v) + case QuantificationMatcher(_, args) => args.forall(quantifiedArg) + case _ => false + } + + exprOps.postTraversal(m => m match { + case QuantificationMatcher(_, args) => + val qArgs = args.filter(quantifiedArg) + + if (qArgs.nonEmpty && qArgs.size < args.size) + return Some("Mixed ground and quantified arguments in " + m.asString) + + case Operator(es, _) if es.collect { case v: Variable if quantified(v) => v }.nonEmpty => + return Some("Invalid operation on quantifiers " + m.asString) + + case (_: Equals) | (_: And) | (_: Or) | (_: Implies) | (_: Not) => // OK + + case Operator(es, _) if (es.flatMap(exprOps.variablesOf).toSet & quantified).nonEmpty => + return Some("Unandled implications from operation " + m.asString) + + case _ => + }) (body) + + body match { + case v: Variable if quantified(v) => + Some("Unexpected free quantifier " + v.asString) + case _ => None + } + } } private class Quantification ( - val pathVar: (Identifier, T), - val qs: (Identifier, T), - val q2s: (Identifier, T), - val insts: (Identifier, T), - val guardVar: T, - val quantifiers: Seq[(Identifier, T)], - val matchers: Set[Matcher[T]], - val allMatchers: Map[T, Set[Matcher[T]]], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val lambdas: Seq[LambdaTemplate[T]], - val template: QuantificationTemplate[T]) extends MatcherQuantification { - - private var _currentQ2Var: T = qs._2 + val pathVar: (Variable, Encoded), + val qs: (Variable, Encoded), + val q2s: (Variable, Encoded), + val insts: (Variable, Encoded), + val guardVar: Encoded, + val quantifiers: Seq[(Variable, Encoded)], + val matchers: Set[Matcher], + val allMatchers: Matchers, + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val lambdas: Seq[LambdaTemplate], + val template: QuantificationTemplate) extends MatcherQuantification { + + private var _currentQ2Var: Encoded = qs._2 def currentQ2Var = _currentQ2Var val holds = qs._2 val body = template.forall.body - private var _currentInsts: Map[T, Set[T]] = Map.empty - def currentInsts = _currentInsts + private var _insts: Map[Encoded, Set[Encoded]] = Map.empty + def instantiations = _insts - protected def instanceSubst(enabler: T): Map[T, T] = { - val nextQ2Var = encoder.encodeId(q2s._1) + protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] = { + val nextQ2Var = encodeSymbol(q2s._1) val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, - q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) + q2s._2 -> nextQ2Var, insts._2 -> encodeSymbol(insts._1)) _currentQ2Var = nextQ2Var subst } - override def registerBlockers(substituter: T => T): Unit = { + override def registerBlockers(substituter: Encoded => Encoded): Unit = { val freshInst = substituter(insts._2) val bs = (blockers.keys ++ applications.keys).map(substituter).toSet - _currentInsts += freshInst -> bs + _insts += freshInst -> bs } } - private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) - private lazy val enablersToBlocker: MutableMap[Set[T], T] = MutableMap.empty - private lazy val blockerToEnablers: MutableMap[T, Set[T]] = MutableMap.empty - private def freshBlocker(enablers: Set[T]): (T, Option[T]) = enablers.toSeq match { + private lazy val blockerSymbol = Variable(FreshIdentifier("blocker", true), BooleanType) + private lazy val enablersToBlocker: MutableMap[Set[Encoded], Encoded] = MutableMap.empty + private lazy val blockerToEnablers: MutableMap[Encoded, Set[Encoded]] = MutableMap.empty + private def freshBlocker(enablers: Set[Encoded]): (Encoded, Option[Encoded]) = enablers.toSeq match { case Seq(b) if isBlocker(b) => (b, None) case _ => val last = enablersToBlocker.get(enablers).orElse { @@ -595,10 +640,10 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage last match { case Some(b) => (b, None) case None => - val nb = encoder.encodeId(blockerId) + val nb = encodeSymbol(blockerSymbol) enablersToBlocker += enablers -> nb blockerToEnablers += nb -> enablers - for (b <- enablers if isBlocker(b)) implies(b, nb) + for (b <- enablers if isBlocker(b)) impliesBlocker(b, nb) blocker(nb) (nb, Some(encodeEnablers(enablers))) @@ -606,30 +651,30 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } private class LambdaAxiom ( - val pathVar: (Identifier, T), - val blocker: T, - val guardVar: T, - val quantifiers: Seq[(Identifier, T)], - val matchers: Set[Matcher[T]], - val allMatchers: Map[T, Set[Matcher[T]]], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val lambdas: Seq[LambdaTemplate[T]], - val template: LambdaTemplate[T]) extends MatcherQuantification { + val pathVar: (Variable, Encoded), + val blocker: Encoded, + val guardVar: Encoded, + val quantifiers: Seq[(Variable, Encoded)], + val matchers: Set[Matcher], + val allMatchers: Map[Encoded, Set[Matcher]], + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val lambdas: Seq[LambdaTemplate], + val template: LambdaTemplate) extends MatcherQuantification { val holds = start val body = template.lambda.body - protected def instanceSubst(enabler: T): Map[T, T] = { + protected def instanceSubst(enabler: Encoded): Map[Encoded, Encoded] = { Map(guardVar -> start, blocker -> enabler) } - override protected def skip(subst: Map[T, Arg[T]]): Boolean = { - val substituter = encoder.substitute(subst.mapValues(_.encoded)) + override protected def skip(subst: Map[Encoded, Arg]): Boolean = { + val substituter = mkSubstituter(subst.mapValues(_.encoded)) val msubst = subst.collect { case (c, Right(m)) => c -> m } allMatchers.forall { case (b, ms) => ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter, msubst))) @@ -638,13 +683,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } private def extractQuorums( - quantified: Set[T], - matchers: Set[Matcher[T]], - lambdas: Seq[LambdaTemplate[T]] - ): Seq[Set[Matcher[T]]] = { - val extMatchers: Set[Matcher[T]] = { - def rec(templates: Seq[LambdaTemplate[T]]): Set[Matcher[T]] = - templates.foldLeft(Set.empty[Matcher[T]]) { + quantified: Set[Encoded], + matchers: Set[Matcher], + lambdas: Seq[LambdaTemplate] + ): Seq[Set[Matcher]] = { + val extMatchers: Set[Matcher] = { + def rec(templates: Seq[LambdaTemplate]): Set[Matcher] = + templates.foldLeft(Set.empty[Matcher]) { case (matchers, template) => matchers ++ template.matchers.flatMap(_._2) ++ rec(template.lambdas) } @@ -656,13 +701,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if args exists (_.left.exists(quantified)) } yield m - purescala.Quantification.extractQuorums(quantifiedMatchers, quantified, - (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, - (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) + extractQuorums(quantifiedMatchers, quantified, + (m: Matcher) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, + (m: Matcher) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } - def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, Arg[T]]): Instantiation[T] = { - def quantifiedMatcher(m: Matcher[T]): Boolean = m.args.exists(a => a match { + def instantiateAxiom(template: LambdaTemplate, substMap: Map[Encoded, Arg]): Clauses = { + def quantifiedMatcher(m: Matcher): Boolean = m.args.exists(a => a match { case Left(v) => isQuantifier(v) case Right(m) => quantifiedMatcher(m) }) @@ -679,15 +724,15 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val key = template.structure -> quantifiers if (quantifiers.isEmpty || lambdaAxioms(key)) { - Instantiation.empty[T] + Seq.empty } else { lambdaAxioms += key - val blockerT = encoder.encodeId(blockerId) + val blockerT = encodeSymbol(blockerSymbol) - val guard = FreshIdentifier("guard", BooleanType, true) - val guardT = encoder.encodeId(guard) + val guard = Variable(FreshIdentifier("guard", true), BooleanType) + val guardT = encodeSymbol(guard) - val substituter = encoder.substitute(substMap.mapValues(_.encoded) + (template.start -> blockerT)) + val substituter = mkSubstituter(substMap.mapValues(_.encoded) + (template.start -> blockerT)) val msubst = substMap.collect { case (c, Right(m)) => c -> m } val allMatchers = template.matchers map { case (b, ms) => @@ -697,13 +742,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val qMatchers = allMatchers.flatMap(_._2).toSet val encArgs = template.args map (arg => Left(arg).substitute(substituter, msubst)) - val app = Application(Variable(template.ids._1), template.arguments.map(_._1.toVariable)) - val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs.map(_.encoded)).toMap + template.ids)(app) + val app = Application(template.ids._1, template.arguments.map(_._1)) + val appT = encodeExpr((template.arguments.map(_._1) zip encArgs.map(_.encoded)).toMap + template.ids)(app) val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs, appT) val instMatchers = allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)) - val enablingClause = encoder.mkImplies(guardT, blockerT) + val enablingClause = mkImplies(guardT, blockerT) val condVars = template.condVars map { case (id, idT) => id -> substituter(idT) } val exprVars = template.exprVars map { case (id, idT) => id -> substituter(idT) } @@ -724,7 +769,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantified = quantifiers.map(_._2).toSet val matchQuorums = extractQuorums(quantified, qMatchers, lambdas) - var instantiation = Instantiation.empty[T] + var instantiation: Clauses = Seq.empty for (matchers <- matchQuorums) { val axiom = new LambdaAxiom(template.pathVar._1 -> substituter(template.start), @@ -748,22 +793,22 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - def instantiateQuantification(template: QuantificationTemplate[T]): (T, Instantiation[T]) = { + def instantiateQuantification(template: QuantificationTemplate): (Encoded, Clauses) = { templates.get(template.key) match { case Some(idT) => - (idT, Instantiation.empty) + (idT, Seq.empty) case None => - val qT = encoder.encodeId(template.qs._1) + val qT = encodeSymbol(template.qs._1) val quantified = template.quantifiers.map(_._2).toSet val matcherSet = template.matchers.flatMap(_._2).toSet val matchQuorums = extractQuorums(quantified, matcherSet, template.lambdas) - var instantiation = Instantiation.empty[T] + var clauses: Clauses = Seq.empty val qs = for (matchers <- matchQuorums) yield { - val newQ = encoder.encodeId(template.qs._1) - val substituter = encoder.substitute(Map(template.qs._2 -> newQ)) + val newQ = encodeSymbol(template.qs._1) + val substituter = mkSubstituter(Map(template.qs._2 -> newQ)) val quantification = new Quantification( template.pathVar, @@ -780,41 +825,41 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val newCtx = new InstantiationContext() for ((b,m) <- instCtx.instantiated) { - instantiation ++= newCtx.instantiate(b, m)(quantification) + clauses ++= newCtx.instantiate(b, m)(quantification) } instCtx.merge(newCtx) quantification.qs._2 } - instantiation = instantiation withClause { + clauses :+= { val newQs = if (qs.isEmpty) trueT else if (qs.size == 1) qs.head - else encoder.mkAnd(qs : _*) - encoder.mkImplies(template.start, encoder.mkEquals(qT, newQs)) + else mkAnd(qs : _*) + mkImplies(template.start, mkEquals(qT, newQs)) } - instantiation ++= instantiateConstants(template.quantifiers, matcherSet) + clauses ++= instantiateConstants(template.quantifiers, matcherSet) templates += template.key -> qT - (qT, instantiation) + (qT, clauses) } } - def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { + def instantiateMatcher(blocker: Encoded, matcher: Matcher): Clauses = { instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) } - def hasIgnored: Boolean = ignoredSubsts.nonEmpty || ignoredMatchers.nonEmpty + def canUnfoldQuantifiers: Boolean = ignoredSubsts.nonEmpty || ignoredMatchers.nonEmpty - def instantiateIgnored(force: Boolean = false): Instantiation[T] = { + def instantiateIgnored(force: Boolean = false): Clauses = { currentGen = if (!force) currentGen + 1 else { val gens = ignoredSubsts.toSeq.flatMap(_._2).map(_._1) ++ ignoredMatchers.toSeq.map(_._1) if (gens.isEmpty) currentGen else gens.min } - var instantiation = Instantiation.empty[T] + var clauses: Clauses = Seq.empty val matchersToRelease = ignoredMatchers.toList.flatMap { case e @ (gen, b, m) => if (gen == currentGen) { @@ -826,7 +871,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } for ((bs,m) <- matchersToRelease) { - instantiation ++= instCtx.instantiate(bs, m)(quantifications.toSeq : _*) + clauses ++= instCtx.instantiate(bs, m)(quantifications.toSeq : _*) } val substsToRelease = quantifications.toList.flatMap { q => @@ -842,28 +887,28 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } for ((q, enablers, subst) <- substsToRelease) { - instantiation ++= q.instantiateSubst(enablers, subst, strict = false) + clauses ++= q.instantiateSubst(enablers, subst, strict = false) } - instantiation + clauses } - private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { - var instantiation: Instantiation[T] = Instantiation.empty + private def instantiateConstants(quantifiers: Seq[(Variable, Encoded)], matchers: Set[Matcher]): Clauses = { + var clauses: Clauses = Seq.empty for (normalizer <- List(abstractNormalizer, concreteNormalizer)) { val quantifierSubst = normalizer.normalSubst(quantifiers) - val substituter = encoder.substitute(quantifierSubst) + val substituter = mkSubstituter(quantifierSubst) for { m <- matchers sm = m.substitute(substituter, Map.empty) if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + } clauses ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) - def unifyMatchers(matchers: Seq[Matcher[T]]): Instantiation[T] = matchers match { + def unifyMatchers(matchers: Seq[Matcher]): Clauses = matchers match { case sm +: others => - var instantiation = Instantiation.empty[T] + var clauses: Clauses = Seq.empty for (pm <- others if correspond(pm, sm)) { val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) val mismatches = encodedArgs.zipWithIndex.collect { @@ -893,114 +938,114 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage mapping ++ seq.map(i => i -> res) } - def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = + def extractArgs(args: Seq[Arg]): Seq[Arg] = (0 until args.size).map(i => args(positions.getOrElse(i, i))) - instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) - instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) + clauses ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) + clauses ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) } - instantiation ++ unifyMatchers(others) + clauses ++ unifyMatchers(others) - case _ => Instantiation.empty[T] + case _ => Seq.empty } if (normalizer == abstractNormalizer) { val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) - instantiation ++= unifyMatchers(substMatchers.toSeq) + clauses ++= unifyMatchers(substMatchers.toSeq) } } - instantiation + clauses } - def checkClauses: Seq[T] = { - val clauses = new scala.collection.mutable.ListBuffer[T] - val keyClause = MutableMap.empty[MatcherKey, (Seq[T], T)] + def getFiniteRangeClauses: Clauses = { + val clauses = new scala.collection.mutable.ListBuffer[Encoded] + val keyClause = MutableMap.empty[MatcherKey, (Clauses, Encoded)] for ((_, bs, m) <- ignoredMatchers) { val key = matcherKey(m.caller, m.tpe) - val QTM(argTypes, _) = key.tpe + val QuantificationTypeMatcher(argTypes, _) = key.tpe val (values, clause) = keyClause.getOrElse(key, { val insts = instCtx.map.get(key).toMatchers - val guard = FreshIdentifier("guard", BooleanType) - val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) - val values = argTypes.map(tpe => FreshIdentifier("value", tpe)) - val expr = andJoin(Variable(guard) +: (elems zip values).map(p => Equals(Variable(p._1), Variable(p._2)))) + val guard = Variable(FreshIdentifier("guard", true), BooleanType) + val elems = argTypes.map(tpe => Variable(FreshIdentifier("elem", true), tpe)) + val values = argTypes.map(tpe => Variable(FreshIdentifier("value", true), tpe)) + val expr = andJoin(guard +: (elems zip values).map(p => Equals(p._1, p._2))) - val guardP = guard -> encoder.encodeId(guard) - val elemsP = elems.map(e => e -> encoder.encodeId(e)) - val valuesP = values.map(v => v -> encoder.encodeId(v)) - val exprT = encoder.encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + val guardP = guard -> encodeSymbol(guard) + val elemsP = elems.map(e => e -> encodeSymbol(e)) + val valuesP = values.map(v => v -> encodeSymbol(v)) + val exprT = encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) val disjuncts = insts.toSeq.map { case (b, im) => - val bp = if (m.caller != im.caller) encoder.mkAnd(encoder.mkEquals(m.caller, im.caller), b) else b + val bp = if (m.caller != im.caller) mkAnd(mkEquals(m.caller, im.caller), b) else b val subst = (elemsP.map(_._2) zip im.args.map(_.encoded)).toMap + (guardP._2 -> bp) - encoder.substitute(subst)(exprT) + mkSubstituter(subst)(exprT) } - val res = (valuesP.map(_._2), encoder.mkOr(disjuncts : _*)) + val res = (valuesP.map(_._2), mkOr(disjuncts : _*)) keyClause += key -> res res }) val b = encodeEnablers(bs) val substMap = (values zip m.args.map(_.encoded)).toMap - clauses += encoder.substitute(substMap)(encoder.mkImplies(b, clause)) + clauses += mkSubstituter(substMap)(mkImplies(b, clause)) } for (q <- quantifications) { - val guard = FreshIdentifier("guard", BooleanType) + val guard = Variable(FreshIdentifier("guard", true), BooleanType) val elems = q.quantifiers.map(_._1) - val values = elems.map(id => id.freshen) - val expr = andJoin(Variable(guard) +: (elems zip values).map(p => Equals(Variable(p._1), Variable(p._2)))) + val values = elems.map(v => v.freshen) + val expr = andJoin(guard +: (elems zip values).map(p => Equals(p._1, p._2))) - val guardP = guard -> encoder.encodeId(guard) - val elemsP = elems.map(e => e -> encoder.encodeId(e)) - val valuesP = values.map(v => v -> encoder.encodeId(v)) - val exprT = encoder.encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + val guardP = guard -> encodeSymbol(guard) + val elemsP = elems.map(e => e -> encodeSymbol(e)) + val valuesP = values.map(v => v -> encodeSymbol(v)) + val exprT = encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) val disjunction = handledSubsts(q) match { - case set if set.isEmpty => encoder.encodeExpr(Map.empty)(BooleanLiteral(false)) - case set => encoder.mkOr(set.toSeq.map { case (enablers, subst) => - val b = if (enablers.isEmpty) trueT else encoder.mkAnd(enablers.toSeq : _*) + case set if set.isEmpty => encodeExpr(Map.empty)(BooleanLiteral(false)) + case set => mkOr(set.toSeq.map { case (enablers, subst) => + val b = if (enablers.isEmpty) trueT else mkAnd(enablers.toSeq : _*) val substMap = (elemsP.map(_._2) zip q.quantifiers.map(p => subst(p._2).encoded)).toMap + (guardP._2 -> b) - encoder.substitute(substMap)(exprT) + mkSubstituter(substMap)(exprT) } : _*) } for ((_, enablers, subst) <- ignoredSubsts(q)) { - val b = if (enablers.isEmpty) trueT else encoder.mkAnd(enablers.toSeq : _*) + val b = if (enablers.isEmpty) trueT else mkAnd(enablers.toSeq : _*) val substMap = (valuesP.map(_._2) zip q.quantifiers.map(p => subst(p._2).encoded)).toMap - clauses += encoder.substitute(substMap)(encoder.mkImplies(b, disjunction)) + clauses += mkSubstituter(substMap)(mkImplies(b, disjunction)) } } - def isQuantified(e: Arg[T]): Boolean = e match { + def isQuantified(e: Arg): Boolean = e match { case Left(t) => isQuantifier(t) case Right(m) => m.args.exists(isQuantified) } for ((key, ctx) <- instCtx.map.instantiations) { - val QTM(argTypes, _) = key.tpe + val QuantificationTypeMatcher(argTypes, _) = key.tpe for { (tpe, idx) <- argTypes.zipWithIndex quants <- abstractNormalizer.get(tpe) if quants.nonEmpty (b, m) <- ctx arg = m.args(idx) if !isQuantified(arg) - } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg.encoded))) : _*) + } clauses += mkAnd(quants.map(q => mkNot(mkEquals(q, arg.encoded))) : _*) - val byPosition: Iterable[Seq[T]] = ctx.flatMap { case (b, m) => + val byPosition: Iterable[Seq[Encoded]] = ctx.flatMap { case (b, m) => if (b != trueT) Seq.empty else m.args.zipWithIndex }.groupBy(_._2).map(p => p._2.toSeq.flatMap { case (a, _) => if (isQuantified(a)) Some(a.encoded) else None }).filter(_.nonEmpty) for ((a +: as) <- byPosition; a2 <- as) { - clauses += encoder.mkEquals(a, a2) + clauses += mkEquals(a, a2) } } @@ -1008,17 +1053,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } trait ModelView { - protected val vars: Map[Identifier, T] + protected val vars: Map[Variable, Encoded] protected val evaluator: evaluators.DeterministicEvaluator - protected def get(id: Identifier): Option[Expr] - protected def eval(elem: T, tpe: TypeTree): Option[Expr] + protected def get(id: Variable): Option[Expr] + protected def eval(elem: Encoded, tpe: Type): Option[Expr] implicit lazy val context = evaluator.context lazy val reporter = context.reporter - private def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { - val QTM(fromTypes, _) = m.tpe + private def extract(b: Encoded, m: Matcher): Option[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = eval(b, BooleanType) optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } @@ -1041,7 +1086,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage }) def rec(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match { - case (_: Lambda) | (_: FiniteLambda) => + case (_: Lambda) => (Seq(expr -> path), (es: Seq[Expr]) => es.head) case Tuple(es) => reconstruct(es.zipWithIndex.map { @@ -1058,64 +1103,21 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage rec(expr, path) } - def getPartialModel: PartialModel = { - val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInstantiations.map { - case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet - } - - val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInstantiations.map { - case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet - } - - val domains = new Domains(lambdaDomains, typeDomains) - - val partialDomains: Map[T, Set[Seq[Expr]]] = partialInstantiations.map { - case (t, domain) => t -> domain.flatMap { case (b, m) => extract(b, m) }.toSet - } - - def extractElse(body: Expr): Expr = body match { - case IfExpr(cond, thenn, elze) => extractElse(elze) - case _ => body - } - - val mapping = vars.map { case (id, idT) => - val value = get(id).getOrElse(simplestValue(id.getType)) - val (functions, recons) = functionsOf(value, Variable(id)) - - id -> recons(functions.map { case (f, path) => - val encoded = encoder.encodeExpr(Map(id -> idT))(path) - val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] - partialDomains.get(encoded).orElse(typeDomains.get(tpe)).map { domain => - FiniteLambda(domain.toSeq.map { es => - val optEv = evaluator.eval(application(f, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(f, es))) - }, f match { - case FiniteLambda(_, dflt, _) => dflt - case Lambda(_, body) => extractElse(body) - case _ => scala.sys.error("What kind of function is this : " + f.asString + " !?") - }, tpe) - }.getOrElse(f) - }) - } - - new PartialModel(mapping, domains) - } - def getTotalModel: Model = { - def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { - val matchers = purescala.ExprOps.collect[(Expr, Seq[Expr])] { - case QM(e, args) => Set(e -> args) + def checkForalls(quantified: Set[Variable], body: Expr): Option[String] = { + val matchers = exprOps.collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) case _ => Set.empty } (body) if (matchers.isEmpty) return Some("No matchers found.") - val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Variable]]) { case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] + case v: Variable if quantified(v) => Set(v) + case _ => Set.empty[Variable] })) } @@ -1124,19 +1126,19 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) def quantifiedArg(e: Expr): Boolean = e match { - case Variable(id) => quantified(id) - case QM(_, args) => args.forall(quantifiedArg) + case v: Variable => quantified(v) + case QuantificationMatcher(_, args) => args.forall(quantifiedArg) case _ => false } - purescala.ExprOps.postTraversal(m => m match { - case QM(_, args) => + exprOps.postTraversal(m => m match { + case QuantificationMatcher(_, args) => val qArgs = args.filter(quantifiedArg) if (qArgs.nonEmpty && qArgs.size < args.size) return Some("Mixed ground and quantified arguments in " + m.asString) - case Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => + case Operator(es, _) if es.collect { case v: Variable if quantified(v) => v }.nonEmpty => return Some("Invalid operation on quantifiers " + m.asString) case (_: Equals) | (_: And) | (_: Or) | (_: Implies) | (_: Not) => // OK @@ -1148,13 +1150,13 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage }) (body) body match { - case Variable(id) if quantified(id) => + case v: Variable if quantified(v) => Some("Unexpected free quantifier " + id.asString) case _ => None } } - val issues: Iterable[(Seq[Identifier], Expr, String)] = for { + val issues: Iterable[(Seq[Variable], Expr, String)] = for { q <- quantifications.view if eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) @@ -1169,27 +1171,27 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val types = typeInstantiations val partials = partialInstantiations - def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { + def extractCond(params: Seq[Variable], args: Seq[(Encoded, Expr)], structure: Map[Encoded, Variable]): Seq[Expr] = (params, args) match { case (id +: rparams, (v, arg) +: rargs) => if (isQuantifier(v)) { structure.get(v) match { - case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) + case Some(pid) => Equals(id, pid) +: extractCond(rparams, rargs, structure) case None => extractCond(rparams, rargs, structure + (v -> id)) } } else { - Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) + Equals(id, arg) +: extractCond(rparams, rargs, structure) } case _ => Seq.empty } new Model(vars.map { case (id, idT) => val value = get(id).getOrElse(simplestValue(id.getType)) - val (functions, recons) = functionsOf(value, Variable(id)) + val (functions, recons) = functionsOf(value, id) id -> recons(functions.map { case (f, path) => - val encoded = encoder.encodeExpr(Map(id -> idT))(path) + val encoded = encodeExpr(Map(id -> idT))(path) val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] - val params = tpe.from.map(tpe => FreshIdentifier("x", tpe, true)) + val params = tpe.from.map(tpe => Variable(FreshIdentifier("x", true), tpe)) partials.get(encoded).orElse(types.get(tpe)).map { domain => val conditionals = domain.flatMap { case (b, m) => extract(b, m).map { args => @@ -1200,7 +1202,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val cond = if (m.args.exists(arg => isQuantifier(arg.encoded))) { extractCond(params, m.args.map(_.encoded) zip args, Map.empty) } else { - (params zip args).map(p => Equals(Variable(p._1), p._2)) + (params zip args).map(p => Equals(p._1, p._2)) } cond -> result @@ -1225,7 +1227,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage }) } - Lambda(params.map(ValDef(_)), body) + Lambda(params.map(_.toVal), body) } }.getOrElse(f) }) @@ -1233,19 +1235,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - def getModel(vs: Map[Identifier, T], ev: DeterministicEvaluator, _get: Identifier => Option[Expr], _eval: (T, TypeTree) => Option[Expr]) = new ModelView { - val vars: Map[Identifier, T] = vs + def getModel(vs: Map[Variable, Encoded], ev: DeterministicEvaluator, _get: Variable => Option[Expr], _eval: (Encoded, Type) => Option[Expr]) = new ModelView { + val vars: Map[Variable, Encoded] = vs val evaluator: DeterministicEvaluator = ev - def get(id: Identifier): Option[Expr] = _get(id) - def eval(elem: T, tpe: TypeTree): Option[Expr] = _eval(elem, tpe) + def get(id: Variable): Option[Expr] = _get(id) + def eval(elem: Encoded, tpe: Type): Option[Expr] = _eval(elem, tpe) } - def getBlockersToPromote(eval: (T, TypeTree) => Option[Expr]): Seq[T] = quantifications.toSeq.flatMap { - case q: Quantification if eval(q.qs._2, BooleanType) == Some(BooleanLiteral(false)) => - val falseInsts = q.currentInsts.filter { case (inst, bs) => eval(inst, BooleanType) == Some(BooleanLiteral(false)) } - falseInsts.flatMap(_._2) + def getInstantiationsWithBlockers = quantifications.toSeq.flatMap { + case q: Quantification => q.instantiations.toSeq case _ => Seq.empty } } - diff --git a/src/main/scala/inox/solvers/unrolling/TemplateEncoder.scala b/src/main/scala/inox/solvers/unrolling/TemplateEncoder.scala deleted file mode 100644 index 74488aa8e7b5ce71aafd115038a0dc74c4ea0b2a..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/TemplateEncoder.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ - -import utils._ - -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} - -trait TemplateEncoder[T] { - def encodeId(id: Identifier): T - def encodeExpr(bindings: Map[Identifier, T])(e: Expr): T - def substitute(map: Map[T, T]): T => T - - // Encodings needed for unrollingbank - def mkNot(v: T): T - def mkOr(ts: T*): T - def mkAnd(ts: T*): T - def mkEquals(l: T, r: T): T - def mkImplies(l: T, r: T): T - - def extractNot(v: T): Option[T] -} - diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index 9912becb82b65d8087fc3a8bfbee7bf4b38611c8..a240d8efec54e22bd4cb7b33432c69fba2e65d93 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -1,42 +1,32 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package unrolling -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.bestRealType -import purescala.Definitions._ -import purescala.Constructors._ -import purescala.Quantification._ - -import theories._ -import utils.SeqUtils._ -import Instantiation._ - -class TemplateGenerator[T](val theories: TheoryEncoder, - val encoder: TemplateEncoder[T], - val assumePreHolds: Boolean) { - private var cache = Map[TypedFunDef, FunctionTemplate[T]]() - private var cacheExpr = Map[Expr, (FunctionTemplate[T], Map[Identifier, Identifier])]() - - private type Clauses = ( - Map[Identifier,T], - Map[Identifier,T], - Map[Identifier, Set[Identifier]], - Map[Identifier, Seq[Expr]], - Seq[LambdaTemplate[T]], - Seq[QuantificationTemplate[T]] +import utils._ +import scala.collection.mutable.{Map => MutableMap} + +trait TemplateGenerator { self: Templates => + import program._ + import program.trees._ + import program.symbols._ + + val assumePreHolds: Boolean + + private type TemplateClauses = ( + Map[Variable, Encoded], + Map[Variable, Encoded], + Map[Variable, Set[Variable]], + Map[Variable, Seq[Expr]], + Seq[LambdaTemplate], + Seq[QuantificationTemplate] ) - private def emptyClauses: Clauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty) + private def emptyClauses: TemplateClauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty) - private implicit class ClausesWrapper(clauses: Clauses) { - def ++(that: Clauses): Clauses = { + private implicit class ClausesWrapper(clauses: TemplateClauses) { + def ++(that: TemplateClauses): TemplateClauses = { val (thisConds, thisExprs, thisTree, thisGuarded, thisLambdas, thisQuants) = clauses val (thatConds, thatExprs, thatTree, thatGuarded, thatLambdas, thatQuants) = that @@ -45,29 +35,9 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } } - val manager = new QuantificationManager[T](encoder) - - def mkTemplate(raw: Expr): (FunctionTemplate[T], Map[Identifier, Identifier]) = { - if (cacheExpr contains raw) { - return cacheExpr(raw) - } - - val mapping = variablesOf(raw).map(id => id -> theories.encode(id)).toMap - val body = theories.encode(raw)(mapping) - - val arguments = mapping.values.toSeq.map(ValDef(_)) - val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType) - - fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable)))) - fakeFunDef.body = Some(body) + private val cache: MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty - val res = mkTemplate(fakeFunDef.typed, false) - val p = (res, mapping) - cacheExpr += raw -> p - p - } - - def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] = { + def mkTemplate(tfd: TypedFunDef): FunctionTemplate = { if (cache contains tfd) { return cache(tfd) } @@ -78,12 +48,12 @@ class TemplateGenerator[T](val theories: TheoryEncoder, val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b)) - val funDefArgs: Seq[Identifier] = tfd.paramIds - val lambdaArguments: Seq[Identifier] = lambdaBody.map(lambdaArgs).toSeq.flatten - val invocation : Expr = FunctionInvocation(tfd, funDefArgs.map(_.toVariable)) + val funDefArgs: Seq[Variable] = tfd.params.map(_.toVariable) + val lambdaArguments: Seq[Variable] = lambdaBody.map(lambdaArgs).toSeq.flatten + val invocation : Expr = tfd.applied(funDefArgs) val invocationEqualsBody : Seq[Expr] = lambdaBody match { - case Some(body) if isRealFunDef => + case Some(body) => val bs = liftedEquals(invocation, body, lambdaArguments) :+ Equals(invocation, body) if(prec.isDefined) { @@ -96,19 +66,16 @@ class TemplateGenerator[T](val theories: TheoryEncoder, Seq.empty } - val start : Identifier = FreshIdentifier("start", BooleanType, true) - val pathVar : (Identifier, T) = start -> encoder.encodeId(start) + val start : Variable = Variable(FreshIdentifier("start", true), BooleanType) + val pathVar : (Variable, Encoded) = start -> encodeSymbol(start) - val allArguments : Seq[Identifier] = funDefArgs ++ lambdaArguments - val arguments : Seq[(Identifier, T)] = allArguments.map(id => id -> encoder.encodeId(id)) + val allArguments : Seq[Variable] = funDefArgs ++ lambdaArguments + val arguments : Seq[(Variable, Encoded)] = allArguments.map(id => id -> encodeSymbol(id)) - val substMap : Map[Identifier, T] = arguments.toMap + pathVar + val substMap : Map[Variable, Encoded] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { + val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) - } else { - (prec.toSeq :+ lambdaBody.get).foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) - } // Now the postcondition. val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = tfd.postcondition match { @@ -133,25 +100,24 @@ class TemplateGenerator[T](val theories: TheoryEncoder, (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) } - val template = FunctionTemplate(tfd, encoder, manager, - pathVar, arguments, condVars, exprVars, condTree, guardedExprs, quantifications, lambdas, isRealFunDef) + val template = FunctionTemplate(tfd, pathVar, arguments, + condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) cache += tfd -> template template } - private def lambdaArgs(expr: Expr): Seq[Identifier] = expr match { + private def lambdaArgs(expr: Expr): Seq[Variable] = expr match { case Lambda(args, body) => args.map(_.id.freshen) ++ lambdaArgs(body) case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda chains") case _ => Seq.empty } - private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Identifier], inlineFirst: Boolean = false): Seq[Expr] = { - def rec(i: Expr, b: Expr, args: Seq[Identifier], inline: Boolean): Seq[Expr] = i.getType match { + private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Variable], inlineFirst: Boolean = false): Seq[Expr] = { + def rec(i: Expr, b: Expr, args: Seq[Variable], inline: Boolean): Seq[Expr] = i.getType match { case FunctionType(from, to) => val (currArgs, nextArgs) = args.splitAt(from.size) - val arguments = currArgs.map(_.toVariable) val apply = if (inline) application _ else Application - val (appliedInv, appliedBody) = (apply(i, arguments), apply(b, arguments)) + val (appliedInv, appliedBody) = (apply(i, currArgs), apply(b, currArgs)) rec(appliedInv, appliedBody, nextArgs, false) :+ Equals(appliedInv, appliedBody) case _ => assert(args.isEmpty, "liftedEquals should consume all provided arguments") @@ -161,29 +127,29 @@ class TemplateGenerator[T](val theories: TheoryEncoder, rec(invocation, body, args, inlineFirst) } - private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = { + private def minimalFlattening(inits: Set[Variable], conj: Expr): (Set[Variable], Expr) = { var mapping: Map[Expr, Expr] = Map.empty - var quantified: Set[Identifier] = inits - var quantifierEqualities: Seq[(Expr, Identifier)] = Seq.empty + var quantified: Set[Variable] = inits + var quantifierEqualities: Seq[(Expr, Variable)] = Seq.empty - val newConj = postMap { + val newConj = exprOps.postMap { case expr if mapping.isDefinedAt(expr) => Some(mapping(expr)) case expr @ QuantificationMatcher(c, args) => - val isMatcher = args.exists { case Variable(id) => quantified(id) case _ => false } - val isRelevant = (variablesOf(expr) & quantified).nonEmpty + val isMatcher = args.exists { case v: Variable => quantified(v) case _ => false } + val isRelevant = (exprOps.variablesOf(expr) & quantified).nonEmpty if (!isMatcher && isRelevant) { val newArgs = args.map { - case arg @ QuantificationMatcher(_, _) if (variablesOf(arg) & quantified).nonEmpty => - val id = FreshIdentifier("flat", arg.getType) - quantifierEqualities :+= (arg -> id) - quantified += id - Variable(id) + case arg @ QuantificationMatcher(_, _) if (exprOps.variablesOf(arg) & quantified).nonEmpty => + val v = Variable(FreshIdentifier("flat", true), arg.getType) + quantifierEqualities :+= (arg -> v) + quantified += v + v case arg => arg } - val newExpr = replace((args zip newArgs).toMap, expr) + val newExpr = exprOps.replace((args zip newArgs).toMap, expr) mapping += expr -> newExpr Some(newExpr) } else { @@ -194,57 +160,59 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } (conj) val flatConj = implies(andJoin(quantifierEqualities.map { - case (arg, id) => Equals(arg, Variable(id)) + case (arg, id) => Equals(arg, id) }), newConj) (quantified, flatConj) } - def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): Clauses = { + def mkClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded]): TemplateClauses = { val (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) = mkExprClauses(pathVar, expr, substMap) val allGuarded = guardedExprs + (pathVar -> (p +: guardedExprs.getOrElse(pathVar, Seq.empty))) (condVars, exprVars, condTree, allGuarded, lambdas, quantifications) } - private def mkExprClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): (Expr, Clauses) = { + private def mkExprClauses(pathVar: Variable, expr: Expr, substMap: Map[Variable, Encoded]): (Expr, TemplateClauses) = { - var condVars = Map[Identifier, T]() - var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) - def storeCond(pathVar: Identifier, id: Identifier) : Unit = { - condVars += id -> encoder.encodeId(id) + var condVars = Map[Variable, Encoded]() + var condTree = Map[Variable, Set[Variable]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Variable, id: Variable) : Unit = { + condVars += id -> encodeSymbol(id) condTree += pathVar -> (condTree(pathVar) + id) } - @inline def encodedCond(id: Identifier) : T = substMap.getOrElse(id, condVars(id)) + @inline def encodedCond(id: Variable): Encoded = substMap.getOrElse(id, condVars(id)) - var exprVars = Map[Identifier, T]() - @inline def storeExpr(id: Identifier) : Unit = exprVars += id -> encoder.encodeId(id) + var exprVars = Map[Variable, Encoded]() + @inline def storeExpr(id: Variable) : Unit = exprVars += id -> encodeSymbol(id) // Represents clauses of the form: // id => expr && ... && expr - var guardedExprs = Map[Identifier, Seq[Expr]]() - def storeGuarded(guardVar: Identifier, expr: Expr) : Unit = { - assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean. " + purescala.ExprOps.explainTyping(expr)) + var guardedExprs = Map[Variable, Seq[Expr]]() + def storeGuarded(guardVar: Variable, expr: Expr) : Unit = { + assert(expr.getType == BooleanType, expr.asString + " is not of type Boolean. " + explainTyping(expr)) val prev = guardedExprs.getOrElse(guardVar, Nil) guardedExprs += guardVar -> (expr +: prev) } - var lambdaVars = Map[Identifier, T]() - @inline def storeLambda(id: Identifier) : T = { - val idT = encoder.encodeId(id) + def iff(e1: Expr, e2: Expr): Unit = storeGuarded(pathVar, Equals(e1, e2)) + + var lambdaVars = Map[Variable, Encoded]() + @inline def storeLambda(id: Variable): Encoded = { + val idT = encodeSymbol(id) lambdaVars += id -> idT idT } - var quantifications = Seq[QuantificationTemplate[T]]() - @inline def registerQuantification(quantification: QuantificationTemplate[T]): Unit = + var quantifications = Seq[QuantificationTemplate]() + @inline def registerQuantification(quantification: QuantificationTemplate): Unit = quantifications :+= quantification - var lambdas = Seq[LambdaTemplate[T]]() - @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda + var lambdas = Seq[LambdaTemplate]() + @inline def registerLambda(lambda: LambdaTemplate) : Unit = lambdas :+= lambda - def rec(pathVar: Identifier, expr: Expr): Expr = { + def rec(pathVar: Variable, expr: Expr): Expr = { expr match { case a @ Assert(cond, err, body) => rec(pathVar, IfExpr(cond, body, Error(body.getType, err getOrElse "assertion failed"))) @@ -252,66 +220,47 @@ class TemplateGenerator[T](val theories: TheoryEncoder, case e @ Ensuring(_, _) => rec(pathVar, e.toAssert) - case l @ Let(i, e : Lambda, b) => + case l @ Let(i, e: Lambda, b) => val re = rec(pathVar, e) // guaranteed variable! - val rb = rec(pathVar, replace(Map(Variable(i) -> re), b)) + val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> re), b)) rb case l @ Let(i, e, b) => - val newExpr : Identifier = FreshIdentifier("lt", i.getType, true) + val newExpr : Variable = Variable(FreshIdentifier("lt", true), i.getType) storeExpr(newExpr) val re = rec(pathVar, e) - storeGuarded(pathVar, Equals(Variable(newExpr), re)) - val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b)) + storeGuarded(pathVar, Equals(newExpr, re)) + val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> newExpr), b)) rb - /* TODO: maybe we want this specialization? - case l @ LetTuple(is, e, b) => - val tuple : Identifier = FreshIdentifier("t", TupleType(is.map(_.getType)), true) - storeExpr(tuple) - val re = rec(pathVar, e) - storeGuarded(pathVar, Equals(Variable(tuple), re)) - - val mapping = for ((id, i) <- is.zipWithIndex) yield { - val newId = FreshIdentifier("ti", id.getType, true) - storeExpr(newId) - storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1))) - - (Variable(id) -> Variable(newId)) - } - - val rb = rec(pathVar, replace(mapping.toMap, b)) - rb - */ case m : MatchExpr => sys.error("'MatchExpr's should have been eliminated before generating templates.") - case p : Passes => sys.error("'Passes's should have been eliminated before generating templates.") case i @ Implies(lhs, rhs) => - if (!isSimple(i)) { + if (!exprOps.isSimple(i)) { rec(pathVar, Or(Not(lhs), rhs)) } else { implies(rec(pathVar, lhs), rec(pathVar, rhs)) } case a @ And(parts) => - val partitions = groupWhile(parts)(isSimple) + val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) partitions.map(andJoin) match { case Seq(e) => e case seq => - val newExpr : Identifier = FreshIdentifier("e", BooleanType, true) + val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) storeExpr(newExpr) - def recAnd(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { + def recAnd(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { case x :: Nil => - storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x))) + storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x))) case x :: xs => - val newBool : Identifier = FreshIdentifier("b", BooleanType, true) + val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) storeCond(pathVar, newBool) val xrec = rec(pathVar, x) - storeGuarded(pathVar, Equals(Variable(newBool), xrec)) - storeGuarded(pathVar, Implies(Not(Variable(newBool)), Not(Variable(newExpr)))) + iff(and(pathVar, xrec), newBool) + iff(and(pathVar, not(xrec)), not(newExpr)) recAnd(newBool, xs) @@ -319,28 +268,28 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } recAnd(pathVar, seq) - Variable(newExpr) + newExpr } case o @ Or(parts) => - val partitions = groupWhile(parts)(isSimple) + val partitions = SeqUtils.groupWhile(parts)(exprOps.isSimple) partitions.map(orJoin) match { case Seq(e) => e case seq => - val newExpr : Identifier = FreshIdentifier("e", BooleanType, true) + val newExpr: Variable = Variable(FreshIdentifier("e", true), BooleanType) storeExpr(newExpr) - def recOr(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { + def recOr(pathVar: Variable, partitions: Seq[Expr]): Unit = partitions match { case x :: Nil => - storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x))) + storeGuarded(pathVar, Equals(newExpr, rec(pathVar, x))) case x :: xs => - val newBool : Identifier = FreshIdentifier("b", BooleanType, true) + val newBool: Variable = Variable(FreshIdentifier("b", true), BooleanType) storeCond(pathVar, newBool) val xrec = rec(pathVar, x) - storeGuarded(pathVar, Equals(Not(Variable(newBool)), xrec)) - storeGuarded(pathVar, Implies(Not(Variable(newBool)), Variable(newExpr))) + iff(and(pathVar, xrec), newExpr) + iff(and(pathVar, not(xrec)), newBool) recOr(newBool, xs) @@ -348,16 +297,16 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } recOr(pathVar, seq) - Variable(newExpr) + newExpr } case i @ IfExpr(cond, thenn, elze) => { - if(isSimple(i)) { + if(exprOps.isSimple(i)) { i } else { - val newBool1 : Identifier = FreshIdentifier("b", BooleanType, true) - val newBool2 : Identifier = FreshIdentifier("b", BooleanType, true) - val newExpr : Identifier = FreshIdentifier("e", i.getType, true) + val newBool1 : Variable = Variable(FreshIdentifier("b", true), BooleanType) + val newBool2 : Variable = Variable(FreshIdentifier("b", true), BooleanType) + val newExpr : Variable = Variable(FreshIdentifier("e", true), i.getType) storeCond(pathVar, newBool1) storeCond(pathVar, newBool2) @@ -368,127 +317,100 @@ class TemplateGenerator[T](val theories: TheoryEncoder, val trec = rec(newBool1, thenn) val erec = rec(newBool2, elze) - storeGuarded(pathVar, or(Variable(newBool1), Variable(newBool2))) - storeGuarded(pathVar, or(not(Variable(newBool1)), not(Variable(newBool2)))) - // TODO can we improve this? i.e. make it more symmetrical? - // Probably it's symmetrical enough to Z3. - storeGuarded(pathVar, Equals(Variable(newBool1), crec)) - storeGuarded(newBool1, Equals(Variable(newExpr), trec)) - storeGuarded(newBool2, Equals(Variable(newExpr), erec)) - Variable(newExpr) - } - } - - case c @ Choose(Lambda(params, cond)) => - val cs = params.map(_.id.freshen.toVariable) + iff(and(pathVar, cond), newBool1) + iff(and(pathVar, not(cond)), newBool2) - for (c <- cs) { - storeExpr(c.id) + storeGuarded(newBool1, Equals(newExpr, trec)) + storeGuarded(newBool2, Equals(newExpr, erec)) + newExpr } - - val freshMap = (params.map(_.id) zip cs).toMap - - storeGuarded(pathVar, replaceFromIDs(freshMap, cond)) - - tupleWrap(cs) - - case FiniteLambda(mapping, dflt, FunctionType(from, to)) => - val args = from.map(tpe => FreshIdentifier("x", tpe)) - val body = mapping.toSeq.foldLeft(dflt) { case (elze, (exprs, res)) => - IfExpr(andJoin((args zip exprs).map(p => Equals(Variable(p._1), p._2))), res, elze) - } - - rec(pathVar, Lambda(args.map(ValDef(_)), body)) + } case l @ Lambda(args, body) => - val idArgs : Seq[Identifier] = lambdaArgs(l) - val trArgs : Seq[T] = idArgs.map(id => substMap.getOrElse(id, encoder.encodeId(id))) + val idArgs : Seq[Variable] = lambdaArgs(l) + val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id))) - val lid = FreshIdentifier("lambda", bestRealType(l.getType), true) - val clauses = liftedEquals(Variable(lid), l, idArgs, inlineFirst = true) + val lid = Variable(FreshIdentifier("lambda", true), bestRealType(l.getType)) + val clauses = liftedEquals(lid, l, idArgs, inlineFirst = true) - val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars - val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) + val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars + val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idArgs zip trArgs) val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst)) - val ids: (Identifier, T) = lid -> storeLambda(lid) + val ids: (Variable, Encoded) = lid -> storeLambda(lid) val (struct, deps) = normalizeStructure(l) - import Template._ - import Instantiation.MapSetWrapper - val (dependencies, (depConds, depExprs, depTree, depGuarded, depLambdas, depQuants)) = - deps.foldLeft[(Seq[T], Clauses)](Seq.empty -> emptyClauses) { + deps.foldLeft[(Seq[Encoded], TemplateClauses)](Seq.empty -> emptyClauses) { case ((dependencies, clsSet), (id, expr)) => - if (!isSimple(expr)) { - val encoded = encoder.encodeId(id) + if (!exprOps.isSimple(expr)) { + val encoded = encodeSymbol(id) val (e, cls @ (_, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst) val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.map(_.qs) - (dependencies :+ encoder.encodeExpr(clauseSubst)(e), clsSet ++ cls) + (dependencies :+ encodeExpr(clauseSubst)(e), clsSet ++ cls) } else { - (dependencies :+ encoder.encodeExpr(localSubst)(expr), clsSet) + (dependencies :+ encodeExpr(localSubst)(expr), clsSet) } } - val (depClauses, depCalls, depApps, _, depMatchers, _) = Template.encode( - encoder, pathVar -> encodedCond(pathVar), Seq.empty, + val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode( + pathVar -> encodedCond(pathVar), Seq.empty, depConds, depExprs, depGuarded, depLambdas, depQuants, localSubst) - val depClosures: Seq[T] = { - val vars = variablesOf(l) - var cls: Seq[Identifier] = Seq.empty - preTraversal { case Variable(id) if vars(id) => cls :+= id case _ => } (l) + val depClosures: Seq[Encoded] = { + val vars = exprOps.variablesOf(l) + var cls: Seq[Variable] = Seq.empty + exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (l) cls.distinct.map(localSubst) } - val structure = new LambdaStructure[T]( encoder, manager, + val structure = new LambdaStructure( struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, depConds, depExprs, depTree, depClauses, depCalls, depApps, depLambdas, depMatchers, depQuants) - val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), + val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar), idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, - lambdaGuarded, lambdaQuants, lambdaTemplates, structure, localSubst, l) + lambdaGuarded, lambdaTemplates, lambdaQuants, structure, localSubst, l) registerLambda(template) - - Variable(lid) + lid case f @ Forall(args, body) => val TopLevelAnds(conjuncts) = body val conjunctQs = conjuncts.map { conjunct => - val vars = variablesOf(conjunct) - val inits = args.map(_.id).filter(vars).toSet + val vars = exprOps.variablesOf(conjunct) + val inits = args.map(_.toVariable).filter(vars).toSet val (quantifiers, flatConj) = minimalFlattening(inits, conjunct) - val idQuantifiers : Seq[Identifier] = quantifiers.toSeq - val trQuantifiers : Seq[T] = idQuantifiers.map(encoder.encodeId) + val idQuantifiers : Seq[Variable] = quantifiers.toSeq + val trQuantifiers : Seq[Encoded] = idQuantifiers.map(encodeSymbol) - val q: Identifier = FreshIdentifier("q", BooleanType, true) - val q2: Identifier = FreshIdentifier("qo", BooleanType, true) - val inst: Identifier = FreshIdentifier("inst", BooleanType, true) - val guard: Identifier = FreshIdentifier("guard", BooleanType, true) + val q: Variable = Variable(FreshIdentifier("q", true), BooleanType) + val q2: Variable = Variable(FreshIdentifier("qo", true), BooleanType) + val inst: Variable = Variable(FreshIdentifier("inst", true), BooleanType) + val guard: Variable = Variable(FreshIdentifier("guard", true), BooleanType) - val clause = Equals(Variable(inst), Implies(Variable(guard), flatConj)) + val clause = Equals(inst, Implies(guard, flatConj)) - val qs: (Identifier, T) = q -> encoder.encodeId(q) - val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars - val clauseSubst: Map[Identifier, T] = localSubst ++ (idQuantifiers zip trQuantifiers) + val qs: (Variable, Encoded) = q -> encodeSymbol(q) + val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars + val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idQuantifiers zip trQuantifiers) val (p, (qConds, qExprs, qTree, qGuarded, qTemplates, qQuants)) = mkExprClauses(pathVar, flatConj, clauseSubst) assert(qQuants.isEmpty, "Unhandled nested quantification in "+clause) val allGuarded = qGuarded + (pathVar -> (Seq( - Equals(Variable(inst), Implies(Variable(guard), p)), - Equals(Variable(q), And(Variable(q2), Variable(inst))) + Equals(inst, Implies(guard, p)), + Equals(q, And(q2, inst)) ) ++ qGuarded.getOrElse(pathVar, Seq.empty))) - val dependencies: Map[Identifier, T] = vars.filterNot(quantifiers).map(id => id -> localSubst(id)).toMap - val template = QuantificationTemplate[T](encoder, manager, pathVar -> encodedCond(pathVar), + val dependencies: Map[Variable, Encoded] = vars.filterNot(quantifiers).map(id => id -> localSubst(id)).toMap + val template = QuantificationTemplate(pathVar -> encodedCond(pathVar), qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, allGuarded, qTemplates, localSubst, - dependencies, Forall(quantifiers.toSeq.sortBy(_.uniqueName).map(ValDef(_)), flatConj)) + dependencies, Forall(quantifiers.toSeq.sortBy(_.id.uniqueName).map(_.toVal), flatConj)) registerQuantification(template) - Variable(q) + q } andJoin(conjunctQs) diff --git a/src/main/scala/inox/solvers/unrolling/TemplateInfo.scala b/src/main/scala/inox/solvers/unrolling/TemplateInfo.scala deleted file mode 100644 index 3dc60c89f29d5efb55abff75b4d1e6d604719b9c..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/TemplateInfo.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Definitions.TypedFunDef -import Template.Arg - -case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[Arg[T]]) { - override def toString = { - tfd.signature + args.map { - case Right(m) => m.toString - case Left(v) => v.toString - }.mkString("(", ", ", ")") - } -} - -case class TemplateAppInfo[T](template: Either[LambdaTemplate[T], T], equals: T, args: Seq[Arg[T]]) { - override def toString = { - val caller = template match { - case Left(tmpl) => tmpl.ids._2 - case Right(c) => c - } - - caller + "|" + equals + args.map { - case Right(m) => m.toString - case Left(v) => v.toString - }.mkString("(", ",", ")") - } -} - -object TemplateAppInfo { - def apply[T](template: LambdaTemplate[T], equals: T, args: Seq[Arg[T]]): TemplateAppInfo[T] = - TemplateAppInfo(Left(template), equals, args) - - def apply[T](caller: T, equals: T, args: Seq[Arg[T]]): TemplateAppInfo[T] = - TemplateAppInfo(Right(caller), equals, args) -} diff --git a/src/main/scala/inox/solvers/unrolling/TemplateManager.scala b/src/main/scala/inox/solvers/unrolling/TemplateManager.scala deleted file mode 100644 index 63cf3ea024e48bf7d3cc509b6533826538776a2b..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/TemplateManager.scala +++ /dev/null @@ -1,538 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Quantification._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.bestRealType - -import utils._ - -import scala.collection.generic.CanBuildFrom - -/** A template instatiation - * - * [[Template]] instances, when provided with concrete arguments and a - * blocker, will generate three outputs used for program unfolding: - * - clauses: clauses that will be added to the underlying solver - * - call blockers: bookkeeping information necessary for named - * function unfolding - * - app blockers: bookkeeping information necessary for first-class - * function unfolding - * - * This object provides helper methods to deal with the triplets - * generated during unfolding. - */ -object Instantiation { - type Clauses[T] = Seq[T] - type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] - type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] - type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) - - def empty[T] = (Seq.empty[T], - Map.empty[T, Set[TemplateCallInfo[T]]], - Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) - - implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { - def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => - k -> (map.getOrElse(k, Set.empty) ++ that.getOrElse(k, Set.empty)) - }.toMap - } - - implicit class MapSeqWrapper[A,B](map: Map[A,Seq[B]]) { - def merge(that: Map[A,Seq[B]]): Map[A,Seq[B]] = (map.keys ++ that.keys).map { k => - k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)).distinct - }.toMap - } - - implicit class InstantiationWrapper[T](i: Instantiation[T]) { - def ++(that: Instantiation[T]): Instantiation[T] = { - val (thisClauses, thisBlockers, thisApps) = i - val (thatClauses, thatBlockers, thatApps) = that - - (thisClauses ++ thatClauses, thisBlockers merge thatBlockers, thisApps merge thatApps) - } - - def withClause(cl: T): Instantiation[T] = (i._1 :+ cl, i._2, i._3) - def withClauses(cls: Seq[T]): Instantiation[T] = (i._1 ++ cls, i._2, i._3) - - def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) - def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) - def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = - (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) - } -} - -import Instantiation.{empty => _, _} -import Template.{Apps, Calls, Functions, Arg} - -/** Pre-compiled sets of clauses with extra bookkeeping information that enables - * efficient unfolding of function calls and applications. - * [[Template]] is a super-type for all such clause sets that can be instantiated - * given a concrete argument list and a blocker in the decision-tree. - */ -trait Template[T] { self => - val encoder : TemplateEncoder[T] - val manager : TemplateManager[T] - - val pathVar : (Identifier, T) - val args : Seq[T] - - val condVars : Map[Identifier, T] - val exprVars : Map[Identifier, T] - val condTree : Map[Identifier, Set[Identifier]] - - val clauses : Clauses[T] - val blockers : Calls[T] - val applications : Apps[T] - val functions : Functions[T] - val lambdas : Seq[LambdaTemplate[T]] - - val quantifications : Seq[QuantificationTemplate[T]] - val matchers : Map[T, Set[Matcher[T]]] - - lazy val start = pathVar._2 - - def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { - val (substMap, instantiation) = Template.substitution(encoder, manager, - condVars, exprVars, condTree, quantifications, lambdas, functions, - (this.args zip args).toMap + (start -> Left(aVar)), pathVar._1, aVar) - instantiation ++ instantiate(substMap) - } - - protected def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { - Template.instantiate(encoder, manager, clauses, - blockers, applications, matchers, substMap) - } - - override def toString : String = "Instantiated template" -} - -object Template { - - type Arg[T] = Either[T, Matcher[T]] - - implicit class ArgWrapper[T](arg: Arg[T]) { - def encoded: T = arg match { - case Left(value) => value - case Right(matcher) => matcher.encoded - } - - def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): Arg[T] = arg match { - case Left(v) => matcherSubst.get(v) match { - case Some(m) => Right(m) - case None => Left(substituter(v)) - } - case Right(m) => Right(m.substitute(substituter, matcherSubst)) - } - } - - private def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { - case FunctionType(from, to) => - val (curr, next) = args.splitAt(from.size) - mkApplication(Application(caller, curr), next) - case _ => - assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}") - caller - } - - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { - assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") - - def rec(e: Expr, args: Seq[Expr]): Expr = e.getType match { - case FunctionType(from, to) => - val (appArgs, outerArgs) = args.splitAt(from.size) - rec(Application(e, appArgs), outerArgs) - case _ if args.isEmpty => e - case _ => scala.sys.error("Should never happen") - } - - val (fiArgs, appArgs) = args.splitAt(tfd.params.size) - val app @ Application(caller, arguments) = rec(FunctionInvocation(tfd, fiArgs), appArgs) - Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) - } - - type Calls[T] = Map[T, Set[TemplateCallInfo[T]]] - type Apps[T] = Map[T, Set[App[T]]] - type Functions[T] = Set[(T, FunctionType, T)] - - def encode[T]( - encoder: TemplateEncoder[T], - pathVar: (Identifier, T), - arguments: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Seq[LambdaTemplate[T]], - quantifications: Seq[QuantificationTemplate[T]], - substMap: Map[Identifier, T] = Map.empty[Identifier, T], - optCall: Option[TypedFunDef] = None, - optApp: Option[(T, FunctionType)] = None - ) : (Clauses[T], Calls[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { - - val idToTrId : Map[Identifier, T] = - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) - - val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) - - val (clauses, cleanGuarded, functions) = { - var functions: Set[(T, FunctionType, T)] = Set.empty - var clauses: Seq[T] = Seq.empty - - val cleanGuarded = guardedExprs.map { - case (b, es) => b -> es.map { e => - def clean(expr: Expr): Expr = postMap { - case FreshFunction(f) => Some(BooleanLiteral(true)) - case _ => None - } (expr) - - val withPaths = CollectorWithPaths { case FreshFunction(f) => f }.traverse(e) - functions ++= withPaths.map { case (f, path) => - val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] - val cleanPath = path.map(clean) - (encodeExpr(and(Variable(b), cleanPath.toPath)), tpe, encodeExpr(f)) - } - - val cleanExpr = clean(e) - clauses :+= encodeExpr(Implies(Variable(b), cleanExpr)) - cleanExpr - } - } - - (clauses, cleanGuarded, functions) - } - - val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2)))) - val optIdApp = optApp.map { case (idT, tpe) => - val id = FreshIdentifier("x", tpe, true) - val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(mkApplication(Variable(id), arguments.map(_._1.toVariable))) - App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) - } - - lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) - .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) - - val (blockers, applications, matchers) = { - var blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = Map.empty - var applications : Map[Identifier, Set[App[T]]] = Map.empty - var matchers : Map[Identifier, Set[Matcher[T]]] = Map.empty - - for ((b,es) <- cleanGuarded) { - var funInfos : Set[TemplateCallInfo[T]] = Set.empty - var appInfos : Set[App[T]] = Set.empty - var matchInfos : Set[Matcher[T]] = Set.empty - - for (e <- es) { - val exprToMatcher = fold[Map[Expr, Matcher[T]]] { (expr, res) => - val result = res.flatten.toMap - - result ++ (expr match { - case QuantificationMatcher(c, args) => - // Note that we rely here on the fact that foldRight visits the matcher's arguments first, - // so any Matcher in arguments will belong to the `result` map - val encodedArgs = args.map(arg => result.get(arg) match { - case Some(matcher) => Right(matcher) - case None => Left(encodeExpr(arg)) - }) - - Some(expr -> Matcher(encodeExpr(c), bestRealType(c.getType), encodedArgs, encodeExpr(expr))) - case _ => None - }) - }(e) - - def encodeArg(arg: Expr): Arg[T] = exprToMatcher.get(arg) match { - case Some(matcher) => Right(matcher) - case None => Left(encodeExpr(arg)) - } - - funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg))) - appInfos ++= firstOrderAppsOf(e).map { case (c, args) => - val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] - App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(mkApplication(c, args))) - } - - matchInfos ++= exprToMatcher.values - } - - val calls = funInfos.filter(i => Some(i) != optIdCall) - if (calls.nonEmpty) blockers += b -> calls - - val apps = appInfos.filter(i => Some(i) != optIdApp) - if (apps.nonEmpty) applications += b -> apps - - val matchs = (matchInfos.filter { case m @ Matcher(_, _, _, menc) => - !optIdApp.exists { case App(_, _, _, aenc) => menc == aenc } - } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None)) - - if (matchs.nonEmpty) matchers += b -> matchs - } - - (blockers, applications, matchers) - } - - val encodedBlockers : Calls[T] = blockers.map(p => idToTrId(p._1) -> p._2) - val encodedApps : Apps[T] = applications.map(p => idToTrId(p._1) -> p._2) - val encodedMatchers : Map[T, Set[Matcher[T]]] = matchers.map(p => idToTrId(p._1) -> p._2) - - val stringRepr : () => String = () => { - " * Activating boolean : " + pathVar._1 + "\n" + - " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + - " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + - " * Clauses : " + (if (cleanGuarded.isEmpty) "\n" else { - "\n " + (for ((b,es) <- cleanGuarded; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" - }) + - " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { - "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" - }) + - " * Application-blocks :" + (if (applications.isEmpty) "\n" else { - "\n " + applications.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" - }) + - " * Matchers :" + (if (matchers.isEmpty) "\n" else { - "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" - }) + - " * Lambdas :\n" + lambdas.map { case template => - " +> " + template.toString.split("\n").mkString("\n ") + "\n" - }.mkString("\n") + - " * Foralls :\n" + quantifications.map { case template => - " +> " + template.toString.split("\n").mkString("\n ") + "\n" - }.mkString("\n") - } - - (clauses, encodedBlockers, encodedApps, functions, encodedMatchers, stringRepr) - } - - def substitution[T]( - encoder: TemplateEncoder[T], - manager: TemplateManager[T], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - condTree: Map[Identifier, Set[Identifier]], - quantifications: Seq[QuantificationTemplate[T]], - lambdas: Seq[LambdaTemplate[T]], - functions: Functions[T], - baseSubst: Map[T, Arg[T]], - pathVar: Identifier, - aVar: T - ): (Map[T, Arg[T]], Instantiation[T]) = { - - val freshSubst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ - manager.freshConds(pathVar -> aVar, condVars, condTree) - val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } - - var subst = freshSubst.mapValues(Left(_)) ++ baseSubst - var instantiation : Instantiation[T] = Instantiation.empty - - manager match { - case lmanager: LambdaManager[T] => - val funSubstituter = encoder.substitute(subst.mapValues(_.encoded)) - for ((b,tpe,f) <- functions) { - instantiation ++= lmanager.registerFunction(funSubstituter(b), tpe, funSubstituter(f)) - } - - // /!\ CAREFUL /!\ - // We have to be wary while computing the lambda subst map since lambdas can - // depend on each other. However, these dependencies cannot be cyclic so it - // suffices to make sure the traversal order is correct. - var seen : Set[LambdaTemplate[T]] = Set.empty - - val lambdaKeys = lambdas.map(lambda => lambda.ids._2 -> lambda).toMap - def extractSubst(lambda: LambdaTemplate[T]): Unit = { - for { - dep <- lambda.structure.closures flatMap lambdaKeys.get - if !seen(dep) - } extractSubst(dep) - - if (!seen(lambda)) { - val substMap = subst.mapValues(_.encoded) - val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) - val (idT, inst) = lmanager.instantiateLambda(substLambda) - instantiation ++= inst - subst += lambda.ids._2 -> Left(idT) - seen += lambda - } - } - - for (l <- lambdas) extractSubst(l) - - case _ => - } - - manager match { - case qmanager: QuantificationManager[T] => - for (q <- quantifications) { - val substMap = subst.mapValues(_.encoded) - val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) - val (qT, inst) = qmanager.instantiateQuantification(substQuant) - instantiation ++= inst - subst += q.qs._2 -> Left(qT) - } - - case _ => - } - - (subst, instantiation) - } - - def instantiate[T]( - encoder: TemplateEncoder[T], - manager: TemplateManager[T], - clauses: Clauses[T], - blockers: Calls[T], - applications: Apps[T], - matchers: Map[T, Set[Matcher[T]]], - substMap: Map[T, Arg[T]] - ): Instantiation[T] = { - - val substituter : T => T = encoder.substitute(substMap.mapValues(_.encoded)) - val msubst = substMap.collect { case (c, Right(m)) => c -> m } - - val newClauses: Clauses[T] = clauses.map(substituter) - - val newBlockers: CallBlockers[T] = blockers.map { case (b,fis) => - substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) - } - - var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - - manager match { - case lmanager: LambdaManager[T] => - for ((b,apps) <- applications; bp = substituter(b); app <- apps) { - instantiation ++= lmanager.instantiateApp(bp, app.substitute(substituter, msubst)) - } - - case _ => - } - - manager match { - case qmanager: QuantificationManager[T] => - for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { - val newMatcher = m.substitute(substituter, msubst) - instantiation ++= qmanager.instantiateMatcher(bp, newMatcher) - } - - case _ => - } - - instantiation - } -} - -object FunctionTemplate { - - def apply[T]( - tfd: TypedFunDef, - encoder: TemplateEncoder[T], - manager: TemplateManager[T], - pathVar: (Identifier, T), - arguments: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - condTree: Map[Identifier, Set[Identifier]], - guardedExprs: Map[Identifier, Seq[Expr]], - quantifications: Seq[QuantificationTemplate[T]], - lambdas: Seq[LambdaTemplate[T]], - isRealFunDef: Boolean - ) : FunctionTemplate[T] = { - - val (clauses, blockers, applications, functions, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, - optCall = Some(tfd)) - - val funString : () => String = () => { - "Template for def " + tfd.signature + - "(" + tfd.params.map(a => a.id + " : " + a.getType).mkString(", ") + ") : " + - tfd.returnType + " is :\n" + templateString() - } - - new FunctionTemplate[T]( - tfd, - encoder, - manager, - pathVar, - arguments.map(_._2), - condVars, - exprVars, - condTree, - clauses, - blockers, - applications, - functions, - lambdas, - matchers, - quantifications, - isRealFunDef, - funString - ) - } -} - -class FunctionTemplate[T] private( - val tfd: TypedFunDef, - val encoder: TemplateEncoder[T], - val manager: TemplateManager[T], - val pathVar: (Identifier, T), - val args: Seq[T], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val condTree: Map[Identifier, Set[Identifier]], - val clauses: Clauses[T], - val blockers: Calls[T], - val applications: Apps[T], - val functions: Functions[T], - val lambdas: Seq[LambdaTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - isRealFunDef: Boolean, - stringRepr: () => String) extends Template[T] { - - private lazy val str : String = stringRepr() - override def toString : String = str -} - -class TemplateManager[T](protected[unrolling] val encoder: TemplateEncoder[T]) extends IncrementalState { - private val condImplies = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) - private val condImplied = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) - - protected def incrementals: List[IncrementalState] = List(condImplies, condImplied) - - def clear(): Unit = incrementals.foreach(_.clear()) - def reset(): Unit = incrementals.foreach(_.reset()) - def push(): Unit = incrementals.foreach(_.push()) - def pop(): Unit = incrementals.foreach(_.pop()) - - def freshConds(path: (Identifier, T), condVars: Map[Identifier, T], tree: Map[Identifier, Set[Identifier]]): Map[T, T] = { - val subst = condVars.map { case (id, idT) => idT -> encoder.encodeId(id) } - val mapping = condVars.mapValues(subst) + path - - for ((parent, children) <- tree; ep = mapping(parent); child <- children) { - val ec = mapping(child) - condImplies += ep -> (condImplies(ep) + ec) - condImplied += ec -> (condImplied(ec) + ep) - } - - subst - } - - def blocker(b: T): Unit = condImplies += (b -> Set.empty) - def isBlocker(b: T): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) - def blockerParents(b: T): Set[T] = condImplied(b) - def blockerChildren(b: T): Set[T] = condImplies(b) - - def implies(b1: T, b2: T): Unit = implies(b1, Set(b2)) - def implies(b1: T, b2s: Set[T]): Unit = { - val fb2s = b2s.filter(_ != b1) - condImplies += b1 -> (condImplies(b1) ++ fb2s) - for (b2 <- fb2s) { - condImplied += b2 -> (condImplies(b2) + b1) - } - } - -} diff --git a/src/main/scala/inox/solvers/unrolling/Templates.scala b/src/main/scala/inox/solvers/unrolling/Templates.scala new file mode 100644 index 0000000000000000000000000000000000000000..3679b9cbdbe4b429774a8fd3dd621a615135a7cf --- /dev/null +++ b/src/main/scala/inox/solvers/unrolling/Templates.scala @@ -0,0 +1,521 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package solvers +package unrolling + +import utils._ + +import scala.collection.generic.CanBuildFrom + +trait Templates extends TemplateGenerator + with FunctionTemplates + with DatatypeTemplates + with LambdaTemplates + with QuantificationTemplates + with IncrementalStateWrapper { + + val program: Program + import program._ + import program.trees._ + import program.symbols._ + + type Encoded <: Printable + + def encodeSymbol(v: Variable): Encoded + def encodeExpr(bindings: Map[Variable, Encoded])(e: Expr): Encoded + def mkSubstituter(map: Map[Encoded, Encoded]): Encoded => Encoded + + def mkNot(e: Encoded): Encoded + def mkOr(es: Encoded*): Encoded + def mkAnd(es: Encoded*): Encoded + def mkEquals(l: Encoded, r: Encoded): Encoded + def mkImplies(l: Encoded, r: Encoded): Encoded + + def extractNot(e: Encoded): Option[Encoded] + + private[unrolling] lazy val trueT = encodeExpr(Map.empty)(BooleanLiteral(true)) + + private var currentGen: Int = 0 + protected def currentGeneration: Int = currentGen + protected def nextGeneration(gen: Int): Int = gen + 5 + + trait Manager extends IncrementalStateWrapper { + def unrollGeneration: Option[Int] + + def unroll: Clauses + def satisfactionAssumptions: Seq[Encoded] + def refutationAssumptions: Seq[Encoded] + def promoteBlocker(b: Encoded): Boolean + } + + private val managers: Seq[Manager] = Seq( + functionsManager, + datatypesManager, + lambdasManager, + quantificationsManager + ) + + def canUnroll: Boolean = managers.exists(_.unrollGeneration.isDefined) + def unroll: Clauses = { + assert(canUnroll, "Impossible to unroll further") + val generation = managers.flatMap(_.unrollGeneration).min + if (generation > currentGen) { + currentGen = generation + } + + managers.flatMap(_.unroll) + } + + def satisfactionAssumptions = managers.flatMap(_.satisfactionAssumptions) + def refutationAssumptions = managers.flatMap(_.refutationAssumptions) + + private val condImplies = new IncrementalMap[Encoded, Set[Encoded]].withDefaultValue(Set.empty) + private val condImplied = new IncrementalMap[Encoded, Set[Encoded]].withDefaultValue(Set.empty) + + val incrementals: Seq[IncrementalState] = managers ++ Seq(condImplies, condImplied) + + protected def freshConds( + path: (Variable, Encoded), + condVars: Map[Variable, Encoded], + tree: Map[Variable, Set[Variable]]): Map[Encoded, Encoded] = { + + val subst = condVars.map { case (v, idT) => idT -> encodeSymbol(v) } + val mapping = condVars.mapValues(subst) + path + + for ((parent, children) <- tree; ep = mapping(parent); child <- children) { + val ec = mapping(child) + condImplies += ep -> (condImplies(ep) + ec) + condImplied += ec -> (condImplied(ec) + ep) + } + + subst + } + + protected def blocker(b: Encoded): Unit = condImplies += (b -> Set.empty) + protected def isBlocker(b: Encoded): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) + protected def blockerParents(b: Encoded): Set[Encoded] = condImplied(b) + protected def blockerChildren(b: Encoded): Set[Encoded] = condImplies(b) + + protected def impliesBlocker(b1: Encoded, b2: Encoded): Unit = impliesBlocker(b1, Set(b2)) + protected def impliesBlocker(b1: Encoded, b2s: Set[Encoded]): Unit = { + val fb2s = b2s.filter(_ != b1) + condImplies += b1 -> (condImplies(b1) ++ fb2s) + for (b2 <- fb2s) { + condImplied += b2 -> (condImplies(b2) + b1) + } + } + + def promoteBlocker(b: Encoded, force: Boolean = false): Boolean = { + var seen: Set[Encoded] = Set.empty + var promoted: Boolean = false + var blockers: Seq[Set[Encoded]] = Seq(Set(b)) + + do { + val (bs +: rest) = blockers + blockers = rest + + val next = (for (b <- bs if !seen(b)) yield { + seen += b + + for (manager <- managers) { + val p = manager.promoteBlocker(b) + promoted = promoted || p + } + + if (force) { + blockerChildren(b) + } else { + Seq.empty[Encoded] + } + }).flatten + + if (next.nonEmpty) blockers :+= next + } while (!promoted && blockers.nonEmpty) + + promoted + } + + implicit val debugSection = DebugSectionSolver + + type Arg = Either[Encoded, Matcher] + + implicit class ArgWrapper(arg: Arg) { + def encoded: Encoded = arg match { + case Left(value) => value + case Right(matcher) => matcher.encoded + } + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): Arg = arg match { + case Left(v) => msubst.get(v) match { + case Some(m) => Right(m) + case None => Left(substituter(v)) + } + case Right(m) => Right(m.substitute(substituter, msubst)) + } + } + + /** Represents a named function call in the unfolding procedure */ + case class Call(tfd: TypedFunDef, args: Seq[Arg]) { + override def toString = { + tfd.signature + args.map { + case Right(m) => m.toString + case Left(v) => v.asString + }.mkString("(", ", ", ")") + } + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): Call = copy( + args = args.map(_.substitute(substituter, msubst)) + ) + } + + /** Represents an application of a first-class function in the unfolding procedure */ + case class App(caller: Encoded, tpe: FunctionType, args: Seq[Arg], encoded: Encoded) { + override def toString: String = + "(" + caller.asString + " : " + tpe.asString + ")" + args.map(_.encoded.asString).mkString("(", ",", ")") + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): App = copy( + caller = substituter(caller), + args = args.map(_.substitute(substituter, msubst)), + encoded = substituter(encoded) + ) + } + + /** Represents an E-matching matcher that will be used to instantiate relevant quantified propositions */ + case class Matcher(caller: Encoded, tpe: Type, args: Seq[Arg], encoded: Encoded) { + override def toString: String = caller.asString + args.map { + case Right(m) => m.toString + case Left(v) => v.asString + }.mkString("(", ",", ")") + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): Matcher = copy( + caller = substituter(caller), + args = args.map(_.substitute(substituter, msubst)), + encoded = substituter(encoded) + ) + } + + + /** Template instantiations + * + * [[Template]] instances, when provided with concrete arguments and a + * blocker, will generate three outputs used for program unfolding: + * - clauses: clauses that will be added to the underlying solver + * - call blockers: bookkeeping information necessary for named + * function unfolding + * - app blockers: bookkeeping information necessary for first-class + * function unfolding + * + * This object provides helper methods to deal with the triplets + * generated during unfolding. + */ + type Clauses = Seq[Encoded] + type CallBlockers = Map[Encoded, Set[Call]] + type AppBlockers = Map[(Encoded, App), Set[TemplateAppInfo]] + + implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { + def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => + k -> (map.getOrElse(k, Set.empty) ++ that.getOrElse(k, Set.empty)) + }.toMap + } + + implicit class MapSeqWrapper[A,B](map: Map[A,Seq[B]]) { + def merge(that: Map[A,Seq[B]]): Map[A,Seq[B]] = (map.keys ++ that.keys).map { k => + k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)).distinct + }.toMap + } + + /** Abstract templates + * + * Pre-compiled sets of clauses with extra bookkeeping information that enables + * efficient unfolding of function calls and applications. + * [[Template]] is a super-type for all such clause sets that can be instantiated + * given a concrete argument list and a blocker in the decision-tree. + */ + type Apps = Map[Encoded, Set[App]] + type Calls = Map[Encoded, Set[Call]] + type Matchers = Map[Encoded, Set[Matcher]] + + trait Template { self => + val pathVar : (Variable, Encoded) + val args : Seq[Encoded] + + val condVars : Map[Variable, Encoded] + val exprVars : Map[Variable, Encoded] + val condTree : Map[Variable, Set[Variable]] + + val clauses : Clauses + val blockers : Calls + val applications : Apps + val matchers : Matchers + + val lambdas : Seq[LambdaTemplate] + val quantifications : Seq[QuantificationTemplate] + + lazy val start = pathVar._2 + + def instantiate(aVar: Encoded, args: Seq[Arg]): Clauses = { + val (substMap, clauses) = Template.substitution( + condVars, exprVars, condTree, lambdas, quantifications, + (this.args zip args).toMap + (start -> Left(aVar)), pathVar._1, aVar) + clauses ++ instantiate(substMap) + } + + protected def instantiate(substMap: Map[Encoded, Arg]): Clauses = + Template.instantiate(clauses, blockers, applications, matchers, substMap) + + override def toString : String = "Instantiated template" + } + + object Template { + private def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { + case FunctionType(from, to) => + val (curr, next) = args.splitAt(from.size) + mkApplication(Application(caller, curr), next) + case _ => + assert(args.isEmpty, s"Non-function typed $caller applied to ${args.mkString(",")}") + caller + } + + private def invocationMatcher(encoder: Expr => Encoded)(tfd: TypedFunDef, args: Seq[Expr]): Matcher = { + assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") + + def rec(e: Expr, args: Seq[Expr]): Expr = e.getType match { + case FunctionType(from, to) => + val (appArgs, outerArgs) = args.splitAt(from.size) + rec(Application(e, appArgs), outerArgs) + case _ if args.isEmpty => e + case _ => scala.sys.error("Should never happen") + } + + val (fiArgs, appArgs) = args.splitAt(tfd.params.size) + val app @ Application(caller, arguments) = rec(tfd.applied(fiArgs), appArgs) + Matcher(encoder(caller), bestRealType(caller.getType), arguments.map(arg => Left(encoder(arg))), encoder(app)) + } + + def encode( + pathVar: (Variable, Encoded), + arguments: Seq[(Variable, Encoded)], + condVars: Map[Variable, Encoded], + exprVars: Map[Variable, Encoded], + guardedExprs: Map[Variable, Seq[Expr]], + lambdas: Seq[LambdaTemplate], + quantifications: Seq[QuantificationTemplate], + substMap: Map[Variable, Encoded] = Map.empty[Variable, Encoded], + optCall: Option[TypedFunDef] = None, + optApp: Option[(Encoded, FunctionType)] = None + ) : (Clauses, Calls, Apps, Matchers, () => String) = { + + val idToTrId : Map[Variable, Encoded] = + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) + + val encoder : Expr => Encoded = encodeExpr(idToTrId) + + val optIdCall = optCall.map(tfd => Call(tfd, arguments.map(p => Left(p._2)))) + val optIdApp = optApp.map { case (idT, tpe) => + val v = Variable(FreshIdentifier("x", true), tpe) + val encoded = encodeExpr(Map(v -> idT) ++ arguments)(mkApplication(v, arguments.map(_._1))) + App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) + } + + lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) + .map(tfd => invocationMatcher(encoder)(tfd, arguments.map(_._1))) + + val (clauses, blockers, applications, matchers) = { + var clauses : Clauses = Seq.empty + var blockers : Map[Variable, Set[Call]] = Map.empty + var applications : Map[Variable, Set[App]] = Map.empty + var matchers : Map[Variable, Set[Matcher]] = Map.empty + + for ((b,es) <- guardedExprs) { + var funInfos : Set[Call] = Set.empty + var appInfos : Set[App] = Set.empty + var matchInfos : Set[Matcher] = Set.empty + + for (e <- es) { + val exprToMatcher = exprOps.fold[Map[Expr, Matcher]] { (expr, res) => + val result = res.flatten.toMap + + result ++ (expr match { + case QuantificationMatcher(c, args) => + // Note that we rely here on the fact that foldRight visits the matcher's arguments first, + // so any Matcher in arguments will belong to the `result` map + val encodedArgs = args.map(arg => result.get(arg) match { + case Some(matcher) => Right(matcher) + case None => Left(encoder(arg)) + }) + + Some(expr -> Matcher(encoder(c), bestRealType(c.getType), encodedArgs, encoder(expr))) + case _ => None + }) + }(e) + + def encodeArg(arg: Expr): Arg = exprToMatcher.get(arg) match { + case Some(matcher) => Right(matcher) + case None => Left(encoder(arg)) + } + + funInfos ++= firstOrderCallsOf(e).map { case (id, tps, args) => + Call(getFunction(id, tps), args.map(encodeArg)) + } + + appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] + App(encoder(c), tpe, args.map(encodeArg), encoder(mkApplication(c, args))) + } + + matchInfos ++= exprToMatcher.values + clauses :+= encoder(Implies(b, e)) + } + + val calls = funInfos.filter(i => Some(i) != optIdCall) + if (calls.nonEmpty) blockers += b -> calls + + val apps = appInfos.filter(i => Some(i) != optIdApp) + if (apps.nonEmpty) applications += b -> apps + + val matchs = (matchInfos.filter { case m @ Matcher(_, _, _, menc) => + !optIdApp.exists { case App(_, _, _, aenc) => menc == aenc } + } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None)) + + if (matchs.nonEmpty) matchers += b -> matchs + } + + (clauses, blockers, applications, matchers) + } + + val encodedBlockers : Calls = blockers.map(p => idToTrId(p._1) -> p._2) + val encodedApps : Apps = applications.map(p => idToTrId(p._1) -> p._2) + val encodedMatchers : Matchers = matchers.map(p => idToTrId(p._1) -> p._2) + + val stringRepr : () => String = () => { + " * Activating boolean : " + pathVar._1 + "\n" + + " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + + " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + + " * Clauses : " + (if (guardedExprs.isEmpty) "\n" else { + "\n " + (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + }) + + " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { + "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + + " * Application-blocks :" + (if (applications.isEmpty) "\n" else { + "\n " + applications.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + + " * Matchers :" + (if (matchers.isEmpty) "\n" else { + "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + + " * Lambdas :\n" + lambdas.map { case template => + " +> " + template.toString.split("\n").mkString("\n ") + "\n" + }.mkString("\n") + + " * Foralls :\n" + quantifications.map { case template => + " +> " + template.toString.split("\n").mkString("\n ") + "\n" + }.mkString("\n") + } + + (clauses, encodedBlockers, encodedApps, encodedMatchers, stringRepr) + } + + def substitution( + condVars: Map[Variable, Encoded], + exprVars: Map[Variable, Encoded], + condTree: Map[Variable, Set[Variable]], + lambdas: Seq[LambdaTemplate], + quantifications: Seq[QuantificationTemplate], + baseSubst: Map[Encoded, Arg], + pathVar: Variable, + aVar: Encoded + ): (Map[Encoded, Arg], Clauses) = { + + val freshSubst = exprVars.map { case (v, vT) => vT -> encodeSymbol(v) } ++ + freshConds(pathVar -> aVar, condVars, condTree) + val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } + + var subst = freshSubst.mapValues(Left(_)) ++ baseSubst + var clauses : Clauses = Seq.empty + + // /!\ CAREFUL /!\ + // We have to be wary while computing the lambda subst map since lambdas can + // depend on each other. However, these dependencies cannot be cyclic so it + // suffices to make sure the traversal order is correct. + var seen : Set[LambdaTemplate] = Set.empty + + val lambdaKeys = lambdas.map(lambda => lambda.ids._2 -> lambda).toMap + def extractSubst(lambda: LambdaTemplate): Unit = { + for { + dep <- lambda.structure.closures flatMap lambdaKeys.get + if !seen(dep) + } extractSubst(dep) + + if (!seen(lambda)) { + val substMap = subst.mapValues(_.encoded) + val substLambda = lambda.substitute(mkSubstituter(substMap), matcherSubst) + val (idT, cls) = instantiateLambda(substLambda) + clauses ++= cls + subst += lambda.ids._2 -> Left(idT) + seen += lambda + } + } + + for (l <- lambdas) extractSubst(l) + + for (q <- quantifications) { + val substMap = subst.mapValues(_.encoded) + val substQuant = q.substitute(mkSubstituter(substMap), matcherSubst) + val (qT, cls) = instantiateQuantification(substQuant) + clauses ++= cls + subst += q.qs._2 -> Left(qT) + } + + (subst, clauses) + } + + def instantiate( + clauses: Clauses, + calls: Calls, + apps: Apps, + matchers: Matchers, + substMap: Map[Encoded, Arg] + ): Clauses = { + + val substituter : Encoded => Encoded = mkSubstituter(substMap.mapValues(_.encoded)) + val msubst = substMap.collect { case (c, Right(m)) => c -> m } + + val allClauses = new scala.collection.mutable.ListBuffer[Encoded] + allClauses ++= clauses.map(substituter) + + for ((b, fis) <- calls; bp = substituter(b); fi <- fis) { + allClauses ++= instantiateCall(bp, fi.substitute(substituter, msubst)) + } + + for ((b,fas) <- apps; bp = substituter(b); fa <- fas) { + allClauses ++= instantiateApp(bp, fa.substitute(substituter, msubst)) + } + + for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { + allClauses ++= instantiateMatcher(bp, m.substitute(substituter, msubst)) + } + + allClauses.toSeq + } + } + + def instantiateExpr(expr: Expr): Clauses = { + val subst = exprOps.variablesOf(expr).map(v => v -> encodeSymbol(v)).toMap + val start = Variable(FreshIdentifier("start", true), BooleanType) + val encodedStart = encodeSymbol(start) + + val tpeClauses = subst.flatMap { case (v, s) => registerSymbol(encodedStart, s, v.getType) }.toSeq + + val (condVars, exprVars, condTree, guardedExprs, lambdas, quants) = + mkClauses(start, expr, subst + (start -> encodedStart)) + + val (clauses, calls, apps, matchers, _) = Template.encode( + start -> encodedStart, subst.toSeq, condVars, exprVars, guardedExprs, lambdas, quants) + + val (substMap, substClauses) = Template.substitution( + condVars, exprVars, condTree, lambdas, quants, Map.empty, start, encodedStart) + + val templateClauses = Template.instantiate(clauses, calls, apps, matchers, Map.empty) + tpeClauses ++ substClauses ++ templateClauses + } +} diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingBank.scala b/src/main/scala/inox/solvers/unrolling/UnrollingBank.scala deleted file mode 100644 index 2c38aca15800f15a70a80bd50ed5c9a2ddae9490..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/solvers/unrolling/UnrollingBank.scala +++ /dev/null @@ -1,414 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package unrolling - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Types._ -import utils._ - -class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: TemplateGenerator[T]) extends IncrementalState { - implicit val debugSection = utils.DebugSectionSolver - implicit val ctx0 = ctx - val reporter = ctx.reporter - - private val encoder = templateGenerator.encoder - private val manager = templateGenerator.manager - - // Function instantiations have their own defblocker - private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() - private val lambdaBlockers = new IncrementalMap[TemplateAppInfo[T], T]() - - // Keep which function invocation is guarded by which guard, - // also specify the generation of the blocker. - private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() - private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() - private val appBlockers = new IncrementalMap[(T, App[T]), T]() - private val blockerToApps = new IncrementalMap[T, (T, App[T])]() - - def push() { - callInfos.push() - defBlockers.push() - lambdaBlockers.push() - appInfos.push() - appBlockers.push() - blockerToApps.push() - } - - def pop() { - callInfos.pop() - defBlockers.pop() - lambdaBlockers.pop() - appInfos.pop() - appBlockers.pop() - blockerToApps.pop() - } - - def clear() { - callInfos.clear() - defBlockers.clear() - lambdaBlockers.clear() - appInfos.clear() - appBlockers.clear() - blockerToApps.clear() - } - - def reset() { - callInfos.reset() - defBlockers.reset() - lambdaBlockers.reset() - appInfos.reset() - appBlockers.reset() - blockerToApps.clear() - } - - def dumpBlockers() = { - val generations = (callInfos.map(_._2._1).toSet ++ appInfos.map(_._2._1).toSet).toSeq.sorted - - generations.foreach { generation => - reporter.debug("--- " + generation) - - for ((b, (gen, origGen, ast, fis)) <- callInfos if gen == generation) { - reporter.debug(f". $b%15s ~> "+fis.mkString(", ")) - } - - for ((app, (gen, origGen, b, notB, infos)) <- appInfos if gen == generation) { - reporter.debug(f". $b%15s ~> "+infos.mkString(", ")) - } - } - } - - def satisfactionAssumptions = currentBlockers ++ manager.assumptions - - def refutationAssumptions = manager.assumptions - - def canUnroll = callInfos.nonEmpty || appInfos.nonEmpty - def canInstantiate = manager.hasIgnored - - def currentBlockers = callInfos.map(_._2._3).toSeq ++ appInfos.map(_._2._4).toSeq - - def getBlockersToUnlock: Seq[T] = { - if (callInfos.isEmpty && appInfos.isEmpty) { - Seq.empty - } else { - val minGeneration = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)).min - val callBlocks = callInfos.filter(_._2._1 == minGeneration).toSeq.map(_._1) - val appBlocks = appInfos.values.filter(_._1 == minGeneration).toSeq.map(_._3) - callBlocks ++ appBlocks - } - } - - def getFiniteRangeClauses: Seq[T] = manager.checkClauses - - private def registerCallBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { - val notId = encoder.mkNot(id) - - callInfos.get(id) match { - case Some((exGen, origGen, _, exFis)) => - // PS: when recycling `b`s, this assertion becomes dangerous. - // It's better to simply take the max of the generations. - // assert(exGen == gen, "Mixing the same id "+id+" with various generations "+ exGen+" and "+gen) - - val minGen = gen min exGen - - callInfos += id -> (minGen, origGen, notId, fis++exFis) - case None => - callInfos += id -> (gen, gen, notId, fis) - } - } - - private def registerAppBlocker(gen: Int, app: (T, App[T]), info: Set[TemplateAppInfo[T]]) : Unit = { - appInfos.get(app) match { - case Some((exGen, origGen, b, notB, exInfo)) => - val minGen = gen min exGen - appInfos += app -> (minGen, origGen, b, notB, exInfo ++ info) - - case None => - val b = appBlockers.get(app) match { - case Some(b) => b - case None => encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) - } - - val notB = encoder.mkNot(b) - appInfos += app -> (gen, gen, b, notB, info) - blockerToApps += b -> app - } - } - - private def freshAppBlocks(apps: Traversable[(T, App[T])]) : Seq[T] = { - apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { - case app @ (blocker, App(caller, tpe, _, _)) => - val firstB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) - val clause = encoder.mkImplies(encoder.mkNot(firstB), encoder.mkNot(blocker)) - - appBlockers += app -> firstB - clause - } - } - - private def extendAppBlock(app: (T, App[T]), infos: Set[TemplateAppInfo[T]]) : T = { - assert(!appInfos.isDefinedAt(app), "appInfo -= app must have been called to ensure blocker freshness") - assert(appBlockers.isDefinedAt(app), "freshAppBlocks must have been called on app before it can be unlocked") - assert(infos.nonEmpty, "No point in extending blockers if no templates have been unrolled!") - - val nextB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) - val extension = encoder.mkOr((infos.map(_.equals).toSeq :+ nextB) : _*) - val clause = encoder.mkEquals(appBlockers(app), extension) - - appBlockers += app -> nextB - clause - } - - def getClauses(expr: Expr, bindings: Map[Identifier, T]): Seq[T] = { - // OK, now this is subtle. This `getTemplate` will return - // a template for a "fake" function. Now, this template will - // define an activating boolean... - val (template, mapping) = templateGenerator.mkTemplate(expr) - val reverse = mapping.map(p => p._2 -> p._1) - - val trArgs = template.tfd.params.map(vd => Left(bindings(reverse(vd.id)))) - - // ...now this template defines clauses that are all guarded - // by that activating boolean. If that activating boolean is - // undefined (or false) these clauses have no effect... - val (newClauses, callBlocks, appBlocks) = template.instantiate(template.start, trArgs) - - val blockClauses = freshAppBlocks(appBlocks.keys) - - for ((b, infos) <- callBlocks) { - registerCallBlocker(nextGeneration(0), b, infos) - } - - for ((app, infos) <- appBlocks) { - registerAppBlocker(nextGeneration(0), app, infos) - } - - // ...so we must force it to true! - val clauses = template.start +: (newClauses ++ blockClauses) - - reporter.debug("Generating clauses for: " + expr.asString) - for (cls <- clauses) { - reporter.debug(" . " + cls.asString(ctx)) - } - - clauses - } - - def nextGeneration(gen: Int) = gen + 5 - - def decreaseAllGenerations() = { - for ((block, (gen, origGen, ast, infos)) <- callInfos) { - // We also decrease the original generation here - callInfos += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, infos) - } - - for ((app, (gen, origGen, b, notB, infos)) <- appInfos) { - appInfos += app -> (math.max(1,gen-1), math.max(1,origGen-1), b, notB, infos) - } - } - - def promoteBlocker(b: T, force: Boolean = false): Boolean = { - var seen: Set[T] = Set.empty - var promoted: Boolean = false - var blockers: Seq[Set[T]] = Seq(Set(b)) - - do { - val (bs +: rest) = blockers - blockers = rest - - val next = (for (b <- bs if !seen(b)) yield { - seen += b - - if (callInfos contains b) { - val (_, origGen, notB, fis) = callInfos(b) - - callInfos += b -> (1, origGen, notB, fis) - promoted = true - } - - if (blockerToApps contains b) { - val app = blockerToApps(b) - val (_, origGen, _, notB, infos) = appInfos(app) - - appInfos += app -> (1, origGen, b, notB, infos) - promoted = true - } - - if (force) { - templateGenerator.manager.blockerChildren(b) - } else { - Set.empty[T] - } - }).flatten - - if (next.nonEmpty) blockers :+= next - } while (!promoted && blockers.nonEmpty) - - promoted - } - - def instantiateQuantifiers(force: Boolean = false): Seq[T] = { - val (newExprs, callBlocks, appBlocks) = manager.instantiateIgnored(force) - val blockExprs = freshAppBlocks(appBlocks.keys) - - val gens = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)) - val gen = if (gens.nonEmpty) gens.min else 0 - - for ((b, newInfos) <- callBlocks) { - registerCallBlocker(nextGeneration(gen), b, newInfos) - } - - for ((newApp, newInfos) <- appBlocks) { - registerAppBlocker(nextGeneration(gen), newApp, newInfos) - } - - val clauses = newExprs ++ blockExprs - if (clauses.nonEmpty) { - reporter.debug("Instantiating ignored quantifiers ("+clauses.size+")") - for (cl <- clauses) { - reporter.debug(" . "+cl) - } - } - - clauses - } - - def unrollBehind(ids: Seq[T]): Seq[T] = { - assert(ids.forall(id => (callInfos contains id) || (blockerToApps contains id))) - - var newClauses : Seq[T] = Seq.empty - - val newCallInfos = ids.flatMap(id => callInfos.get(id).map(id -> _)) - callInfos --= ids - - val apps = ids.flatMap(id => blockerToApps.get(id)) - val thisAppInfos = apps.map(app => app -> { - val (gen, _, _, _, infos) = appInfos(app) - (gen, infos) - }) - - blockerToApps --= ids - appInfos --= apps - - for ((app, (_, infos)) <- thisAppInfos if infos.nonEmpty) { - val extension = extendAppBlock(app, infos) - reporter.debug(" -> extending lambda blocker: " + extension) - newClauses :+= extension - } - - var fastAppInfos : Seq[((T, App[T]), (Int, Set[TemplateAppInfo[T]]))] = Seq.empty - - for ((id, (gen, _, _, infos)) <- newCallInfos; info @ TemplateCallInfo(tfd, args) <- infos) { - var newCls = Seq[T]() - - val defBlocker = defBlockers.get(info) match { - case Some(defBlocker) => - // we already have defBlocker => f(args) = body - defBlocker - - case None => - // we need to define this defBlocker and link it to definition - val defBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) - defBlockers += info -> defBlocker - - val template = templateGenerator.mkTemplate(tfd) - //reporter.debug(template) - - val (newExprs, callBlocks, appBlocks) = template.instantiate(defBlocker, args) - - // we handle obvious appBlocks in an immediate manner in order to increase - // performance for folds that just pass a lambda around to recursive calls - val (fastApps, nextApps) = appBlocks.partition(p => p._2.toSeq match { - case Seq(TemplateAppInfo(_, equals, _)) if equals == manager.trueT => true - case _ => false - }) - - fastAppInfos ++= fastApps.mapValues(gen -> _) - - val blockExprs = freshAppBlocks(nextApps.keys) - - for((b, newInfos) <- callBlocks) { - registerCallBlocker(nextGeneration(gen), b, newInfos) - } - - for ((app, newInfos) <- nextApps) { - registerAppBlocker(nextGeneration(gen), app, newInfos) - } - - newCls ++= newExprs - newCls ++= blockExprs - defBlocker - } - - // We connect it to the defBlocker: blocker => defBlocker - if (defBlocker != id) { - newCls :+= encoder.mkImplies(id, defBlocker) - manager.implies(id, defBlocker) - } - - reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") - for (cl <- newCls) { - reporter.debug(" . "+cl) - } - - newClauses ++= newCls - } - - for ((app @ (b, _), (gen, infos)) <- thisAppInfos ++ fastAppInfos; - info @ TemplateAppInfo(tmpl, equals, args) <- infos; - template <- tmpl.left) { - var newCls = Seq.empty[T] - - val lambdaBlocker = lambdaBlockers.get(info) match { - case Some(lambdaBlocker) => lambdaBlocker - - case None => - val lambdaBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) - lambdaBlockers += info -> lambdaBlocker - - val (newExprs, callBlocks, appBlocks) = template.instantiate(lambdaBlocker, args) - val blockExprs = freshAppBlocks(appBlocks.keys) - - for ((b, newInfos) <- callBlocks) { - registerCallBlocker(nextGeneration(gen), b, newInfos) - } - - for ((newApp, newInfos) <- appBlocks) { - registerAppBlocker(nextGeneration(gen), newApp, newInfos) - } - - newCls ++= newExprs - newCls ++= blockExprs - lambdaBlocker - } - - val enabler = if (equals == manager.trueT) b else encoder.mkAnd(equals, b) - newCls :+= encoder.mkImplies(enabler, lambdaBlocker) - manager.implies(b, lambdaBlocker) - - reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") - for (cl <- newCls) { - reporter.debug(" . "+cl) - } - - newClauses ++= newCls - } - - reporter.debug(s" - ${newClauses.size} new clauses") - //context.reporter.ifDebug { debug => - // debug(s" - new clauses:") - // debug("@@@@") - // for (cl <- newClauses) { - // debug(""+cl) - // } - // debug("////") - //} - - //dumpBlockers - //readLine() - - newClauses - } -} diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 2620bd5dcbec43cb4775a8e3dafe70776ccafd21..0e34ca2226c67caa2fcbfd9c4730f1c93d0e084c 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -1,101 +1,72 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package unrolling -import purescala.Common._ -import purescala.Definitions._ -import purescala.Quantification._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.bestRealType import utils._ import theories._ import evaluators._ -import Template._ - -trait UnrollingProcedure extends LeonComponent { - val name = "Unroll-P" - val description = "Leon Unrolling Procedure" - - val optUnrollFactor = LeonLongOptionDef("unrollfactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") - val optFeelingLucky = LeonFlagOptionDef("feelinglucky", "Use evaluator to find counter-examples early", false) - val optCheckModels = LeonFlagOptionDef("checkmodels", "Double-check counter-examples with evaluator", false) - val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unfolding while remaining fair", false) - val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) - val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) - val optPartialModels = LeonFlagOptionDef("partialmodels", "Extract domains for quantifiers and bounded first-class functions", false) - val optSilentErrors = LeonFlagOptionDef("silenterrors", "Fail silently into UNKNOWN when encountering an error", false) - - override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre, optUnrollFactor, optPartialModels) -} -object UnrollingProcedure extends UnrollingProcedure - -trait AbstractUnrollingSolver[T] - extends UnrollingProcedure - with Solver - with EvaluatingSolver { - - val unfoldFactor = context.findOptionOrDefault(optUnrollFactor) - val feelingLucky = context.findOptionOrDefault(optFeelingLucky) - val checkModels = context.findOptionOrDefault(optCheckModels) - val useCodeGen = context.findOptionOrDefault(optUseCodeGen) - val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) - val assumePreHolds = context.findOptionOrDefault(optAssumePre) - val partialModels = context.findOptionOrDefault(optPartialModels) - val silentErrors = context.findOptionOrDefault(optSilentErrors) - - protected var foundDefinitiveAnswer = false - protected var definitiveAnswer : Option[Boolean] = None - protected var definitiveModel : Model = Model.empty - protected var definitiveCore : Set[Expr] = Set.empty - - def check: Option[Boolean] = genericCheck(Set.empty) - - def getModel: Model = if (foundDefinitiveAnswer && definitiveAnswer.getOrElse(false)) { - definitiveModel - } else { - Model.empty +object optUnrollFactor extends InoxLongOptionDef("unrollfactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") +object optFeelingLucky extends InoxFlagOptionDef("feelinglucky", "Use evaluator to find counter-examples early", false) +object optUnrollCores extends InoxFlagOptionDef("unrollcores", "Use unsat-cores to drive unfolding while remaining fair", false) +object optAssumePre extends InoxFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) + +trait AbstractUnrollingSolver + extends Solver { + + import program._ + import program.trees._ + import program.symbols._ + + val theories: TheoryEncoder + lazy val encodedProgram: Program { val trees: program.trees.type } = theories.encode(program) + + type Encoded + implicit val printable: Encoded => Printable + + val templates: Templates { + val program: encodedProgram.type + type Encoded = AbstractUnrollingSolver.this.Encoded } - def getUnsatCore: Set[Expr] = if (foundDefinitiveAnswer && !definitiveAnswer.getOrElse(true)) { - definitiveCore - } else { - Set.empty + val evaluator: DeterministicEvaluator with SolvingEvaluator { + val program: AbstractUnrollingSolver.this.program.type } + val unfoldFactor = options.findOptionOrDefault(optUnrollFactor) + val feelingLucky = options.findOptionOrDefault(optFeelingLucky) + val checkModels = options.findOptionOrDefault(optCheckModels) + val unrollUnsatCores = options.findOptionOrDefault(optUnrollCores) + val assumePreHolds = options.findOptionOrDefault(optAssumePre) + val silentErrors = options.findOptionOrDefault(optSilentErrors) + + def check(model: Boolean = false, cores: Boolean = false): SolverResponse = + checkAssumptions(model = model, cores = cores)(Set.empty) + private val constraints = new IncrementalSeq[Expr]() - private val freeVars = new IncrementalMap[Identifier, T]() + private val freeVars = new IncrementalMap[Variable, Encoded]() protected var interrupted : Boolean = false - lazy val templateGenerator = new TemplateGenerator(theoryEncoder, templateEncoder, assumePreHolds) - lazy val unrollingBank = new UnrollingBank(context, templateGenerator) - def push(): Unit = { - unrollingBank.push() + templates.push() constraints.push() freeVars.push() } def pop(): Unit = { - unrollingBank.pop() + templates.pop() constraints.pop() freeVars.pop() } override def reset() = { - foundDefinitiveAnswer = false interrupted = false - unrollingBank.reset() + templates.reset() constraints.reset() freeVars.reset() } @@ -108,12 +79,12 @@ trait AbstractUnrollingSolver[T] interrupted = false } - protected def declareVariable(id: Identifier): T + protected def declareVariable(v: Variable): Encoded def assertCnstr(expression: Expr): Unit = { constraints += expression - val bindings = variablesOf(expression).map(id => id -> freeVars.cached(id) { - declareVariable(theoryEncoder.encode(id)) + val bindings = exprOps.variablesOf(expression).map(v => v -> freeVars.cached(v) { + declareVariable(theories.encode(v)) }).toMap val newClauses = unrollingBank.getClauses(expression, bindings) @@ -122,29 +93,18 @@ trait AbstractUnrollingSolver[T] } } - def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { - foundDefinitiveAnswer = true - definitiveAnswer = res - definitiveModel = model - definitiveCore = core - } - - implicit val printable: T => Printable + def solverAssert(cnstr: Encoded): Unit - val theoryEncoder: TheoryEncoder - val templateEncoder: TemplateEncoder[T] - - def solverAssert(cnstr: T): Unit - - /** We define solverCheckAssumptions in CPS in order for solvers that don't - * support this feature to be able to use the provided [[solverCheck]] CPS - * construction. - */ - def solverCheckAssumptions[R](assumptions: Seq[T])(block: Option[Boolean] => R): R = - solverCheck(assumptions)(block) - - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = - genericCheck(assumptions) + /** Simpler version of [[Solver.SolverResponses]] used internally in + * [[AbstractUnrollingSolver]] and children. + * + * These enable optimizations for the native Z3 solver (such as avoiding + * full Z3 model extraction in certain cases). + */ + protected sealed trait Response + protected case object Unknown extends Response + protected case class Unsat(cores: Option[Set[Encoded]]) extends Response + protected case class Sat(model: Option[ModelWrapper]) extends Response /** Provides CPS solver.check call. CPS is necessary in order for calls that * depend on solver.getModel to be able to access the model BEFORE the call @@ -165,27 +125,33 @@ trait AbstractUnrollingSolver[T] * This sequence of calls can also be used to mimic solver.checkAssumptions() * for solvers that don't support the construct natively. */ - def solverCheck[R](clauses: Seq[T])(block: Option[Boolean] => R): R + def solverCheck[R](clauses: Seq[Encoded], model: Boolean = false, cores: Boolean = false) + (block: Response => R): R - def solverUnsatCore: Option[Seq[T]] + /** We define solverCheckAssumptions in CPS in order for solvers that don't + * support this feature to be able to use the provided [[solverCheck]] CPS + * construction. + */ + def solverCheckAssumptions[R](assumptions: Seq[Encoded], model: Boolean = false, cores: Boolean = false) + (block: Response => R): R = solverCheck(assumptions)(block) trait ModelWrapper { - def modelEval(elem: T, tpe: TypeTree): Option[Expr] + def modelEval(elem: Encoded, tpe: Type): Option[Expr] - def eval(elem: T, tpe: TypeTree): Option[Expr] = modelEval(elem, theoryEncoder.encode(tpe)).flatMap { + def eval(elem: Encoded, tpe: Type): Option[Expr] = modelEval(elem, theories.encode(tpe)).flatMap { expr => try { - Some(theoryEncoder.decode(expr)(Map.empty)) + Some(theories.decode(expr)(Map.empty)) } catch { case u: Unsupported => None } } - def get(id: Identifier): Option[Expr] = eval(freeVars(id), theoryEncoder.encode(id.getType)).filter { - case Variable(_) => false + def get(v: Variable): Option[Expr] = eval(freeVars(v), theories.encode(id.getType)).filter { + case v: Variable => false case _ => true } - private[AbstractUnrollingSolver] def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { + private[AbstractUnrollingSolver] def extract(b: Encoded, m: templates.Matcher): Option[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = eval(b, BooleanType) optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => @@ -199,19 +165,16 @@ trait AbstractUnrollingSolver[T] } } - def solverGetModel: ModelWrapper - private def emit(silenceErrors: Boolean)(msg: String) = if (silenceErrors) reporter.debug(msg) else reporter.warning(msg) - private def extractModel(wrapper: ModelWrapper): Model = - new Model(freeVars.toMap.map(p => p._1 -> wrapper.get(p._1).getOrElse(simplestValue(p._1.getType)))) - - private def validateModel(model: Model, assumptions: Seq[Expr], silenceErrors: Boolean): Boolean = { + private def validateModel(model: ModelWrapper, assumptions: Seq[Expr], silenceErrors: Boolean): Boolean = { val expr = andJoin(assumptions ++ constraints) - val newExpr = model.toSeq.foldLeft(expr){ - case (e, (k, v)) => let(k, v, e) + // we have to check case class constructors in model for ADT invariants + val newExpr = freeVars.toSeq.foldLeft(expr) { case (e, (v, _)) => + val value = model.get(v).getOrElse(simplestValue(v.getType)) + let(v.toVal, value, e) } evaluator.eval(newExpr) match { @@ -233,256 +196,243 @@ trait AbstractUnrollingSolver[T] } } - private def getPartialModel: PartialModel = { - val wrapped = solverGetModel - val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) - view.getPartialModel - } - private def getTotalModel: Model = { val wrapped = solverGetModel val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) view.getTotalModel } - def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { - foundDefinitiveAnswer = false + def checkAssumptions(model: Boolean = false, cores: Boolean = false)(assumptions: Set[Expr]) = { - // TODO: theory encoder for assumptions!? - val encoder = templateGenerator.encoder.encodeExpr(freeVars.toMap) _ - val assumptionsSeq : Seq[Expr] = assumptions.toSeq - val encodedAssumptions : Seq[T] = assumptionsSeq.map(encoder) - val encodedToAssumptions : Map[T, Expr] = (encodedAssumptions zip assumptionsSeq).toMap + val assumptionsSeq : Seq[Expr] = assumptions.toSeq + val encodedAssumptions : Seq[Encoded] = assumptionsSeq.map { expr => + val vars = exprOps.variablesOf(expr) + templates.encodeExpr(vars.map(v => theories.encode(v) -> freeVars(v)).toMap)(expr) + } + val encodedToAssumptions : Map[Encoded, Expr] = (encodedAssumptions zip assumptionsSeq).toMap - def encodedCoreToCore(core: Seq[T]): Set[Expr] = { + def encodedCoreToCore(core: Set[Encoded]): Set[Expr] = { core.flatMap(ast => encodedToAssumptions.get(ast) match { - case Some(n @ Not(Variable(_))) => Some(n) - case Some(v @ Variable(_)) => Some(v) + case Some(n @ Not(_: Variable)) => Some(n) + case Some(v: Variable) => Some(v) case _ => None - }).toSet + }) + } + + sealed abstract class CheckState + class CheckResult(val response: SolverResponses.SolverResponse) extends CheckState + case class Validate(resp: Sat) extends CheckState + case object ModelCheck extends CheckState + case object FiniteRangeCheck extends CheckState + case object InstantiateQuantifiers extends CheckState + case object ProofCheck extends CheckState + case object Unroll extends CheckState + + object CheckResult { + def apply(resp: SolverResponses.SolverResponse) = new CheckResult(resp) + def apply(resp: Response): CheckResult = new CheckResult(resp match { + case Unknown => SolverResponses.Unknown + case Sat(None) => SolverResponses.SatResponse + case Sat(Some(model)) => SolverResponses.SatResponseWithModel(model) + case Unsat(None) => SolverResponses.UnsatResponse + case Unsat(Some(core)) => SolverResponses.UnsatResponseWithCores(encodedCoreToCore(core)) + }) + def unapply(res: CheckResult): Option[SolverResponses.SolverResponse] = Some(res.response) } - while (!foundDefinitiveAnswer && !interrupted) { - reporter.debug(" - Running search...") - var quantify = false + object Abort { + def unapply(resp: Response): Boolean = resp == Unknown || interrupted + } - def check[R](clauses: Seq[T])(block: Option[Boolean] => R) = - if (partialModels || templateGenerator.manager.quantifications.isEmpty) - solverCheckAssumptions(clauses)(block) - else solverCheck(clauses)(block) + var currentState: CheckState = ModelCheck + while (!currentState.isInstanceOf[CheckResult]) { + currentState = currentState match { + case _ if interrupted => + CheckResult(Unknown) - val timer = context.timers.solvers.check.start() - check(encodedAssumptions.toSeq ++ unrollingBank.satisfactionAssumptions) { res => - timer.stop() + case ModelCheck => + reporter.debug(" - Running search...") - reporter.debug(" - Finished search with blocked literals") + val withModel = model && !templates.hasIgnored + val withCores = cores || unrollUnsatCores - res match { - case None => - foundAnswer(None) + val timer = ctx.timers.solvers.check.start() + solverCheckAssumptions( + encodedAssumptions.toSeq ++ templates.satisfactionAssumptions, + model = withModel, + cores = withCores + ) { res => + timer.stop() - case Some(true) => // SAT - val (stop, model) = if (interrupted) { - (true, Model.empty) - } else if (partialModels) { - (true, getPartialModel) - } else { - val clauses = unrollingBank.getFiniteRangeClauses - if (clauses.isEmpty) { - (true, extractModel(solverGetModel)) - } else { - reporter.debug(" - Verifying model transitivity") - - val timer = context.timers.solvers.check.start() - solverCheck(clauses) { res => - timer.stop() - - reporter.debug(" - Finished transitivity check") - - res match { - case Some(true) => - (true, getTotalModel) - - case Some(false) => - reporter.debug(" - Transitivity not guaranteed for model") - (false, Model.empty) - - case None => - reporter.warning(" - Unknown for transitivity!?") - (false, Model.empty) - } - } - } - } + reporter.debug(" - Finished search with blocked literals") - if (!interrupted) { - if (!stop) { - if (!unrollingBank.canInstantiate) { - reporter.error("Something went wrong. The model is not transitive yet we can't instantiate!?") - reporter.error(model.asString(context)) - foundAnswer(None, model) - } else { - quantify = true - } - } else { - val valid = !checkModels || validateModel(model, assumptionsSeq, silenceErrors = silentErrors) + res match { + case Abort() => + CheckResult(Unknown) - if (valid) { - foundAnswer(Some(true), model) - } else if (silentErrors) { - foundAnswer(None, model) + case sat: Sat => + if (templates.hasIgnored) { + FiniteRangeCheck } else { - reporter.error( - "Something went wrong. The model should have been valid, yet we got this: " + - model.asString(context) + - " for formula " + andJoin(assumptionsSeq ++ constraints).asString) - foundAnswer(None, model) + Validate(sat) } - } - } - if (interrupted) { - foundAnswer(None) - } + case _: Unsat if !templates.canUnroll => + CheckResult(res) - case Some(false) if !unrollingBank.canUnroll => - solverUnsatCore match { - case Some(core) => - val exprCore = encodedCoreToCore(core) - foundAnswer(Some(false), core = exprCore) - case None => - foundAnswer(Some(false)) - } + case Unsat(Some(cores)) if unrollUnsatCores => + for (c <- cores) templates.extractNot(c) match { + case Some(b) => templates.promoteBlocker(b) + case None => reporter.fatalError("Unexpected blocker polarity for unsat core unrolling: " + c) + } + ProofCheck - case Some(false) => - if (unrollUnsatCores) { - solverUnsatCore match { - case Some(core) => - unrollingBank.decreaseAllGenerations() - - for (c <- core) templateGenerator.encoder.extractNot(c) match { - case Some(b) => unrollingBank.promoteBlocker(b) - case None => reporter.fatalError("Unexpected blocker polarity for unsat core unrolling: " + c) - } - case None => - reporter.fatalError("Can't unroll unsat core for incompatible solver " + name) - } + case _ => + ProofCheck } - } - } + } - if (!quantify && !foundDefinitiveAnswer && !interrupted) { - if (feelingLucky) { - reporter.debug(" - Running search without blocked literals (w/ lucky test)") - } else { - reporter.debug(" - Running search without blocked literals (w/o lucky test)") - } + case FiniteRangeCheck => + reporter.debug(" - Verifying finite ranges") - val timer = context.timers.solvers.check.start() - solverCheckAssumptions(encodedAssumptions.toSeq ++ unrollingBank.refutationAssumptions) { res2 => - timer.stop() + val clauses = templates.getFiniteRangeClauses + val timer = ctx.timers.solvers.check.start() + solverCheck( + encodedAssumptions.toSeq ++ templates.satisfactionAssumptions ++ clauses, + model = model + ) { res => + timer.stop() - reporter.debug(" - Finished search without blocked literals") - - res2 match { - case Some(false) => - solverUnsatCore match { - case Some(core) => - val exprCore = encodedCoreToCore(core) - foundAnswer(Some(false), core = exprCore) - case None => - foundAnswer(Some(false)) - } - - case Some(true) => - if (!interrupted) { - val model = solverGetModel - - if (this.feelingLucky) { - // we might have been lucky :D - val extracted = extractModel(model) - val valid = validateModel(extracted, assumptionsSeq, silenceErrors = true) - if (valid) foundAnswer(Some(true), extracted) - } + reporter.debug(" - Finished checking finite ranges") - if (!foundDefinitiveAnswer) { - val promote = templateGenerator.manager.getBlockersToPromote(model.eval) - if (promote.nonEmpty) { - unrollingBank.decreaseAllGenerations() + res match { + case Abort() => + CheckResult(Unknown) - for (b <- promote) unrollingBank.promoteBlocker(b, force = true) - } - } - } + case sat: Sat => + Validate(sat) - case None => - foundAnswer(None) + case _ => + InstantiateQuantifiers + } } + + case Validate(sat) => sat match { + case Sat(None) => CheckResult(SolverResponses.SatResponse) + case Sat(Some(model)) => + val valid = !checkModels || validateModel(model, assumptionsSeq, silenceErrors = silentErrors) + if (valid) { + CheckResult(model) + } else { + reporter.error( + "Something went wrong. The model should have been valid, yet we got this: " + + model.asString + + " for formula " + andJoin(assumptionsSeq ++ constraints).asString) + CheckResult(Unknown) + } } - } - if (!foundDefinitiveAnswer && !interrupted) { - reporter.debug("- We need to keep going") + case InstantiateQuantifiers => + if (!templates.canUnfoldQuantifiers) { + reporter.error("Something went wrong. The model is not transitive yet we can't instantiate!?") + CheckResult(Unknown) + } else { + // TODO: promote ignored quantifiers! + Unroll + } + + case ProofCheck => + if (feelingLucky) { + reporter.debug(" - Running search without blocked literals (w/ lucky test)") + } else { + reporter.debug(" - Running search without blocked literals (w/o lucky test)") + } - reporter.debug(" - more instantiations") - val newClauses = unrollingBank.instantiateQuantifiers(force = quantify) + val timer = ctx.timers.solvers.check.start() + solverCheckAssumptions( + encodedAssumptions.toSeq ++ templates.refutationAssumptions, + model = feelingLucky, + cores = cores + ) { res => + timer.stop() - for (cls <- newClauses) { - solverAssert(cls) - } + reporter.debug(" - Finished search without blocked literals") - reporter.debug(" - finished instantiating") + res match { + case Abort() => + CheckResult(Unknown) - // unfolling `unfoldFactor` times - for (i <- 1 to unfoldFactor.toInt) { - val toRelease = unrollingBank.getBlockersToUnlock + case _: Unsat => + CheckResult(res) - reporter.debug(" - more unrollings") + case Sat(Some(model)) if feelingLucky => + if (validateModel(model, assumptionsSeq, silenceErrors = true)) { + CheckResult(res) + } else { + for { + (inst, bs) <- templates.getInstantiationsWithBlockers + if !model.isTrue(inst) + b <- bs + } templates.promoteBlocker(b, force = true) - val timer = context.timers.solvers.unroll.start() - val newClauses = unrollingBank.unrollBehind(toRelease) - timer.stop() + Unroll + } - for (ncl <- newClauses) { - solverAssert(ncl) + case _ => + Unroll + } } - } - reporter.debug(" - finished unrolling") + case Unroll => + reporter.debug("- We need to keep going") + + val timer = ctx.timers.solvers.unroll.start() + // unfolling `unfoldFactor` times + for (i <- 1 to unfoldFactor.toInt) { + val newClauses = templates.unroll + for (ncl <- newClauses) { + solverAssert(ncl) + } + } + timer.stop() + + reporter.debug(" - finished unrolling") + ModelCheck } } - if (interrupted) { - None - } else { - definitiveAnswer - } + val CheckResult(res) = currentState + res } } -class UnrollingSolver( - val sctx: SolverContext, - val program: Program, - underlying: Solver, - theories: TheoryEncoder = new NoEncoder -) extends AbstractUnrollingSolver[Expr] { +trait UnrollingSolver extends AbstractUnrollingSolver { + import program._ + import program.trees._ + import program.symbols._ - override val name = "U:"+underlying.name + type Encoded = Expr + val solver: Solver { val program: encodedProgram.type } + + override val name = "U:"+solver.name def free() { - underlying.free() + solver.free() } val printable = (e: Expr) => e - val templateEncoder = new TemplateEncoder[Expr] { - def encodeId(id: Identifier): Expr= Variable(id.freshen) - def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { - replaceFromIDs(bindings, e) + val templates = new Templates { + val program = encodedProgram + type Encoded = Expr + + def encodeSymbol(v: Variable): Expr = v.freshen + def encodeExpr(bindings: Map[Variable, Expr])(e: Expr): Expr = { + exprOps.replaceFromSymbols(bindings, e) } def substitute(substMap: Map[Expr, Expr]): Expr => Expr = { - (e: Expr) => replace(substMap, e) + (e: Expr) => exprOps.replace(substMap, e) } def mkNot(e: Expr) = not(e) @@ -497,20 +447,27 @@ class UnrollingSolver( } } - val theoryEncoder = theories - - val solver = underlying - - def declareVariable(id: Identifier): Variable = id.toVariable + def declareVariable(v: Variable): Variable = v def solverAssert(cnstr: Expr): Unit = { solver.assertCnstr(cnstr) } - def solverCheck[R](clauses: Seq[Expr])(block: Option[Boolean] => R): R = { + case class Model(model: Map[ValDef, Expr]) extends ModelWrapper { + def modelEval(elem: Expr, tpe: Type): Option[Expr] = evaluator.eval(elem, model).result + override def toString = model.mkString("\n") + } + + def solverCheck[R](clauses: Seq[Expr], model: Boolean = false, cores: Boolean = false)(block: Response => R): R = { solver.push() for (cls <- clauses) solver.assertCnstr(cls) - val res = solver.check + val res = solver.check(model = model, cores = cores) match { + case solver.SolverResponses.Unknown => Unknown + case solver.SolverResponses.UnsatResponse => Unsat(None) + case solver.SolverResponses.UnsatResponseWithCores(cores) => Unsat(Some(cores)) + case solver.SolverResponses.SatResponse => Sat(None) + case solver.SolverResponses.SatResponseWithModel(model) => Sat(Some(Model(model))) + } val r = block(res) solver.pop() r @@ -520,11 +477,11 @@ class UnrollingSolver( def solverGetModel: ModelWrapper = new ModelWrapper { val model = solver.getModel - def modelEval(elem: Expr, tpe: TypeTree): Option[Expr] = evaluator.eval(elem, model).result + def modelEval(elem: Expr, tpe: Type): Option[Expr] = evaluator.eval(elem, model).result override def toString = model.toMap.mkString("\n") } - override def dbg(msg: => Any) = underlying.dbg(msg) + override def dbg(msg: => Any) = solver.dbg(msg) override def push(): Unit = { super.push() @@ -536,18 +493,8 @@ class UnrollingSolver( solver.pop() } - override def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { - super.foundAnswer(res, model, core) - - if (!interrupted && res.isEmpty && model.isEmpty) { - reporter.ifDebug { debug => - debug("Unknown result!?") - } - } - } - override def reset(): Unit = { - underlying.reset() + solver.reset() super.reset() } diff --git a/src/main/scala/inox/utils/IncrementalStateWrapper.scala b/src/main/scala/inox/utils/IncrementalStateWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..ea7628ae3c47da87254fe80c6bbb6e53273a54a4 --- /dev/null +++ b/src/main/scala/inox/utils/IncrementalStateWrapper.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox.utils + +trait IncrementalStateWrapper extends IncrementalState { + val incrementals: Seq[IncrementalState] + + def push(): Unit = incrementals.foreach(_.push()) + def pop(): Unit = incrementals.foreach(_.pop()) + + def clear(): Unit = incrementals.foreach(_.clear()) + def reset(): Unit = incrementals.foreach(_.reset()) +}