From e07343479c0fe3a31e7358fdf53e553f01480723 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Tue, 15 Mar 2016 15:38:03 +0100 Subject: [PATCH] Quantifier improvements + invariant preservation in program transforms --- .../leon/evaluators/RecursiveEvaluator.scala | 4 +- .../frontends/scalac/CodeExtraction.scala | 9 +- src/main/scala/leon/purescala/DefOps.scala | 73 ++- .../scala/leon/purescala/Definitions.scala | 42 +- .../scala/leon/purescala/Expressions.scala | 1 - .../scala/leon/purescala/PrettyPrinter.scala | 22 +- .../scala/leon/purescala/Quantification.scala | 5 + src/main/scala/leon/solvers/ADTManager.scala | 11 + .../solvers/combinators/UnrollingSolver.scala | 204 +------ .../solvers/smtlib/SMTLIBCVC4Target.scala | 4 - .../leon/solvers/smtlib/SMTLIBSolver.scala | 2 - .../leon/solvers/smtlib/SMTLIBTarget.scala | 30 +- .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 4 - .../solvers/templates/LambdaManager.scala | 4 +- .../templates/QuantificationManager.scala | 537 +++++++++++++----- .../solvers/templates/TemplateGenerator.scala | 1 - .../solvers/templates/TemplateManager.scala | 45 +- .../solvers/templates/UnrollingBank.scala | 11 +- .../leon/solvers/z3/AbstractZ3Solver.scala | 67 +-- .../scala/leon/solvers/z3/FairZ3Solver.scala | 14 +- .../solvers/z3/UninterpretedZ3Solver.scala | 12 - .../leon/solvers/z3/Z3StringConversion.scala | 15 +- .../leon/synthesis/rules/StringRender.scala | 25 +- .../transformations/IntToRealProgram.scala | 4 +- .../invalid/AbstractRefinementMap.scala | 22 + .../invalid/AbstractRefinementMap2.scala | 24 + .../valid/AbstractRefinementMap.scala | 24 + .../solvers/QuantifierSolverSuite.scala | 4 +- .../leon/unit/purescala/TypeOpsSuite.scala | 6 +- 29 files changed, 699 insertions(+), 527 deletions(-) create mode 100644 src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala create mode 100644 src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala create mode 100644 src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index f752b1895..ee1c689bc 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -518,7 +518,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int implicit val debugSection = utils.DebugSectionVerification - ctx.reporter.debug("Executing forall!") + ctx.reporter.debug("Executing forall: " + f.asString) val mapping = variablesOf(f).map(id => id -> rctx.mappings(id)).toMap val context = mapping.toSeq.sortBy(_._1.uniqueName).map(_._2) @@ -546,7 +546,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val domainCnstr = orJoin(quorums.map { quorum => val quantifierDomains = quorum.flatMap { case (path, caller, args) => - val matcher = e(expr) match { + val matcher = e(caller) match { case l: Lambda => gctx.lambdas.getOrElse(l, l) case ev => ev } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 85e587444..9dc3514a4 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -333,12 +333,12 @@ trait CodeExtraction extends ASTExtractors { } classToInvariants.get(sym).foreach { bodies => + val cd = classesToClasses(sym) val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType) fd.addFlag(IsADTInvariant) + fd.addFlags(cd.flags.collect { case annot : purescala.Definitions.Annotation => annot }) - val cd = classesToClasses(sym) cd.registerMethod(fd) - cd.addFlag(IsADTInvariant) val ctparams = sym.tpe match { case TypeRef(_, _, tps) => extractTypeParams(tps).map(_._1) @@ -381,7 +381,6 @@ trait CodeExtraction extends ASTExtractors { case t => extractFunOrMethodBody(None, t) - } case _ => } @@ -559,9 +558,9 @@ trait CodeExtraction extends ASTExtractors { // Extract class val cd = if (sym.isAbstractClass) { - AbstractClassDef(id, tparams, parent.map(_._1)) + new AbstractClassDef(id, tparams, parent.map(_._1)) } else { - CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) + new CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) } cd.setPos(sym.pos) //println(s"Registering $sym") diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 8d7f626d8..5363066f6 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -290,7 +290,7 @@ object DefOps { case _ => None } - + /** Clones the given program by replacing some functions by other functions. * * @param p The original program @@ -300,10 +300,11 @@ object DefOps { * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) - : (Program, Map[FunDef, FunDef])= { + : (Program, Map[FunDef, FunDef]) = { var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. + def fdMapFCached(fd: FunDef): Option[FunDef] = { fdMapFCache.get(fd) match { case Some(e) => e @@ -313,7 +314,7 @@ object DefOps { new_fd } } - + def duplicateParents(fd: FunDef): Unit = { fdMapCache.get(fd) match { case None => @@ -324,7 +325,7 @@ object DefOps { case _ => } } - + def fdMap(fd: FunDef): FunDef = { fdMapCache.get(fd) match { case Some(Some(e)) => e @@ -353,9 +354,9 @@ object DefOps { } ) }) - - for(fd <- newP.definedFunctions) { - if(ExprOps.exists{ + + for (fd <- newP.definedFunctions) { + if (ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd @@ -366,6 +367,11 @@ object DefOps { fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) } } + + for (cd <- newP.classHierarchyRoots) { + cd.invariant.foreach(inv => cd.setInvariant(fdMap(inv))) + } + (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } @@ -381,7 +387,7 @@ object DefOps { None }(e) } - + def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{ case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp)) case _ => None @@ -404,11 +410,13 @@ object DefOps { def replaceCaseClassDefs(p: Program)(cdMapFOriginal: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { + var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() var cdMapCache = Map[ClassDef, Option[ClassDef]]() var idMapCache = Map[Identifier, Identifier]() var fdMapFCache = Map[FunDef, Option[FunDef]]() var fdMapCache = Map[FunDef, Option[FunDef]]() + def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { cd match { case ccd: CaseClassDef => @@ -420,19 +428,20 @@ object DefOps { case acd: AbstractClassDef => None } } + def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) case e => None }(tt).asInstanceOf[T] - + def duplicateClassDef(cd: ClassDef): ClassDef = { cdMapCache.get(cd) match { case Some(new_cd) => new_cd.get // None would have meant that this class would never be duplicated, which is not possible. case None => val parent = cd.parent.map(duplicateAbstractClassType) - val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{ + val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse { cd match { case acd:AbstractClassDef => acd.duplicate(parent = parent) case ccd:CaseClassDef => @@ -443,7 +452,7 @@ object DefOps { new_cd } } - + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { TypeOps.postMap{ case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) @@ -451,7 +460,7 @@ object DefOps { case _ => None }(act).asInstanceOf[AbstractClassType] } - + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. // If something extends List[A] and A is modified, then the first something should be modified. def dependencies(s: ClassDef): Set[ClassDef] = { @@ -461,7 +470,7 @@ object DefOps { case _ => Set() }(p))))(Set(s)) } - + def cdMap(cd: ClassDef): ClassDef = { cdMapCache.get(cd) match { case Some(Some(new_cd)) => new_cd @@ -475,6 +484,7 @@ object DefOps { cdMapCache(cd).getOrElse(cd) } } + def idMap(id: Identifier): Identifier = { if (!(idMapCache contains id)) { val new_id = id.duplicate(tpe = tpMap(id.getType)) @@ -482,11 +492,11 @@ object DefOps { } idMapCache(id) } - + def idHasToChange(id: Identifier): Boolean = { typeHasToChange(id.getType) } - + def typeHasToChange(tp: TypeTree): Boolean = { TypeOps.exists{ case AbstractClassType(acd, _) => cdMap(acd) != acd @@ -494,7 +504,7 @@ object DefOps { case _ => false }(tp) } - + def patternHasToChange(p: Pattern): Boolean = { PatternOps.exists { case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) @@ -503,7 +513,7 @@ object DefOps { case e => false } (p) } - + def exprHasToChange(e: Expr): Boolean = { ExprOps.exists{ case Let(id, expr, body) => idHasToChange(id) @@ -523,11 +533,11 @@ object DefOps { false }(e) } - + def funDefHasToChange(fd: FunDef): Boolean = { exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) } - + def funHasToChange(fd: FunDef): Boolean = { funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCache.get(fd) match { @@ -536,7 +546,7 @@ object DefOps { case None => funDefHasToChange(fd) }) } - + def fdMapFCached(fd: FunDef): Option[FunDef] = { fdMapFCache.get(fd) match { case Some(e) => e @@ -550,7 +560,7 @@ object DefOps { new_fd } } - + def duplicateParents(fd: FunDef): Unit = { fdMapCache.get(fd) match { case None => @@ -561,7 +571,7 @@ object DefOps { case _ => } } - + def fdMap(fd: FunDef): FunDef = { fdMapCache.get(fd) match { case Some(Some(e)) => e @@ -575,7 +585,7 @@ object DefOps { fdMapCache(fd).getOrElse(fd) } } - + val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { @@ -591,6 +601,7 @@ object DefOps { } ) }) + def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) @@ -598,7 +609,7 @@ object DefOps { case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) case e => None }(e) - + def replaceClassDefsUse(e: Expr): Expr = { ExprOps.postMap { case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) @@ -623,19 +634,23 @@ object DefOps { None }(e) } - - for(fd <- newP.definedFunctions) { - if(fdMapCache.getOrElse(fd, None).isDefined) { + + for (fd <- newP.definedFunctions) { + if (fdMapCache.getOrElse(fd, None).isDefined) { fd.fullBody = replaceClassDefsUse(fd.fullBody) } } + + // make sure classDef invariants are correctly assigned to transformed classes + for ((cd, optNew) <- cdMapCache; newCd <- optNew; inv <- newCd.invariant) { + newCd.setInvariant(fdMap(inv)) + } + (newP, cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, idMapCache, fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) } - - def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 98497afd3..3c26299ab 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -186,11 +186,11 @@ object Definitions { } lazy val algebraicDataTypes : Map[AbstractClassDef, Seq[CaseClassDef]] = defs.collect { - case c@CaseClassDef(_, _, Some(p), _) => c + case c : CaseClassDef if c.parent.isDefined => c }.groupBy(_.parent.get.classDef) lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { - case c @ CaseClassDef(_, _, None, _) => c + case c : CaseClassDef if !c.parent.isDefined => c } } @@ -227,7 +227,7 @@ object Definitions { // Is inlined case object IsInlined extends FunctionFlag // Is an ADT invariant method - case object IsADTInvariant extends FunctionFlag with ClassFlag + case object IsADTInvariant extends FunctionFlag case object IsInner extends FunctionFlag /** Represents a class definition (either an abstract- or a case-class) */ @@ -280,13 +280,14 @@ object Definitions { private var _invariant: Option[FunDef] = None - def invariant = _invariant - def hasInvariant = flags contains IsADTInvariant - def setInvariant(fd: FunDef): Unit = { - addFlag(IsADTInvariant) - _invariant = Some(fd) + def invariant: Option[FunDef] = parent.flatMap(_.classDef.invariant).orElse(_invariant) + def setInvariant(fd: FunDef): Unit = parent match { + case Some(act) => act.classDef.setInvariant(fd) + case None => _invariant = Some(fd) } + def hasInvariant: Boolean = invariant.isDefined || (root.knownChildren.exists(cd => cd.methods.exists(_.isInvariant))) + def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap @@ -336,9 +337,9 @@ object Definitions { } /** Abstract classes. */ - case class AbstractClassDef(id: Identifier, - tparams: Seq[TypeParameterDef], - parent: Option[AbstractClassType]) extends ClassDef { + class AbstractClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType]) extends ClassDef { val fields = Nil val isAbstract = true @@ -362,16 +363,17 @@ object Definitions { ): AbstractClassDef = { val acd = new AbstractClassDef(id, tparams, parent) acd.addFlags(this.flags) - parent.foreach(_.classDef.ancestors.foreach(_.registerChild(acd))) + if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => acd.setInvariant(inv)) + parent.map(_.classDef.ancestors.map(_.registerChild(acd))) acd.copiedFrom(this) } } /** Case classes/ case objects. */ - case class CaseClassDef(id: Identifier, - tparams: Seq[TypeParameterDef], - parent: Option[AbstractClassType], - isCaseObject: Boolean) extends ClassDef { + class CaseClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType], + val isCaseObject: Boolean) extends ClassDef { private var _fields = Seq[ValDef]() @@ -393,14 +395,14 @@ object Definitions { ) } else index } - + lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) + def typed: CaseClassType = typed(tparams.map(_.tp)) def typed(tps: Seq[TypeTree]): CaseClassType = { require(tps.length == tparams.length) CaseClassType(this, tps) } - def typed: CaseClassType = typed(tparams.map(_.tp)) /** Duplication of this [[CaseClassDef]]. * @note This will not replace recursive [[CaseClassDef]] calls in [[fields]] nor the parent abstract class types @@ -415,9 +417,9 @@ object Definitions { val cd = new CaseClassDef(id, tparams, parent, isCaseObject) cd.setFields(fields) cd.addFlags(this.flags) + if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => cd.setInvariant(inv)) + parent.map(_.classDef.ancestors.map(_.registerChild(cd))) cd.copiedFrom(this) - parent.foreach(_.classDef.ancestors.foreach(_.registerChild(cd))) - cd } } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 6d6382709..8a7294552 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -444,7 +444,6 @@ object Expressions { * This is useful e.g. to present counterexamples of generic types. */ case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { - // TODO: Is it valid that GenericValue(tp, 0) != GenericValue(tp, 1)? val getType = tp } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 3773ebce4..f1b80fd9d 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -497,10 +497,10 @@ class PrettyPrinter(opts: PrinterOptions, | ${nary(defs, "\n\n")} |}""" - case acd @ AbstractClassDef(id, tparams, parent) => - p"abstract class $id${nary(tparams, ", ", "[", "]")}" + case acd : AbstractClassDef => + p"abstract class ${acd.id}${nary(acd.tparams, ", ", "[", "]")}" - parent.foreach{ par => + acd.parent.foreach{ par => p" extends ${par.id}" } @@ -510,22 +510,22 @@ class PrettyPrinter(opts: PrinterOptions, |}""" } - case ccd @ CaseClassDef(id, tparams, parent, isObj) => - if (isObj) { - p"case object $id" + case ccd : CaseClassDef => + if (ccd.isCaseObject) { + p"case object ${ccd.id}" } else { - p"case class $id" + p"case class ${ccd.id}" } - p"${nary(tparams, ", ", "[", "]")}" + p"${nary(ccd.tparams, ", ", "[", "]")}" - if (!isObj) { + if (!ccd.isCaseObject) { p"(${ccd.fields})" } - parent.foreach { par => + ccd.parent.foreach { par => // Remember child and parents tparams are simple bijection - p" extends ${par.id}${nary(tparams, ", ", "[", "]")}" + p" extends ${par.id}${nary(ccd.tparams, ", ", "[", "]")}" } if (ccd.methods.nonEmpty) { diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index 079be92d1..c815f4ca5 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -22,10 +22,15 @@ object Quantification { qargs: A => Set[B] ): Seq[Set[A]] = { def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + def allQArgs(m: A): Set[B] = qargs(m) ++ margs(m).flatMap(allQArgs) val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap val reverseMap : Map[A, Set[A]] = expandedMap.toSeq .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set + .map { case (m, ms) => // filter redundant matchers + val allM = allQArgs(m) + m -> ms.filter(rm => allQArgs(rm) != allM) + } def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { if (qss.contains(quantified)) { diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index b52be838a..7aefd9577 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -136,6 +136,17 @@ class ADTManager(ctx: LeonContext) { findDependencies(base) } + case tp @ TypeParameter(id) => + if (!(discovered contains t) && !(defined contains t)) { + val sym = freshId(id.name) + + val c = Constructor(freshId(sym.name), tp, List( + (freshId("val"), IntegerType) + )) + + discovered += (tp -> DataType(sym, Seq(c))) + } + case _ => } } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 6594eb07e..8a924762a 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -215,180 +215,14 @@ trait AbstractUnrollingSolver[T] private def getPartialModel: PartialModel = { val wrapped = solverGetModel - - val typeInsts = templateGenerator.manager.typeInstantiations - val partialInsts = templateGenerator.manager.partialInstantiations - val lambdaInsts = templateGenerator.manager.lambdaInstantiations - - val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { - case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val funDomains: Map[Identifier, Set[Seq[Expr]]] = freeVars.toMap.map { case (id, idT) => - id -> partialInsts.get(idT).toSeq.flatten.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { - case (l, domain) => l -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val model = new Model(freeVars.toMap.map { case (id, _) => - val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) - id -> (funDomains.get(id) match { - case Some(domain) => - val dflt = value match { - case FiniteLambda(_, dflt, _) => dflt - case Lambda(_, IfExpr(_, _, dflt)) => dflt - case _ => scala.sys.error("Can't extract default from " + value) - } - - FiniteLambda(domain.toSeq.map { es => - val optEv = evaluator.eval(application(value, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) - }, dflt, id.getType.asInstanceOf[FunctionType]) - - case None => postMap { - case p @ FiniteLambda(mapping, dflt, tpe) => - Some(FiniteLambda(typeDomains.get(tpe) match { - case Some(domain) => domain.toSeq.map { es => - val optEv = evaluator.eval(application(value, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) - } - case _ => Seq.empty - }, dflt, tpe)) - case _ => None - } (value) - }) - }) - - val domains = new Domains(lambdaDomains, typeDomains) - new PartialModel(model.toMap, domains) + val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) + view.getPartialModel } private def getTotalModel: Model = { val wrapped = solverGetModel - - def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (body) - - if (matchers.isEmpty) - return Some("No matchers found.") - - val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } - - val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) - if (bijectiveMappings.size > 1) - return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) - - def quantifiedArg(e: Expr): Boolean = e match { - case Variable(id) => quantified(id) - case QuantificationMatcher(_, args) => args.forall(quantifiedArg) - case _ => false - } - - postTraversal(m => m match { - case QuantificationMatcher(_, args) => - val qArgs = args.filter(quantifiedArg) - - if (qArgs.nonEmpty && qArgs.size < args.size) - return Some("Mixed ground and quantified arguments in " + m.asString) - - case Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => - return Some("Invalid operation on quantifiers " + m.asString) - - case (_: Equals) | (_: And) | (_: Or) | (_: Implies) => // OK - - case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => - return Some("Unandled implications from operation " + m.asString) - - case _ => - }) (body) - - body match { - case Variable(id) if quantified(id) => - Some("Unexpected free quantifier " + id.asString) - case _ => None - } - } - - val issues: Iterable[(Seq[Identifier], Expr, String)] = for { - q <- templateGenerator.manager.quantifications.view - if wrapped.eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) - msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) - } yield (q.quantifiers.map(_._1), q.body, msg) - - if (issues.nonEmpty) { - val (quantifiers, body, msg) = issues.head - reporter.warning("Model soundness not guaranteed for \u2200" + - quantifiers.map(_.asString).mkString(",") + ". " + body.asString+" :\n => " + msg) - } - - val typeInsts = templateGenerator.manager.typeInstantiations - val partialInsts = templateGenerator.manager.partialInstantiations - - def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { - case (id +: rparams, (v, arg) +: rargs) => - if (templateGenerator.manager.isQuantifier(v)) { - structure.get(v) match { - case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) - case None => extractCond(rparams, rargs, structure + (v -> id)) - } - } else { - Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) - } - case _ => Seq.empty - } - - new Model(freeVars.toMap.map { case (id, idT) => - val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) - id -> (id.getType match { - case FunctionType(from, to) => - val params = from.map(tpe => FreshIdentifier("x", tpe, true)) - val domain = partialInsts.get(idT).orElse(typeInsts.get(bestRealType(id.getType))).toSeq.flatten - val conditionals = domain.flatMap { case (b, m) => - wrapped.extract(b, m).map { args => - val result = evaluator.eval(application(value, args)).result.getOrElse { - scala.sys.error("Unexpectedly failed to evaluate " + application(value, args)) - } - - val cond = if (m.args.exists(arg => templateGenerator.manager.isQuantifier(arg.encoded))) { - extractCond(params, m.args.map(_.encoded) zip args, Map.empty) - } else { - (params zip args).map(p => Equals(Variable(p._1), p._2)) - } - - cond -> result - } - } - - val filteredConds = conditionals - .foldLeft(Map.empty[Seq[Expr], Expr]) { case (mapping, (conds, result)) => - if (mapping.isDefinedAt(conds)) mapping else mapping + (conds -> result) - } - - if (filteredConds.isEmpty) { - // TODO: warning?? - value - } else { - val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size) - val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => - if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze) - } - - Lambda(params.map(ValDef(_)), body) - } - - case _ => value - }) - }) + val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) + view.getTotalModel } def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { @@ -430,7 +264,7 @@ trait AbstractUnrollingSolver[T] } else if (partialModels) { (true, getPartialModel) } else { - val clauses = templateGenerator.manager.checkClauses + val clauses = unrollingBank.getFiniteRangeClauses if (clauses.isEmpty) { (true, extractModel(solverGetModel)) } else { @@ -473,8 +307,10 @@ trait AbstractUnrollingSolver[T] if (valid) { foundAnswer(Some(true), model) } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model.asString(context)) + reporter.error( + "Something went wrong. The model should have been valid, yet we got this: " + + model.asString(context) + + " for formula " + andJoin(assumptionsSeq ++ constraints).asString) foundAnswer(None, model) } } @@ -534,11 +370,23 @@ trait AbstractUnrollingSolver[T] } case Some(true) => - if (this.feelingLucky && !interrupted) { - // we might have been lucky :D - val model = extractModel(solverGetModel) - val valid = validateModel(model, assumptionsSeq, silenceErrors = true) - if (valid) foundAnswer(Some(true), model) + if (!interrupted) { + val model = solverGetModel + + if (this.feelingLucky) { + // we might have been lucky :D + val extracted = extractModel(model) + val valid = validateModel(extracted, assumptionsSeq, silenceErrors = true) + if (valid) foundAnswer(Some(true), extracted) + } + + if (!foundDefinitiveAnswer) { + unrollingBank.decreaseAllGenerations() + + for (b <- templateGenerator.manager.getBlockersToPromote(model.eval)) { + unrollingBank.promoteBlocker(b) + } + } } case None => diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index d9deb8790..30ae22f98 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -50,10 +50,6 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { super.fromSMT(t, otpe) } - case (SimpleSymbol(s), Some(tp: TypeParameter)) => - val n = s.name.split("_").toList.last - GenericValue(tp, n.toInt) - case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), Some(SetType(base))) => FiniteSet(Set(), base) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 78b9ccbd9..72b475c5c 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -108,7 +108,6 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) selectors.push() testers.push() variables.push() - genericValues.push() sorts.push() lambdas.push() functions.push() @@ -122,7 +121,6 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) selectors.pop() testers.pop() variables.pop() - genericValues.pop() sorts.pop() lambdas.pop() functions.pop() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 242f32121..546e62603 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -153,7 +153,6 @@ trait SMTLIBTarget extends Interruptible { protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() protected val testers = new IncrementalBijection[TypeTree, SSymbol]() protected val variables = new IncrementalBijection[Identifier, SSymbol]() - protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() protected val sorts = new IncrementalBijection[TypeTree, Sort]() protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() protected val lambdas = new IncrementalBijection[FunctionType, SSymbol]() @@ -226,13 +225,6 @@ trait SMTLIBTarget extends Interruptible { unsupported(other, "Unable to extract from raw array for " + tpe) } - protected def declareUninterpretedSort(t: TypeParameter): Sort = { - val s = id2sym(t.id) - val cmd = DeclareSort(s, 0) - emit(cmd) - Sort(SMTIdentifier(s)) - } - protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { @@ -252,10 +244,7 @@ trait SMTLIBTarget extends Interruptible { case FunctionType(from, to) => Ints.IntSort() - case tp: TypeParameter => - declareUninterpretedSort(tp) - - case _: ClassType | _: TupleType | _: ArrayType | UnitType => + case _: ClassType | _: TupleType | _: ArrayType | _: TypeParameter | UnitType => declareStructuralSort(tpe) case other => @@ -532,13 +521,9 @@ trait SMTLIBTarget extends Interruptible { toSMT(matchToIfThenElse(m)) case gv @ GenericValue(tpe, n) => - genericValues.cachedB(gv) { - val v = declareVariable(FreshIdentifier("gv" + n, tpe)) - for ((ogv, ov) <- genericValues.aToB if ogv.getType == tpe) { - emit(SMTAssert(Core.Not(Core.Equals(v, ov)))) - } - v - } + declareSort(tpe) + val constructor = constructors.toB(tpe) + FunctionApplication(constructor, Seq(toSMT(InfiniteIntegerLiteral(n)))) /** * ===== Everything else ===== @@ -803,11 +788,12 @@ trait SMTLIBTarget extends Interruptible { case cct: CaseClassType => val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) CaseClass(cct, rargs) + case tt: TupleType => val rargs = args.zip(tt.bases).map(fromSMT) tupleWrap(rargs) - case at@ArrayType(baseType) => + case at @ ArrayType(baseType) => val IntLiteral(size) = fromSMT(args(0), Int32Type) val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) @@ -825,6 +811,10 @@ trait SMTLIBTarget extends Interruptible { finiteArray(entries, None, baseType) } + case tp @ TypeParameter(id) => + val InfiniteIntegerLiteral(n) = fromSMT(args(0), IntegerType) + GenericValue(tp, n.toInt) + case t => unsupported(t, "Woot? structural type that is non-structural") } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 5744e1a10..1be3f7ecf 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -72,10 +72,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget { override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { (t, otpe) match { - case (SimpleSymbol(s), Some(tp: TypeParameter)) => - val n = s.name.split("!").toList.last - GenericValue(tp, n.toInt) - case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => if (letDefs contains k) { // Need to recover value form function model diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index f80a143a5..1e715ce1d 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -99,7 +99,7 @@ trait KeyedTemplate[T, E <: Expr] { case _ => Seq.empty } - structure -> rec(structure).map(dependencies) + structure -> rec(structure).distinct.map(dependencies) } } @@ -325,7 +325,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco instantiated += key if (knownFree(tpe) contains caller) { - instantiation withApp (key -> TemplateAppInfo(caller, trueT, args)) + instantiation } else if (byID contains caller) { instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) } else { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index d10b786e5..4509e6069 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -6,13 +6,16 @@ package templates import leon.utils._ import purescala.Common._ +import purescala.Definitions._ import purescala.Extractors._ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import purescala.TypeOps._ -import purescala.Quantification.{QuantificationTypeMatcher => QTM} +import purescala.Quantification.{QuantificationTypeMatcher => QTM, QuantificationMatcher => QM, Domains} + +import evaluators._ import Instantiation._ import Template._ @@ -55,7 +58,8 @@ class QuantificationTemplate[T]( val matchers: Map[T, Set[Matcher[T]]], val lambdas: Seq[LambdaTemplate[T]], val dependencies: Map[Identifier, T], - val struct: (Forall, Map[Identifier, Identifier])) extends KeyedTemplate[T, Forall] { + val struct: (Forall, Map[Identifier, Identifier]), + stringRepr: () => String) extends KeyedTemplate[T, Forall] { val structure = struct._1 lazy val start = pathVar._2 @@ -89,9 +93,13 @@ class QuantificationTemplate[T]( }, lambdas.map(_.substitute(substituter, matcherSubst)), dependencies.map { case (id, value) => id -> substituter(value) }, - struct + struct, + stringRepr ) } + + private lazy val str : String = stringRepr() + override def toString : String = str } object QuantificationTemplate { @@ -118,7 +126,7 @@ object QuantificationTemplate { val insts: (Identifier, T) = inst -> encoder.encodeId(inst) val guards: (Identifier, T) = guard -> encoder.encodeId(guard) - val (clauses, blockers, applications, functions, matchers, _) = + val (clauses, blockers, applications, functions, matchers, templateString) = Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, substMap = baseSubstMap + q2s + insts + guards + qs) @@ -128,7 +136,8 @@ object QuantificationTemplate { new QuantificationTemplate[T](quantificationManager, pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, - clauses, blockers, applications, matchers, lambdas, keyDeps, key -> structSubst) + clauses, blockers, applications, matchers, lambdas, keyDeps, key -> structSubst, + () => "Template for " + proposition + " is :\n" + templateString()) } } @@ -171,37 +180,46 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage qkey == key || (qkey.tpe == key.tpe && (qkey.isInstanceOf[TypeKey] || key.isInstanceOf[TypeKey])) } - private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty - private val uniformQuantSet: MutableSet[T] = MutableSet.empty + class VariableNormalizer { + private val varMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty + private val varSet: MutableSet[T] = MutableSet.empty - def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) - def uniformQuants(ids: Seq[Identifier]): Seq[T] = { - val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => - val prev = uniformQuantMap.get(tpe) match { - case Some(seq) => seq - case None => Seq.empty - } + def normalize(ids: Seq[Identifier]): Seq[T] = { + val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => + val prev = varMap.get(tpe) match { + case Some(seq) => seq + case None => Seq.empty + } - if (prev.size >= idst.size) { - idst zip prev.take(idst.size) - } else { - val (handled, newIds) = idst.splitAt(prev.size) - val uIds = newIds.map(id => id -> encoder.encodeId(id)) + if (prev.size >= idst.size) { + idst zip prev.take(idst.size) + } else { + val (handled, newIds) = idst.splitAt(prev.size) + val uIds = newIds.map(id => id -> encoder.encodeId(id)) - uniformQuantMap(tpe) = prev ++ uIds.map(_._2) - uniformQuantSet ++= uIds.map(_._2) + varMap(tpe) = prev ++ uIds.map(_._2) + varSet ++= uIds.map(_._2) - (handled zip prev) ++ uIds - } - }.toMap + (handled zip prev) ++ uIds + } + }.toMap - ids.map(mapping) - } + ids.map(mapping) + } + + def normalSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { + (qs.map(_._2) zip normalize(qs.map(_._1))).toMap + } - private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { - (qs.map(_._2) zip uniformQuants(qs.map(_._1))).toMap + def contains(idT: T): Boolean = varSet(idT) + def get(tpe: TypeTree): Option[Seq[T]] = varMap.get(tpe) } + private val abstractNormalizer = new VariableNormalizer + private val concreteNormalizer = new VariableNormalizer + + def isQuantifier(idT: T): Boolean = abstractNormalizer.contains(idT) + override def assumptions: Seq[T] = super.assumptions ++ quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq @@ -217,11 +235,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { - case Right(ma) => matcherDepth(ma) + private def maxDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + case Right(ma) => maxDepth(ma) case _ => 0 }).max + private def totalDepth(m: Matcher[T]): Int = 1 + m.args.map { + case Right(ma) => totalDepth(ma) + case _ => 0 + }.sum + private def encodeEnablers(es: Set[T]): T = if (es.isEmpty) trueT else encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) @@ -365,10 +388,10 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage lazy val quantified: Set[T] = quantifiers.map(_._2).toSet lazy val start = pathVar._2 - private lazy val depth = matchers.map(matcherDepth).max + private lazy val depth = matchers.map(maxDepth).max private lazy val transMatchers: Set[Matcher[T]] = (for { (b, ms) <- allMatchers.toSeq - m <- ms if !matchers(m) && matcherDepth(m) <= depth + m <- ms if !matchers(m) && maxDepth(m) <= depth } yield m).toSet /* Build a mapping from applications in the quantified statement to all potential concrete @@ -402,19 +425,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } : _*) } - /* 2.3. filter out bindings that don't make sense where abstract sub-matchers - * (matchers in arguments of other matchers) are bound to different concrete - * matchers in the argument and quorum positions - */ - allMappings.filter { s => - def expand(ms: Traversable[(Arg[T], Arg[T])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { - case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) - case _ => Set.empty[(Matcher[T], Matcher[T])] - }.toSet - - expand(s.map(p => Right(p._2) -> Right(p._3))).groupBy(_._1).forall(_._2.size == 1) - } - allMappings } } @@ -462,7 +472,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if (!skip(subst)) { if (!isStrict) { - ignoreSubst(enablers, subst) + val msubst = subst.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(subst.mapValues(_.encoded)) + ignoredSubsts(this) += ((currentGen + 3, enablers, subst)) } else { instantiation ++= instantiateSubst(enablers, subst, strict = true) } @@ -491,6 +503,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val msubst = substMap.collect { case (c, Right(m)) => c -> m } val substituter = encoder.substitute(substMap.mapValues(_.encoded)) + registerBlockers(substituter) + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Map.empty, substMap) @@ -501,7 +515,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if (strict && (matchers(m) || transMatchers(m))) { instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) } else if (!matchers(m)) { - ignoredMatchers += ((currentGen + 3, sb, sm)) + ignoredMatchers += ((currentGen + 2 + totalDepth(m), sb, sm)) } } @@ -509,20 +523,11 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - def ignoreSubst(enablers: Set[T], subst: Map[T, Arg[T]]): Unit = { - val msubst = subst.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(subst.mapValues(_.encoded)) - val nextGen = (if (matchers.forall { m => - val sm = m.substitute(substituter, msubst) - instCtx(enablers -> sm) - }) currentGen + 3 else currentGen + 3) - - ignoredSubsts(this) += ((nextGen, enablers, subst)) - } - protected def instanceSubst(enabler: T): Map[T, T] protected def skip(subst: Map[T, Arg[T]]): Boolean = false + + protected def registerBlockers(substituter: T => T): Unit = () } private class Quantification ( @@ -543,7 +548,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val lambdas: Seq[LambdaTemplate[T]], val template: QuantificationTemplate[T]) extends MatcherQuantification { - var currentQ2Var: T = qs._2 + private var _currentQ2Var: T = qs._2 + def currentQ2Var = _currentQ2Var val holds = qs._2 val body = { val quantified = quantifiers.map(_._1).toSet @@ -551,15 +557,24 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage replaceFromIDs(mapping, template.structure.body) } + private var _currentInsts: Map[T, Set[T]] = Map.empty + def currentInsts = _currentInsts + protected def instanceSubst(enabler: T): Map[T, T] = { val nextQ2Var = encoder.encodeId(q2s._1) val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) - currentQ2Var = nextQ2Var + _currentQ2Var = nextQ2Var subst } + + override def registerBlockers(substituter: T => T): Unit = { + val freshInst = substituter(insts._2) + val bs = (blockers.keys ++ applications.keys).map(substituter).toSet + _currentInsts += freshInst -> bs + } } private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) @@ -647,69 +662,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } - private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { - val quantifierSubst = uniformSubst(quantifiers) - val substituter = encoder.substitute(quantifierSubst) - var instantiation: Instantiation[T] = Instantiation.empty - - for { - m <- matchers - sm = m.substitute(substituter, Map.empty) - if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } { - instantiation ++= instCtx.instantiate(Set.empty, m)(quantifications.toSeq : _*) - instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) - } - - def unifyMatchers(matchers: Seq[Matcher[T]]): Unit = matchers match { - case sm +: others => - for (pm <- others if correspond(pm, sm)) { - val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) - val mismatches = encodedArgs.zipWithIndex.collect { - case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) - }.toMap - - def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { - case idx +: xs => - val (p1, p2) = mismatches(idx) - val newPartials = Seq(idx) +: partials.map { seq => - if (mismatches(seq.head)._1 == p2) idx +: seq - else if (mismatches(seq.last)._2 == p1) seq :+ idx - else seq - } - - val (closed, remaining) = newPartials.partition { seq => - mismatches(seq.head)._1 == mismatches(seq.last)._2 - } - closed ++ extractChains(xs, partials ++ remaining) - - case _ => Seq.empty - } - - val chains = extractChains(mismatches.keys.toSeq, Seq.empty) - val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => - val res = seq.min - mapping ++ seq.map(i => i -> res) - } - - def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = - (0 until args.size).map(i => args(positions.getOrElse(i, i))) - - instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) - instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) - } - - unifyMatchers(others) - - case _ => - } - - val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) - unifyMatchers(substMatchers.toSeq) - - instantiation - } - def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, Arg[T]]): Instantiation[T] = { def quantifiedMatcher(m: Matcher[T]): Boolean = m.args.exists(a => a match { case Left(v) => isQuantifier(v) @@ -724,7 +676,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - val quantifiers = quantified zip uniformQuants(quantified) + val quantifiers = quantified zip abstractNormalizer.normalize(quantified) val key = template.key -> quantifiers if (quantifiers.isEmpty || lambdaAxioms(key)) { @@ -805,7 +757,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case None => val qT = encoder.encodeId(template.qs._1) val quantified = template.quantifiers.map(_._2).toSet - val matchQuorums = extractQuorums(quantified, template.matchers.flatMap(_._2).toSet, template.lambdas) + val matcherSet = template.matchers.flatMap(_._2).toSet + val matchQuorums = extractQuorums(quantified, matcherSet, template.lambdas) var instantiation = Instantiation.empty[T] @@ -843,7 +796,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage encoder.mkImplies(template.start, encoder.mkEquals(qT, newQs)) } - instantiation ++= instantiateConstants(template.quantifiers, template.matchers.flatMap(_._2).toSet) + instantiation ++= instantiateConstants(template.quantifiers, matcherSet) templates += template.key -> qT (qT, instantiation) @@ -874,7 +827,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } for ((bs,m) <- matchersToRelease) { - instCtx.instantiate(bs, m)(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(bs, m)(quantifications.toSeq : _*) } val substsToRelease = quantifications.toList.flatMap { q => @@ -896,9 +849,74 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instantiation } + private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { + var instantiation: Instantiation[T] = Instantiation.empty + + for (normalizer <- List(abstractNormalizer, concreteNormalizer)) { + val quantifierSubst = normalizer.normalSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + + for { + m <- matchers + sm = m.substitute(substituter, Map.empty) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + + def unifyMatchers(matchers: Seq[Matcher[T]]): Instantiation[T] = matchers match { + case sm +: others => + var instantiation = Instantiation.empty[T] + for (pm <- others if correspond(pm, sm)) { + val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) + val mismatches = encodedArgs.zipWithIndex.collect { + case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) + }.toMap + + def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { + case idx +: xs => + val (p1, p2) = mismatches(idx) + val newPartials = Seq(idx) +: partials.map { seq => + if (mismatches(seq.head)._1 == p2) idx +: seq + else if (mismatches(seq.last)._2 == p1) seq :+ idx + else seq + } + + val (closed, remaining) = newPartials.partition { seq => + mismatches(seq.head)._1 == mismatches(seq.last)._2 + } + closed ++ extractChains(xs, partials ++ remaining) + + case _ => Seq.empty + } + + val chains = extractChains(mismatches.keys.toSeq, Seq.empty) + val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => + val res = seq.min + mapping ++ seq.map(i => i -> res) + } + + def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = + (0 until args.size).map(i => args(positions.getOrElse(i, i))) + + instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) + } + + instantiation ++ unifyMatchers(others) + + case _ => Instantiation.empty[T] + } + + if (normalizer == abstractNormalizer) { + val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) + instantiation ++= unifyMatchers(substMatchers.toSeq) + } + } + + instantiation + } + def checkClauses: Seq[T] = { val clauses = new scala.collection.mutable.ListBuffer[T] - //val keySets = scala.collection.mutable.Map.empty[MatcherKey, T] val keyClause = MutableMap.empty[MatcherKey, (Seq[T], T)] for ((_, bs, m) <- ignoredMatchers) { @@ -961,21 +979,270 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } + def isQuantified(e: Arg[T]): Boolean = e match { + case Left(t) => isQuantifier(t) + case Right(m) => m.args.exists(isQuantified) + } + for ((key, ctx) <- instCtx.map.instantiations) { val QTM(argTypes, _) = key.tpe for { - (tpe,idx) <- argTypes.zipWithIndex - quants <- uniformQuantMap.get(tpe) if quants.nonEmpty + (tpe, idx) <- argTypes.zipWithIndex + quants <- abstractNormalizer.get(tpe) if quants.nonEmpty (b, m) <- ctx - arg = m.args(idx).encoded if !isQuantifier(arg) - } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg))) : _*) - } + arg = m.args(idx) if !isQuantified(arg) + } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg.encoded))) : _*) + + val byPosition: Iterable[Seq[T]] = ctx.flatMap { case (b, m) => + if (b != trueT) Seq.empty else m.args.zipWithIndex + }.groupBy(_._2).map(p => p._2.toSeq.flatMap { + case (a, _) => if (isQuantified(a)) Some(a.encoded) else None + }).filter(_.nonEmpty) - for ((tpe, base +: rest) <- uniformQuantMap; q <- rest) { - clauses += encoder.mkEquals(base, q) + for ((a +: as) <- byPosition; a2 <- as) { + clauses += encoder.mkEquals(a, a2) + } } clauses.toSeq } + + trait ModelView { + protected val vars: Map[Identifier, T] + protected val evaluator: evaluators.DeterministicEvaluator + + protected def get(id: Identifier): Option[Expr] + protected def eval(elem: T, tpe: TypeTree): Option[Expr] + + implicit lazy val context = evaluator.context + lazy val reporter = context.reporter + + private def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { + val QTM(fromTypes, _) = m.tpe + val optEnabler = eval(b, BooleanType) + optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => + val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } + if (optArgs.forall(_.isDefined)) Some(optArgs.map(_.get)) + else None + } + } + + private def functionsOf(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = { + + def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)], + recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = + (subs.flatMap(_._1), (exprs: Seq[Expr]) => { + var curr = exprs + recons(subs.map { case (es, recons) => + val (used, remaining) = curr.splitAt(es.size) + curr = remaining + recons(used) + }) + }) + + def rec(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match { + case (_: Lambda) | (_: FiniteLambda) => + (Seq(expr -> path), (es: Seq[Expr]) => es.head) + + case Tuple(es) => reconstruct(es.zipWithIndex.map { + case (e, i) => rec(e, TupleSelect(path, i + 1)) + }, Tuple) + + case CaseClass(cct, es) => reconstruct((cct.classDef.fieldsIds zip es).map { + case (id, e) => rec(e, CaseClassSelector(cct, path, id)) + }, CaseClass(cct, _)) + + case _ => (Seq.empty, (es: Seq[Expr]) => expr) + } + + rec(expr, path) + } + + def getPartialModel: PartialModel = { + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInstantiations.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInstantiations.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val domains = new Domains(lambdaDomains, typeDomains) + + val partialDomains: Map[T, Set[Seq[Expr]]] = partialInstantiations.map { + case (t, domain) => t -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + def extractElse(body: Expr): Expr = body match { + case IfExpr(cond, thenn, elze) => extractElse(elze) + case _ => body + } + + val mapping = vars.map { case (id, idT) => + val value = get(id).getOrElse(simplestValue(id.getType)) + val (functions, recons) = functionsOf(value, Variable(id)) + + id -> recons(functions.map { case (f, path) => + val encoded = encoder.encodeExpr(Map(id -> idT))(path) + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + partialDomains.get(encoded).orElse(typeDomains.get(tpe)).map { domain => + FiniteLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(application(f, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(f, es))) + }, f match { + case FiniteLambda(_, dflt, _) => dflt + case Lambda(_, body) => extractElse(body) + case _ => scala.sys.error("What kind of function is this : " + f.asString + " !?") + }, tpe) + }.getOrElse(f) + }) + } + + new PartialModel(mapping, domains) + } + + def getTotalModel: Model = { + + def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { + val matchers = purescala.ExprOps.collect[(Expr, Seq[Expr])] { + case QM(e, args) => Set(e -> args) + case _ => Set.empty + } (body) + + if (matchers.isEmpty) + return Some("No matchers found.") + + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) + } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) + return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) + + def quantifiedArg(e: Expr): Boolean = e match { + case Variable(id) => quantified(id) + case QM(_, args) => args.forall(quantifiedArg) + case _ => false + } + + purescala.ExprOps.postTraversal(m => m match { + case QM(_, args) => + val qArgs = args.filter(quantifiedArg) + + if (qArgs.nonEmpty && qArgs.size < args.size) + return Some("Mixed ground and quantified arguments in " + m.asString) + + case Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => + return Some("Invalid operation on quantifiers " + m.asString) + + case (_: Equals) | (_: And) | (_: Or) | (_: Implies) | (_: Not) => // OK + + case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => + return Some("Unandled implications from operation " + m.asString) + + case _ => + }) (body) + + body match { + case Variable(id) if quantified(id) => + Some("Unexpected free quantifier " + id.asString) + case _ => None + } + } + + val issues: Iterable[(Seq[Identifier], Expr, String)] = for { + q <- quantifications.view + if eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) + msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) + } yield (q.quantifiers.map(_._1), q.body, msg) + + if (issues.nonEmpty) { + val (quantifiers, body, msg) = issues.head + reporter.warning("Model soundness not guaranteed for \u2200" + + quantifiers.map(_.asString).mkString(",") + ". " + body.asString+" :\n => " + msg) + } + + val types = typeInstantiations + val partials = partialInstantiations + + def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { + case (id +: rparams, (v, arg) +: rargs) => + if (isQuantifier(v)) { + structure.get(v) match { + case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) + case None => extractCond(rparams, rargs, structure + (v -> id)) + } + } else { + Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) + } + case _ => Seq.empty + } + + new Model(vars.map { case (id, idT) => + val value = get(id).getOrElse(simplestValue(id.getType)) + val (functions, recons) = functionsOf(value, Variable(id)) + + id -> recons(functions.map { case (f, path) => + val encoded = encoder.encodeExpr(Map(id -> idT))(path) + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val params = tpe.from.map(tpe => FreshIdentifier("x", tpe, true)) + partials.get(encoded).orElse(types.get(tpe)).map { domain => + val conditionals = domain.flatMap { case (b, m) => + extract(b, m).map { args => + val result = evaluator.eval(application(value, args)).result.getOrElse { + scala.sys.error("Unexpectedly failed to evaluate " + application(value, args)) + } + + val cond = if (m.args.exists(arg => isQuantifier(arg.encoded))) { + extractCond(params, m.args.map(_.encoded) zip args, Map.empty) + } else { + (params zip args).map(p => Equals(Variable(p._1), p._2)) + } + + cond -> result + } + }.toMap + + if (conditionals.isEmpty) { + value + } else { + val ((_, dflt)) +: rest = conditionals.toSeq.sortBy { case (conds, _) => + (conds.flatMap(variablesOf).toSet.size, conds.size) + } + + val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => + if (conds.isEmpty) elze else (elze match { + case pres if res == pres => res + case _ => IfExpr(andJoin(conds), res, elze) + }) + } + + Lambda(params.map(ValDef(_)), body) + } + }.getOrElse(f) + }) + }) + } + } + + def getModel(vs: Map[Identifier, T], ev: DeterministicEvaluator, _get: Identifier => Option[Expr], _eval: (T, TypeTree) => Option[Expr]) = new ModelView { + val vars: Map[Identifier, T] = vs + val evaluator: DeterministicEvaluator = ev + + def get(id: Identifier): Option[Expr] = _get(id) + def eval(elem: T, tpe: TypeTree): Option[Expr] = _eval(elem, tpe) + } + + def getBlockersToPromote(eval: (T, TypeTree) => Option[Expr]): Seq[T] = quantifications.toSeq.flatMap { + case q: Quantification if eval(q.qs._2, BooleanType) == Some(BooleanLiteral(false)) => + val falseInsts = q.currentInsts.filter { case (inst, bs) => eval(inst, BooleanType) == Some(BooleanLiteral(false)) } + falseInsts.flatMap(_._2) + case _ => Seq.empty + } } + diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 34b7b35f4..9126f77c1 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -98,7 +98,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val substMap : Map[Identifier, T] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } else { diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 2bb0cbbd0..8c4752dbf 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -24,7 +24,9 @@ object Instantiation { type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) - def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) + def empty[T] = (Seq.empty[T], + Map.empty[T, Set[TemplateCallInfo[T]]], + Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => @@ -51,14 +53,13 @@ object Instantiation { def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) - def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = { + def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) - } } } import Instantiation.{empty => _, _} -import Template.Arg +import Template.{Apps, Calls, Functions, Arg} trait Template[T] { self => val encoder : TemplateEncoder[T] @@ -71,9 +72,9 @@ trait Template[T] { self => val exprVars : Map[Identifier, T] val condTree : Map[Identifier, Set[Identifier]] - val clauses : Seq[T] - val blockers : Map[T, Set[TemplateCallInfo[T]]] - val applications : Map[T, Set[App[T]]] + val clauses : Clauses[T] + val blockers : Calls[T] + val applications : Apps[T] val functions : Set[(T, FunctionType, T)] val lambdas : Seq[LambdaTemplate[T]] @@ -132,6 +133,7 @@ object Template { Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) } + type Calls[T] = Map[T, Set[TemplateCallInfo[T]]] type Apps[T] = Map[T, Set[App[T]]] type Functions[T] = Set[(T, FunctionType, T)] @@ -147,7 +149,7 @@ object Template { substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None - ) : (Clauses[T], CallBlockers[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { + ) : (Clauses[T], Calls[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { val idToTrId : Map[Identifier, T] = condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) @@ -249,8 +251,8 @@ object Template { (blockers, applications, matchers) } - val encodedBlockers : Map[T, Set[TemplateCallInfo[T]]] = blockers.map(p => idToTrId(p._1) -> p._2) - val encodedApps : Map[T, Set[App[T]]] = applications.map(p => idToTrId(p._1) -> p._2) + val encodedBlockers : Calls[T] = blockers.map(p => idToTrId(p._1) -> p._2) + val encodedApps : Apps[T] = applications.map(p => idToTrId(p._1) -> p._2) val encodedMatchers : Map[T, Set[Matcher[T]]] = matchers.map(p => idToTrId(p._1) -> p._2) val stringRepr : () => String = () => { @@ -271,6 +273,9 @@ object Template { }) + " * Lambdas :\n" + lambdas.map { case template => " +> " + template.toString.split("\n").mkString("\n ") + "\n" + }.mkString("\n") + + " * Foralls :\n" + quantifications.map { case template => + " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") } @@ -285,7 +290,7 @@ object Template { condTree: Map[Identifier, Set[Identifier]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], - functions: Set[(T, FunctionType, T)], + functions: Functions[T], baseSubst: Map[T, Arg[T]], pathVar: Identifier, aVar: T @@ -351,9 +356,9 @@ object Template { def instantiate[T]( encoder: TemplateEncoder[T], manager: TemplateManager[T], - clauses: Seq[T], - blockers: Map[T, Set[TemplateCallInfo[T]]], - applications: Map[T, Set[App[T]]], + clauses: Clauses[T], + blockers: Calls[T], + applications: Apps[T], matchers: Map[T, Set[Matcher[T]]], substMap: Map[T, Arg[T]] ): Instantiation[T] = { @@ -361,9 +366,9 @@ object Template { val substituter : T => T = encoder.substitute(substMap.mapValues(_.encoded)) val msubst = substMap.collect { case (c, Right(m)) => c -> m } - val newClauses = clauses.map(substituter) + val newClauses: Clauses[T] = clauses.map(substituter) - val newBlockers = blockers.map { case (b,fis) => + val newBlockers: CallBlockers[T] = blockers.map { case (b,fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) } @@ -451,10 +456,10 @@ class FunctionTemplate[T] private( val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val functions: Set[(T, FunctionType, T)], + val clauses: Clauses[T], + val blockers: Calls[T], + val applications: Apps[T], + val functions: Functions[T], val lambdas: Seq[LambdaTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], val quantifications: Seq[QuantificationTemplate[T]], diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index acee03b73..2383f65f6 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -100,6 +100,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } } + def getFiniteRangeClauses: Seq[T] = manager.checkClauses + private def registerCallBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { val notId = encoder.mkNot(id) @@ -174,7 +176,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val blockClauses = freshAppBlocks(appBlocks.keys) - for((b, infos) <- callBlocks) { + for ((b, infos) <- callBlocks) { registerCallBlocker(nextGeneration(0), b, infos) } @@ -224,6 +226,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def instantiateQuantifiers(force: Boolean = false): Seq[T] = { val (newExprs, callBlocks, appBlocks) = manager.instantiateIgnored(force) val blockExprs = freshAppBlocks(appBlocks.keys) + val gens = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)) val gen = if (gens.nonEmpty) gens.min else 0 @@ -367,12 +370,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } - /* - for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos if infos.isEmpty) { - registerAppBlocker(nextGeneration(gen), app, infos) - } - */ - reporter.debug(s" - ${newClauses.size} new clauses") //context.reporter.ifDebug { debug => // debug(s" - new clauses:") diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 01a86e2b9..4e24697b6 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,6 +19,8 @@ import purescala.Types._ case class UnsoundExtractionException(ast: Z3AST, msg: String) extends Exception("Can't extract " + ast + " : " + msg) +object AbstractZ3Solver + // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" trait AbstractZ3Solver extends Solver { @@ -45,8 +47,19 @@ trait AbstractZ3Solver extends Solver { } } - protected[leon] val z3cfg : Z3Config - protected[leon] var z3 : Z3Context = null + // FIXME: (dirty?) hack to bypass z3lib bug. + // Uses the unique AbstractZ3Solver to ensure synchronization (no assumption on context). + protected[leon] val z3cfg : Z3Config = + AbstractZ3Solver.synchronized(new Z3Config( + "MODEL" -> true, + "TYPE_CHECK" -> true, + "WELL_SORTED_CHECK" -> true + )) + toggleWarningMessages(true) + + protected[leon] var z3 : Z3Context = null + + lazy protected val solver = z3.mkSolver() override def free(): Unit = { freed = true @@ -73,28 +86,21 @@ trait AbstractZ3Solver extends Solver { } } - def genericValueToDecl(gv: GenericValue): Z3FuncDecl = { - generics.cachedB(gv) { - z3.mkFreshFuncDecl(gv.tp.id.uniqueName+"#"+gv.id+"!val", Seq(), typeToSort(gv.tp)) - } - } - // ADT Manager protected val adtManager = new ADTManager(context) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() - protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() protected val lambdas = new IncrementalBijection[FunctionType, Z3FuncDecl]() protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() protected val variables = new IncrementalBijection[Expr, Z3AST]() - protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() - protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() + protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() var isInitialized = false - protected[leon] def initZ3() { + protected[leon] def initZ3(): Unit = { if (!isInitialized) { val timer = context.timers.solvers.z3.init.start() @@ -102,7 +108,6 @@ trait AbstractZ3Solver extends Solver { functions.clear() lambdas.clear() - generics.clear() sorts.clear() variables.clear() constructors.clear() @@ -117,11 +122,7 @@ trait AbstractZ3Solver extends Solver { } } - protected[leon] def restartZ3() { - isInitialized = false - - initZ3() - } + initZ3() def rootType(ct: TypeTree): TypeTree = ct match { case ct: ClassType => ct.root @@ -218,7 +219,7 @@ trait AbstractZ3Solver extends Solver { case Int32Type | BooleanType | IntegerType | RealType | CharType => sorts.toB(oldtt) - case tpe @ (_: ClassType | _: ArrayType | _: TupleType | UnitType) => + case tpe @ (_: ClassType | _: ArrayType | _: TupleType | _: TypeParameter | UnitType) => sorts.cachedB(tpe) { declareStructuralSort(tpe) } @@ -239,14 +240,6 @@ trait AbstractZ3Solver extends Solver { z3.mkArraySort(fromSort, toSort) } - case tt @ TypeParameter(id) => - sorts.cachedB(tt) { - val symbol = z3.mkFreshStringSymbol(id.name) - val newTPSort = z3.mkUninterpretedSort(symbol) - - newTPSort - } - case ft @ FunctionType(from, to) => sorts.cachedB(ft) { val symbol = z3.mkFreshStringSymbol(ft.toString) @@ -531,9 +524,10 @@ trait AbstractZ3Solver extends Solver { z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) } - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) + typeToSort(tp) + val constructor = constructors.toB(tp) + constructor(rec(InfiniteIntegerLiteral(id))) case other => unsupported(other) @@ -607,8 +601,6 @@ trait AbstractZ3Solver extends Solver { val tfd = functions.toA(decl) assert(tfd.params.size == argsSize) FunctionInvocation(tfd, args.zip(tfd.params).map{ case (a, p) => rec(a, p.getType) }) - } else if (generics containsB decl) { - generics.toA(decl) } else if (constructors containsB decl) { constructors.toA(decl) match { case cct: CaseClassType => @@ -640,6 +632,13 @@ trait AbstractZ3Solver extends Solver { case (s : IntLiteral, arr) => unsound(args(1), "invalid array type") case (size, _) => unsound(args(0), "invalid array size") } + + case tp @ TypeParameter(id) => + val InfiniteIntegerLiteral(n) = rec(args(0), IntegerType) + GenericValue(tp, n.toInt) + + case t => + unsupported(t, "Woot? structural type that is non-structural") } } else { tpe match { @@ -671,10 +670,6 @@ trait AbstractZ3Solver extends Solver { } } - case tp: TypeParameter => - val id = t.toString.split("!").last.toInt - GenericValue(tp, id) - case MapType(from, to) => rec(t, RawArrayType(from, library.optionType(to))) match { case r: RawArrayValue => diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index fa3352b1e..5de176ff2 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -32,14 +32,6 @@ class FairZ3Solver(val context: LeonContext, val program: Program) override protected val reporter = context.reporter override def reset(): Unit = super[AbstractZ3Solver].reset() - // FIXME: Dirty hack to bypass z3lib bug. Assumes context is the same over all instances of FairZ3Solver - protected[leon] val z3cfg = context.synchronized { new Z3Config( - "MODEL" -> true, - "TYPE_CHECK" -> true, - "WELL_SORTED_CHECK" -> true - )} - toggleWarningMessages(true) - def solverCheck[R](clauses: Seq[Z3AST])(block: Option[Boolean] => R): R = { solver.push() for (cls <- clauses) solver.assertCnstr(cls) @@ -145,12 +137,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } } - initZ3() - - val solver = z3.mkSolver() - private val incrementals: List[IncrementalState] = List( - errors, functions, generics, lambdas, sorts, variables, + errors, functions, lambdas, sorts, variables, constructors, selectors, testers ) diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 6f40151cf..829c77343 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -29,18 +29,6 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) val name = "Z3-u" val description = "Uninterpreted Z3 Solver" - // this is fixed - protected[leon] val z3cfg = new Z3Config( - "MODEL" -> true, - "TYPE_CHECK" -> true, - "WELL_SORTED_CHECK" -> true - ) - toggleWarningMessages(true) - - initZ3() - - val solver = z3.mkSolver() - def push() { solver.push() freeVariables.push() diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index fa262caa6..c5fa2b0ba 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -20,23 +20,25 @@ object StringEcoSystem { val id = FreshIdentifier(name, tpe) f(id) } + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) } - - val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + + val StringList = new AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) val StringListTyped = StringList.typed val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => - val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + val d = new CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) d.setFields(Seq(ValDef(head), ValDef(tail))) d } + StringList.registerChild(StringCons) val StringConsTyped = StringCons.typed - val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNil = new CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) val StringNilTyped = StringNil.typed StringList.registerChild(StringNil) - + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => @@ -48,6 +50,7 @@ object StringEcoSystem { }) fd } + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) fd.body = Some( @@ -61,7 +64,7 @@ object StringEcoSystem { ) fd } - + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) fd.body = Some{ diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index ec681763e..b86894867 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -22,7 +22,6 @@ import solvers.string.StringSolver import programsets.DirectProgramSet import programsets.JoinProgramSet - /** A template generator for a given type tree. * Extend this class using a concrete type tree, * Then use the apply method to get a hole which can be a placeholder for holes in the template. @@ -210,12 +209,12 @@ case object StringRender extends Rule("StringRender") { def askQuestion(input: List[Identifier], r: RuleClosed)(implicit c: LeonContext, p: Program): List[disambiguation.Question[StringLiteral]] = { //if !s.contains(EDIT_ME) val qb = new disambiguation.QuestionBuilder(input, r.solutions, (seq: Seq[Expr], expr: Expr) => expr match { - case s@StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s) + case s @ StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s) case _ => None }) qb.result() } - + /** Converts the stream of solutions to a RuleApplication */ def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)])(implicit program: Program): RuleApplication = { if(solutions.isEmpty) RuleFailed() else { @@ -361,8 +360,8 @@ case object StringRender extends Rule("StringRender") { def extractCaseVariants(cct: CaseClassType, ctx: StringSynthesisContext) : (Stream[WithIds[MatchCase]], StringSynthesisResult) = cct match { - case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => - val typeMap = tparams.zip(tparams2).toMap + case CaseClassType(ccd: CaseClassDef, tparams2) => + val typeMap = ccd.tparams.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) val (rhs, result) = createFunDefsTemplates(ctx.copy(currentCaseClassParent=Some(cct)), fields.map(Variable)) // Invoke functions for each of the fields. @@ -387,11 +386,11 @@ case object StringRender extends Rule("StringRender") { */ def constantPatternMatching(fd: FunDef, act: AbstractClassType): WithIds[MatchExpr] = { val cases = (ListBuffer[WithIds[MatchCase]]() /: act.knownCCDescendants) { - case (acc, cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => - val typeMap = tparams.zip(tparams2).toMap + case (acc, cct @ CaseClassType(ccd, tparams2)) => + val typeMap = ccd.tparams.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) - val rhs = StringLiteral(id.asString) + val rhs = StringLiteral(ccd.id.asString) MatchCase(pattern, None, rhs) acc += ((MatchCase(pattern, None, rhs), Nil)) case (acc, e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e) @@ -458,17 +457,17 @@ case object StringRender extends Rule("StringRender") { val fd = createEmptyFunDef(ctx, dependentType) val ctx2 = preUpdateFunDefBody(dependentType, fd, ctx) // Inserts the FunDef in the assignments so that it can already be used. t.root match { - case act@AbstractClassType(acd@AbstractClassDef(id, tparams, parent), tps) => + case act @ AbstractClassType(acd, tps) => // Create a complete FunDef body with pattern matching val allKnownDescendantsAreCCAndHaveZeroArgs = act.knownCCDescendants.forall { - case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => ccd.fields.isEmpty + case CaseClassType(ccd, tparams2) => ccd.fields.isEmpty case _ => false } //TODO: Test other templates not only with Wilcard patterns, but more cases options for non-recursive classes (e.g. Option, Boolean, Finite parameterless case classes.) val (ctx3, cases) = ((ctx2, ListBuffer[Stream[WithIds[MatchCase]]]()) /: act.knownCCDescendants) { - case ((ctx22, acc), cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => + case ((ctx22, acc), cct @ CaseClassType(ccd, tparams2)) => val (newCases, result) = extractCaseVariants(cct, ctx22) val ctx23 = ctx22.copy(result = result) (ctx23, acc += newCases) @@ -481,7 +480,7 @@ case object StringRender extends Rule("StringRender") { } else allMatchExprsEnd gatherInputs(ctx3.add(dependentType, fd, allMatchExprs), q, result += Stream((functionInvocation(fd, input::ctx.provided_functions.toList.map(Variable)), Nil))) - case cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => + case cct @ CaseClassType(ccd, tparams2) => val (newCases, result3) = extractCaseVariants(cct, ctx2) val allMatchExprs = newCases.map(acase => mergeMatchCases(fd)(Seq(acase))) gatherInputs(ctx2.copy(result = result3).add(dependentType, fd, allMatchExprs), q, @@ -580,4 +579,4 @@ case object StringRender extends Rule("StringRender") { case _ => Nil } } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala index fa09f9dca..f5444dec5 100644 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -40,7 +40,7 @@ abstract class ProgramTypeTransformer { val absType = ccdef.parent.get Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) } else None - val newclassDef = ccdef.copy(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) + val newclassDef = ccdef.duplicate(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) //important: register a child if a parent was newly created. if (newparent.isDefined) @@ -55,7 +55,7 @@ abstract class ProgramTypeTransformer { val absType = acdef.parent.get Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) } else None - val newClassDef = acdef.copy(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) + val newClassDef = acdef.duplicate(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) defmap += (acdef -> newClassDef) newClassDef.asInstanceOf[T] } diff --git a/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala new file mode 100644 index 000000000..30acb0f91 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala @@ -0,0 +1,22 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => res.contains(x) ==> f.ens(x)) } +} + diff --git a/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala new file mode 100644 index 000000000..f0cb18413 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala @@ -0,0 +1,24 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap2 { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + require(forall((x: A) => pre(x) ==> ens(f(x)))) + + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => /* res.contains(x) ==> */ f.ens(x)) } +} + diff --git a/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala b/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala new file mode 100644 index 000000000..c42579dce --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala @@ -0,0 +1,24 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + require(forall((x: A) => pre(x) ==> ens(f(x)))) + + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => res.contains(x) ==> f.ens(x)) } +} + diff --git a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala index 4ff5ccc71..fa2260afc 100644 --- a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala @@ -20,7 +20,7 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { val sources = List() - override val leonOpts = List("checkmodels") + override val leonOpts = List("--checkmodels") val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( @@ -126,6 +126,7 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { checkSolver(solver, expr, true) } + /* test(s"Satisfiable quantified formula $ename in $sname with partial models") { implicit fix => val (ctx, pgm) = fix val newCtx = ctx.copy(options = ctx.options.filter(_ != UnrollingProcedure.optPartialModels) :+ @@ -133,6 +134,7 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { val solver = sf(newCtx, pgm) checkSolver(solver, expr, true) } + */ } for ((sname, sf) <- getFactories; (ename, expr) <- unsatisfiable) { diff --git a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala index e24be6f28..c6158d66a 100644 --- a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala @@ -19,13 +19,13 @@ class TypeOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. val tp2 = TypeParameter.fresh("A") val tp3 = TypeParameter.fresh("B") - val listD = AbstractClassDef(FreshIdentifier("List"), Seq(tpD), None) + val listD = new AbstractClassDef(FreshIdentifier("List"), Seq(tpD), None) val listT = listD.typed - val nilD = CaseClassDef(FreshIdentifier("Nil"), Seq(tpD), Some(listT), false) + val nilD = new CaseClassDef(FreshIdentifier("Nil"), Seq(tpD), Some(listT), false) val nilT = nilD.typed - val consD = CaseClassDef(FreshIdentifier("Cons"), Seq(tpD), Some(listT), false) + val consD = new CaseClassDef(FreshIdentifier("Cons"), Seq(tpD), Some(listT), false) val consT = consD.typed // Simple tests for fixed types -- GitLab