diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 041caba78a7d6eebe9cd6f81d3a5f1f56f3cd3ba..55ab79460e296e2b37c026261e29621488f6fa46 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -313,19 +313,45 @@ trait CodeExtraction extends ASTExtractors { } private def fillLeonUnit(u: ScalaUnit): Unit = { + def extractClassMembers(sym: Symbol, tpl: Template): Unit = { + for (t <- tpl.body if !t.isEmpty) { + extractFunOrMethodBody(Some(sym), t) + } + + classToInvariants.get(sym).foreach { bodies => + val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType) + fd.addFlag(IsADTInvariant) + + val cd = classesToClasses(sym) + cd.registerMethod(fd) + cd.addFlag(IsADTInvariant) + val ctparams = sym.tpe match { + case TypeRef(_, _, tps) => + extractTypeParams(tps).map(_._1) + case _ => + Nil + } + + val tparamsMap = (ctparams zip cd.tparams.map(_.tp)).toMap + val dctx = DefContext(tparamsMap) + + val body = andJoin(bodies.toSeq.filter(_ != EmptyTree).map { + body => flattenBlocks(extractTreeOrNoTree(body)(dctx)) + }) + + fd.fullBody = body + } + } + for (t <- u.defs) t match { case t if isIgnored(t.symbol) => // ignore case ExAbstractClass(_, sym, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExCaseClass(_, sym, _, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExObjectDef(n, templ) => for (t <- templ.body if !t.isEmpty) t match { @@ -334,14 +360,10 @@ trait CodeExtraction extends ASTExtractors { None case ExAbstractClass(_, sym, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExCaseClass(_, sym, _, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case t => extractFunOrMethodBody(None, t) @@ -442,6 +464,7 @@ trait CodeExtraction extends ASTExtractors { private var isMethod = Set[Symbol]() private var methodToClass = Map[FunDef, LeonClassDef]() + private var classToInvariants = Map[Symbol, Set[Tree]]() /** * For the function in $defs with name $owner, find its parameter with index $index, @@ -543,8 +566,53 @@ trait CodeExtraction extends ASTExtractors { if (tpe != id.getType) println(tpe, id.getType) LeonValDef(id.setPos(t.pos)).setPos(t.pos) } + //println(s"Fields of $sym") ccd.setFields(fields) + + // checks whether this type definition could lead to an infinite type + def computeChains(tpe: LeonType): Map[TypeParameterDef, Set[LeonClassDef]] = { + var seen: Set[LeonClassDef] = Set.empty + var chains: Map[TypeParameterDef, Set[LeonClassDef]] = Map.empty + + def rec(tpe: LeonType): Set[LeonClassDef] = tpe match { + case ct: ClassType => + val root = ct.classDef.root + if (!seen(ct.classDef.root)) { + seen += ct.classDef.root + for (cct <- ct.root.knownCCDescendants; + (tp, tpe) <- cct.classDef.tparams zip cct.tps) { + val relevant = rec(tpe) + chains += tp -> (chains.getOrElse(tp, Set.empty) ++ relevant) + for (cd <- relevant; vd <- cd.fields) { + rec(vd.getType) + } + } + } + Set(root) + + case Types.NAryType(tpes, _) => + tpes.flatMap(rec).toSet + } + + rec(tpe) + chains + } + + val chains = computeChains(ccd.typed) + + def check(tp: TypeParameterDef, seen: Set[LeonClassDef]): Unit = chains.get(tp) match { + case Some(classDefs) => + if ((seen intersect classDefs).nonEmpty) { + outOfSubsetError(sym.pos, "Infinite types are not allowed") + } else { + for (cd <- classDefs; tp <- cd.tparams) check(tp, seen + cd) + } + case None => + } + + for (tp <- ccd.tparams) check(tp, Set.empty) + case _ => } @@ -568,6 +636,9 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) + case ExRequiredExpression(body) => + classToInvariants += sym -> (classToInvariants.getOrElse(sym, Set.empty) + body) + // Default values for parameters case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) => isMethod += fsym @@ -621,6 +692,8 @@ trait CodeExtraction extends ASTExtractors { } } + private val invId = FreshIdentifier("inv", BooleanType) + private var isLazy = Set[LeonValDef]() private var defsToDefs = Map[Symbol, FunDef]() diff --git a/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala b/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala deleted file mode 100644 index d4583e55311cfaf843983a3f8af70ac46f7b3675..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import TypeOps._ - -object CheckADTFieldsTypes extends UnitPhase[Program] { - - val name = "ADT Fields" - val description = "Check that fields of ADTs are hierarchy roots" - - def apply(ctx: LeonContext, program: Program) = { - program.definedClasses.foreach { - case ccd: CaseClassDef => - for(vd <- ccd.fields) { - val tpe = vd.getType - if (bestRealType(tpe) != tpe) { - ctx.reporter.warning(ccd.getPos, "Definition of "+ccd.id.asString(ctx)+" has a field of a sub-type ("+vd.asString(ctx)+"): " + - "this type is not supported as-is by solvers and will be up-cast. " + - "This may cause issues such as crashes.") - } - } - case _ => - } - } - -} diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 3437d34efeef0d47a37eded5d85cb167d8ee2b6b..95218cb2857a97c6834ac039f20c426a0736be8f 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -131,7 +131,7 @@ object Constructors { */ def caseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { caseClass match { - case CaseClass(ct, fields) if ct.classDef == classType.classDef => + case CaseClass(ct, fields) if ct.classDef == classType.classDef && !ct.classDef.hasInvariant => fields(ct.classDef.selectorID2Index(selector)) case _ => CaseClassSelector(classType, caseClass, selector) diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 78c2b93edecf113c717a625a3c974c8dcf43e474..856cdb37a21f3b73d07b03aa21ca62affd51971d 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -325,7 +325,7 @@ object DefOps { fdMapCache(fd).getOrElse(fd) } } - + val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { @@ -636,6 +636,7 @@ object DefOps { } ) }) + if (!found) { println(s"addDefs could not find anchor definition! Not found: $after") p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match { @@ -644,9 +645,10 @@ object DefOps { } println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) } + res } - + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index dfc78d4c547ddf40fb6c2a1e39bdab40611dcbd1..733eaf124dbb6dd8499c81347f7fc864c23ca453 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -182,7 +182,6 @@ object Definitions { lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { case c @ CaseClassDef(_, _, None, _) => c } - } // A class that represents flags that annotate a FunDef with different attributes @@ -217,7 +216,8 @@ object Definitions { case object IsSynthetic extends FunctionFlag // Is inlined case object IsInlined extends FunctionFlag - + // Is an ADT invariant method + case object IsADTInvariant extends FunctionFlag with ClassFlag /** Useful because case classes and classes are somewhat unified in some * patterns (of pattern-matching, that is) */ @@ -268,6 +268,15 @@ object Definitions { def flags = _flags + private var _invariant: Option[FunDef] = None + + def invariant = _invariant + def hasInvariant = flags contains IsADTInvariant + def setInvariant(fd: FunDef): Unit = { + addFlag(IsADTInvariant) + _invariant = Some(fd) + } + def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap @@ -289,6 +298,23 @@ object Definitions { ccd } + def isInductive: Boolean = { + def induct(tpe: TypeTree, seen: Set[ClassDef]): Boolean = tpe match { + case ct: ClassType => + val root = ct.classDef.root + seen(root) || ct.fields.forall(vd => induct(vd.getType, seen + root)) + case TupleType(tpes) => + tpes.forall(tpe => induct(tpe, seen)) + case _ => true + } + + if (this == root && !this.isAbstract) false + else if (this != root) root.isInductive + else knownCCDescendants.forall { ccd => + ccd.fields.forall(vd => induct(vd.getType, Set(root))) + } + } + val isAbstract: Boolean val isCaseObject: Boolean @@ -474,6 +500,7 @@ object Definitions { def canBeField = canBeLazyField || canBeStrictField def isRealFunction = !canBeField def isSynthetic = flags contains IsSynthetic + def isInvariant = flags contains IsADTInvariant def methodOwner = flags collectFirst { case IsMethod(cd) => cd } /* Wrapping in TypedFunDef */ diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 3c597d26e76976d44a939754fdea8c1caf45d29e..e8e87c7976a907cfca16ddc79f0cd758fe7d8402 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -271,7 +271,7 @@ object Expressions { * @param rhs The expression to the right of `=>` * @see [[Expressions.MatchExpr]] */ - case class MatchCase(pattern : Pattern, optGuard : Option[Expr], rhs: Expr) extends Tree { + case class MatchCase(pattern: Pattern, optGuard: Option[Expr], rhs: Expr) extends Tree { def expressions: Seq[Expr] = optGuard.toList :+ rhs } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 2014739eaa3fba0c84ce10685087262e7e227872..0e9e11d171a15e2b0b5fb77d0b3f0e65d44be38b 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -251,6 +251,10 @@ object MethodLifting extends TransformationPhase { ) } + if (cd.methods.exists(md => md.id == fd.id && md.isInvariant)) { + cd.setInvariant(nfd) + } + mdToFds += fd -> nfd fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) } diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 9ec0e4b41f33aa0895ff160a5379d16a7cb88d87..6f2518d549be6b1d439ff9cb08594558d02449d5 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -114,6 +114,8 @@ object Types { } } + def invariant = classDef.invariant.map(_.typed(tps)) + def knownDescendants = classDef.knownDescendants.map( _.typed(tps) ) def knownCCDescendants: Seq[CaseClassType] = classDef.knownCCDescendants.map( _.typed(tps) ) @@ -128,8 +130,8 @@ object Types { case t => throw LeonFatalError("Unexpected translated parent type: "+t) } } - } + case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType diff --git a/src/main/scala/leon/solvers/templates/DatatypeManager.scala b/src/main/scala/leon/solvers/templates/DatatypeManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..8f9c2202adac023c52de03a693a075a660e21a99 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/DatatypeManager.scala @@ -0,0 +1,214 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.TypeOps.bestRealType + +import utils._ +import utils.SeqUtils._ +import Instantiation._ +import Template._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +object DatatypeTemplate { + + def apply[T]( + encoder: TemplateEncoder[T], + manager: DatatypeManager[T], + tpe: TypeTree + ) : DatatypeTemplate[T] = { + val id = FreshIdentifier("x", tpe, true) + val expr = manager.typeUnroller(Variable(id)) + + val pathVar = FreshIdentifier("b", BooleanType, true) + + var condVars = Map[Identifier, T]() + var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Identifier, id: Identifier): Unit = { + condVars += id -> encoder.encodeId(id) + condTree += pathVar -> (condTree(pathVar) + id) + } + + var guardedExprs = Map[Identifier, Seq[Expr]]() + def storeGuarded(pathVar: Identifier, expr: Expr): Unit = { + val prev = guardedExprs.getOrElse(pathVar, Nil) + guardedExprs += pathVar -> (expr +: prev) + } + + def requireDecomposition(e: Expr): Boolean = exists { + case _: FunctionInvocation => true + case _ => false + } (e) + + def rec(pathVar: Identifier, expr: Expr): Unit = expr match { + case i @ IfExpr(cond, thenn, elze) if requireDecomposition(i) => + val newBool1: Identifier = FreshIdentifier("b", BooleanType, true) + val newBool2: Identifier = FreshIdentifier("b", BooleanType, true) + + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) + + storeGuarded(pathVar, or(Variable(newBool1), Variable(newBool2))) + storeGuarded(pathVar, or(not(Variable(newBool1)), not(Variable(newBool2)))) + storeGuarded(pathVar, Equals(Variable(newBool1), cond)) + + rec(newBool1, thenn) + rec(newBool2, elze) + + case a @ And(es) if requireDecomposition(a) => + val partitions = groupWhile(es)(!requireDecomposition(_)) + for (e <- partitions.map(andJoin)) rec(pathVar, e) + + case _ => + storeGuarded(pathVar, expr) + } + + rec(pathVar, expr) + + val (idT, pathVarT) = (encoder.encodeId(id), encoder.encodeId(pathVar)) + val (clauses, blockers, _, functions, _, _) = Template.encode(encoder, + pathVar -> pathVarT, Seq(id -> idT), condVars, Map.empty, guardedExprs, Seq.empty, Seq.empty) + + new DatatypeTemplate[T](encoder, manager, + pathVar -> pathVarT, idT, condVars, condTree, clauses, blockers, functions) + } +} + +class DatatypeTemplate[T] private ( + val encoder: TemplateEncoder[T], + val manager: DatatypeManager[T], + val pathVar: (Identifier, T), + val argument: T, + val condVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val functions: Set[(T, FunctionType, T)]) extends Template[T] { + + val args = Seq(argument) + val exprVars = Map.empty[Identifier, T] + val applications = Map.empty[T, Set[App[T]]] + val lambdas = Seq.empty[LambdaTemplate[T]] + val matchers = Map.empty[T, Set[Matcher[T]]] + val quantifications = Seq.empty[QuantificationTemplate[T]] + + def instantiate(blocker: T, arg: T): Instantiation[T] = instantiate(blocker, Seq(Left(arg))) +} + +class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { + + private val classCache: MutableMap[ClassType, FunDef] = MutableMap.empty + + private def classTypeUnroller(ct: ClassType): FunDef = classCache.get(ct) match { + case Some(fd) => fd + case None => + val param = ValDef(FreshIdentifier("x", ct)) + val fd = new FunDef(FreshIdentifier("unroll"+ct.classDef.id), Seq.empty, Seq(param), BooleanType) + classCache += ct -> fd + + val matchers = ct match { + case (act: AbstractClassType) => act.knownCCDescendants + case (cct: CaseClassType) => Seq(cct) + } + + fd.body = Some(MatchExpr(param.toVariable, matchers.map { cct => + val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) + val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) + MatchCase(pattern, None, expr) + })) + + fd + } + + private val requireChecking: MutableSet[ClassType] = MutableSet.empty + private val requireCache: MutableMap[TypeTree, Boolean] = MutableMap.empty + + private def requireTypeUnrolling(tpe: TypeTree): Boolean = requireCache.get(tpe) match { + case Some(res) => res + case None => + val res = tpe match { + case ft: FunctionType => true + case ct: CaseClassType if ct.classDef.hasParent => true + case ct: ClassType if requireChecking(ct.root) => false + case ct: ClassType => + requireChecking += ct.root + val classTypes = ct.root +: (ct.root match { + case (act: AbstractClassType) => act.knownDescendants + case (cct: CaseClassType) => Seq.empty + }) + + classTypes.exists(ct => ct.classDef.hasInvariant || ct.fieldsTypes.exists(requireTypeUnrolling)) + + case BooleanType | UnitType | CharType | IntegerType | + RealType | Int32Type | StringType | (_: TypeParameter) => false + + case NAryType(tpes, _) => tpes.exists(requireTypeUnrolling) + } + + requireCache += tpe -> res + res + } + + def typeUnroller(expr: Expr): Expr = expr.getType match { + case tpe if !requireTypeUnrolling(tpe) => + BooleanLiteral(true) + + case ct: ClassType => + val inv = if (ct.classDef.root.hasInvariant) { + Seq(FunctionInvocation(ct.root.invariant.get, Seq(expr))) + } else { + Seq.empty + } + + val subtype = if (ct != ct.root) { + Seq(IsInstanceOf(expr, ct)) + } else { + Seq.empty + } + + val induct = if (!ct.classDef.isInductive) { + val matchers = ct.root match { + case (act: AbstractClassType) => act.knownCCDescendants + case (cct: CaseClassType) => Seq(cct) + } + + MatchExpr(expr, matchers.map { cct => + val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) + val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) + MatchCase(pattern, None, expr) + }) + } else { + val fd = classTypeUnroller(ct.root) + FunctionInvocation(fd.typed, Seq(expr)) + } + + andJoin(inv ++ subtype :+ induct) + + case TupleType(tpes) => + andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2)))) + + case FunctionType(_, _) => + FreshFunction(expr) + + case _ => scala.sys.error("TODO") + } + + private val typeCache: MutableMap[TypeTree, DatatypeTemplate[T]] = MutableMap.empty + + protected def typeTemplate(tpe: TypeTree): DatatypeTemplate[T] = typeCache.getOrElse(tpe, { + val template = DatatypeTemplate(encoder, this, tpe) + typeCache += tpe -> template + template + }) +} + diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 036654c25ab64839efa5af4cb0d49a6566ffcf20..edb096ef222243506d0af47b69f8ef340c5fe857 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -5,19 +5,30 @@ package solvers package templates import purescala.Common._ +import purescala.Definitions._ import purescala.Expressions._ +import purescala.Constructors._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps.bestRealType import utils._ +import utils.SeqUtils._ import Instantiation._ import Template._ -case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]]) { +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]], encoded: T) { override def toString = "(" + caller + " : " + tpe + ")" + args.map(_.encoded).mkString("(", ",", ")") } +case class FreshFunction(expr: Expr) extends Expr with Extractable { + val getType = BooleanType + val extract = Some(Seq(expr), (exprs: Seq[Expr]) => FreshFunction(exprs.head)) +} + object LambdaTemplate { def apply[T]( @@ -39,10 +50,12 @@ object LambdaTemplate { val id = ids._2 val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + val (clauses, blockers, applications, functions, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + assert(functions.isEmpty, "Only synthetic type explorers should introduce functions!") + val lambdaString : () => String = () => { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } @@ -63,9 +76,9 @@ object LambdaTemplate { clauses, blockers, applications, - quantifications, - matchers, lambdas, + matchers, + quantifications, keyDeps, key, lambdaString @@ -107,15 +120,16 @@ class LambdaTemplate[T] private ( val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], val lambdas: Seq[LambdaTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val quantifications: Seq[QuantificationTemplate[T]], val dependencies: Map[Identifier, T], val structuralKey: Lambda, stringRepr: () => String) extends Template[T] with KeyedTemplate[T, Lambda] { val args = arguments.map(_._2) - val tpe = ids._1.getType.asInstanceOf[FunctionType] + val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] + val functions: Set[(T, FunctionType, T)] = Set.empty def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): LambdaTemplate[T] = { val newStart = substituter(start) @@ -135,14 +149,14 @@ class LambdaTemplate[T] private ( )) } - val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) + val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) val newMatchers = matchers.map { case (b, ms) => val bp = if (b == start) newStart else b bp -> ms.map(_.substitute(substituter, matcherSubst)) } - val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) + val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) @@ -158,9 +172,9 @@ class LambdaTemplate[T] private ( newClauses, newBlockers, newApplications, - newQuantifications, - newMatchers, newLambdas, + newMatchers, + newQuantifications, newDependencies, structuralKey, stringRepr @@ -172,7 +186,7 @@ class LambdaTemplate[T] private ( new LambdaTemplate[T]( ids._1 -> idT, encoder, manager, pathVar, arguments, condVars, exprVars, condTree, clauses map substituter, // make sure the body-defining clause is inlined! - blockers, applications, quantifications, matchers, lambdas, + blockers, applications, lambdas, matchers, quantifications, dependencies, structuralKey, stringRepr ) } @@ -185,25 +199,82 @@ class LambdaTemplate[T] private ( } } -class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { +class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(encoder) { private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Map[(Expr, Seq[T]), LambdaTemplate[T]]].withDefaultValue(Map.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) - protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + protected val maybeFree = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) + protected val freeBlockers = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) private val instantiated = new IncrementalSet[(T, App[T])] override protected def incrementals: List[IncrementalState] = - super.incrementals ++ List(byID, byType, applications, freeLambdas, instantiated) + super.incrementals ++ List(byID, byType, applications, knownFree, maybeFree, freeBlockers, instantiated) + + def registerFunction(b: T, tpe: FunctionType, f: T): Seq[T] = { + val ft = bestRealType(tpe).asInstanceOf[FunctionType] + val bs = fixpoint((bs: Set[T]) => bs.flatMap(blockerParents))(Set(b)) + + val (known, neqClauses) = if ((bs intersect typeEnablers).nonEmpty) { + maybeFree += ft -> (maybeFree(ft) + (b -> f)) + (false, byType(ft).values.toSeq.map { t => + encoder.mkImplies(b, encoder.mkNot(encoder.mkEquals(t.ids._2, f))) + }) + } else { + knownFree += ft -> (knownFree(ft) + f) + (true, byType(ft).values.toSeq.map(t => encoder.mkNot(encoder.mkEquals(t.ids._2, f)))) + } - def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { - for ((id, idT) <- lambdas) id.getType match { - case ft: FunctionType => - freeLambdas += ft -> (freeLambdas(ft) + idT) - case _ => + val extClauses = freeBlockers(tpe).map { case (oldB, freeF) => + val equals = encoder.mkEquals(f, freeF) + val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) + val extension = encoder.mkOr(if (known) equals else encoder.mkAnd(b, equals), nextB) + encoder.mkEquals(oldB, extension) } + + neqClauses ++ extClauses + } + + def assumptions: Seq[T] = freeBlockers.flatMap(_._2.map(p => encoder.mkNot(p._2))).toSeq + + private val typeBlockers = new IncrementalMap[T, T]() + private val typeEnablers: MutableSet[T] = MutableSet.empty + + private def typeUnroller(blocker: T, app: App[T]): Instantiation[T] = typeBlockers.get(app.encoded) match { + case Some(typeBlocker) => + implies(blocker, typeBlocker) + (Seq(encoder.mkImplies(blocker, typeBlocker)), Map.empty, Map.empty) + + case None => + val App(caller, tpe @ FunctionType(_, to), args, value) = app + val typeBlocker = encoder.encodeId(FreshIdentifier("t", BooleanType)) + typeBlockers += value -> typeBlocker + implies(blocker, typeBlocker) + + val template = typeTemplate(to) + val instantiation = template.instantiate(typeBlocker, value) + + val (b, extClauses) = if (knownFree(tpe) contains caller) { + (blocker, Seq.empty) + } else { + val firstB = encoder.encodeId(FreshIdentifier("b_free", BooleanType, true)) + implies(firstB, typeBlocker) + typeEnablers += firstB + + val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) + freeBlockers += tpe -> (freeBlockers(tpe) + (nextB -> caller)) + + val clause = encoder.mkEquals(firstB, encoder.mkOr( + knownFree(tpe).map(idT => encoder.mkEquals(caller, idT)).toSeq ++ + maybeFree(tpe).map { case (b, idT) => encoder.mkAnd(b, encoder.mkEquals(caller, idT)) } :+ + nextB : _*)) + (firstB, Seq(clause)) + } + + instantiation withClauses extClauses withClause encoder.mkImplies(b, typeBlocker) } def instantiateLambda(template: LambdaTemplate[T]): (T, Instantiation[T]) = { @@ -219,12 +290,15 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(enco var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) // make sure the new lambda isn't equal to any free lambda var - clauses ++= freeLambdas(newTemplate.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(idT, pIdT))) + clauses ++= knownFree(newTemplate.tpe).map(f => encoder.mkNot(encoder.mkEquals(idT, f))) + clauses ++= maybeFree(newTemplate.tpe).map { p => + encoder.mkImplies(p._1, encoder.mkNot(encoder.mkEquals(idT, p._2))) + } byID += idT -> newTemplate byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.key -> newTemplate)) - for (blockedApp @ (_, App(caller, tpe, args)) <- applications(newTemplate.tpe)) { + for (blockedApp @ (_, App(caller, tpe, args, _)) <- applications(newTemplate.tpe)) { val equals = encoder.mkEquals(idT, caller) appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(newTemplate, equals, args))) } @@ -234,17 +308,22 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(enco } def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { - val App(caller, tpe, args) = app - val instantiation = Instantiation.empty[T] + val App(caller, tpe @ FunctionType(_, to), args, encoded) = app + + val instantiation: Instantiation[T] = if (byID contains caller) { + Instantiation.empty + } else { + typeUnroller(blocker, app) + } - if (freeLambdas(tpe).contains(caller)) instantiation else { + if (knownFree(tpe) contains caller) instantiation else { val key = blocker -> app if (instantiated(key)) instantiation else { instantiated += key if (byID contains caller) { - instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) + empty withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) } else { // make sure that even if byType(tpe) is empty, app is recorded in blockers diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index 84a3017b8078faebc19fc606f5bd52f5be91fcb0..b53382f6b46a8e5a06004b415d3d64a759f06f97 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -117,8 +117,8 @@ object QuantificationTemplate { val insts: (Identifier, T) = inst -> encoder.encodeId(inst) val guards: (Identifier, T) = guard -> encoder.encodeId(guard) - val (clauses, blockers, applications, matchers, _) = - Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, + val (clauses, blockers, applications, functions, matchers, _) = + Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, substMap = baseSubstMap + q2s + insts + guards + qs) val (structuralQuant, structSubst) = normalizeStructure(proposition) @@ -139,13 +139,12 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[T], Map[T,Arg[T]])]] private val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[T], Map[T,Arg[T]])]] - private val known = new IncrementalSet[T] private val lambdaAxioms = new IncrementalSet[((Expr, Seq[T]), Seq[(Identifier, T)])] private val templates = new IncrementalMap[(Expr, Seq[T]), T] override protected def incrementals: List[IncrementalState] = List(quantifications, instCtx, ignoredMatchers, ignoredSubsts, - handledSubsts, known, lambdaAxioms, templates) ++ super.incrementals + handledSubsts, lambdaAxioms, templates) ++ super.incrementals private var currentGen = 0 @@ -155,7 +154,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { - case _: FunctionType if known(caller) => CallerKey(caller, tpe) + case ft: FunctionType if knownFree(ft)(caller) => CallerKey(caller, tpe) case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structuralKey, tpe) case _ => TypeKey(tpe) } @@ -201,7 +200,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (qs.map(_._2) zip uniformQuants(qs.map(_._1))).toMap } - def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq + override def assumptions: Seq[T] = super.assumptions ++ + quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq def typeInstantiations: Map[TypeTree, Matchers] = instCtx.map.instantiations.collect { case (TypeKey(tpe), matchers) => tpe -> matchers @@ -215,11 +215,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - override def registerFree(ids: Seq[(Identifier, T)]): Unit = { - super.registerFree(ids) - known ++= ids.map(_._2) - } - private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { case Right(ma) => matcherDepth(ma) case _ => 0 @@ -292,7 +287,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage def get(key: MatcherKey): Context = key match { case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) - case key => funMap.getOrElse(key, new Context) + case key => funMap.getOrElse(key, new Context) ++ tpeMap.getOrElse(key.tpe, new Context) } def instantiations: Map[MatcherKey, Matchers] = @@ -486,14 +481,14 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) - val (substMap, inst) = Template.substitution(encoder, QuantificationManager.this, - condVars, exprVars, condTree, Seq.empty, lambdas, baseSubst, pathVar._1, enabler) + val (substMap, inst) = Template.substitution[T](encoder, QuantificationManager.this, + condVars, exprVars, condTree, Seq.empty, lambdas, Set.empty, baseSubst, pathVar._1, enabler) instantiation ++= inst val msubst = substMap.collect { case (c, Right(m)) => c -> m } val substituter = encoder.substitute(substMap.mapValues(_.encoded)) instantiation ++= Template.instantiate(encoder, QuantificationManager.this, - clauses, blockers, applications, Seq.empty, Map.empty, lambdas, substMap) + clauses, blockers, applications, Map.empty, substMap) for ((b,ms) <- allMatchers; m <- ms) { val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) @@ -928,8 +923,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val valuesP = values.map(v => v -> encoder.encodeId(v)) val exprT = encoder.encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) - val disjuncts = insts.toSeq.map { case (b, m) => - val subst = (elemsP.map(_._2) zip m.args.map(_.encoded)).toMap + (guardP._2 -> b) + val disjuncts = insts.toSeq.map { case (b, im) => + val bp = if (m.caller != im.caller) encoder.mkAnd(encoder.mkEquals(m.caller, im.caller), b) else b + val subst = (elemsP.map(_._2) zip im.args.map(_.encoded)).toMap + (guardP._2 -> bp) encoder.substitute(subst)(exprT) } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 2b0e95a4d77de19253730a7fde9b2e04890ca5f6..16dc1c19895da49d2d88518a2f56ae0a47cd9b83 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -14,6 +14,7 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ +import utils.SeqUtils._ import Instantiation._ class TemplateGenerator[T](val encoder: TemplateEncoder[T], @@ -49,8 +50,10 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], return cacheExpr(body) } - val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, variablesOf(body).toSeq.map(ValDef(_)), body.getType) + val arguments = variablesOf(body).toSeq.map(ValDef(_)) + val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType) + fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable)))) fakeFunDef.body = Some(body) val res = mkTemplate(fakeFunDef.typed, false) @@ -98,7 +101,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } else { - mkClauses(start, lambdaBody.get, substMap) + (prec.toSeq :+ lambdaBody.get).foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } // Now the postcondition. @@ -214,7 +217,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], // Represents clauses of the form: // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() - def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { + def storeGuarded(guardVar: Identifier, expr: Expr) : Unit = { assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean." + ( purescala.ExprOps.fold[String]{ (e, se) => s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString @@ -222,7 +225,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], )) val prev = guardedExprs.getOrElse(guardVar, Nil) - guardedExprs += guardVar -> (expr +: prev) } @@ -249,25 +251,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], }(e) } - def groupWhile[T](es: Seq[T])(p: T => Boolean): Seq[Seq[T]] = { - var res: Seq[Seq[T]] = Nil - - var c = es - while (!c.isEmpty) { - val (span, rest) = c.span(p) - - if (span.isEmpty) { - res :+= Seq(rest.head) - c = rest.tail - } else { - res :+= span - c = rest - } - } - - res - } - def rec(pathVar: Identifier, expr: Expr): Expr = { expr match { case a @ Assert(cond, err, body) => diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 332d3fdf1442bfab25f72839d9ef00e04de73dd5..62a3f0c7617fbe819f7671e1b7a024fdc5b4832c 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -8,6 +8,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Quantification._ +import purescala.Constructors._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ @@ -18,9 +19,9 @@ import utils._ import scala.collection.generic.CanBuildFrom object Instantiation { - type Clauses[T] = Seq[T] - type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] - type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] + type Clauses[T] = Seq[T] + type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] + type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) @@ -61,32 +62,36 @@ import Template.Arg trait Template[T] { self => val encoder : TemplateEncoder[T] - val manager : QuantificationManager[T] + val manager : TemplateManager[T] + + val pathVar : (Identifier, T) + val args : Seq[T] - val pathVar: (Identifier, T) - val args : Seq[T] val condVars : Map[Identifier, T] val exprVars : Map[Identifier, T] val condTree : Map[Identifier, Set[Identifier]] - val clauses : Seq[T] - val blockers : Map[T, Set[TemplateCallInfo[T]]] + + val clauses : Seq[T] + val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] + val functions : Set[(T, FunctionType, T)] + val lambdas : Seq[LambdaTemplate[T]] + val quantifications : Seq[QuantificationTemplate[T]] - val matchers : Map[T, Set[Matcher[T]]] - val lambdas : Seq[LambdaTemplate[T]] + val matchers : Map[T, Set[Matcher[T]]] lazy val start = pathVar._2 def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { val (substMap, instantiation) = Template.substitution(encoder, manager, - condVars, exprVars, condTree, quantifications, lambdas, + condVars, exprVars, condTree, quantifications, lambdas, functions, (this.args zip args).toMap + (start -> Left(aVar)), pathVar._1, aVar) instantiation ++ instantiate(substMap) } protected def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { - Template.instantiate(encoder, manager, - clauses, blockers, applications, quantifications, matchers, lambdas, substMap) + Template.instantiate(encoder, manager, clauses, + blockers, applications, matchers, substMap) } override def toString : String = "Instantiated template" @@ -127,6 +132,9 @@ object Template { Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) } + type Apps[T] = Map[T, Set[App[T]]] + type Functions[T] = Set[(T, FunctionType, T)] + def encode[T]( encoder: TemplateEncoder[T], pathVar: (Identifier, T), @@ -135,23 +143,46 @@ object Template { exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], lambdas: Seq[LambdaTemplate[T]], + quantifications: Seq[QuantificationTemplate[T]], substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None - ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], Map[T, Set[Matcher[T]]], () => String) = { + ) : (Clauses[T], CallBlockers[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { val idToTrId : Map[Identifier, T] = - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) - val clauses : Seq[T] = (for ((b,es) <- guardedExprs; e <- es) yield { - encodeExpr(Implies(Variable(b), e)) - }).toSeq + val (clauses, cleanGuarded, functions) = { + var functions: Set[(T, FunctionType, T)] = Set.empty + var clauses: Seq[T] = Seq.empty + + val cleanGuarded = guardedExprs.map { case (b, es) => b -> es.map { e => + val withPaths = CollectorWithPaths { case FreshFunction(f) => f }.traverse(e) + functions ++= withPaths.map { case (f, TopLevelAnds(paths)) => + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val path = andJoin(paths.filterNot(_.isInstanceOf[FreshFunction])) + (encodeExpr(and(Variable(b), path)), tpe, encodeExpr(f)) + } + + val cleanExpr = postMap { + case FreshFunction(f) => Some(BooleanLiteral(true)) + case _ => None + } (e) + + clauses :+= encodeExpr(Implies(Variable(b), cleanExpr)) + cleanExpr + }} + + (clauses, cleanGuarded, functions) + } val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2)))) val optIdApp = optApp.map { case (idT, tpe) => - App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2))) + val id = FreshIdentifier("x", tpe, true) + val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(Application(Variable(id), arguments.map(_._1.toVariable))) + App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) } lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) @@ -162,7 +193,7 @@ object Template { var applications : Map[Identifier, Set[App[T]]] = Map.empty var matchers : Map[Identifier, Set[Matcher[T]]] = Map.empty - for ((b,es) <- guardedExprs) { + for ((b,es) <- cleanGuarded) { var funInfos : Set[TemplateCallInfo[T]] = Set.empty var appInfos : Set[App[T]] = Set.empty var matchInfos : Set[Matcher[T]] = Set.empty @@ -192,7 +223,8 @@ object Template { funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg))) appInfos ++= firstOrderAppsOf(e).map { case (c, args) => - App(encodeExpr(c), bestRealType(c.getType).asInstanceOf[FunctionType], args.map(encodeArg)) + val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] + App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(Application(c, args))) } matchInfos ++= exprToMatcher.values @@ -204,10 +236,8 @@ object Template { val apps = appInfos -- optIdApp if (apps.nonEmpty) applications += b -> apps - val matchs = matchInfos.filter { case m @ Matcher(mc, mtpe, margs, _) => - !optIdApp.exists { case App(ac, atpe, aargs) => - mc == ac && mtpe == atpe && margs == aargs - } + val matchs = matchInfos.filter { case m @ Matcher(_, _, _, menc) => + !optIdApp.exists { case App(_, _, _, aenc) => menc == aenc } } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None) if (matchs.nonEmpty) matchers += b -> matchs @@ -224,8 +254,8 @@ object Template { " * Activating boolean : " + pathVar._1 + "\n" + " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + - " * Clauses : " + (if (guardedExprs.isEmpty) "\n" else { - "\n " + (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + " * Clauses : " + (if (cleanGuarded.isEmpty) "\n" else { + "\n " + (for ((b,es) <- cleanGuarded; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" }) + " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" @@ -241,17 +271,18 @@ object Template { }.mkString("\n") } - (clauses, encodedBlockers, encodedApps, encodedMatchers, stringRepr) + (clauses, encodedBlockers, encodedApps, functions, encodedMatchers, stringRepr) } def substitution[T]( encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], condTree: Map[Identifier, Set[Identifier]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], + functions: Set[(T, FunctionType, T)], baseSubst: Map[T, Arg[T]], pathVar: Identifier, aVar: T @@ -259,40 +290,56 @@ object Template { val freshSubst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ manager.freshConds(pathVar -> aVar, condVars, condTree) val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } - var subst = freshSubst.mapValues(Left(_)) ++ baseSubst - // /!\ CAREFUL /!\ - // We have to be wary while computing the lambda subst map since lambdas can - // depend on each other. However, these dependencies cannot be cyclic so it - // suffices to make sure the traversal order is correct. + var subst = freshSubst.mapValues(Left(_)) ++ baseSubst var instantiation : Instantiation[T] = Instantiation.empty - var seen : Set[LambdaTemplate[T]] = Set.empty - - val lambdaKeys = lambdas.map(lambda => lambda.ids._1 -> lambda).toMap - def extractSubst(lambda: LambdaTemplate[T]): Unit = { - for { - dep <- lambda.dependencies.flatMap(p => lambdaKeys.get(p._1)) - if !seen(dep) - } extractSubst(dep) - - if (!seen(lambda)) { - val substMap = subst.mapValues(_.encoded) - val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) - val (idT, inst) = manager.instantiateLambda(substLambda) - instantiation ++= inst - subst += lambda.ids._2 -> Left(idT) - seen += lambda - } + + manager match { + case lmanager: LambdaManager[T] => + val funSubstituter = encoder.substitute(subst.mapValues(_.encoded)) + for ((b,tpe,f) <- functions) instantiation = instantiation withClauses { + lmanager.registerFunction(funSubstituter(b), tpe, funSubstituter(f)) + } + + // /!\ CAREFUL /!\ + // We have to be wary while computing the lambda subst map since lambdas can + // depend on each other. However, these dependencies cannot be cyclic so it + // suffices to make sure the traversal order is correct. + var seen : Set[LambdaTemplate[T]] = Set.empty + + val lambdaKeys = lambdas.map(lambda => lambda.ids._1 -> lambda).toMap + def extractSubst(lambda: LambdaTemplate[T]): Unit = { + for { + dep <- lambda.dependencies.flatMap(p => lambdaKeys.get(p._1)) + if !seen(dep) + } extractSubst(dep) + + if (!seen(lambda)) { + val substMap = subst.mapValues(_.encoded) + val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) + val (idT, inst) = lmanager.instantiateLambda(substLambda) + instantiation ++= inst + subst += lambda.ids._2 -> Left(idT) + seen += lambda + } + } + + for (l <- lambdas) extractSubst(l) + + case _ => } - for (l <- lambdas) extractSubst(l) + manager match { + case qmanager: QuantificationManager[T] => + for (q <- quantifications) { + val substMap = subst.mapValues(_.encoded) + val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) + val (qT, inst) = qmanager.instantiateQuantification(substQuant) + instantiation ++= inst + subst += q.qs._2 -> Left(qT) + } - for (q <- quantifications) { - val substMap = subst.mapValues(_.encoded) - val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) - val (qT, inst) = manager.instantiateQuantification(substQuant) - instantiation ++= inst - subst += q.qs._2 -> Left(qT) + case _ => } (subst, instantiation) @@ -300,13 +347,11 @@ object Template { def instantiate[T]( encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], clauses: Seq[T], blockers: Map[T, Set[TemplateCallInfo[T]]], applications: Map[T, Set[App[T]]], - quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], - lambdas: Seq[LambdaTemplate[T]], substMap: Map[T, Arg[T]] ): Instantiation[T] = { @@ -314,20 +359,31 @@ object Template { val msubst = substMap.collect { case (c, Right(m)) => c -> m } val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b,fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) } var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - for ((b,apps) <- applications; bp = substituter(b); app <- apps) { - val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(_.substitute(substituter, msubst))) - instantiation ++= manager.instantiateApp(bp, newApp) + manager match { + case lmanager: LambdaManager[T] => + for ((b,apps) <- applications; bp = substituter(b); app <- apps) { + val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(_.substitute(substituter, msubst))) + instantiation ++= lmanager.instantiateApp(bp, newApp) + } + + case _ => } - for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { - val newMatcher = m.substitute(substituter, msubst) - instantiation ++= manager.instantiateMatcher(bp, newMatcher) + manager match { + case qmanager: QuantificationManager[T] => + for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { + val newMatcher = m.substitute(substituter, msubst) + instantiation ++= qmanager.instantiateMatcher(bp, newMatcher) + } + + case _ => } instantiation @@ -339,7 +395,7 @@ object FunctionTemplate { def apply[T]( tfd: TypedFunDef, encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], pathVar: (Identifier, T), arguments: Seq[(Identifier, T)], condVars: Map[Identifier, T], @@ -351,9 +407,8 @@ object FunctionTemplate { isRealFunDef: Boolean ) : FunctionTemplate[T] = { - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, - substMap = quantifications.map(q => q.qs).toMap, + val (clauses, blockers, applications, functions, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, optCall = Some(tfd)) val funString : () => String = () => { @@ -374,9 +429,10 @@ object FunctionTemplate { clauses, blockers, applications, - quantifications, - matchers, + functions, lambdas, + matchers, + quantifications, isRealFunDef, funString ) @@ -386,7 +442,7 @@ object FunctionTemplate { class FunctionTemplate[T] private( val tfd: TypedFunDef, val encoder: TemplateEncoder[T], - val manager: QuantificationManager[T], + val manager: TemplateManager[T], val pathVar: (Identifier, T), val args: Seq[T], val condVars: Map[Identifier, T], @@ -395,19 +451,15 @@ class FunctionTemplate[T] private( val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], + val functions: Set[(T, FunctionType, T)], val lambdas: Seq[LambdaTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val quantifications: Seq[QuantificationTemplate[T]], isRealFunDef: Boolean, stringRepr: () => String) extends Template[T] { private lazy val str : String = stringRepr() override def toString : String = str - - override def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args.map(_.left.get)) - super.instantiate(aVar, args) - } } class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index a84dc350e38275bf30b592d1e82badd0576b87c8..e66de514a3dcba2fb1aa2aa4787163268e6867d8 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -142,7 +142,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } private def freshAppBlocks(apps: Traversable[(T, App[T])]) : Seq[T] = { - apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { case app @ (blocker, App(caller, tpe, _)) => + apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { case app @ (blocker, App(caller, tpe, _, _)) => val firstB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) val freeEq = functionVars.getOrElse(tpe, Set()).toSeq.map(t => encoder.mkEquals(t, caller)) @@ -328,6 +328,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // We connect it to the defBlocker: blocker => defBlocker if (defBlocker != id) { newCls :+= encoder.mkImplies(id, defBlocker) + manager.implies(id, defBlocker) } reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") @@ -367,6 +368,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val enabler = if (equals == manager.trueT) b else encoder.mkAnd(equals, b) newCls :+= encoder.mkImplies(enabler, lambdaBlocker) + manager.implies(b, lambdaBlocker) reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") for (cl <- newCls) { diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index f59618e17ac4eec3881dbdfc30c2ad8133f5ae54..50de827c17168bacd3db23049ed76237f73f132b 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -32,7 +32,7 @@ trait StructuralSize { )) absFun.typed } - + def size(expr: Expr) : Expr = { def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = { // we want to reuse generic size functions for sub-types diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 06f35dc3a47df0a48836e6137eaae96218745d39..72103a4f70098527f4f80dba71c47a4b46c5d2d1 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -36,7 +36,6 @@ class PreprocessingPhase(desugarXLang: Boolean = false, genc: Boolean = false) e MethodLifting andThen TypingPhase andThen synthesis.ConversionPhase andThen - CheckADTFieldsTypes andThen InliningPhase val pipeX = if (!genc && desugarXLang) { diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index f2290a68d11bc668c348af3954af16b95f0f7d88..ada7499120353737d34bc79e2c9cc312d6702580 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer object SeqUtils { type Tuple[T] = Seq[T] - def cartesianProduct[T](seqs: Tuple[Seq[T]]): Seq[Tuple[T]] = { val sizes = seqs.map(_.size) val max = sizes.product @@ -59,6 +58,25 @@ object SeqUtils { rec(sum, arity) filterNot (_.head == 0) } + + def groupWhile[T](es: Seq[T])(p: T => Boolean): Seq[Seq[T]] = { + var res: Seq[Seq[T]] = Nil + + var c = es + while (!c.isEmpty) { + val (span, rest) = c.span(p) + + if (span.isEmpty) { + res :+= Seq(rest.head) + c = rest.tail + } else { + res :+= span + c = rest + } + } + + res + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { @@ -103,4 +121,4 @@ class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], ret } } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala index 1bd9a695788877bd2a0034ec151294daddd5ab59..aa88b39dec1e45377aa77f87ded6a0535abac782 100644 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -71,6 +71,12 @@ object InjectAsserts extends SimpleLeonPhase[Program, Program] { e ).setPos(e)) + case e @ CaseClass(cct, args) if cct.root.classDef.hasInvariant => + Some(assertion(FunctionInvocation(cct.root.invariant.get, Seq(e)), + Some("ADT invariant"), + e + ).setPos(e)) + case _ => None }) diff --git a/src/test/resources/regression/termination/valid/Ackermann.scala b/src/test/resources/regression/termination/valid/Ackermann.scala new file mode 100644 index 0000000000000000000000000000000000000000..11ea76bee4ae9ae9c8225642d413d0f9054691f1 --- /dev/null +++ b/src/test/resources/regression/termination/valid/Ackermann.scala @@ -0,0 +1,10 @@ +import leon.lang._ + +object Ackermann { + def ackermann(m: BigInt, n: BigInt): BigInt = { + require(m >= 0 && n >= 0) + if (m == 0) n + 1 + else if (n == 0) ackermann(m - 1, 1) + else ackermann(m - 1, ackermann(m, n - 1)) + } +}