diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 3939332776cd21af143f355bb6d0e8fc94459b51..efce1a2ff6937c860e684b1df9c03166089572bc 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -36,7 +36,7 @@ object Main { // Add whatever you need here. lazy val allComponents : Set[LeonComponent] = allPhases.toSet ++ Set( - solvers.combinators.UnrollingProcedure, MainComponent, GlobalOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component + solvers.unrolling.UnrollingProcedure, MainComponent, GlobalOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component ) /* diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 0e2a3fe235621adf8db2746f56e96db338479929..866c3ce096d3f53fd0a81d533fd64e50ff3b0a0e 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -92,7 +92,8 @@ trait CodeGeneration { def idToSafeJVMName(id: Identifier) = { scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") } - def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + idToSafeJVMName(d.id) + + def defToJVMName(d: Definition): String = "Leon$CodeGen$" + idToSafeJVMName(d.id) /** Retrieve the name of the underlying lazy field from a lazy field accessor method */ private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying" @@ -201,7 +202,7 @@ trait CodeGeneration { funDef.fullBody } else { funDef.body.getOrElse( - if(funDef.annotations contains "extern") { + if (funDef.annotations contains "extern") { Error(funDef.id.getType, "Body of " + funDef.id.name + " not implemented at compile-time and still executed.") } else { throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name) @@ -489,8 +490,9 @@ trait CodeGeneration { } ch << New(ccName) << DUP load(monitorID, ch) + loadTypes(cct.tps, ch) - for((a, vd) <- as zip cct.classDef.fields) { + for ((a, vd) <- as zip cct.classDef.fields) { vd.getType match { case TypeParameter(_) => mkBoxedExpr(a, ch) @@ -668,6 +670,8 @@ trait CodeGeneration { load(monitorID, ch) ch << DUP_X2 << POP + loadTypes(Seq(tp), ch) + ch << DUP_X2 << POP ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList @@ -1529,7 +1533,7 @@ trait CodeGeneration { } } - def compileAbstractClassDef(acd : AbstractClassDef) { + def compileAbstractClassDef(acd: AbstractClassDef) { val cName = defToJVMName(acd) @@ -1642,7 +1646,6 @@ trait CodeGeneration { } def compileCaseClassDef(ccd: CaseClassDef) { - val cName = defToJVMName(ccd) val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) // An instantiation of ccd with its own type parameters @@ -1656,13 +1659,14 @@ trait CodeGeneration { CLASS_ACC_FINAL ).asInstanceOf[U2]) - if(ccd.parent.isEmpty) { + if (ccd.parent.isEmpty) { cf.addInterface(CaseClassClass) } // Case class parameters val fieldsTypes = ccd.fields.map { vd => (vd.id, typeToJVM(vd.getType)) } - val constructorArgs = (monitorID -> s"L$MonitorClass;") +: fieldsTypes + val tpeParam = if (ccd.tparams.isEmpty) Seq() else Seq(tpsID -> "[I") + val constructorArgs = (monitorID -> s"L$MonitorClass;") +: (tpeParam ++ fieldsTypes) val newLocs = NoLocals.withFields(constructorArgs.map { case (id, jvmt) => (id, (cName, id.name, jvmt)) @@ -1674,7 +1678,7 @@ trait CodeGeneration { // Compile methods for (method <- methods) { - compileFunDef(method,ccd) + compileFunDef(method, ccd) } // 
Compile lazy fields @@ -1688,7 +1692,7 @@ trait CodeGeneration { } // definition of the constructor - for((id, jvmt) <- constructorArgs) { + for ((id, jvmt) <- constructorArgs) { val fh = cf.addField(jvmt, id.name) fh.setFlags(( FIELD_ACC_PUBLIC | @@ -1710,7 +1714,7 @@ trait CodeGeneration { } var c = 1 - for((id, jvmt) <- constructorArgs) { + for ((id, jvmt) <- constructorArgs) { cch << ALoad(0) cch << (jvmt match { case "I" | "Z" => ILoad(c) @@ -1734,6 +1738,17 @@ trait CodeGeneration { // Now initialize fields for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } + + // Finally check invariant (if it exists) + if (params.checkContracts && ccd.hasInvariant) { + val thisId = FreshIdentifier("this", cct, true) + val invLocals = newLocs.withVar(thisId -> 0) + mkExpr(IfExpr(FunctionInvocation(cct.invariant.get, Seq(Variable(thisId))), + BooleanLiteral(true), + Error(BooleanType, "ADT Invariant failed @" + ccd.invariant.get.getPos)), cch)(invLocals) + cch << POP + } + cch << RETURN cch.freeze } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 1bc3600fb74e3af430167431cc68abdb050b95c2..cd6bad57e5b374c6a7088931989c87a0d782068f 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -80,7 +80,7 @@ class CompilationUnit(val ctx: LeonContext, id } - def defineClass(df: Definition) { + def defineClass(df: Definition): Unit = { val cName = defToJVMName(df) val cf = df match { @@ -105,7 +105,8 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val sig = "(L"+MonitorClass+";" + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" + val tpeParam = if (cd.tparams.isEmpty) "" else "[I" + val sig = "(L"+MonitorClass+";" + tpeParam + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None } @@ -114,7 +115,6 @@ class CompilationUnit(val ctx: LeonContext, // Returns className, methodName, methodSignature private[this] var funDefInfo = Map[FunDef, (String, String, String)]() - /** * Returns (cn, mn, sig) where * cn is the module name @@ -213,7 +213,8 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val jvmArgs = monitor +: args.map(valueToJVM) + val tpeParam = if (cct.tps.isEmpty) Seq() else Seq(cct.tps.map(registerType).toArray) + val jvmArgs = monitor +: (tpeParam ++ args.map(valueToJVM)) cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") @@ -259,10 +260,8 @@ class CompilationUnit(val ctx: LeonContext, val lc = loader.loadClass(afName) val conss = lc.getConstructors.sortBy(_.getParameterTypes.length) - println(conss) assert(conss.nonEmpty) val lambdaConstructor = conss.last - println(args.toArray) lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef] case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) => @@ -541,7 +540,7 @@ class CompilationUnit(val ctx: LeonContext, for (m <- u.modules) { defineClass(m) - for(funDef <- m.definedFunctions) { + for (funDef <- m.definedFunctions) { defToModuleOrClass += funDef -> m } } diff --git 
a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala index 0861ce3bc406711e5d4d460c5e5cacdb7eb09cd0..6ed6b2d6777294f7206d895a2b6ec3ac881e8d61 100644 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.{HashMap => MutableMap, Set => MutableSet} import scala.concurrent.duration._ import solvers.SolverFactory -import solvers.combinators.UnrollingProcedure +import solvers.unrolling.UnrollingProcedure import synthesis._ diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index f752b1895d9ac04a629d1b9b68c2a36d447cf07e..418ac42e62409d18d7ebb39276b9fe40e85b063b 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -15,7 +15,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.DefOps import solvers.{PartialModel, Model, SolverFactory} -import solvers.combinators.UnrollingProcedure +import solvers.unrolling.UnrollingProcedure import scala.collection.mutable.{Map => MutableMap} import scala.concurrent.duration._ import org.apache.commons.lang3.StringEscapeUtils @@ -191,8 +191,18 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case _ => BooleanLiteral(lv == rv) } - case CaseClass(cd, args) => - CaseClass(cd, args.map(e)) + case CaseClass(cct, args) => + val cc = CaseClass(cct, args.map(e)) + if (cct.classDef.hasInvariant) { + e(FunctionInvocation(cct.invariant.get, Seq(cc))) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => + throw RuntimeError("ADT invariant violation for " + cct.classDef.id.asString + " reached in evaluation.: " + cct.invariant.get.asString) + case other => + throw RuntimeError(typeErrorMsg(other, BooleanType)) + } + } + cc case AsInstanceOf(expr, ct) => val le = e(expr) @@ -518,7 +528,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int implicit val debugSection = utils.DebugSectionVerification - ctx.reporter.debug("Executing forall!") + ctx.reporter.debug("Executing forall: " + f.asString) val mapping = variablesOf(f).map(id => id -> rctx.mappings(id)).toMap val context = mapping.toSeq.sortBy(_._1.uniqueName).map(_._2) @@ -546,7 +556,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val domainCnstr = orJoin(quorums.map { quorum => val quantifierDomains = quorum.flatMap { case (path, caller, args) => - val matcher = e(expr) match { + val matcher = e(caller) match { case l: Lambda => gctx.lambdas.getOrElse(l, l) case ev => ev } diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 16d42e892881044e7892de1de7bd23de8daf9911..d53869f898288ba466b8d79729ba8964dbf06095 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -14,7 +14,7 @@ import purescala.Expressions._ import purescala.Quantification._ import leon.solvers.{SolverFactory, PartialModel} -import leon.solvers.combinators.UnrollingProcedure +import leon.solvers.unrolling.UnrollingProcedure import leon.utils.StreamUtils._ import scala.concurrent.duration._ @@ -166,7 +166,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) val domainCnstr = orJoin(quorums.map { quorum => val quantifierDomains = quorum.flatMap { case 
(path, caller, args) => - val optMatcher = e(expr) match { + val optMatcher = e(caller) match { case Stream(l: Lambda) => Some(gctx.lambdas.getOrElse(l, l)) case Stream(ev) => Some(ev) case _ => None diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 85e587444a41b59e38a057ac6fc8e1c6e616b8e5..9dc3514a4996ffd67d77482ffa786fef469d4b82 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -333,12 +333,12 @@ trait CodeExtraction extends ASTExtractors { } classToInvariants.get(sym).foreach { bodies => + val cd = classesToClasses(sym) val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType) fd.addFlag(IsADTInvariant) + fd.addFlags(cd.flags.collect { case annot : purescala.Definitions.Annotation => annot }) - val cd = classesToClasses(sym) cd.registerMethod(fd) - cd.addFlag(IsADTInvariant) val ctparams = sym.tpe match { case TypeRef(_, _, tps) => extractTypeParams(tps).map(_._1) @@ -381,7 +381,6 @@ trait CodeExtraction extends ASTExtractors { case t => extractFunOrMethodBody(None, t) - } case _ => } @@ -559,9 +558,9 @@ trait CodeExtraction extends ASTExtractors { // Extract class val cd = if (sym.isAbstractClass) { - AbstractClassDef(id, tparams, parent.map(_._1)) + new AbstractClassDef(id, tparams, parent.map(_._1)) } else { - CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) + new CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) } cd.setPos(sym.pos) //println(s"Registering $sym") diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index 7634da75ad61fca3613566861457c3d77ed6784b..4cc4750d39e5d9e070929b40dd845a4c2546b58e 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -191,7 +191,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F } import leon.solvers._ - import leon.solvers.combinators.UnrollingSolver + import leon.solvers.unrolling.UnrollingSolver def solveUsingLeon(leonctx: LeonContext, p: Program, vc: VC) = { val solFactory = SolverFactory.uninterpreted(leonctx, program) val smtUnrollZ3 = new UnrollingSolver(ctx.leonContext, program, solFactory.getNewSolver()) with TimeoutSolver diff --git a/src/main/scala/leon/laziness/FreeVariableFactory.scala b/src/main/scala/leon/laziness/FreeVariableFactory.scala index 441da84ba6844b3922456d2d63f45d600bcc05d4..7388227095e049d344b5e1b4a354ff69f46c7a08 100644 --- a/src/main/scala/leon/laziness/FreeVariableFactory.scala +++ b/src/main/scala/leon/laziness/FreeVariableFactory.scala @@ -15,22 +15,22 @@ import purescala.Types._ */ object FreeVariableFactory { - val fvClass = AbstractClassDef(FreshIdentifier("FreeVar@"), Seq(), None) + val fvClass = new AbstractClassDef(FreshIdentifier("FreeVar@"), Seq(), None) val fvType = AbstractClassType(fvClass, Seq()) val varCase = { - val cdef = CaseClassDef(FreshIdentifier("Var@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("Var@"), Seq(), Some(fvType), false) cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) fvClass.registerChild(cdef) cdef } val nextCase = { - val cdef = CaseClassDef(FreshIdentifier("NextVar@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("NextVar@"), Seq(), Some(fvType), false) 
cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) fvClass.registerChild(cdef) cdef } val nilCase = { - val cdef = CaseClassDef(FreshIdentifier("NilVar@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("NilVar@"), Seq(), Some(fvType), false) fvClass.registerChild(cdef) cdef } diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala index eb9e35c31119cb216f356e88777a9e762417359b..90064561d1da3382384e9835482faa80c0786d17 100644 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -134,14 +134,14 @@ object LazinessUtil { } def isLazyType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => - cid.name == "Lazy" + case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => + ccd.id.name == "Lazy" case _ => false } def isMemType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => - cid.name == "Mem" + case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => + ccd.id.name == "Mem" case _ => false } diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala index 58fa1674135b4b2f07aaf2b94f4a44b0715d1e7a..8e86581d86d38216485bb96a0a7a25866007d0c1 100644 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -742,7 +742,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } def transformCaseClasses = p.definedClasses.foreach { - case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) if !ccd.flags.contains(Annotation("library", Seq())) && + case ccd: CaseClassDef if !ccd.flags.contains(Annotation("library", Seq())) && ccd.fields.exists(vd => isLazyType(vd.getType)) => val nfields = ccd.fields.map { fld => unwrapLazyType(fld.getType) match { diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala index 1a69fa89dd438210969981f1cd452e4ac6ba90cf..ae7da4aa2bcafd6b3234569f6244038cb3a6b950 100644 --- a/src/main/scala/leon/laziness/LazyClosureFactory.scala +++ b/src/main/scala/leon/laziness/LazyClosureFactory.scala @@ -59,7 +59,7 @@ class LazyClosureFactory(p: Program) { ops.tail.forall(op => isMemoized(op) == isMemoized(ops.head)) } val absTParams = (1 to tpcount).map(i => TypeParameterDef(TypeParameter.fresh("T" + i))) - tpename -> AbstractClassDef(FreshIdentifier(typeNameToADTName(tpename), Untyped), + tpename -> new AbstractClassDef(FreshIdentifier(typeNameToADTName(tpename), Untyped), absTParams, None) }.toMap var opToAdt = Map[FunDef, CaseClassDef]() @@ -76,7 +76,7 @@ class LazyClosureFactory(p: Program) { assert(opfd.tparams.size == absTParamsDef.size) val absType = AbstractClassType(absClass, opfd.tparams.map(_.tp)) val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) - val cdef = CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) + val cdef = new CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) val nfields = opfd.params.map { vd => val fldType = vd.getType unwrapLazyType(fldType) match { @@ -105,7 +105,7 @@ class LazyClosureFactory(p: Program) { case NAryType(tparams, tcons) => tcons(absTParams) } val eagerid = FreshIdentifier("Eager" + TypeUtil.typeNameWOParams(clresType)) - val eagerClosure = CaseClassDef(eagerid, absTParamsDef, + val 
eagerClosure = new CaseClassDef(eagerid, absTParamsDef, Some(AbstractClassType(absClass, absTParams)), isCaseObject = false) eagerClosure.setFields(Seq(ValDef(FreshIdentifier("a", clresType)))) absClass.registerChild(eagerClosure) @@ -166,7 +166,7 @@ class LazyClosureFactory(p: Program) { val fldType = SetType(AbstractClassType(absClass, tparams)) ValDef(FreshIdentifier(typeToFieldName(tn), fldType)) } - val ccd = CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) + val ccd = new CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) ccd.setFields(fields) ccd } diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 848ded806dc049df282c85b022c3c35a09be10da..c3af4f2004f449d7087afe892ebacc225b9df3d6 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -95,8 +95,10 @@ object Common { * @param tpe The type of the identifier * @param alwaysShowUniqueID If the unique ID should always be shown */ - def apply(name: String, tpe: TypeTree = Untyped, alwaysShowUniqueID: Boolean = false) : Identifier = - new Identifier(decode(name), uniqueCounter.nextGlobal, uniqueCounter.next(name), tpe, alwaysShowUniqueID) + def apply(name: String, tpe: TypeTree = Untyped, alwaysShowUniqueID: Boolean = false) : Identifier = { + val decoded = decode(name) + new Identifier(decoded, uniqueCounter.nextGlobal, uniqueCounter.next(decoded), tpe, alwaysShowUniqueID) + } /** Builds a fresh identifier, whose ID is always shown * @@ -104,7 +106,7 @@ object Common { * @param forceId The forced ID of the identifier * @param tpe The type of the identifier */ - def apply(name: String, forceId: Int, tpe: TypeTree): Identifier = + def apply(name: String, forceId: Int, tpe: TypeTree): Identifier = new Identifier(decode(name), uniqueCounter.nextGlobal, forceId, tpe, alwaysShowUniqueID = true) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 8d7f626d8a3a0771f1044b8d6b4470fd97232d68..753b0944c62023257398b09305ed85020eae64b2 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -9,6 +9,8 @@ import ExprOps.{preMap, functionCallsOf} import leon.purescala.Types.AbstractClassType import leon.purescala.Types._ +import scala.collection.mutable.{Map => MutableMap} + object DefOps { private def packageOf(df: Definition)(implicit pgm: Program): PackageRef = { @@ -284,13 +286,111 @@ object DefOps { } } + def replaceDefsInProgram(p: Program)(fdMap: Map[FunDef, FunDef] = Map.empty, + cdMap: Map[ClassDef, ClassDef] = Map.empty): Program = { + p.copy(units = for (u <- p.units) yield { + u.copy(defs = u.defs.map { + case m : ModuleDef => + m.copy(defs = for (df <- m.defs) yield { + df match { + case cd : ClassDef => cdMap.getOrElse(cd, cd) + case fd : FunDef => fdMap.getOrElse(fd, fd) + case d => d + } + }) + case cd: ClassDef => cdMap.getOrElse(cd, cd) + case d => d + }) + }) + } + + def replaceDefs(p: Program)(fdMapF: FunDef => Option[FunDef], + cdMapF: ClassDef => Option[ClassDef], + fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap, + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[Identifier, Identifier], Map[FunDef, FunDef], Map[ClassDef, ClassDef]) = { + + val idMap: MutableMap[Identifier, Identifier] = MutableMap.empty + val cdMap: MutableMap[ClassDef , ClassDef ] = MutableMap.empty + val fdMap: MutableMap[FunDef , 
FunDef ] = MutableMap.empty + + val dependencies = new DependencyFinder + val transformer = new TreeTransformer { + override def transform(id: Identifier): Identifier = idMap.getOrElse(id, { + val ntpe = transform(id.getType) + val nid = if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) + idMap += id -> nid + nid + }) + + override def transform(cd: ClassDef): ClassDef = cdMap.getOrElse(cd, cd) + override def transform(fd: FunDef): FunDef = fdMap.getOrElse(fd, fd) + + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = expr match { + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + val nfi = fiMapF(fi, transform(fd)) getOrElse expr + super.transform(nfi) + case cc @ CaseClass(cct, args) => + val ncc = ciMapF(cc, transform(cct).asInstanceOf[CaseClassType]) getOrElse expr + super.transform(ncc) + case _ => super.transform(expr) + } + } + + for (fd <- p.definedFunctions; nfd <- fdMapF(fd)) fdMap += fd -> nfd + for (cd <- p.definedClasses; ncd <- cdMapF(cd)) cdMap += cd -> ncd + + def requiresReplacement(d: Definition): Boolean = dependencies(d).exists { + case cd: ClassDef => cdMap contains cd + case fd: FunDef => fdMap contains fd + case _ => false + } + + def trCd(cd: ClassDef): ClassDef = cdMap.getOrElse(cd, { + val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) + val newCd = cd match { + case acd: AbstractClassDef => acd.duplicate(parent = parent) + case ccd: CaseClassDef => ccd.duplicate(parent = parent) + } + cdMap += cd -> newCd + newCd + }) + + for (cd <- p.definedClasses if requiresReplacement(cd)) trCd(cd) + for (fd <- p.definedFunctions if requiresReplacement(fd) && !(fdMap contains fd)) { + val newId = transformer.transform(fd.id) + val newReturn = transformer.transform(fd.returnType) + val newParams = fd.params map (vd => ValDef(transformer.transform(vd.id))) + fdMap += fd -> fd.duplicate(id = newId, params = newParams, returnType = newReturn) + } + + for ((cd,ncd) <- cdMap) (cd, ncd) match { + case (ccd: CaseClassDef, nccd: CaseClassDef) => + nccd.setFields(ccd.fields map (vd => ValDef(transformer.transform(vd.id)))) + ccd.invariant.foreach(fd => nccd.setInvariant(transformer.transform(fd))) + case _ => + } + + for ((fd,nfd) <- fdMap) { + val bindings = (fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap ++ + nfd.params.map(vd => vd.id -> vd.id) + + nfd.fullBody = transformer.transform(nfd.fullBody)(bindings) + } + + val fdsMap = fdMap.toMap + val cdsMap = cdMap.toMap + val newP = replaceDefsInProgram(p)(fdsMap, cdsMap) + (newP, idMap.toMap, fdsMap, cdsMap) + } + private def defaultFiMap(fi: FunctionInvocation, nfd: FunDef): Option[Expr] = (fi, nfd) match { case (FunctionInvocation(old, args), newfd) if old.fd != newfd => Some(FunctionInvocation(newfd.typed(old.tps), args)) case _ => None } - + /** Clones the given program by replacing some functions by other functions. * * @param p The original program @@ -300,80 +400,15 @@ object DefOps { * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) - : (Program, Map[FunDef, FunDef])= { - - var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache - var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. 
- def fdMapFCached(fd: FunDef): Option[FunDef] = { - fdMapFCache.get(fd) match { - case Some(e) => e - case None => - val new_fd = fdMapF(fd) - fdMapFCache += fd -> new_fd - new_fd - } - } - - def duplicateParents(fd: FunDef): Unit = { - fdMapCache.get(fd) match { - case None => - fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) - for(fp <- p.callGraph.callers(fd)) { - duplicateParents(fp) - } - case _ => - } - } - - def fdMap(fd: FunDef): FunDef = { - fdMapCache.get(fd) match { - case Some(Some(e)) => e - case Some(None) => fd - case None => - if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { - duplicateParents(fd) - } else { // Verify that for all - fdMapCache += fd -> None - } - fdMapCache(fd).getOrElse(fd) - } - } - - val newP = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.map { - case m : ModuleDef => - m.copy(defs = for (df <- m.defs) yield { - df match { - case f : FunDef => fdMap(f) - case d => d - } - }) - case d => d - } - ) - }) - - for(fd <- newP.definedFunctions) { - if(ExprOps.exists{ - case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd - case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ - case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd - case _ => false - }(c.pattern)) - case _ => false - }(fd.fullBody)) { - fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) - } - } - (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) + : (Program, Map[Identifier, Identifier], Map[FunDef, FunDef], Map[ClassDef, ClassDef]) = { + replaceDefs(p)(fdMapF, cd => None, fiMapF) } def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { preMap { - case me@MatchExpr(scrut, cases) => + case me @ MatchExpr(scrut, cases) => Some(MatchExpr(scrut, cases.map(matchcase => matchcase match { - case mc@MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs).copiedFrom(mc) + case mc @ MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs).copiedFrom(mc) })).copiedFrom(me)) case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => fiMapF(fi, fdMapF(fd)).map(_.copiedFrom(fi)) @@ -381,7 +416,7 @@ object DefOps { None }(e) } - + def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{ case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp)) case _ => None @@ -409,6 +444,7 @@ object DefOps { var idMapCache = Map[Identifier, Identifier]() var fdMapFCache = Map[FunDef, Option[FunDef]]() var fdMapCache = Map[FunDef, Option[FunDef]]() + def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { cd match { case ccd: CaseClassDef => @@ -420,19 +456,20 @@ object DefOps { case acd: AbstractClassDef => None } } + def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) case e => None }(tt).asInstanceOf[T] - + def duplicateClassDef(cd: ClassDef): ClassDef = { cdMapCache.get(cd) match { case Some(new_cd) => new_cd.get // None would have meant that this class would never be duplicated, which is not possible. 
case None => val parent = cd.parent.map(duplicateAbstractClassType) - val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{ + val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse { cd match { case acd:AbstractClassDef => acd.duplicate(parent = parent) case ccd:CaseClassDef => @@ -443,7 +480,7 @@ object DefOps { new_cd } } - + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { TypeOps.postMap{ case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) @@ -451,7 +488,7 @@ object DefOps { case _ => None }(act).asInstanceOf[AbstractClassType] } - + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. // If something extends List[A] and A is modified, then the first something should be modified. def dependencies(s: ClassDef): Set[ClassDef] = { @@ -461,7 +498,7 @@ object DefOps { case _ => Set() }(p))))(Set(s)) } - + def cdMap(cd: ClassDef): ClassDef = { cdMapCache.get(cd) match { case Some(Some(new_cd)) => new_cd @@ -475,6 +512,7 @@ object DefOps { cdMapCache(cd).getOrElse(cd) } } + def idMap(id: Identifier): Identifier = { if (!(idMapCache contains id)) { val new_id = id.duplicate(tpe = tpMap(id.getType)) @@ -482,11 +520,11 @@ object DefOps { } idMapCache(id) } - + def idHasToChange(id: Identifier): Boolean = { typeHasToChange(id.getType) } - + def typeHasToChange(tp: TypeTree): Boolean = { TypeOps.exists{ case AbstractClassType(acd, _) => cdMap(acd) != acd @@ -494,7 +532,7 @@ object DefOps { case _ => false }(tp) } - + def patternHasToChange(p: Pattern): Boolean = { PatternOps.exists { case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) @@ -503,7 +541,7 @@ object DefOps { case e => false } (p) } - + def exprHasToChange(e: Expr): Boolean = { ExprOps.exists{ case Let(id, expr, body) => idHasToChange(id) @@ -523,11 +561,11 @@ object DefOps { false }(e) } - + def funDefHasToChange(fd: FunDef): Boolean = { exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) } - + def funHasToChange(fd: FunDef): Boolean = { funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCache.get(fd) match { @@ -536,7 +574,7 @@ object DefOps { case None => funDefHasToChange(fd) }) } - + def fdMapFCached(fd: FunDef): Option[FunDef] = { fdMapFCache.get(fd) match { case Some(e) => e @@ -550,7 +588,7 @@ object DefOps { new_fd } } - + def duplicateParents(fd: FunDef): Unit = { fdMapCache.get(fd) match { case None => @@ -561,7 +599,7 @@ object DefOps { case _ => } } - + def fdMap(fd: FunDef): FunDef = { fdMapCache.get(fd) match { case Some(Some(e)) => e @@ -575,7 +613,7 @@ object DefOps { fdMapCache(fd).getOrElse(fd) } } - + val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { @@ -591,6 +629,7 @@ object DefOps { } ) }) + def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) @@ -598,7 +637,7 @@ object DefOps { case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) case e => None }(e) - + def replaceClassDefsUse(e: Expr): Expr = { ExprOps.postMap { case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) @@ -623,19 +662,23 @@ object DefOps { None }(e) } - - 
for(fd <- newP.definedFunctions) { - if(fdMapCache.getOrElse(fd, None).isDefined) { + + for (fd <- newP.definedFunctions) { + if (fdMapCache.getOrElse(fd, None).isDefined) { fd.fullBody = replaceClassDefsUse(fd.fullBody) } } + + // make sure classDef invariants are correctly assigned to transformed classes + for ((cd, optNew) <- cdMapCache; newCd <- optNew; inv <- newCd.invariant) { + newCd.setInvariant(fdMap(inv)) + } + (newP, cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, idMapCache, fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) } - - def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala new file mode 100644 index 0000000000000000000000000000000000000000..89c5b613b426f8d6c9b6b2eeaa1273da2496ed40 --- /dev/null +++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala @@ -0,0 +1,104 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +import utils._ +import scala.collection.mutable.{Set => MutableSet} + +class DefinitionTransformer( + idMap: Bijection[Identifier, Identifier] = new Bijection[Identifier, Identifier], + fdMap: Bijection[FunDef , FunDef ] = new Bijection[FunDef , FunDef ], + cdMap: Bijection[ClassDef , ClassDef ] = new Bijection[ClassDef , ClassDef ]) extends TreeTransformer { + + override def transform(id: Identifier): Identifier = idMap.cachedB(id) { + val ntpe = transform(id.getType) + if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) + } + + override def transform(fd: FunDef): FunDef = fdMap.getBorElse(fd, if (tmpDefs(fd)) fd else { + transformDefs(fd) + fdMap.toB(fd) + }) + + override def transform(cd: ClassDef): ClassDef = cdMap.getBorElse(cd, if (tmpDefs(cd)) cd else { + transformDefs(cd) + cdMap.toB(cd) + }) + + private val dependencies = new DependencyFinder + private val tmpDefs: MutableSet[Definition] = MutableSet.empty + + private def transformDefs(base: Definition): Unit = { + val deps = dependencies(base) + val (cds, fds) = { + val (c, f) = deps.partition(_.isInstanceOf[ClassDef]) + (c.map(_.asInstanceOf[ClassDef]), f.map(_.asInstanceOf[FunDef])) + } + + tmpDefs ++= cds.filterNot(cdMap containsA _) ++ fds.filterNot(fdMap containsA _) + + var requireCache: Map[Definition, Boolean] = Map.empty + def required(d: Definition): Boolean = requireCache.getOrElse(d, { + val res = d match { + case fd: FunDef => + val newReturn = transform(fd.returnType) + lazy val newParams = fd.params.map(vd => ValDef(transform(vd.id))) + lazy val newBody = transform(fd.fullBody)((fd.params.map(_.id) zip newParams.map(_.id)).toMap) + newReturn != fd.returnType || newParams != fd.params || newBody != fd.fullBody + + case cd: ClassDef => + cd.fieldsIds.exists(id => transform(id.getType) != id.getType) || + cd.invariant.exists(required) + + case _ => scala.sys.error("Should never happen!?") + } + + requireCache += d -> res + res + }) + + val req = deps filter required + val allReq = req ++ (deps filter (d => (dependencies(d) & req).nonEmpty)) + val requiredCds = allReq collect { case cd: ClassDef => cd } + val requiredFds = allReq collect { case fd: FunDef => fd } + tmpDefs --= deps + + val nonReq = deps filterNot allReq + cdMap ++= nonReq collect { case cd: ClassDef => cd -> cd } + fdMap ++= nonReq collect { case fd: FunDef => fd -> fd } + + 
def trCd(cd: ClassDef): ClassDef = cdMap.cachedB(cd) { + val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) + cd match { + case acd: AbstractClassDef => acd.duplicate(id = transform(acd.id), parent = parent) + case ccd: CaseClassDef => ccd.duplicate(id = transform(ccd.id), parent = parent) + } + } + + for (cd <- requiredCds) trCd(cd) + for (fd <- requiredFds) { + val newReturn = transform(fd.returnType) + val newParams = fd.params map (vd => ValDef(transform(vd.id))) + fdMap += fd -> fd.duplicate(id = transform(fd.id), params = newParams, returnType = newReturn) + } + + for (ccd <- requiredCds collect { case ccd: CaseClassDef => ccd }) { + val newCcd = cdMap.toB(ccd).asInstanceOf[CaseClassDef] + newCcd.setFields(ccd.fields.map(vd => ValDef(transform(vd.id)))) + ccd.invariant.foreach(fd => newCcd.setInvariant(transform(fd))) + } + + for (fd <- requiredFds) { + val nfd = fdMap.toB(fd) + fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) + } + } +} + diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 98497afd31586ef8cff40eec4a68ef36da63162b..0103bd950b80b86b1acbd493bda5aad73a4cfef0 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -186,11 +186,11 @@ object Definitions { } lazy val algebraicDataTypes : Map[AbstractClassDef, Seq[CaseClassDef]] = defs.collect { - case c@CaseClassDef(_, _, Some(p), _) => c + case c : CaseClassDef if c.parent.isDefined => c }.groupBy(_.parent.get.classDef) lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { - case c @ CaseClassDef(_, _, None, _) => c + case c : CaseClassDef if !c.parent.isDefined => c } } @@ -227,7 +227,7 @@ object Definitions { // Is inlined case object IsInlined extends FunctionFlag // Is an ADT invariant method - case object IsADTInvariant extends FunctionFlag with ClassFlag + case object IsADTInvariant extends FunctionFlag case object IsInner extends FunctionFlag /** Represents a class definition (either an abstract- or a case-class) */ @@ -235,7 +235,7 @@ object Definitions { self => def subDefinitions = fields ++ methods ++ tparams - + val id: Identifier val tparams: Seq[TypeParameterDef] def fields: Seq[ValDef] @@ -280,13 +280,14 @@ object Definitions { private var _invariant: Option[FunDef] = None - def invariant = _invariant - def hasInvariant = flags contains IsADTInvariant - def setInvariant(fd: FunDef): Unit = { - addFlag(IsADTInvariant) - _invariant = Some(fd) + def invariant: Option[FunDef] = parent.flatMap(_.classDef.invariant).orElse(_invariant) + def setInvariant(fd: FunDef): Unit = parent match { + case Some(act) => act.classDef.setInvariant(fd) + case None => _invariant = Some(fd) } + def hasInvariant: Boolean = invariant.isDefined || (root.knownChildren.exists(cd => cd.methods.exists(_.isInvariant))) + def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap @@ -336,9 +337,9 @@ object Definitions { } /** Abstract classes. 
*/ - case class AbstractClassDef(id: Identifier, - tparams: Seq[TypeParameterDef], - parent: Option[AbstractClassType]) extends ClassDef { + class AbstractClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType]) extends ClassDef { val fields = Nil val isAbstract = true @@ -362,16 +363,17 @@ object Definitions { ): AbstractClassDef = { val acd = new AbstractClassDef(id, tparams, parent) acd.addFlags(this.flags) - parent.foreach(_.classDef.ancestors.foreach(_.registerChild(acd))) + if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => acd.setInvariant(inv)) + parent.foreach(_.classDef.registerChild(acd)) acd.copiedFrom(this) } } /** Case classes/ case objects. */ - case class CaseClassDef(id: Identifier, - tparams: Seq[TypeParameterDef], - parent: Option[AbstractClassType], - isCaseObject: Boolean) extends ClassDef { + class CaseClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType], + val isCaseObject: Boolean) extends ClassDef { private var _fields = Seq[ValDef]() @@ -393,14 +395,14 @@ object Definitions { ) } else index } - + lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) + def typed: CaseClassType = typed(tparams.map(_.tp)) def typed(tps: Seq[TypeTree]): CaseClassType = { require(tps.length == tparams.length) CaseClassType(this, tps) } - def typed: CaseClassType = typed(tparams.map(_.tp)) /** Duplication of this [[CaseClassDef]]. * @note This will not replace recursive [[CaseClassDef]] calls in [[fields]] nor the parent abstract class types @@ -415,9 +417,9 @@ object Definitions { val cd = new CaseClassDef(id, tparams, parent, isCaseObject) cd.setFields(fields) cd.addFlags(this.flags) + if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => cd.setInvariant(inv)) + parent.foreach(_.classDef.registerChild(cd)) cd.copiedFrom(this) - parent.foreach(_.classDef.ancestors.foreach(_.registerChild(cd))) - cd } } diff --git a/src/main/scala/leon/purescala/DependencyFinder.scala b/src/main/scala/leon/purescala/DependencyFinder.scala new file mode 100644 index 0000000000000000000000000000000000000000..5382046e1bba0fd5b6040f3a6b01c994a29595b4 --- /dev/null +++ b/src/main/scala/leon/purescala/DependencyFinder.scala @@ -0,0 +1,84 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +class DependencyFinder { + private val deps: MutableMap[Definition, Set[Definition]] = MutableMap.empty + + def apply(d: Definition): Set[Definition] = deps.getOrElse(d, { + new Finder(d).dependencies + }) + + private class Finder(var current: Definition) extends TreeTraverser { + val foundDeps: MutableMap[Definition, MutableSet[Definition]] = MutableMap.empty + foundDeps(current) = MutableSet.empty + + private def withCurrent[T](d: Definition)(b: => T): T = { + if (!(foundDeps contains d)) foundDeps(d) = MutableSet.empty + val c = current + current = d + val res = b + current = c + res + } + + override def traverse(id: Identifier): Unit = traverse(id.getType) + + override def traverse(cd: ClassDef): Unit = if (!foundDeps(current)(cd)) { + foundDeps(current) += cd + if (!(deps contains cd) && !(foundDeps contains cd)) { + for (cd <- cd.root.knownDescendants :+ cd) { + cd.invariant foreach (fd => withCurrent(cd)(traverse(fd))) + withCurrent(cd)(cd.fieldsIds 
foreach traverse) + cd.parent foreach { p => + foundDeps(p.classDef) = foundDeps.getOrElse(p.classDef, MutableSet.empty) + cd + foundDeps(cd) = foundDeps.getOrElse(cd, MutableSet.empty) + p.classDef + } + } + } + } + + override def traverse(fd: FunDef): Unit = if (!foundDeps(current)(fd)) { + foundDeps(current) += fd + if (!(deps contains fd) && !(foundDeps contains fd)) withCurrent(fd) { + fd.params foreach (vd => traverse(vd.id)) + traverse(fd.returnType) + traverse(fd.fullBody) + } + } + + def dependencies: Set[Definition] = { + current match { + case fd: FunDef => traverse(fd) + case cd: ClassDef => traverse(cd) + case _ => + } + + for ((d, ds) <- foundDeps) { + deps(d) = deps.getOrElse(d, Set.empty) ++ ds + } + + var changed = false + do { + for ((d, ds) <- deps.toSeq) { + val next = ds.flatMap(d => deps.getOrElse(d, Set.empty)) + if (!(next subsetOf ds)) { + deps(d) = next + changed = true + } + } + } while (changed) + + deps(current) + } + } +} diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 6d6382709e704d4562a5a0027005d7b0ef4c4264..79172eb7a538c298e5b5eb026eb3e8dfca397f3a 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -388,7 +388,7 @@ object Expressions { * @param out The output expression * @param cases The cases to compare against */ - case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr { + case class Passes(in: Expr, out: Expr, cases: Seq[MatchCase]) extends Expr { require(cases.nonEmpty) val getType = leastUpperBound(cases.map(_.rhs.getType)) match { @@ -444,7 +444,6 @@ object Expressions { * This is useful e.g. to present counterexamples of generic types. */ case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { - // TODO: Is it valid that GenericValue(tp, 0) != GenericValue(tp, 1)? 
val getType = tp } @@ -847,6 +846,43 @@ object Expressions { val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /* Bag operations */ + /** $encodingof `Bag[base](elements)` */ + case class FiniteBag(elements: Map[Expr, Int], base: TypeTree) extends Expr { + val getType = BagType(base).unveilUntyped + } + /** $encodingof `bag.get(element)` or `bag(element)` */ + case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { + val getType = IntegerType + } + /** $encodingof `bag.length` */ + /* + case class BagCardinality(bag: Expr) extends Expr { + val getType = IntegerType + } + */ + /** $encodingof `bag1.subsetOf(bag2)` */ + /* + case class SubbagOf(bag1: Expr, bag2: Expr) extends Expr { + val getType = BooleanType + } + */ + /** $encodingof `bag1.intersect(bag2)` */ + case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + /** $encodingof `bag1 ++ bag2` */ + case class BagUnion(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + /** $encodingof `bag1 -- bag2` */ + /* + case class SetDifference(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + */ + + // TODO: Add checks for these expressions too /* Map operations */ diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c7d8a1a40ac9d2bf7f1386311941fe1006881ad3..5c09c2d296bcbf7ddeb3518f2c42acdc6a695cb8 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -144,6 +144,12 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) case SetDifference(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) + case MultiplicityInBag(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => MultiplicityInBag(es(0), es(1))) + case BagIntersection(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagIntersection(es(0), es(1))) + case BagUnion(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagUnion(es(0), es(1))) case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case MapUnion(t1, t2) => @@ -173,6 +179,9 @@ object Extractors { case SubString(t1, a, b) => Some((t1::a::b::Nil, es => SubString(es(0), es(1), es(2)))) case FiniteSet(els, base) => Some((els.toSeq, els => FiniteSet(els.toSet, base))) + case FiniteBag(els, base) => + val seq = els.toSeq + Some((seq.map(_._1), els => FiniteBag((els zip seq.map(_._2)).toMap, base))) case FiniteMap(args, f, t) => { val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq val builder = (as: Seq[Expr]) => { @@ -381,5 +390,4 @@ object Extractors { } } } - } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 957505f297b551ff0bab84955015de35b0a0d7e2..13810f83db03029ec0a8f66dfb00112a841199f9 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -156,19 +156,18 @@ object MethodLifting extends TransformationPhase { isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp })) } - if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { // Don't need to compose methods - val paramsMap = 
fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap + val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap def thisToReceiver(e: Expr): Option[Expr] = e match { - case th@This(ct) => + case th @ This(ct) => Some(asInstOf(receiver.toVariable, ct).setPos(th)) case _ => None } val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap) - nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) ) + nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody)) // Add precondition if the method was defined in a subclass val pre = and( @@ -186,19 +185,22 @@ object MethodLifting extends TransformationPhase { m <- c.methods if m.id == fd.id (from,to) <- m.params zip fdParams } yield (from.id, to.id)).toMap + val classParamsMap = (for { c <- cd.knownDescendants :+ cd (from, to) <- c.tparams zip ctParams } yield (from, to.tp)).toMap + val methodParamsMap = (for { c <- cd.knownDescendants :+ cd m <- c.methods if m.id == fd.id (from,to) <- m.tparams zip fd.tparams } yield (from, to.tp)).toMap + def inst(cs: Seq[MatchCase]) = instantiateType( matchExpr(Variable(receiver), cs).setPos(fd), classParamsMap ++ methodParamsMap, - paramsMap + paramsMap + (receiver -> receiver) ) /* Separately handle pre, post, body */ diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 3773ebce48b3c51e056cee01f1ed059288d25bc5..f1b80fd9de6df1996d33fdc7936669be6a997b73 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -497,10 +497,10 @@ class PrettyPrinter(opts: PrinterOptions, | ${nary(defs, "\n\n")} |}""" - case acd @ AbstractClassDef(id, tparams, parent) => - p"abstract class $id${nary(tparams, ", ", "[", "]")}" + case acd : AbstractClassDef => + p"abstract class ${acd.id}${nary(acd.tparams, ", ", "[", "]")}" - parent.foreach{ par => + acd.parent.foreach{ par => p" extends ${par.id}" } @@ -510,22 +510,22 @@ class PrettyPrinter(opts: PrinterOptions, |}""" } - case ccd @ CaseClassDef(id, tparams, parent, isObj) => - if (isObj) { - p"case object $id" + case ccd : CaseClassDef => + if (ccd.isCaseObject) { + p"case object ${ccd.id}" } else { - p"case class $id" + p"case class ${ccd.id}" } - p"${nary(tparams, ", ", "[", "]")}" + p"${nary(ccd.tparams, ", ", "[", "]")}" - if (!isObj) { + if (!ccd.isCaseObject) { p"(${ccd.fields})" } - parent.foreach { par => + ccd.parent.foreach { par => // Remember child and parents tparams are simple bijection - p" extends ${par.id}${nary(tparams, ", ", "[", "]")}" + p" extends ${par.id}${nary(ccd.tparams, ", ", "[", "]")}" } if (ccd.methods.nonEmpty) { diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index 079be92d12c0ebaf8913842207a3ba1a65a387fd..c815f4ca5a27eaeb04660a06863fc9ac61ca8fc8 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -22,10 +22,15 @@ object Quantification { qargs: A => Set[B] ): Seq[Set[A]] = { def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + def allQArgs(m: A): Set[B] = qargs(m) ++ margs(m).flatMap(allQArgs) val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap val reverseMap : Map[A, Set[A]] = expandedMap.toSeq .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set + .map { case (m, ms) => // filter redundant matchers + val 
allM = allQArgs(m) + m -> ms.filter(rm => allQArgs(rm) != allM) + } def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { if (qss.contains(quantified)) { diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala index 5c0abc5e5472a52ed750444366e8c5eff38e10f5..664b18978167d3f0f752650644aae71aecf5feb1 100644 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -57,7 +57,7 @@ object RestoreMethods extends TransformationPhase { }) }) - val (np2, _) = replaceFunDefs(np)(fd => None, { (fi, fd) => + val (np2, _, _, _) = replaceFunDefs(np)(fd => None, { (fi, fd) => fdToMd.get(fi.tfd.fd) match { case Some(md) => Some(MethodInvocation( diff --git a/src/main/scala/leon/purescala/TreeTransformer.scala b/src/main/scala/leon/purescala/TreeTransformer.scala new file mode 100644 index 0000000000000000000000000000000000000000..77a5659f1dbb8d49d7b81983656e33eaaa128a9c --- /dev/null +++ b/src/main/scala/leon/purescala/TreeTransformer.scala @@ -0,0 +1,234 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +import utils._ +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +trait TreeTransformer { + def transform(id: Identifier): Identifier = id + def transform(cd: ClassDef): ClassDef = cd + def transform(fd: FunDef): FunDef = fd + + def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = e match { + case Variable(id) if bindings contains id => Variable(bindings(id)).copiedFrom(e) + case Variable(id) => Variable(transform(id)).copiedFrom(e) + case FiniteLambda(mappings, default, tpe) => + FiniteLambda(mappings.map { case (ks, v) => (ks map transform, transform(v)) }, + transform(default), transform(tpe).asInstanceOf[FunctionType]).copiedFrom(e) + case Lambda(args, body) => + val newArgs = args.map(vd => ValDef(transform(vd.id))) + val newBindings = (args zip newArgs).map(p => p._1.id -> p._2.id) + Lambda(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) + case Forall(args, body) => + val newArgs = args.map(vd => ValDef(transform(vd.id))) + val newBindings = (args zip newArgs).map(p => p._1.id -> p._2.id) + Forall(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) + case Let(a, expr, body) => + val newA = transform(a) + Let(newA, transform(expr), transform(body)(bindings + (a -> newA))).copiedFrom(e) + case CaseClass(cct, args) => + CaseClass(transform(cct).asInstanceOf[CaseClassType], args map transform).copiedFrom(e) + case CaseClassSelector(cct, caseClass, selector) => + val newCct @ CaseClassType(ccd, _) = transform(cct) + val newSelector = ccd.fieldsIds(cct.classDef.fieldsIds.indexOf(selector)) + CaseClassSelector(newCct, transform(caseClass), newSelector).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(transform(fd), tpes map transform), args map transform).copiedFrom(e) + case MethodInvocation(rec, cd, TypedFunDef(fd, tpes), args) => + MethodInvocation(transform(rec), transform(cd), TypedFunDef(transform(fd), tpes map transform), args map transform).copiedFrom(e) + case This(ct) => + This(transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + 
AsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(transform(scrutinee), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { + val (newPattern, newBindings) = transform(pattern) + val allBindings = bindings ++ newBindings + MatchCase(newPattern, guard.map(g => transform(g)(allBindings)), transform(rhs)(allBindings)).copiedFrom(cse) + }).copiedFrom(e) + case Passes(in, out, cases) => + Passes(transform(in), transform(out), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { + val (newPattern, newBindings) = transform(pattern) + val allBindings = bindings ++ newBindings + MatchCase(newPattern, guard.map(g => transform(g)(allBindings)), transform(rhs)(allBindings)).copiedFrom(cse) + }).copiedFrom(e) + case FiniteSet(es, tpe) => + FiniteSet(es map transform, transform(tpe)).copiedFrom(e) + case FiniteBag(es, tpe) => + FiniteBag(es map { case (k, v) => transform(k) -> v }, transform(tpe)).copiedFrom(e) + case FiniteMap(pairs, from, to) => + FiniteMap(pairs map { case (k, v) => transform(k) -> transform(v) }, transform(from), transform(to)).copiedFrom(e) + case EmptyArray(tpe) => + EmptyArray(transform(tpe)).copiedFrom(e) + case Hole(tpe, alts) => + Hole(transform(tpe), alts map transform).copiedFrom(e) + case NoTree(tpe) => + NoTree(transform(tpe)).copiedFrom(e) + case Error(tpe, desc) => + Error(transform(tpe), desc).copiedFrom(e) + case Operator(es, builder) => + val newEs = es map transform + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + } + + def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case InstanceOfPattern(binder, ct) => + val newBinder = binder map transform + val newPat = InstanceOfPattern(newBinder, transform(ct).asInstanceOf[ClassType]).copiedFrom(pat) + (newPat, (binder zip newBinder).toMap) + case WildcardPattern(binder) => + val newBinder = binder map transform + val newPat = WildcardPattern(newBinder).copiedFrom(pat) + (newPat, (binder zip newBinder).toMap) + case CaseClassPattern(binder, ct, subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = CaseClassPattern(newBinder, transform(ct).asInstanceOf[CaseClassType], newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) + case TuplePattern(binder, subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = TuplePattern(newBinder, newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = UnapplyPattern(newBinder, TypedFunDef(transform(fd), tpes map transform), newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) + case PatternExtractor(subs, builder) => + val (newSubs, subBinders) = (subs map transform).unzip + (builder(newSubs).copiedFrom(pat), subBinders.flatten.toMap) + } + + def transform(tpe: TypeTree): TypeTree = tpe match { + case cct @ CaseClassType(ccd, args) => + CaseClassType(transform(ccd).asInstanceOf[CaseClassDef], args map transform).copiedFrom(tpe) + case act @ AbstractClassType(acd, args) => + AbstractClassType(transform(acd).asInstanceOf[AbstractClassDef], args map transform).copiedFrom(tpe) + case 
NAryType(ts, builder) => builder(ts map transform).copiedFrom(tpe) + } +} + +trait TreeTraverser { + def traverse(id: Identifier): Unit = () + def traverse(cd: ClassDef): Unit = () + def traverse(fd: FunDef): Unit = () + + def traverse(e: Expr): Unit = e match { + case Variable(id) => traverse(id) + case FiniteLambda(mappings, default, tpe) => + (default +: mappings.toSeq.flatMap(p => p._2 +: p._1)) foreach traverse + traverse(tpe) + case Lambda(args, body) => + args foreach (vd => traverse(vd.id)) + traverse(body) + case Forall(args, body) => + args foreach (vd => traverse(vd.id)) + traverse(body) + case Let(a, expr, body) => + traverse(a) + traverse(expr) + traverse(body) + case CaseClass(cct, args) => + traverse(cct) + args foreach traverse + case CaseClassSelector(cct, caseClass, selector) => + traverse(cct) + traverse(caseClass) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + traverse(fd) + tpes foreach traverse + args foreach traverse + case MethodInvocation(rec, cd, TypedFunDef(fd, tpes), args) => + traverse(rec) + traverse(cd) + traverse(fd) + tpes foreach traverse + args foreach traverse + case This(ct) => + traverse(ct) + case IsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + case AsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + case MatchExpr(scrutinee, cases) => + traverse(scrutinee) + for (cse @ MatchCase(pattern, guard, rhs) <- cases) { + traverse(pattern) + guard foreach traverse + traverse(rhs) + } + case FiniteSet(es, tpe) => + es foreach traverse + traverse(tpe) + case FiniteBag(es, tpe) => + es foreach { case (k, _) => traverse(k) } + traverse(tpe) + case FiniteMap(pairs, from, to) => + pairs foreach { case (k, v) => traverse(k); traverse(v) } + traverse(from) + traverse(to) + case EmptyArray(tpe) => + traverse(tpe) + case Hole(tpe, alts) => + traverse(tpe) + alts foreach traverse + case NoTree(tpe) => + traverse(tpe) + case Error(tpe, desc) => + traverse(tpe) + case Operator(es, builder) => + es foreach traverse + case e => + } + + def traverse(pat: Pattern): Unit = pat match { + case InstanceOfPattern(binder, ct) => + binder foreach traverse + traverse(ct) + case WildcardPattern(binder) => + binder foreach traverse + case CaseClassPattern(binder, ct, subs) => + binder foreach traverse + traverse(ct) + subs foreach traverse + case TuplePattern(binder, subs) => + binder foreach traverse + subs foreach traverse + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subs) => + binder foreach traverse + traverse(fd) + tpes foreach traverse + subs foreach traverse + case PatternExtractor(subs, builder) => + subs foreach traverse + } + + def traverse(tpe: TypeTree): Unit = tpe match { + case cct @ CaseClassType(ccd, args) => + traverse(ccd) + args foreach traverse + case act @ AbstractClassType(acd, args) => + traverse(acd) + args foreach traverse + case NAryType(ts, builder) => + ts foreach traverse + } +} diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 99da89dbc20e7def4e5dcb7a2cc1cfc97fb21577..1a33a0883a1ec311126b6c165c542534c74dff9a 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -244,192 +244,12 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp }) _ } - def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = { - - // Simple rec without affecting map - val srec = rec(idsMap) _ - - def onMatchLike(e: Expr, cases : 
Seq[MatchCase]) = { - - val newTpe = tpeSub(e.getType) - - def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { - maps.flatten.toMap - } - - def trCase(c: MatchCase) = c match { - case SimpleCase(p, b) => - val (newP, newIds) = trPattern(p, newTpe) - SimpleCase(newP, rec(idsMap ++ newIds)(b)) - - case GuardedCase(p, g, b) => - val (newP, newIds) = trPattern(p, newTpe) - GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b)) - } - - def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match { - case (InstanceOfPattern(ob, ct), _) => - val newCt = tpeSub(ct).asInstanceOf[ClassType] - val newOb = ob.map(id => freshId(id, newCt)) - - (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap) - - case (TuplePattern(ob, sps), tpt @ TupleType(stps)) => - val newOb = ob.map(id => freshId(id, tpt)) - - val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (CaseClassPattern(ob, cct, sps), _) => - val newCt = tpeSub(cct).asInstanceOf[CaseClassType] - - val newOb = ob.map(id => freshId(id, newCt)) - - val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (up@UnapplyPattern(ob, fd, sps), tp) => - val newFd = if ((fd.tps map tpeSub) == fd.tps) fd else fd.fd.typed(fd.tps map tpeSub) - val newOb = ob.map(id => freshId(id,tp)) - val exType = tpeSub(up.someType.tps.head) - val exTypes = unwrapTupleType(exType, exType.isInstanceOf[TupleType]) - val (newSps, newMaps) = (sps zip exTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - (UnapplyPattern(newOb, newFd, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (WildcardPattern(ob), expTpe) => - val newOb = ob.map(id => freshId(id, expTpe)) - - (WildcardPattern(newOb), (ob zip newOb).toMap) - - case (LiteralPattern(ob, lit), expType) => - val newOb = ob.map(id => freshId(id, expType)) - (LiteralPattern(newOb,lit), (ob zip newOb).toMap) - - case _ => - sys.error(s"woot!? 
$p:$expType") - } - - (srec(e), cases.map(trCase))//.copiedFrom(m) - } - - e match { - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi) - - case mi @ MethodInvocation(r, cd, TypedFunDef(fd, tps), args) => - MethodInvocation(srec(r), cd, TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(mi) - - case th @ This(ct) => - This(tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(th) - - case cc @ CaseClass(ct, args) => - CaseClass(tpeSub(ct).asInstanceOf[CaseClassType], args.map(srec)).copiedFrom(cc) - - case cc @ CaseClassSelector(ct, e, sel) => - caseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) - - case cc @ IsInstanceOf(e, ct) => - IsInstanceOf(srec(e), tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(cc) - - case cc @ AsInstanceOf(e, ct) => - AsInstanceOf(srec(e), tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(cc) - - case l @ Let(id, value, body) => - val newId = freshId(id, tpeSub(id.getType)) - Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l) - - case l @ LetDef(fds, bd) => - val fdsMapping = for(fd <- fds) yield { - val id = fd.id.freshen - val tparams = fd.tparams map { p => - TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) - } - val returnType = tpeSub(fd.returnType) - val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) - val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = ExprOps.preMap { - case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => - Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) - case _ => - None - } _ - (fd, newFd, subCalls) - } - // We group the subcalls functions all in once - val subCalls = fdsMapping.map(_._3).reduceLeft { _ andThen _ } - - // We apply all the functions mappings at once - val newFds = for((fd, newFd, _) <- fdsMapping) yield { - val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody)) - newFd.fullBody = fullBody - newFd - } - val newBd = srec(subCalls(bd)).copiedFrom(bd) - - letDef(newFds, newBd).copiedFrom(l) - - case l @ Lambda(args, body) => - val newArgs = args.map { arg => - val tpe = tpeSub(arg.getType) - arg.copy(id = freshId(arg.id, tpe)) - } - val mapping = args.map(_.id) zip newArgs.map(_.id) - Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) - - case f @ Forall(args, body) => - val newArgs = args.map { arg => - val tpe = tpeSub(arg.getType) - arg.copy(id = freshId(arg.id, tpe)) - } - val mapping = args.map(_.id) zip newArgs.map(_.id) - Forall(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(f) - - case p @ Passes(in, out, cases) => - val (newIn, newCases) = onMatchLike(in, cases) - passes(newIn, srec(out), newCases).copiedFrom(p) - - case m @ MatchExpr(e, cases) => - val (newE, newCases) = onMatchLike(e, cases) - matchExpr(newE, newCases).copiedFrom(m) - - case Error(tpe, desc) => - Error(tpeSub(tpe), desc).copiedFrom(e) - - case Hole(tpe, alts) => - Hole(tpeSub(tpe), alts.map(srec)).copiedFrom(e) - - case g @ GenericValue(tpar, id) => - tpeSub(tpar) match { - case newTpar : TypeParameter => - GenericValue(newTpar, id).copiedFrom(g) - case other => // FIXME any better ideas? 
- throw LeonFatalError(Some(s"Tried to substitute $tpar with $other within GenericValue $g")) - } - - case s @ FiniteSet(elems, tpe) => - FiniteSet(elems.map(srec), tpeSub(tpe)).copiedFrom(s) - - case m @ FiniteMap(elems, from, to) => - FiniteMap(elems.map{ case (k, v) => (srec(k), srec(v)) }, tpeSub(from), tpeSub(to)).copiedFrom(m) - - case f @ FiniteLambda(mapping, dflt, FunctionType(from, to)) => - FiniteLambda(mapping.map { case (ks, v) => ks.map(srec) -> srec(v) }, srec(dflt), - FunctionType(from.map(tpeSub), tpeSub(to))).copiedFrom(f) - - case v @ Variable(id) if idsMap contains id => - Variable(idsMap(id)).copiedFrom(v) - - case n @ Operator(es, builder) => - builder(es.map(srec)).copiedFrom(n) - - case _ => - e - } + val transformer = new TreeTransformer { + override def transform(id: Identifier): Identifier = freshId(id, transform(id.getType)) + override def transform(tpe: TypeTree): TypeTree = tpeSub(tpe) } - rec(ids)(e) + transformer.transform(e)(ids) } } } diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 0ee936c813b942d3e55564506a200b31630f0283..e2838104fd704110fbb06e6234ea1ae2599a20e7 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -80,6 +80,7 @@ object Types { } case class SetType(base: TypeTree) extends TypeTree + case class BagType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree case class FunctionType(from: Seq[TypeTree], to: TypeTree) extends TypeTree case class ArrayType(base: TypeTree) extends TypeTree diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index b52be838a0c00417d76be9274a20e8588a8221be..7aefd95773e05f0a22f9bb435f2b2771a2c93da0 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -136,6 +136,17 @@ class ADTManager(ctx: LeonContext) { findDependencies(base) } + case tp @ TypeParameter(id) => + if (!(discovered contains t) && !(defined contains t)) { + val sym = freshId(id.name) + + val c = Constructor(freshId(sym.name), tp, List( + (freshId("val"), IntegerType) + )) + + discovered += (tp -> DataType(sym, Seq(c))) + } + case _ => } } diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 8656515790b90c7d828a267b025ed40f87f8e094..8b84cd05eaff99293b6db2144ef64ece16f552aa 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -4,7 +4,9 @@ package leon package solvers import combinators._ +import unrolling._ import z3._ +import cvc4._ import smtlib._ import purescala.Definitions._ @@ -79,12 +81,10 @@ object SolverFactory { def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => - // Previously: new FairZ3Solver(ctx, program) with TimeoutSolver - SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver) + SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) case "unrollz3" => - // Previously: new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver - SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new Z3UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) case "enum" => 
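// --- Editor's note (illustration only, not part of this patch) ---------------
// The ADTManager change above, together with the SMTLIBTarget changes further
// down, stops modelling type parameters as uninterpreted SMT sorts with
// explicitly distinct constants; each TypeParameter instead becomes an ordinary
// one-constructor datatype wrapping an integer, so GenericValue(tp, n) is sent
// to the solver as that constructor applied to n and read back the same way.
// A toy round-trip with hypothetical names (`Tp`, `encode`, `decode`):
object TypeParamEncodingSketch {
  final case class Tp(index: BigInt) // stands for the generated single-constructor datatype

  def encode(n: Int): Tp = Tp(BigInt(n)) // GenericValue(tp, n)  ~>  Tp(n)
  def decode(v: Tp): Int = v.index.toInt // Tp(n)  ~>  GenericValue(tp, n)

  // Distinctness of generic values now follows from constructor injectivity
  // (Tp(0) != Tp(1)), so the removed genericValues bijection and its pairwise
  // disequality assertions are no longer needed.
}
// -----------------------------------------------------------------------------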
SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) @@ -93,15 +93,13 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - // Previously: new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver - SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new Z3UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) case "smt-z3-q" => - // Previously: new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver - SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) + SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) case "smt-cvc4" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new CVC4UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) case "smt-cvc4-proof" => SolverFactory(() => new SMTLIBCVC4ProofSolver(ctx, program) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala deleted file mode 100644 index b94233f285e1fe63c486dd2711c63e115ef1dd7b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Quantification._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.DefOps -import purescala.TypeOps -import purescala.Extractors._ -import utils._ -import templates._ -import evaluators._ -import Template._ -import leon.solvers.z3.Z3StringConversion -import leon.utils.Bijection -import leon.solvers.z3.StringEcoSystem - -object Z3StringCapableSolver { - def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t) - def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e) - def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType) - def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id) - def thatShouldBeConverted(fd: FunDef): Boolean = { - (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted) - } - def thatShouldBeConverted(cd: ClassDef): Boolean = cd match { - case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted) - case _ => false - } - def thatShouldBeConverted(p: Program): Boolean = { - (p.definedFunctions exists thatShouldBeConverted) || - (p.definedClasses exists thatShouldBeConverted) - } - - def convert(p: Program): (Program, Option[Z3StringConversion]) = { - val converter = new Z3StringConversion(p) - import converter.Forward._ - var hasStrings = false - val program_with_strings = converter.getProgram - val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { - val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = 
DefOps.replaceCaseClassDefs(program_with_strings)((cd: ClassDef) => { - cd match { - case acd:AbstractClassDef => None - case ccd:CaseClassDef => - if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { - Some((parent: Option[AbstractClassType]) => ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) - } else None - } - }) - converter.mappedVariables.clear() // We will compose them later, they have been stored in idMap - res - } else { - (program_with_strings, Map[ClassDef, ClassDef](), Map[Identifier, Identifier](), Map[FunDef, FunDef]()) - } - val fdMapInverse = fdMap.map(kv => kv._2 -> kv._1).toMap - val idMapInverse = idMap.map(kv => kv._2 -> kv._1).toMap - var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() - val (new_program, _) = DefOps.replaceFunDefs(program_with_correct_classes)((fd: FunDef) => { - globalFdMap.get(fd).map(_._2).orElse( - if(thatShouldBeConverted(fd)) { - val idMap = fd.params.zip(fd.params).map(origvd_vd => origvd_vd._1.id -> convertId(origvd_vd._2.id)).toMap - val newFdId = convertId(fd.id) - val newFd = fd.duplicate(newFdId, - fd.tparams, - fd.params.map(vd => ValDef(idMap(vd.id))), - convertType(fd.returnType)) - globalFdMap += fd -> ((idMap, newFd)) - hasStrings = hasStrings || (program_with_strings.library.escape.get != fd) - Some(newFd) - } else None - ) - }) - if(!hasStrings) { - (p, None) - } else { - converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) - for((fd, (idMap, newFd)) <- globalFdMap) { - implicit val idVarMap = idMap.mapValues(id => Variable(id)) - newFd.fullBody = convertExpr(newFd.fullBody) - } - converter.mappedVariables.composeA(id => idMapInverse.getOrElse(id, id)) - converter.globalFdMap.composeA(fd => fdMapInverse.getOrElse(fd, fd)) - converter.globalClassMap ++= cdMap - (new_program, Some(converter)) - } - } -} - -abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( - val context: LeonContext, - val program: Program, - val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) extends Solver { - - protected val (new_program, optConverter) = Z3StringCapableSolver.convert(program) - var someConverter = optConverter - - val underlying = underlyingConstructor(new_program, someConverter) - var solverInvokedWithStrings = false - - def getModel: leon.solvers.Model = { - val model = underlying.getModel - someConverter match { - case None => model - case Some(converter) => - val ids = model.ids.toSeq - val exprs = ids.map(model.apply) - import converter.Backward._ - val original_ids = ids.map(convertId) - val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } - - model match { - case hm: PartialModel => - val new_domain = new Domains( - hm.domains.lambdas.map(kv => - (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], - kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, - hm.domains.tpes.map(kv => - (convertType(kv._1), - kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap - ) - - new PartialModel(original_ids.zip(original_exprs).toMap, new_domain) - case _ => - new Model(original_ids.zip(original_exprs).toMap) - } - } - } - - // Members declared in leon.utils.Interruptible - def interrupt(): Unit = underlying.interrupt() - def recoverInterrupt(): Unit = underlying.recoverInterrupt() - - // Converts expression on the fly if needed, creating a string converter if needed. 
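// --- Editor's note (illustration only, not part of this patch) ---------------
// This deleted wrapper (and the on-the-fly string conversion below) is
// superseded by the TheoryEncoder pipeline added later in the patch
// (solvers/theories/*): SMTLIBSolver now receives a `theories` encoder, encodes
// each asserted constraint, and decodes values while rebuilding the model.
// A minimal sketch of that shape, using simplified hypothetical interfaces
// (`MiniEncoder`, `MiniSolver`, `EncodingSolver` are not Leon types):
trait MiniEncoder[E] { def encode(e: E): E; def decode(e: E): E }
trait MiniSolver[E]  { def assertCnstr(e: E): Unit; def model: Map[String, E] }

class EncodingSolver[E](underlying: MiniSolver[E], theories: MiniEncoder[E])
    extends MiniSolver[E] {
  // Constraints are translated into the target theory before being asserted...
  def assertCnstr(e: E): Unit = underlying.assertCnstr(theories.encode(e))
  // ...and model values are translated back into the source theory on the way out.
  def model: Map[String, E] = underlying.model.mapValues(theories.decode).toMap
}
// -----------------------------------------------------------------------------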
- def convertExprOnTheFly(expression: Expr, withConverter: Z3StringConversion => Expr): Expr = { - someConverter match { - case None => - if(solverInvokedWithStrings || exists(e => TypeOps.exists(StringType == _)(e.getType))(expression)) { // On the fly conversion - solverInvokedWithStrings = true - val c = new Z3StringConversion(program) - someConverter = Some(c) - withConverter(c) - } else expression - case Some(converter) => - withConverter(converter) - } - } - - // Members declared in leon.solvers.Solver - def assertCnstr(expression: Expr): Unit = { - someConverter.map{converter => - import converter.Forward._ - val newExpression = convertExpr(expression)(Map()) - underlying.assertCnstr(newExpression) - }.getOrElse{ - underlying.assertCnstr(convertExprOnTheFly(expression, _.Forward.convertExpr(expression)(Map()))) - } - } - def getUnsatCore: Set[Expr] = { - someConverter.map{converter => - import converter.Backward._ - underlying.getUnsatCore map (e => convertExpr(e)(Map())) - }.getOrElse(underlying.getUnsatCore) - } - - def check: Option[Boolean] = underlying.check - def free(): Unit = underlying.free() - def pop(): Unit = underlying.pop() - def push(): Unit = underlying.push() - def reset(): Unit = underlying.reset() - def name: String = underlying.name -} - -import z3._ - -trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] => -} - -trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self: Z3StringCapableSolver[TUnderlying] => -} - -trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self: Z3StringCapableSolver[TUnderlying] => - // Members declared in leon.solvers.EvaluatingSolver - val useCodeGen: Boolean = underlying.useCodeGen -} - -class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) - extends CodeGenEvaluator(context, originalProgram) { - - override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { - import converter._ - super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) - .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) - ) - } -} - -class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) - extends DefaultEvaluator(context, originalProgram) { - - override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = { - import converter._ - Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model))) - } -} - -class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program, - originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) { - override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator - someConverter match { - case Some(converter) => - if (useCodeGen) { - new ConvertibleCodeGenEvaluator(context, originalProgram, converter) - } else { - new ConvertibleDefaultEvaluator(context, originalProgram, converter) - } - case None => - if (useCodeGen) { - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - } - } -} - -class Z3StringFairZ3Solver(context: LeonContext, program: Program) - extends Z3StringCapableSolver(context, program, - (prgm: Program, someConverter: 
Option[Z3StringConversion]) => - new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) - with Z3StringEvaluatingSolver[FairZ3Solver] { - - // Members declared in leon.solvers.z3.AbstractZ3Solver - protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) - } - } -} - -class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) - extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => - new UnrollingSolver(context, program, underlyingSolverConstructor(program))) - with Z3StringNaiveAssumptionSolver[UnrollingSolver] - with Z3StringEvaluatingSolver[UnrollingSolver] { - - override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore -} - -class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) - extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => - new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { - - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) - } - } -} - diff --git a/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..3ddfca8d9524f9052dc5e6e49266006e9ffa2e3b --- /dev/null +++ b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package cvc4 + +import purescala.Definitions._ + +import unrolling._ +import theories._ + +class CVC4UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) + extends UnrollingSolver(context, program, underlying, theories = new NoEncoder) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index 4c17f80892dfad288c0522e3c918a31eb4cef9ae..5de74510b465d76cfc152c476cff96321c4758e9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -5,9 +5,12 @@ package solvers.smtlib import OptionParsers._ +import solvers.theories._ import purescala.Definitions.Program -class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBCVC4Target { +class SMTLIBCVC4Solver(context: LeonContext, program: Program) + extends SMTLIBSolver(context, program, new NoEncoder) + with SMTLIBCVC4Target { def targetName = "cvc4" diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index d9deb8790f129b7d81a103ef9e25853ae332694e..30ae22f989a2b021e56165fdaf3021246720a013 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -50,10 +50,6 @@ trait SMTLIBCVC4Target extends 
SMTLIBTarget { super.fromSMT(t, otpe) } - case (SimpleSymbol(s), Some(tp: TypeParameter)) => - val n = s.name.split("_").toList.last - GenericValue(tp, n.toInt) - case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), Some(SetType(base))) => FiniteSet(Set(), base) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala index 980d87fa8ae2b443ae1c954594297e753449114b..76037a9451a27011a7ff47609073fcc07e99a92b 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala @@ -26,7 +26,7 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget { protected def withInductiveHyp(cond: Expr): Expr = { val inductiveHyps = for { - fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq + fi @ FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq } yield { val post = application( tfd.withParamSubst(args, tfd.postOrTrue), @@ -38,6 +38,5 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget { // We want to check if the negation of the vc is sat under inductive hyp. // So we need to see if (indHyp /\ !vc) is satisfiable liftLets(matchToIfThenElse(andJoin(inductiveHyps :+ not(cond)))) - } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 78b9ccbd90f00260ab04c8d0a55da2ed00c24f5f..9f6f92b93c0cafc7cf73210879d32b3c9defff2b 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -13,13 +13,18 @@ import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => SMTFunDef, import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} -abstract class SMTLIBSolver(val context: LeonContext, val program: Program) +import theories._ +import utils._ + +abstract class SMTLIBSolver(val context: LeonContext, val program: Program, theories: TheoryEncoder) extends Solver with SMTLIBTarget with NaiveAssumptionSolver { /* Solver name */ def targetName: String override def name: String = "smt-"+targetName + private val ids = new IncrementalBijection[Identifier, Identifier]() + override def dbg(msg: => Any) = { debugOut foreach { o => o.write(msg.toString) @@ -28,8 +33,10 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } /* Public solver interface */ - def assertCnstr(expr: Expr): Unit = if(!hasError) { + def assertCnstr(raw: Expr): Unit = if (!hasError) { try { + val bindings = variablesOf(raw).map(id => id -> ids.cachedB(id)(theories.encode(id))).toMap + val expr = theories.encode(raw)(bindings) variablesOf(expr).foreach(declareVariable) val term = toSMT(expr)(Map()) @@ -84,8 +91,13 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) for (me <- smodel) me match { case DefineFun(SMTFunDef(s, args, kind, e)) if syms(s) => - val id = variables.toA(s) - model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) + try { + val id = variables.toA(s) + val value = fromSMT(e, id.getType)(Map(), modelFunDefs) + model += ids.getAorElse(id, id) -> theories.decode(value)(variablesOf(value).map(id => id -> ids.toA(id)).toMap) + } catch { + case _: Unsupported => + } case _ => } @@ -101,14 +113,14 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } } - override def getModel: Model = getModel( _ => true) + override def 
getModel: Model = getModel(_ => true) override def push(): Unit = { + ids.push() constructors.push() selectors.push() testers.push() variables.push() - genericValues.push() sorts.push() lambdas.push() functions.push() @@ -118,11 +130,11 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } override def pop(): Unit = { + ids.pop() constructors.pop() selectors.pop() testers.pop() variables.pop() - genericValues.pop() sorts.pop() lambdas.pop() functions.pop() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 242f321219eebb34c93c4756dd10a4ba92d114df..7418526c9338fce9c558d118fb947af503ee53ec 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -153,7 +153,6 @@ trait SMTLIBTarget extends Interruptible { protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() protected val testers = new IncrementalBijection[TypeTree, SSymbol]() protected val variables = new IncrementalBijection[Identifier, SSymbol]() - protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() protected val sorts = new IncrementalBijection[TypeTree, Sort]() protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() protected val lambdas = new IncrementalBijection[FunctionType, SSymbol]() @@ -226,13 +225,6 @@ trait SMTLIBTarget extends Interruptible { unsupported(other, "Unable to extract from raw array for " + tpe) } - protected def declareUninterpretedSort(t: TypeParameter): Sort = { - val s = id2sym(t.id) - val cmd = DeclareSort(s, 0) - emit(cmd) - Sort(SMTIdentifier(s)) - } - protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { @@ -252,10 +244,7 @@ trait SMTLIBTarget extends Interruptible { case FunctionType(from, to) => Ints.IntSort() - case tp: TypeParameter => - declareUninterpretedSort(tp) - - case _: ClassType | _: TupleType | _: ArrayType | UnitType => + case _: ClassType | _: TupleType | _: ArrayType | _: TypeParameter | UnitType => declareStructuralSort(tpe) case other => @@ -305,7 +294,6 @@ trait SMTLIBTarget extends Interruptible { conflicts.foreach { declareStructuralSort } declareStructuralSort(t) } - } protected def declareVariable(id: Identifier): SSymbol = { @@ -532,13 +520,9 @@ trait SMTLIBTarget extends Interruptible { toSMT(matchToIfThenElse(m)) case gv @ GenericValue(tpe, n) => - genericValues.cachedB(gv) { - val v = declareVariable(FreshIdentifier("gv" + n, tpe)) - for ((ogv, ov) <- genericValues.aToB if ogv.getType == tpe) { - emit(SMTAssert(Core.Not(Core.Equals(v, ov)))) - } - v - } + declareSort(tpe) + val constructor = constructors.toB(tpe) + FunctionApplication(constructor, Seq(toSMT(InfiniteIntegerLiteral(n)))) /** * ===== Everything else ===== @@ -803,11 +787,12 @@ trait SMTLIBTarget extends Interruptible { case cct: CaseClassType => val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) CaseClass(cct, rargs) + case tt: TupleType => val rargs = args.zip(tt.bases).map(fromSMT) tupleWrap(rargs) - case at@ArrayType(baseType) => + case at @ ArrayType(baseType) => val IntLiteral(size) = fromSMT(args(0), Int32Type) val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) @@ -825,6 +810,10 @@ trait SMTLIBTarget extends Interruptible { finiteArray(entries, None, baseType) } + case tp @ TypeParameter(id) => + val InfiniteIntegerLiteral(n) = fromSMT(args(0), IntegerType) + GenericValue(tp, 
n.toInt) + case t => unsupported(t, "Woot? structural type that is non-structural") } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala index 083e07f34b76cbd118727c326803e521d4967ba6..0be498f1352e0ced484318ca5b48d3f8d7b1da00 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala @@ -1,14 +1,18 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package leon -package solvers.smtlib +package solvers +package smtlib import purescala.Definitions.Program +import theories._ + /** * This solver models function definitions as universally quantified formulas. * It is not meant as an underlying solver to UnrollingSolver, and does not handle HOFs. */ -class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) extends SMTLIBZ3Solver(context, program) - with SMTLIBQuantifiedSolver - with SMTLIBZ3QuantifiedTarget +class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) + extends SMTLIBZ3Solver(context, program) + with SMTLIBQuantifiedSolver + with SMTLIBZ3QuantifiedTarget diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala index fbd4b581c7c34835c6deb1845c04717e0beafd2c..f355392d069edfd55ab97a530b814bb0211a8e8a 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala @@ -65,6 +65,5 @@ trait SMTLIBZ3QuantifiedTarget extends SMTLIBZ3Target with SMTLIBQuantifiedTarge } functions.toB(tfd) - } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 58a88d1a6153bed23c38e059d54c8e8a46a5f1fd..6a58f9f1a294180182595bcb34d35d1a2ffe08a7 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -13,49 +13,11 @@ import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.parser.CommandsResponses.GetModelResponseSuccess import _root_.smtlib.theories.Core.{Equals => _, _} -class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBZ3Target { +import theories._ - def getProgram: Program = program - - // EK: We use get-model instead in order to extract models for arrays - override def getModel: Model = { - - val res = emit(GetModel()) - - val smodel: Seq[SExpr] = res match { - case GetModelResponseSuccess(model) => model - case _ => Nil - } - - var modelFunDefs = Map[SSymbol, DefineFun]() - - // First pass to gather functions (arrays defs) - for (me <- smodel) me match { - case me @ DefineFun(SMTFunDef(a, args, _, _)) if args.nonEmpty => - modelFunDefs += a -> me - case _ => - } - - var model = Map[Identifier, Expr]() - - for (me <- smodel) me match { - case DefineFun(SMTFunDef(s, args, kind, e)) => - if(args.isEmpty) { - variables.getA(s) match { - case Some(id) => - // EK: this is a little hack, we pass models for array functions as let-defs - try { - model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) - } catch { - case _ : Unsupported => - - } - case _ => // function, should be handled elsewhere - } - } - case _ => - } - new Model(model) - } +class SMTLIBZ3Solver(context: LeonContext, program: Program) + extends SMTLIBSolver(context, program, new StringEncoder) + with SMTLIBZ3Target { + def 
getProgram: Program = program } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 5744e1a107f8c1c25b8e82e02c0619b50ca9c5c9..1be3f7ecf930990749565a1c7ca2181cd2a100dd 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -72,10 +72,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget { override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { (t, otpe) match { - case (SimpleSymbol(s), Some(tp: TypeParameter)) => - val n = s.name.split("!").toList.last - GenericValue(tp, n.toInt) - case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => if (letDefs contains k) { // Need to recover value form function model diff --git a/src/main/scala/leon/solvers/theories/BagEncoder.scala b/src/main/scala/leon/solvers/theories/BagEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..4ba7fcaa42307ae9c1b65bd3b81c4da03ab62b22 --- /dev/null +++ b/src/main/scala/leon/solvers/theories/BagEncoder.scala @@ -0,0 +1,14 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Types._ + +class BagEncoder(val context: LeonContext) extends TheoryEncoder { + val encoder = new Encoder + val decoder = new Decoder +} diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/leon/solvers/theories/StringEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..8f33513a64898c17a7a4b2dfda13a037b2d93f5d --- /dev/null +++ b/src/main/scala/leon/solvers/theories/StringEncoder.scala @@ -0,0 +1,203 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Types._ +import purescala.Definitions._ +import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator +import leon.evaluators.EvaluationResults + +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } + + val StringList = new AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = new CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d + } + + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = new CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) + + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", 
StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd + } + + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd + } + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd + } + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd + } + + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } + + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, 
StringDrop, StringSlice) +} + +class StringEncoder extends TheoryEncoder { + import StringEcoSystem._ + + private val stringBijection = new Bijection[String, Expr]() + + private def convertToString(e: Expr): String = stringBijection.cachedA(e)(e match { + case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) + case CaseClass(_, Seq()) => "" + }) + + private def convertFromString(v: String): Expr = stringBijection.cachedB(v) { + v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ + case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) + } + } + + val encoder = new Encoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case StringLiteral(v) => + convertFromString(v) + case StringLength(a) => + FunctionInvocation(StringSize.typed, Seq(transform(a))).copiedFrom(e) + case StringConcat(a, b) => + FunctionInvocation(StringListConcat.typed, Seq(transform(a), transform(b))).copiedFrom(e) + case SubString(a, start, Plus(start2, length)) if start == start2 => + FunctionInvocation(StringTake.typed, Seq(FunctionInvocation(StringDrop.typed, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e) + case SubString(a, start, end) => + FunctionInvocation(StringSlice.typed, Seq(transform(a), transform(start), transform(end))).copiedFrom(e) + case _ => super.transform(e) + } + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case StringType => StringListTyped + case _ => super.transform(tpe) + } + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case LiteralPattern(binder, StringLiteral(s)) => + val newBinder = binder map transform + val newPattern = s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + (newPattern.copy(binder = newBinder), (binder zip newBinder).filter(p => p._1 != p._2).toMap) + case _ => super.transform(pat) + } + } + + val decoder = new Decoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)).copiedFrom(cc) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(transform(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(transform(a), transform(b)).copiedFrom(e) + case FunctionInvocation(StringTake, Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = transform(start) + SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e) + case _ => super.transform(e) + } + + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case StringListTyped | StringConsTyped | StringNilTyped => StringType + case _ => super.transform(tpe) + } + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + val newBinder = b map transform + (LiteralPattern(newBinder , StringLiteral("")), (b zip newBinder).filter(p => p._1 != p._2).toMap) + + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { + case (LiteralPattern(_, StringLiteral(s)), binders) => + val newBinder = b map transform + (LiteralPattern(newBinder, StringLiteral(elem + s)), (b zip newBinder).filter(p => p._1 != p._2).toMap ++ 
binders) + case (e, binders) => + (LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)), binders) + } + + case _ => super.transform(pat) + } + } +} + diff --git a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..39b5a632c5321220416e8142f30eb90c65b887a5 --- /dev/null +++ b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala @@ -0,0 +1,82 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +trait TheoryEncoder { self => + protected val encoder: Encoder + protected val decoder: Decoder + + private val idMap = new Bijection[Identifier, Identifier] + private val fdMap = new Bijection[FunDef , FunDef ] + private val cdMap = new Bijection[ClassDef , ClassDef ] + + def encode(id: Identifier): Identifier = encoder.transform(id) + def decode(id: Identifier): Identifier = decoder.transform(id) + + def encode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = encoder.transform(expr) + def decode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = decoder.transform(expr) + + def encode(tpe: TypeTree): TypeTree = encoder.transform(tpe) + def decode(tpe: TypeTree): TypeTree = decoder.transform(tpe) + + def encode(fd: FunDef): FunDef = encoder.transform(fd) + def decode(fd: FunDef): FunDef = decoder.transform(fd) + + protected class Encoder extends purescala.DefinitionTransformer(idMap, fdMap, cdMap) + protected class Decoder extends purescala.DefinitionTransformer(idMap.swap, fdMap.swap, cdMap.swap) + + def >>(that: TheoryEncoder): TheoryEncoder = new TheoryEncoder { + val encoder = new Encoder { + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + val mapSeq = bindings.toSeq + val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.encoder.transform(id.getType)) } + val e2 = self.encoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) + that.encoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + } + + override def transform(tpe: TypeTree): TypeTree = that.encoder.transform(self.encoder.transform(tpe)) + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { + val (pat2, bindings) = self.encoder.transform(pat) + val (pat3, bindings2) = that.encoder.transform(pat2) + (pat3, bindings2.map { case (id, id2) => id -> bindings2(id2) }) + } + } + + val decoder = new Decoder { + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + val mapSeq = bindings.toSeq + val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.decoder.transform(id.getType)) } + val e2 = that.decoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) + self.decoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + } + + override def transform(tpe: TypeTree): TypeTree = self.decoder.transform(that.decoder.transform(tpe)) + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { + val (pat2, bindings) = 
that.decoder.transform(pat) + val (pat3, bindings2) = self.decoder.transform(pat2) + (pat3, bindings.map { case (id, id2) => id -> bindings2(id2) }) + } + } + } +} + +class NoEncoder extends TheoryEncoder { + val encoder = new Encoder + val decoder = new Decoder +} + diff --git a/src/main/scala/leon/solvers/templates/DatatypeManager.scala b/src/main/scala/leon/solvers/unrolling/DatatypeManager.scala similarity index 98% rename from src/main/scala/leon/solvers/templates/DatatypeManager.scala rename to src/main/scala/leon/solvers/unrolling/DatatypeManager.scala index dcfa67e8343f88fb53dfa0a18212e0a2aa383016..554689830fe4b61442e6272cd658c9212da93da8 100644 --- a/src/main/scala/leon/solvers/templates/DatatypeManager.scala +++ b/src/main/scala/leon/solvers/unrolling/DatatypeManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -157,6 +157,8 @@ class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(en case BooleanType | UnitType | CharType | IntegerType | RealType | Int32Type | StringType | (_: TypeParameter) => false + case at: ArrayType => true + case NAryType(tpes, _) => tpes.exists(requireTypeUnrolling) } @@ -205,6 +207,9 @@ class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(en case FunctionType(_, _) => FreshFunction(expr) + case at: ArrayType => + GreaterEquals(ArrayLength(expr), IntLiteral(0)) + case _ => scala.sys.error("TODO") } diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala similarity index 97% rename from src/main/scala/leon/solvers/templates/LambdaManager.scala rename to src/main/scala/leon/solvers/unrolling/LambdaManager.scala index f80a143a5095f0f0ccc25cd836a4c27f96be72dc..f1ad010d81c25a08b6e5683dffe581552315e05e 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -99,7 +99,7 @@ trait KeyedTemplate[T, E <: Expr] { case _ => Seq.empty } - structure -> rec(structure).map(dependencies) + structure -> rec(structure).distinct.map(dependencies) } } @@ -196,9 +196,9 @@ class LambdaTemplate[T] private ( } class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(encoder) { - private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + private[unrolling] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) - protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected[unrolling] val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Map[(Expr, Seq[T]), LambdaTemplate[T]]].withDefaultValue(Map.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) @@ -325,7 +325,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco instantiated += key if (knownFree(tpe) contains caller) { - instantiation withApp (key -> TemplateAppInfo(caller, trueT, args)) + instantiation } else if (byID contains caller) { instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) } else { diff --git 
a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala similarity index 66% rename from src/main/scala/leon/solvers/templates/QuantificationManager.scala rename to src/main/scala/leon/solvers/unrolling/QuantificationManager.scala index d10b786e5ac3c01a188e1d05f54cebf8add5d362..af31e6543edc4ca06bf58648f37412bb9cf6b22d 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala @@ -2,17 +2,20 @@ package leon package solvers -package templates +package unrolling import leon.utils._ import purescala.Common._ +import purescala.Definitions._ import purescala.Extractors._ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import purescala.TypeOps._ -import purescala.Quantification.{QuantificationTypeMatcher => QTM} +import purescala.Quantification.{QuantificationTypeMatcher => QTM, QuantificationMatcher => QM, Domains} + +import evaluators._ import Instantiation._ import Template._ @@ -55,7 +58,8 @@ class QuantificationTemplate[T]( val matchers: Map[T, Set[Matcher[T]]], val lambdas: Seq[LambdaTemplate[T]], val dependencies: Map[Identifier, T], - val struct: (Forall, Map[Identifier, Identifier])) extends KeyedTemplate[T, Forall] { + val struct: (Forall, Map[Identifier, Identifier]), + stringRepr: () => String) extends KeyedTemplate[T, Forall] { val structure = struct._1 lazy val start = pathVar._2 @@ -89,9 +93,13 @@ class QuantificationTemplate[T]( }, lambdas.map(_.substitute(substituter, matcherSubst)), dependencies.map { case (id, value) => id -> substituter(value) }, - struct + struct, + stringRepr ) } + + private lazy val str : String = stringRepr() + override def toString : String = str } object QuantificationTemplate { @@ -118,7 +126,7 @@ object QuantificationTemplate { val insts: (Identifier, T) = inst -> encoder.encodeId(inst) val guards: (Identifier, T) = guard -> encoder.encodeId(guard) - val (clauses, blockers, applications, functions, matchers, _) = + val (clauses, blockers, applications, functions, matchers, templateString) = Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, substMap = baseSubstMap + q2s + insts + guards + qs) @@ -128,7 +136,8 @@ object QuantificationTemplate { new QuantificationTemplate[T](quantificationManager, pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, - clauses, blockers, applications, matchers, lambdas, keyDeps, key -> structSubst) + clauses, blockers, applications, matchers, lambdas, keyDeps, key -> structSubst, + () => "Template for " + proposition + " is :\n" + templateString()) } } @@ -171,37 +180,46 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage qkey == key || (qkey.tpe == key.tpe && (qkey.isInstanceOf[TypeKey] || key.isInstanceOf[TypeKey])) } - private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty - private val uniformQuantSet: MutableSet[T] = MutableSet.empty + class VariableNormalizer { + private val varMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty + private val varSet: MutableSet[T] = MutableSet.empty - def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) - def uniformQuants(ids: Seq[Identifier]): Seq[T] = { - val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => - val prev = uniformQuantMap.get(tpe) match { - case Some(seq) => seq - 
case None => Seq.empty - } + def normalize(ids: Seq[Identifier]): Seq[T] = { + val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => + val prev = varMap.get(tpe) match { + case Some(seq) => seq + case None => Seq.empty + } - if (prev.size >= idst.size) { - idst zip prev.take(idst.size) - } else { - val (handled, newIds) = idst.splitAt(prev.size) - val uIds = newIds.map(id => id -> encoder.encodeId(id)) + if (prev.size >= idst.size) { + idst zip prev.take(idst.size) + } else { + val (handled, newIds) = idst.splitAt(prev.size) + val uIds = newIds.map(id => id -> encoder.encodeId(id)) - uniformQuantMap(tpe) = prev ++ uIds.map(_._2) - uniformQuantSet ++= uIds.map(_._2) + varMap(tpe) = prev ++ uIds.map(_._2) + varSet ++= uIds.map(_._2) - (handled zip prev) ++ uIds - } - }.toMap + (handled zip prev) ++ uIds + } + }.toMap - ids.map(mapping) - } + ids.map(mapping) + } + + def normalSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { + (qs.map(_._2) zip normalize(qs.map(_._1))).toMap + } - private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { - (qs.map(_._2) zip uniformQuants(qs.map(_._1))).toMap + def contains(idT: T): Boolean = varSet(idT) + def get(tpe: TypeTree): Option[Seq[T]] = varMap.get(tpe) } + private val abstractNormalizer = new VariableNormalizer + private val concreteNormalizer = new VariableNormalizer + + def isQuantifier(idT: T): Boolean = abstractNormalizer.contains(idT) + override def assumptions: Seq[T] = super.assumptions ++ quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq @@ -217,11 +235,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { - case Right(ma) => matcherDepth(ma) + private def maxDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + case Right(ma) => maxDepth(ma) case _ => 0 }).max + private def totalDepth(m: Matcher[T]): Int = 1 + m.args.map { + case Right(ma) => totalDepth(ma) + case _ => 0 + }.sum + private def encodeEnablers(es: Set[T]): T = if (es.isEmpty) trueT else encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) @@ -365,10 +388,10 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage lazy val quantified: Set[T] = quantifiers.map(_._2).toSet lazy val start = pathVar._2 - private lazy val depth = matchers.map(matcherDepth).max + private lazy val depth = matchers.map(maxDepth).max private lazy val transMatchers: Set[Matcher[T]] = (for { (b, ms) <- allMatchers.toSeq - m <- ms if !matchers(m) && matcherDepth(m) <= depth + m <- ms if !matchers(m) && maxDepth(m) <= depth } yield m).toSet /* Build a mapping from applications in the quantified statement to all potential concrete @@ -402,19 +425,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } : _*) } - /* 2.3. 
filter out bindings that don't make sense where abstract sub-matchers - * (matchers in arguments of other matchers) are bound to different concrete - * matchers in the argument and quorum positions - */ - allMappings.filter { s => - def expand(ms: Traversable[(Arg[T], Arg[T])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { - case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) - case _ => Set.empty[(Matcher[T], Matcher[T])] - }.toSet - - expand(s.map(p => Right(p._2) -> Right(p._3))).groupBy(_._1).forall(_._2.size == 1) - } - allMappings } } @@ -462,7 +472,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if (!skip(subst)) { if (!isStrict) { - ignoreSubst(enablers, subst) + val msubst = subst.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(subst.mapValues(_.encoded)) + ignoredSubsts(this) += ((currentGen + 3, enablers, subst)) } else { instantiation ++= instantiateSubst(enablers, subst, strict = true) } @@ -491,6 +503,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val msubst = substMap.collect { case (c, Right(m)) => c -> m } val substituter = encoder.substitute(substMap.mapValues(_.encoded)) + registerBlockers(substituter) + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Map.empty, substMap) @@ -501,7 +515,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage if (strict && (matchers(m) || transMatchers(m))) { instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) } else if (!matchers(m)) { - ignoredMatchers += ((currentGen + 3, sb, sm)) + ignoredMatchers += ((currentGen + 2 + totalDepth(m), sb, sm)) } } @@ -509,20 +523,11 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - def ignoreSubst(enablers: Set[T], subst: Map[T, Arg[T]]): Unit = { - val msubst = subst.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(subst.mapValues(_.encoded)) - val nextGen = (if (matchers.forall { m => - val sm = m.substitute(substituter, msubst) - instCtx(enablers -> sm) - }) currentGen + 3 else currentGen + 3) - - ignoredSubsts(this) += ((nextGen, enablers, subst)) - } - protected def instanceSubst(enabler: T): Map[T, T] protected def skip(subst: Map[T, Arg[T]]): Boolean = false + + protected def registerBlockers(substituter: T => T): Unit = () } private class Quantification ( @@ -543,7 +548,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val lambdas: Seq[LambdaTemplate[T]], val template: QuantificationTemplate[T]) extends MatcherQuantification { - var currentQ2Var: T = qs._2 + private var _currentQ2Var: T = qs._2 + def currentQ2Var = _currentQ2Var val holds = qs._2 val body = { val quantified = quantifiers.map(_._1).toSet @@ -551,15 +557,24 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage replaceFromIDs(mapping, template.structure.body) } + private var _currentInsts: Map[T, Set[T]] = Map.empty + def currentInsts = _currentInsts + protected def instanceSubst(enabler: T): Map[T, T] = { val nextQ2Var = encoder.encodeId(q2s._1) val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) - currentQ2Var = nextQ2Var + _currentQ2Var = nextQ2Var subst } + + override def registerBlockers(substituter: T => T): Unit = { + val freshInst = substituter(insts._2) + val bs = (blockers.keys ++ 
applications.keys).map(substituter).toSet + _currentInsts += freshInst -> bs + } } private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) @@ -647,69 +662,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } - private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { - val quantifierSubst = uniformSubst(quantifiers) - val substituter = encoder.substitute(quantifierSubst) - var instantiation: Instantiation[T] = Instantiation.empty - - for { - m <- matchers - sm = m.substitute(substituter, Map.empty) - if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } { - instantiation ++= instCtx.instantiate(Set.empty, m)(quantifications.toSeq : _*) - instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) - } - - def unifyMatchers(matchers: Seq[Matcher[T]]): Unit = matchers match { - case sm +: others => - for (pm <- others if correspond(pm, sm)) { - val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) - val mismatches = encodedArgs.zipWithIndex.collect { - case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) - }.toMap - - def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { - case idx +: xs => - val (p1, p2) = mismatches(idx) - val newPartials = Seq(idx) +: partials.map { seq => - if (mismatches(seq.head)._1 == p2) idx +: seq - else if (mismatches(seq.last)._2 == p1) seq :+ idx - else seq - } - - val (closed, remaining) = newPartials.partition { seq => - mismatches(seq.head)._1 == mismatches(seq.last)._2 - } - closed ++ extractChains(xs, partials ++ remaining) - - case _ => Seq.empty - } - - val chains = extractChains(mismatches.keys.toSeq, Seq.empty) - val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => - val res = seq.min - mapping ++ seq.map(i => i -> res) - } - - def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = - (0 until args.size).map(i => args(positions.getOrElse(i, i))) - - instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) - instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) - } - - unifyMatchers(others) - - case _ => - } - - val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) - unifyMatchers(substMatchers.toSeq) - - instantiation - } - def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, Arg[T]]): Instantiation[T] = { def quantifiedMatcher(m: Matcher[T]): Boolean = m.args.exists(a => a match { case Left(v) => isQuantifier(v) @@ -724,7 +676,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - val quantifiers = quantified zip uniformQuants(quantified) + val quantifiers = quantified zip abstractNormalizer.normalize(quantified) val key = template.key -> quantifiers if (quantifiers.isEmpty || lambdaAxioms(key)) { @@ -805,7 +757,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case None => val qT = encoder.encodeId(template.qs._1) val quantified = template.quantifiers.map(_._2).toSet - val matchQuorums = extractQuorums(quantified, template.matchers.flatMap(_._2).toSet, template.lambdas) + val matcherSet = template.matchers.flatMap(_._2).toSet + val matchQuorums = extractQuorums(quantified, matcherSet, 
template.lambdas) var instantiation = Instantiation.empty[T] @@ -843,7 +796,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage encoder.mkImplies(template.start, encoder.mkEquals(qT, newQs)) } - instantiation ++= instantiateConstants(template.quantifiers, template.matchers.flatMap(_._2).toSet) + instantiation ++= instantiateConstants(template.quantifiers, matcherSet) templates += template.key -> qT (qT, instantiation) @@ -874,7 +827,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } for ((bs,m) <- matchersToRelease) { - instCtx.instantiate(bs, m)(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(bs, m)(quantifications.toSeq : _*) } val substsToRelease = quantifications.toList.flatMap { q => @@ -896,9 +849,74 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instantiation } + private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { + var instantiation: Instantiation[T] = Instantiation.empty + + for (normalizer <- List(abstractNormalizer, concreteNormalizer)) { + val quantifierSubst = normalizer.normalSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + + for { + m <- matchers + sm = m.substitute(substituter, Map.empty) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + + def unifyMatchers(matchers: Seq[Matcher[T]]): Instantiation[T] = matchers match { + case sm +: others => + var instantiation = Instantiation.empty[T] + for (pm <- others if correspond(pm, sm)) { + val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) + val mismatches = encodedArgs.zipWithIndex.collect { + case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) + }.toMap + + def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { + case idx +: xs => + val (p1, p2) = mismatches(idx) + val newPartials = Seq(idx) +: partials.map { seq => + if (mismatches(seq.head)._1 == p2) idx +: seq + else if (mismatches(seq.last)._2 == p1) seq :+ idx + else seq + } + + val (closed, remaining) = newPartials.partition { seq => + mismatches(seq.head)._1 == mismatches(seq.last)._2 + } + closed ++ extractChains(xs, partials ++ remaining) + + case _ => Seq.empty + } + + val chains = extractChains(mismatches.keys.toSeq, Seq.empty) + val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => + val res = seq.min + mapping ++ seq.map(i => i -> res) + } + + def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = + (0 until args.size).map(i => args(positions.getOrElse(i, i))) + + instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) + } + + instantiation ++ unifyMatchers(others) + + case _ => Instantiation.empty[T] + } + + if (normalizer == abstractNormalizer) { + val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) + instantiation ++= unifyMatchers(substMatchers.toSeq) + } + } + + instantiation + } + def checkClauses: Seq[T] = { val clauses = new scala.collection.mutable.ListBuffer[T] - //val keySets = scala.collection.mutable.Map.empty[MatcherKey, T] val keyClause = MutableMap.empty[MatcherKey, (Seq[T], T)] for ((_, bs, m) <- ignoredMatchers) { @@ -961,21 
+979,270 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } + def isQuantified(e: Arg[T]): Boolean = e match { + case Left(t) => isQuantifier(t) + case Right(m) => m.args.exists(isQuantified) + } + for ((key, ctx) <- instCtx.map.instantiations) { val QTM(argTypes, _) = key.tpe for { - (tpe,idx) <- argTypes.zipWithIndex - quants <- uniformQuantMap.get(tpe) if quants.nonEmpty + (tpe, idx) <- argTypes.zipWithIndex + quants <- abstractNormalizer.get(tpe) if quants.nonEmpty (b, m) <- ctx - arg = m.args(idx).encoded if !isQuantifier(arg) - } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg))) : _*) - } + arg = m.args(idx) if !isQuantified(arg) + } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg.encoded))) : _*) + + val byPosition: Iterable[Seq[T]] = ctx.flatMap { case (b, m) => + if (b != trueT) Seq.empty else m.args.zipWithIndex + }.groupBy(_._2).map(p => p._2.toSeq.flatMap { + case (a, _) => if (isQuantified(a)) Some(a.encoded) else None + }).filter(_.nonEmpty) - for ((tpe, base +: rest) <- uniformQuantMap; q <- rest) { - clauses += encoder.mkEquals(base, q) + for ((a +: as) <- byPosition; a2 <- as) { + clauses += encoder.mkEquals(a, a2) + } } clauses.toSeq } + + trait ModelView { + protected val vars: Map[Identifier, T] + protected val evaluator: evaluators.DeterministicEvaluator + + protected def get(id: Identifier): Option[Expr] + protected def eval(elem: T, tpe: TypeTree): Option[Expr] + + implicit lazy val context = evaluator.context + lazy val reporter = context.reporter + + private def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { + val QTM(fromTypes, _) = m.tpe + val optEnabler = eval(b, BooleanType) + optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => + val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } + if (optArgs.forall(_.isDefined)) Some(optArgs.map(_.get)) + else None + } + } + + private def functionsOf(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = { + + def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)], + recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = + (subs.flatMap(_._1), (exprs: Seq[Expr]) => { + var curr = exprs + recons(subs.map { case (es, recons) => + val (used, remaining) = curr.splitAt(es.size) + curr = remaining + recons(used) + }) + }) + + def rec(expr: Expr, path: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match { + case (_: Lambda) | (_: FiniteLambda) => + (Seq(expr -> path), (es: Seq[Expr]) => es.head) + + case Tuple(es) => reconstruct(es.zipWithIndex.map { + case (e, i) => rec(e, TupleSelect(path, i + 1)) + }, Tuple) + + case CaseClass(cct, es) => reconstruct((cct.classDef.fieldsIds zip es).map { + case (id, e) => rec(e, CaseClassSelector(cct, path, id)) + }, CaseClass(cct, _)) + + case _ => (Seq.empty, (es: Seq[Expr]) => expr) + } + + rec(expr, path) + } + + def getPartialModel: PartialModel = { + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInstantiations.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInstantiations.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val domains = new Domains(lambdaDomains, typeDomains) + + val partialDomains: Map[T, Set[Seq[Expr]]] = partialInstantiations.map { + case (t, domain) => t -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + def 
extractElse(body: Expr): Expr = body match { + case IfExpr(cond, thenn, elze) => extractElse(elze) + case _ => body + } + + val mapping = vars.map { case (id, idT) => + val value = get(id).getOrElse(simplestValue(id.getType)) + val (functions, recons) = functionsOf(value, Variable(id)) + + id -> recons(functions.map { case (f, path) => + val encoded = encoder.encodeExpr(Map(id -> idT))(path) + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + partialDomains.get(encoded).orElse(typeDomains.get(tpe)).map { domain => + FiniteLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(application(f, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(f, es))) + }, f match { + case FiniteLambda(_, dflt, _) => dflt + case Lambda(_, body) => extractElse(body) + case _ => scala.sys.error("What kind of function is this : " + f.asString + " !?") + }, tpe) + }.getOrElse(f) + }) + } + + new PartialModel(mapping, domains) + } + + def getTotalModel: Model = { + + def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { + val matchers = purescala.ExprOps.collect[(Expr, Seq[Expr])] { + case QM(e, args) => Set(e -> args) + case _ => Set.empty + } (body) + + if (matchers.isEmpty) + return Some("No matchers found.") + + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) + } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) + return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) + + def quantifiedArg(e: Expr): Boolean = e match { + case Variable(id) => quantified(id) + case QM(_, args) => args.forall(quantifiedArg) + case _ => false + } + + purescala.ExprOps.postTraversal(m => m match { + case QM(_, args) => + val qArgs = args.filter(quantifiedArg) + + if (qArgs.nonEmpty && qArgs.size < args.size) + return Some("Mixed ground and quantified arguments in " + m.asString) + + case Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => + return Some("Invalid operation on quantifiers " + m.asString) + + case (_: Equals) | (_: And) | (_: Or) | (_: Implies) | (_: Not) => // OK + + case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => + return Some("Unandled implications from operation " + m.asString) + + case _ => + }) (body) + + body match { + case Variable(id) if quantified(id) => + Some("Unexpected free quantifier " + id.asString) + case _ => None + } + } + + val issues: Iterable[(Seq[Identifier], Expr, String)] = for { + q <- quantifications.view + if eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) + msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) + } yield (q.quantifiers.map(_._1), q.body, msg) + + if (issues.nonEmpty) { + val (quantifiers, body, msg) = issues.head + reporter.warning("Model soundness not guaranteed for \u2200" + + quantifiers.map(_.asString).mkString(",") + ". 
" + body.asString+" :\n => " + msg) + } + + val types = typeInstantiations + val partials = partialInstantiations + + def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { + case (id +: rparams, (v, arg) +: rargs) => + if (isQuantifier(v)) { + structure.get(v) match { + case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) + case None => extractCond(rparams, rargs, structure + (v -> id)) + } + } else { + Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) + } + case _ => Seq.empty + } + + new Model(vars.map { case (id, idT) => + val value = get(id).getOrElse(simplestValue(id.getType)) + val (functions, recons) = functionsOf(value, Variable(id)) + + id -> recons(functions.map { case (f, path) => + val encoded = encoder.encodeExpr(Map(id -> idT))(path) + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val params = tpe.from.map(tpe => FreshIdentifier("x", tpe, true)) + partials.get(encoded).orElse(types.get(tpe)).map { domain => + val conditionals = domain.flatMap { case (b, m) => + extract(b, m).map { args => + val result = evaluator.eval(application(f, args)).result.getOrElse { + scala.sys.error("Unexpectedly failed to evaluate " + application(f, args)) + } + + val cond = if (m.args.exists(arg => isQuantifier(arg.encoded))) { + extractCond(params, m.args.map(_.encoded) zip args, Map.empty) + } else { + (params zip args).map(p => Equals(Variable(p._1), p._2)) + } + + cond -> result + } + }.toMap + + if (conditionals.isEmpty) { + value + } else { + val ((_, dflt)) +: rest = conditionals.toSeq.sortBy { case (conds, _) => + (conds.flatMap(variablesOf).toSet.size, conds.size) + } + + val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => + if (conds.isEmpty) elze else (elze match { + case pres if res == pres => res + case _ => IfExpr(andJoin(conds), res, elze) + }) + } + + Lambda(params.map(ValDef(_)), body) + } + }.getOrElse(f) + }) + }) + } + } + + def getModel(vs: Map[Identifier, T], ev: DeterministicEvaluator, _get: Identifier => Option[Expr], _eval: (T, TypeTree) => Option[Expr]) = new ModelView { + val vars: Map[Identifier, T] = vs + val evaluator: DeterministicEvaluator = ev + + def get(id: Identifier): Option[Expr] = _get(id) + def eval(elem: T, tpe: TypeTree): Option[Expr] = _eval(elem, tpe) + } + + def getBlockersToPromote(eval: (T, TypeTree) => Option[Expr]): Seq[T] = quantifications.toSeq.flatMap { + case q: Quantification if eval(q.qs._2, BooleanType) == Some(BooleanLiteral(false)) => + val falseInsts = q.currentInsts.filter { case (inst, bs) => eval(inst, BooleanType) == Some(BooleanLiteral(false)) } + falseInsts.flatMap(_._2) + case _ => Seq.empty + } } + diff --git a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala b/src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala similarity index 61% rename from src/main/scala/leon/solvers/templates/TemplateEncoder.scala rename to src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala index c2a2051b15f3e3a857deb52d73d9d973362ffd56..74488aa8e7b5ce71aafd115038a0dc74c4ea0b2a 100644 --- a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala @@ -2,10 +2,18 @@ package leon package solvers -package templates +package unrolling -import purescala.Common.Identifier -import purescala.Expressions.Expr +import purescala.Common._ +import purescala.Expressions._ +import purescala.Definitions._ +import 
purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} trait TemplateEncoder[T] { def encodeId(id: Identifier): T @@ -21,3 +29,4 @@ trait TemplateEncoder[T] { def extractNot(v: T): Option[T] } + diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala similarity index 96% rename from src/main/scala/leon/solvers/templates/TemplateGenerator.scala rename to src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala index 34b7b35f428bff701af9a25c06258da5a43443ed..127c1004eb600c70b240726088c74f72da1d9357 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Expressions._ @@ -14,13 +14,15 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ +import theories._ import utils.SeqUtils._ import Instantiation._ -class TemplateGenerator[T](val encoder: TemplateEncoder[T], +class TemplateGenerator[T](val theories: TheoryEncoder, + val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { private var cache = Map[TypedFunDef, FunctionTemplate[T]]() - private var cacheExpr = Map[Expr, FunctionTemplate[T]]() + private var cacheExpr = Map[Expr, (FunctionTemplate[T], Map[Identifier, Identifier])]() private type Clauses = ( Map[Identifier,T], @@ -45,20 +47,24 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val manager = new QuantificationManager[T](encoder) - def mkTemplate(body: Expr): FunctionTemplate[T] = { - if (cacheExpr contains body) { - return cacheExpr(body) + def mkTemplate(raw: Expr): (FunctionTemplate[T], Map[Identifier, Identifier]) = { + if (cacheExpr contains raw) { + return cacheExpr(raw) } - val arguments = variablesOf(body).toSeq.map(ValDef(_)) + val mapping = variablesOf(raw).map(id => id -> theories.encode(id)).toMap + val body = theories.encode(raw)(mapping) + + val arguments = mapping.values.toSeq.map(ValDef(_)) val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType) fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable)))) fakeFunDef.body = Some(body) val res = mkTemplate(fakeFunDef.typed, false) - cacheExpr += body -> res - res + val p = (res, mapping) + cacheExpr += raw -> p + p } def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] = { @@ -98,7 +104,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val substMap : Map[Identifier, T] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } else { diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/unrolling/TemplateInfo.scala similarity index 98% rename from src/main/scala/leon/solvers/templates/TemplateInfo.scala rename to src/main/scala/leon/solvers/unrolling/TemplateInfo.scala index 455704dc43928c473108d29fbb8789d13e6a78a3..3dc60c89f29d5efb55abff75b4d1e6d604719b9c 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateInfo.scala 
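
[Editor's note, not part of the patch] The TemplateEncoder trait above only declares its abstract members; the operations actually exercised by the template machinery in this patch are encodeId, encodeExpr, mkAnd, mkNot, mkEquals, mkImplies, substitute and extractNot. For reference, a minimal instance over plain Leon expressions (essentially the one UnrollingSolver installs later in this diff) could be sketched as below. The object name is hypothetical, the member list may be partial, and this is a sketch rather than part of the change itself.

    import leon.purescala.Common.Identifier
    import leon.purescala.Expressions._
    import leon.purescala.ExprOps.{replace, replaceFromIDs}
    import leon.purescala.Constructors.andJoin

    // Hypothetical reference encoder over Expr; mirrors the anonymous instance in UnrollingSolver.
    object ExprEncoderSketch {
      // Fresh symbolic variable for each Leon identifier.
      def encodeId(id: Identifier): Expr = Variable(id.freshen)
      // Encoding an expression is just substituting the declared variables for its free ids.
      def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = replaceFromIDs(bindings, e)
      def mkNot(e: Expr): Expr = Not(e)
      def mkAnd(es: Expr*): Expr = andJoin(es)
      def mkEquals(l: Expr, r: Expr): Expr = Equals(l, r)
      def mkImplies(l: Expr, r: Expr): Expr = Implies(l, r)
      // `substitute` returns a reusable Expr => Expr, as the templates expect.
      def substitute(substMap: Map[Expr, Expr]): Expr => Expr = (e: Expr) => replace(substMap, e)
      def extractNot(e: Expr): Option[Expr] = e match {
        case Not(inner) => Some(inner)
        case _          => None
      }
    }
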
@@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Definitions.TypedFunDef import Template.Arg diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala similarity index 93% rename from src/main/scala/leon/solvers/templates/TemplateManager.scala rename to src/main/scala/leon/solvers/unrolling/TemplateManager.scala index 2bb0cbbd0b00abdfdc21b04b2067a68bd5a6134d..e55c46f2fc97f442e6a49a8a71307cb5390fc5de 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -24,7 +24,9 @@ object Instantiation { type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) - def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) + def empty[T] = (Seq.empty[T], + Map.empty[T, Set[TemplateCallInfo[T]]], + Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => @@ -51,14 +53,13 @@ object Instantiation { def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) - def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = { + def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) - } } } import Instantiation.{empty => _, _} -import Template.Arg +import Template.{Apps, Calls, Functions, Arg} trait Template[T] { self => val encoder : TemplateEncoder[T] @@ -71,9 +72,9 @@ trait Template[T] { self => val exprVars : Map[Identifier, T] val condTree : Map[Identifier, Set[Identifier]] - val clauses : Seq[T] - val blockers : Map[T, Set[TemplateCallInfo[T]]] - val applications : Map[T, Set[App[T]]] + val clauses : Clauses[T] + val blockers : Calls[T] + val applications : Apps[T] val functions : Set[(T, FunctionType, T)] val lambdas : Seq[LambdaTemplate[T]] @@ -132,6 +133,7 @@ object Template { Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) } + type Calls[T] = Map[T, Set[TemplateCallInfo[T]]] type Apps[T] = Map[T, Set[App[T]]] type Functions[T] = Set[(T, FunctionType, T)] @@ -147,7 +149,7 @@ object Template { substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None - ) : (Clauses[T], CallBlockers[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { + ) : (Clauses[T], Calls[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { val idToTrId : Map[Identifier, T] = condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) @@ -249,8 +251,8 @@ object Template { (blockers, applications, matchers) } - val encodedBlockers : Map[T, Set[TemplateCallInfo[T]]] = blockers.map(p => idToTrId(p._1) -> p._2) - val encodedApps : Map[T, Set[App[T]]] = applications.map(p => idToTrId(p._1) -> p._2) + val encodedBlockers : Calls[T] = blockers.map(p => idToTrId(p._1) -> p._2) + val encodedApps : Apps[T] = 
applications.map(p => idToTrId(p._1) -> p._2) val encodedMatchers : Map[T, Set[Matcher[T]]] = matchers.map(p => idToTrId(p._1) -> p._2) val stringRepr : () => String = () => { @@ -271,6 +273,9 @@ object Template { }) + " * Lambdas :\n" + lambdas.map { case template => " +> " + template.toString.split("\n").mkString("\n ") + "\n" + }.mkString("\n") + + " * Foralls :\n" + quantifications.map { case template => + " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") } @@ -285,7 +290,7 @@ object Template { condTree: Map[Identifier, Set[Identifier]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], - functions: Set[(T, FunctionType, T)], + functions: Functions[T], baseSubst: Map[T, Arg[T]], pathVar: Identifier, aVar: T @@ -351,9 +356,9 @@ object Template { def instantiate[T]( encoder: TemplateEncoder[T], manager: TemplateManager[T], - clauses: Seq[T], - blockers: Map[T, Set[TemplateCallInfo[T]]], - applications: Map[T, Set[App[T]]], + clauses: Clauses[T], + blockers: Calls[T], + applications: Apps[T], matchers: Map[T, Set[Matcher[T]]], substMap: Map[T, Arg[T]] ): Instantiation[T] = { @@ -361,9 +366,9 @@ object Template { val substituter : T => T = encoder.substitute(substMap.mapValues(_.encoded)) val msubst = substMap.collect { case (c, Right(m)) => c -> m } - val newClauses = clauses.map(substituter) + val newClauses: Clauses[T] = clauses.map(substituter) - val newBlockers = blockers.map { case (b,fis) => + val newBlockers: CallBlockers[T] = blockers.map { case (b,fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) } @@ -451,10 +456,10 @@ class FunctionTemplate[T] private( val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val functions: Set[(T, FunctionType, T)], + val clauses: Clauses[T], + val blockers: Calls[T], + val applications: Apps[T], + val functions: Functions[T], val lambdas: Seq[LambdaTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], val quantifications: Seq[QuantificationTemplate[T]], @@ -465,7 +470,7 @@ class FunctionTemplate[T] private( override def toString : String = str } -class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { +class TemplateManager[T](protected[unrolling] val encoder: TemplateEncoder[T]) extends IncrementalState { private val condImplies = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) private val condImplied = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) @@ -492,6 +497,7 @@ class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) e def blocker(b: T): Unit = condImplies += (b -> Set.empty) def isBlocker(b: T): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) def blockerParents(b: T): Set[T] = condImplied(b) + def blockerChildren(b: T): Set[T] = condImplies(b) def implies(b1: T, b2: T): Unit = implies(b1, Set(b2)) def implies(b1: T, b2s: Set[T]): Unit = { diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala similarity index 89% rename from src/main/scala/leon/solvers/templates/UnrollingBank.scala rename to src/main/scala/leon/solvers/unrolling/UnrollingBank.scala index acee03b73ece82c3e64eeb2d8c9149929a052aac..2c38aca15800f15a70a80bd50ed5c9a2ddae9490 100644 --- 
a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Expressions._ @@ -100,6 +100,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } } + def getFiniteRangeClauses: Seq[T] = manager.checkClauses + private def registerCallBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { val notId = encoder.mkNot(id) @@ -159,13 +161,14 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat clause } - def getClauses(expr: Expr, bindings: Map[Expr, T]): Seq[T] = { + def getClauses(expr: Expr, bindings: Map[Identifier, T]): Seq[T] = { // OK, now this is subtle. This `getTemplate` will return // a template for a "fake" function. Now, this template will // define an activating boolean... - val template = templateGenerator.mkTemplate(expr) + val (template, mapping) = templateGenerator.mkTemplate(expr) + val reverse = mapping.map(p => p._2 -> p._1) - val trArgs = template.tfd.params.map(vd => Left(bindings(Variable(vd.id)))) + val trArgs = template.tfd.params.map(vd => Left(bindings(reverse(vd.id)))) // ...now this template defines clauses that are all guarded // by that activating boolean. If that activating boolean is @@ -174,7 +177,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val blockClauses = freshAppBlocks(appBlocks.keys) - for((b, infos) <- callBlocks) { + for ((b, infos) <- callBlocks) { registerCallBlocker(nextGeneration(0), b, infos) } @@ -193,7 +196,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat clauses } - def nextGeneration(gen: Int) = gen + 3 + def nextGeneration(gen: Int) = gen + 5 def decreaseAllGenerations() = { for ((block, (gen, origGen, ast, infos)) <- callInfos) { @@ -206,24 +209,50 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } } - def promoteBlocker(b: T) = { - if (callInfos contains b) { - val (_, origGen, notB, fis) = callInfos(b) - - callInfos += b -> (1, origGen, notB, fis) - } + def promoteBlocker(b: T, force: Boolean = false): Boolean = { + var seen: Set[T] = Set.empty + var promoted: Boolean = false + var blockers: Seq[Set[T]] = Seq(Set(b)) - if (blockerToApps contains b) { - val app = blockerToApps(b) - val (_, origGen, _, notB, infos) = appInfos(app) + do { + val (bs +: rest) = blockers + blockers = rest - appInfos += app -> (1, origGen, b, notB, infos) - } + val next = (for (b <- bs if !seen(b)) yield { + seen += b + + if (callInfos contains b) { + val (_, origGen, notB, fis) = callInfos(b) + + callInfos += b -> (1, origGen, notB, fis) + promoted = true + } + + if (blockerToApps contains b) { + val app = blockerToApps(b) + val (_, origGen, _, notB, infos) = appInfos(app) + + appInfos += app -> (1, origGen, b, notB, infos) + promoted = true + } + + if (force) { + templateGenerator.manager.blockerChildren(b) + } else { + Set.empty[T] + } + }).flatten + + if (next.nonEmpty) blockers :+= next + } while (!promoted && blockers.nonEmpty) + + promoted } def instantiateQuantifiers(force: Boolean = false): Seq[T] = { val (newExprs, callBlocks, appBlocks) = manager.instantiateIgnored(force) val blockExprs = freshAppBlocks(appBlocks.keys) + val gens = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)) val gen = if (gens.nonEmpty) gens.min else 0 @@ -367,12 +396,6 @@ class UnrollingBank[T <% 
Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } - /* - for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos if infos.isEmpty) { - registerAppBlocker(nextGeneration(gen), app, infos) - } - */ - reporter.debug(s" - ${newClauses.size} new clauses") //context.reporter.ifDebug { debug => // debug(s" - new clauses:") diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala similarity index 63% rename from src/main/scala/leon/solvers/combinators/UnrollingSolver.scala rename to src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala index 6594eb07e239876a5e046ebb3a99606f91225a93..852e125ed88db4327b18cfabc329d31e1366a2f6 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala @@ -2,7 +2,7 @@ package leon package solvers -package combinators +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -15,7 +15,7 @@ import purescala.Types._ import purescala.TypeOps.bestRealType import utils._ -import templates._ +import theories._ import evaluators._ import Template._ @@ -55,9 +55,7 @@ trait AbstractUnrollingSolver[T] protected var definitiveModel : Model = Model.empty protected var definitiveCore : Set[Expr] = Set.empty - def check: Option[Boolean] = { - genericCheck(Set.empty) - } + def check: Option[Boolean] = genericCheck(Set.empty) def getModel: Model = if (foundDefinitiveAnswer && definitiveAnswer.getOrElse(false)) { definitiveModel @@ -71,14 +69,14 @@ trait AbstractUnrollingSolver[T] Set.empty } - private val freeVars = new IncrementalMap[Identifier, T]() private val constraints = new IncrementalSeq[Expr]() + private val freeVars = new IncrementalMap[Identifier, T]() protected var interrupted : Boolean = false protected val reporter = context.reporter - lazy val templateGenerator = new TemplateGenerator(templateEncoder, assumePreHolds) + lazy val templateGenerator = new TemplateGenerator(theoryEncoder, templateEncoder, assumePreHolds) lazy val unrollingBank = new UnrollingBank(context, templateGenerator) def push(): Unit = { @@ -110,11 +108,15 @@ trait AbstractUnrollingSolver[T] interrupted = false } - def assertCnstr(expression: Expr, bindings: Map[Identifier, T]): Unit = { + protected def declareVariable(id: Identifier): T + + def assertCnstr(expression: Expr): Unit = { constraints += expression - freeVars ++= bindings + val bindings = variablesOf(expression).map(id => id -> freeVars.cached(id) { + declareVariable(theoryEncoder.encode(id)) + }).toMap - val newClauses = unrollingBank.getClauses(expression, bindings.map { case (k, v) => Variable(k) -> v }) + val newClauses = unrollingBank.getClauses(expression, bindings) for (cl <- newClauses) { solverAssert(cl) } @@ -128,6 +130,8 @@ trait AbstractUnrollingSolver[T] } implicit val printable: T => Printable + + val theoryEncoder: TheoryEncoder val templateEncoder: TemplateEncoder[T] def solverAssert(cnstr: T): Unit @@ -166,8 +170,16 @@ trait AbstractUnrollingSolver[T] def solverUnsatCore: Option[Seq[T]] trait ModelWrapper { - def get(id: Identifier): Option[Expr] - def eval(elem: T, tpe: TypeTree): Option[Expr] + def modelEval(elem: T, tpe: TypeTree): Option[Expr] + + def eval(elem: T, tpe: TypeTree): Option[Expr] = modelEval(elem, theoryEncoder.encode(tpe)).map { + expr => theoryEncoder.decode(expr)(Map.empty) + } + + def get(id: Identifier): Option[Expr] = eval(freeVars(id), theoryEncoder.encode(id.getType)).filter { + 
case Variable(_) => false + case _ => true + } private[AbstractUnrollingSolver] def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe @@ -215,185 +227,20 @@ trait AbstractUnrollingSolver[T] private def getPartialModel: PartialModel = { val wrapped = solverGetModel - - val typeInsts = templateGenerator.manager.typeInstantiations - val partialInsts = templateGenerator.manager.partialInstantiations - val lambdaInsts = templateGenerator.manager.lambdaInstantiations - - val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { - case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val funDomains: Map[Identifier, Set[Seq[Expr]]] = freeVars.toMap.map { case (id, idT) => - id -> partialInsts.get(idT).toSeq.flatten.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { - case (l, domain) => l -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet - } - - val model = new Model(freeVars.toMap.map { case (id, _) => - val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) - id -> (funDomains.get(id) match { - case Some(domain) => - val dflt = value match { - case FiniteLambda(_, dflt, _) => dflt - case Lambda(_, IfExpr(_, _, dflt)) => dflt - case _ => scala.sys.error("Can't extract default from " + value) - } - - FiniteLambda(domain.toSeq.map { es => - val optEv = evaluator.eval(application(value, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) - }, dflt, id.getType.asInstanceOf[FunctionType]) - - case None => postMap { - case p @ FiniteLambda(mapping, dflt, tpe) => - Some(FiniteLambda(typeDomains.get(tpe) match { - case Some(domain) => domain.toSeq.map { es => - val optEv = evaluator.eval(application(value, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) - } - case _ => Seq.empty - }, dflt, tpe)) - case _ => None - } (value) - }) - }) - - val domains = new Domains(lambdaDomains, typeDomains) - new PartialModel(model.toMap, domains) + val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) + view.getPartialModel } private def getTotalModel: Model = { val wrapped = solverGetModel - - def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (body) - - if (matchers.isEmpty) - return Some("No matchers found.") - - val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } - - val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) - if (bijectiveMappings.size > 1) - return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) - - def quantifiedArg(e: Expr): Boolean = e match { - case Variable(id) => quantified(id) - case QuantificationMatcher(_, args) => args.forall(quantifiedArg) - case _ => false - } - - postTraversal(m => m match { - case QuantificationMatcher(_, args) => - val qArgs = args.filter(quantifiedArg) - - if (qArgs.nonEmpty && qArgs.size < args.size) - return Some("Mixed ground and quantified arguments in " + m.asString) - - case 
Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => - return Some("Invalid operation on quantifiers " + m.asString) - - case (_: Equals) | (_: And) | (_: Or) | (_: Implies) => // OK - - case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => - return Some("Unandled implications from operation " + m.asString) - - case _ => - }) (body) - - body match { - case Variable(id) if quantified(id) => - Some("Unexpected free quantifier " + id.asString) - case _ => None - } - } - - val issues: Iterable[(Seq[Identifier], Expr, String)] = for { - q <- templateGenerator.manager.quantifications.view - if wrapped.eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) - msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) - } yield (q.quantifiers.map(_._1), q.body, msg) - - if (issues.nonEmpty) { - val (quantifiers, body, msg) = issues.head - reporter.warning("Model soundness not guaranteed for \u2200" + - quantifiers.map(_.asString).mkString(",") + ". " + body.asString+" :\n => " + msg) - } - - val typeInsts = templateGenerator.manager.typeInstantiations - val partialInsts = templateGenerator.manager.partialInstantiations - - def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { - case (id +: rparams, (v, arg) +: rargs) => - if (templateGenerator.manager.isQuantifier(v)) { - structure.get(v) match { - case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) - case None => extractCond(rparams, rargs, structure + (v -> id)) - } - } else { - Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) - } - case _ => Seq.empty - } - - new Model(freeVars.toMap.map { case (id, idT) => - val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) - id -> (id.getType match { - case FunctionType(from, to) => - val params = from.map(tpe => FreshIdentifier("x", tpe, true)) - val domain = partialInsts.get(idT).orElse(typeInsts.get(bestRealType(id.getType))).toSeq.flatten - val conditionals = domain.flatMap { case (b, m) => - wrapped.extract(b, m).map { args => - val result = evaluator.eval(application(value, args)).result.getOrElse { - scala.sys.error("Unexpectedly failed to evaluate " + application(value, args)) - } - - val cond = if (m.args.exists(arg => templateGenerator.manager.isQuantifier(arg.encoded))) { - extractCond(params, m.args.map(_.encoded) zip args, Map.empty) - } else { - (params zip args).map(p => Equals(Variable(p._1), p._2)) - } - - cond -> result - } - } - - val filteredConds = conditionals - .foldLeft(Map.empty[Seq[Expr], Expr]) { case (mapping, (conds, result)) => - if (mapping.isDefinedAt(conds)) mapping else mapping + (conds -> result) - } - - if (filteredConds.isEmpty) { - // TODO: warning?? - value - } else { - val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size) - val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => - if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze) - } - - Lambda(params.map(ValDef(_)), body) - } - - case _ => value - }) - }) + val view = templateGenerator.manager.getModel(freeVars.toMap, evaluator, wrapped.get, wrapped.eval) + view.getTotalModel } def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { foundDefinitiveAnswer = false + // TODO: theory encoder for assumptions!? 
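
[Editor's note, not part of the patch] The TODO above flags that assumptions are handed straight to the template encoder, whereas assertCnstr first lowers constraints through theoryEncoder; theory-specific constructs in an assumption (strings, given the StringEncoder installed by FairZ3Solver) would therefore reach the solver untranslated. A hedged sketch of what addressing it might look like, reusing only members visible in this file; encodeAssumption is a hypothetical helper and assumes theoryEncoder.encode(id) is stable for identifiers already declared via assertCnstr.

    // Hypothetical helper, not in the patch: lower an assumption before encoding it.
    def encodeAssumption(a: Expr): T = {
      // Mirror assertCnstr: translate free identifiers and the expression into the core language.
      val mapping = variablesOf(a).map(id => id -> theoryEncoder.encode(id)).toMap
      val lowered = theoryEncoder.encode(a)(mapping)
      // freeVars is keyed by the original identifiers, so re-key the bindings through
      // `mapping` before running the template encoder on the lowered expression
      // (assumes the theory encoder maps an already-seen identifier to the same result).
      val bindings = freeVars.toMap.collect {
        case (id, t) if mapping contains id => mapping(id) -> t
      }
      templateGenerator.encoder.encodeExpr(bindings)(lowered)
    }
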
val encoder = templateGenerator.encoder.encodeExpr(freeVars.toMap) _ val assumptionsSeq : Seq[Expr] = assumptions.toSeq val encodedAssumptions : Seq[T] = assumptionsSeq.map(encoder) @@ -407,7 +254,7 @@ trait AbstractUnrollingSolver[T] }).toSet } - while(!foundDefinitiveAnswer && !interrupted) { + while (!foundDefinitiveAnswer && !interrupted) { reporter.debug(" - Running search...") var quantify = false @@ -430,7 +277,7 @@ trait AbstractUnrollingSolver[T] } else if (partialModels) { (true, getPartialModel) } else { - val clauses = templateGenerator.manager.checkClauses + val clauses = unrollingBank.getFiniteRangeClauses if (clauses.isEmpty) { (true, extractModel(solverGetModel)) } else { @@ -473,8 +320,10 @@ trait AbstractUnrollingSolver[T] if (valid) { foundAnswer(Some(true), model) } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model.asString(context)) + reporter.error( + "Something went wrong. The model should have been valid, yet we got this: " + + model.asString(context) + + " for formula " + andJoin(assumptionsSeq ++ constraints).asString) foundAnswer(None, model) } } @@ -534,11 +383,24 @@ trait AbstractUnrollingSolver[T] } case Some(true) => - if (this.feelingLucky && !interrupted) { - // we might have been lucky :D - val model = extractModel(solverGetModel) - val valid = validateModel(model, assumptionsSeq, silenceErrors = true) - if (valid) foundAnswer(Some(true), model) + if (!interrupted) { + val model = solverGetModel + + if (this.feelingLucky) { + // we might have been lucky :D + val extracted = extractModel(model) + val valid = validateModel(extracted, assumptionsSeq, silenceErrors = true) + if (valid) foundAnswer(Some(true), extracted) + } + + if (!foundDefinitiveAnswer) { + val promote = templateGenerator.manager.getBlockersToPromote(model.eval) + if (promote.nonEmpty) { + unrollingBank.decreaseAllGenerations() + + for (b <- promote) unrollingBank.promoteBlocker(b, force = true) + } + } } case None => @@ -584,8 +446,12 @@ trait AbstractUnrollingSolver[T] } } -class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) - extends AbstractUnrollingSolver[Expr] { +class UnrollingSolver( + val context: LeonContext, + val program: Program, + underlying: Solver, + theories: TheoryEncoder = new NoEncoder +) extends AbstractUnrollingSolver[Expr] { override val name = "U:"+underlying.name @@ -596,10 +462,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val printable = (e: Expr) => e val templateEncoder = new TemplateEncoder[Expr] { - def encodeId(id: Identifier): Expr= { - Variable(id.freshen) - } - + def encodeId(id: Identifier): Expr= Variable(id.freshen) def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { replaceFromIDs(bindings, e) } @@ -620,11 +483,11 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying } } + val theoryEncoder = theories + val solver = underlying - def assertCnstr(expression: Expr): Unit = { - assertCnstr(expression, variablesOf(expression).map(id => id -> id.toVariable).toMap) - } + def declareVariable(id: Identifier): Variable = id.toVariable def solverAssert(cnstr: Expr): Unit = { solver.assertCnstr(cnstr) @@ -643,8 +506,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def solverGetModel: ModelWrapper = new ModelWrapper { val model = solver.getModel - def get(id: Identifier): Option[Expr] = model.get(id) - def eval(elem: Expr, tpe: 
TypeTree): Option[Expr] = evaluator.eval(elem, model).result + def modelEval(elem: Expr, tpe: TypeTree): Option[Expr] = evaluator.eval(elem, model).result override def toString = model.toMap.mkString("\n") } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 01a86e2b91289487940f5875becdd30e05716b63..41a8f3722e8d30d9d65d5fb7efe64247de963afb 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,6 +19,8 @@ import purescala.Types._ case class UnsoundExtractionException(ast: Z3AST, msg: String) extends Exception("Can't extract " + ast + " : " + msg) +object AbstractZ3Solver + // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" trait AbstractZ3Solver extends Solver { @@ -45,8 +47,19 @@ trait AbstractZ3Solver extends Solver { } } - protected[leon] val z3cfg : Z3Config - protected[leon] var z3 : Z3Context = null + // FIXME: (dirty?) hack to bypass z3lib bug. + // Uses the unique AbstractZ3Solver to ensure synchronization (no assumption on context). + protected[leon] val z3cfg : Z3Config = + AbstractZ3Solver.synchronized(new Z3Config( + "MODEL" -> true, + "TYPE_CHECK" -> true, + "WELL_SORTED_CHECK" -> true + )) + toggleWarningMessages(true) + + protected[leon] var z3 : Z3Context = null + + lazy protected val solver = z3.mkSolver() override def free(): Unit = { freed = true @@ -73,28 +86,21 @@ trait AbstractZ3Solver extends Solver { } } - def genericValueToDecl(gv: GenericValue): Z3FuncDecl = { - generics.cachedB(gv) { - z3.mkFreshFuncDecl(gv.tp.id.uniqueName+"#"+gv.id+"!val", Seq(), typeToSort(gv.tp)) - } - } - // ADT Manager protected val adtManager = new ADTManager(context) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() - protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() protected val lambdas = new IncrementalBijection[FunctionType, Z3FuncDecl]() protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() protected val variables = new IncrementalBijection[Expr, Z3AST]() - protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() - protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() + protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() var isInitialized = false - protected[leon] def initZ3() { + protected[leon] def initZ3(): Unit = { if (!isInitialized) { val timer = context.timers.solvers.z3.init.start() @@ -102,7 +108,6 @@ trait AbstractZ3Solver extends Solver { functions.clear() lambdas.clear() - generics.clear() sorts.clear() variables.clear() constructors.clear() @@ -117,11 +122,7 @@ trait AbstractZ3Solver extends Solver { } } - protected[leon] def restartZ3() { - isInitialized = false - - initZ3() - } + initZ3() def rootType(ct: TypeTree): TypeTree = ct match { case ct: ClassType => ct.root @@ -218,7 +219,7 @@ trait AbstractZ3Solver extends Solver { case Int32Type | BooleanType | IntegerType | RealType | CharType => sorts.toB(oldtt) - case tpe @ (_: ClassType | _: ArrayType | _: TupleType | UnitType) => + case tpe @ (_: ClassType | _: ArrayType | _: 
TupleType | _: TypeParameter | UnitType) => sorts.cachedB(tpe) { declareStructuralSort(tpe) } @@ -239,14 +240,6 @@ trait AbstractZ3Solver extends Solver { z3.mkArraySort(fromSort, toSort) } - case tt @ TypeParameter(id) => - sorts.cachedB(tt) { - val symbol = z3.mkFreshStringSymbol(id.name) - val newTPSort = z3.mkUninterpretedSort(symbol) - - newTPSort - } - case ft @ FunctionType(from, to) => sorts.cachedB(ft) { val symbol = z3.mkFreshStringSymbol(ft.toString) @@ -259,7 +252,7 @@ trait AbstractZ3Solver extends Solver { protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if (initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of @@ -531,9 +524,10 @@ trait AbstractZ3Solver extends Solver { z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) } - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) + typeToSort(tp) + val constructor = constructors.toB(tp) + constructor(rec(InfiniteIntegerLiteral(id))) case other => unsupported(other) @@ -607,8 +601,6 @@ trait AbstractZ3Solver extends Solver { val tfd = functions.toA(decl) assert(tfd.params.size == argsSize) FunctionInvocation(tfd, args.zip(tfd.params).map{ case (a, p) => rec(a, p.getType) }) - } else if (generics containsB decl) { - generics.toA(decl) } else if (constructors containsB decl) { constructors.toA(decl) match { case cct: CaseClassType => @@ -640,6 +632,13 @@ trait AbstractZ3Solver extends Solver { case (s : IntLiteral, arr) => unsound(args(1), "invalid array type") case (size, _) => unsound(args(0), "invalid array size") } + + case tp @ TypeParameter(id) => + val InfiniteIntegerLiteral(n) = rec(args(0), IntegerType) + GenericValue(tp, n.toInt) + + case t => + unsupported(t, "Woot? structural type that is non-structural") } } else { tpe match { @@ -671,10 +670,6 @@ trait AbstractZ3Solver extends Solver { } } - case tp: TypeParameter => - val id = t.toString.split("!").last.toInt - GenericValue(tp, id) - case MapType(from, to) => rec(t, RawArrayType(from, library.optionType(to))) match { case r: RawArrayValue => diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index fa3352b1ef2b5810afd063e0eb28f503e9d4a57d..f8237a4885fdc459efe261bb4333afc36944b500 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -4,7 +4,6 @@ package leon package solvers package z3 -import utils._ import _root_.z3.scala._ import purescala.Common._ @@ -13,8 +12,9 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ -import solvers.templates._ -import solvers.combinators._ +import unrolling._ +import theories._ +import utils._ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver @@ -32,13 +32,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) override protected val reporter = context.reporter override def reset(): Unit = super[AbstractZ3Solver].reset() - // FIXME: Dirty hack to bypass z3lib bug. 
Assumes context is the same over all instances of FairZ3Solver - protected[leon] val z3cfg = context.synchronized { new Z3Config( - "MODEL" -> true, - "TYPE_CHECK" -> true, - "WELL_SORTED_CHECK" -> true - )} - toggleWarningMessages(true) + def declareVariable(id: Identifier): Z3AST = variables.cachedB(Variable(id)) { + templateEncoder.encodeId(id) + } def solverCheck[R](clauses: Seq[Z3AST])(block: Option[Boolean] => R): R = { solver.push() @@ -88,14 +84,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) */ - def get(id: Identifier): Option[Expr] = variables.getB(id.toVariable).flatMap { - z3ID => eval(z3ID, id.getType) match { - case Some(Variable(id)) => None - case e => e - } - } - - def eval(elem: Z3AST, tpe: TypeTree): Option[Expr] = tpe match { + def modelEval(elem: Z3AST, tpe: TypeTree): Option[Expr] = tpe match { case BooleanType => model.evalAs[Boolean](elem).map(BooleanLiteral) case Int32Type => model.evalAs[Int](elem).map(IntLiteral).orElse { model.eval(elem).flatMap(t => softFromZ3Formula(model, t, Int32Type)) @@ -114,6 +103,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def asString(implicit ctx: LeonContext) = z3.toString } + val theoryEncoder = new StringEncoder + val templateEncoder = new TemplateEncoder[Z3AST] { def encodeId(id: Identifier): Z3AST = { idToFreshZ3Id(id) @@ -145,12 +136,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } } - initZ3() - - val solver = z3.mkSolver() - private val incrementals: List[IncrementalState] = List( - errors, functions, generics, lambdas, sorts, variables, + errors, functions, lambdas, sorts, variables, constructors, selectors, testers ) @@ -182,13 +169,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } } - def assertCnstr(expression: Expr): Unit = { + override def assertCnstr(expression: Expr): Unit = { try { - val bindings = variablesOf(expression).map(id => id -> variables.cachedB(Variable(id)) { - templateGenerator.encoder.encodeId(id) - }).toMap - - assertCnstr(expression, bindings) + super.assertCnstr(expression) } catch { case _: Unsupported => addError() diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 6f40151cf396348822e2ab3b6e5cdde07a6e8b93..829c773438dd8a1509ab08560a730cf1ede86b18 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -29,18 +29,6 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) val name = "Z3-u" val description = "Uninterpreted Z3 Solver" - // this is fixed - protected[leon] val z3cfg = new Z3Config( - "MODEL" -> true, - "TYPE_CHECK" -> true, - "WELL_SORTED_CHECK" -> true - ) - toggleWarningMessages(true) - - initZ3() - - val solver = z3.mkSolver() - def push() { solver.push() freeVariables.push() diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala deleted file mode 100644 index fa262caa6ede90a8c5196090364a80a0be9dc0fc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ /dev/null @@ -1,379 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package z3 - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Types._ -import 
purescala.Definitions._ -import leon.utils.Bijection -import leon.purescala.DefOps -import leon.purescala.TypeOps -import leon.purescala.Extractors.Operator -import leon.evaluators.EvaluationResults - -object StringEcoSystem { - private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { - val id = FreshIdentifier(name, tpe) - f(id) - } - private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { - withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) - } - - val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) - val StringListTyped = StringList.typed - val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => - val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) - d.setFields(Seq(ValDef(head), ValDef(tail))) - d - } - StringList.registerChild(StringCons) - val StringConsTyped = StringCons.typed - val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) - val StringNilTyped = StringNil.typed - StringList.registerChild(StringNil) - - val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => - val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) - fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(lengthArg), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) - )) - }) - fd - } - val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => - val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(x), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) - ))) - } - ) - fd - } - - val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => - val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) - fd.body = Some{ - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringNilTyped, Seq()), - CaseClass(StringConsTyped, Seq(Variable(h), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) - )))) - } - } - } - fd - } - - val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, 
id) => - val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) - )))) - }} - ) - fd - } - - val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => - val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) - fd.body = Some( - FunctionInvocation(StringTake.typed, - Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), - Minus(Variable(to), Variable(from))))) - fd - } } - - val classDefs = Seq(StringList, StringCons, StringNil) - val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) -} - -class Z3StringConversion(val p: Program) extends Z3StringConverters { - import StringEcoSystem._ - def getProgram = program_with_string_methods - - lazy val program_with_string_methods = { - val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) - DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) - } -} - -trait Z3StringConverters { - import StringEcoSystem._ - val mappedVariables = new Bijection[Identifier, Identifier]() - - val globalClassMap = new Bijection[ClassDef, ClassDef]() // To be added manually - - val globalFdMap = new Bijection[FunDef, FunDef]() - - val stringBijection = new Bijection[String, Expr]() - - def convertToString(e: Expr): String = - stringBijection.cachedA(e) { - e match { - case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) - case CaseClass(_, Seq()) => "" - } - } - def convertFromString(v: String): Expr = - stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ - case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) - } - } - - trait BidirectionalConverters { - def convertFunDef(fd: FunDef): FunDef - def hasIdConversion(id: Identifier): Boolean - def convertId(id: Identifier): Identifier - def convertClassDef(d: ClassDef): ClassDef - def isTypeToConvert(tpe: TypeTree): Boolean - def convertType(tpe: TypeTree): TypeTree - def convertPattern(pattern: Pattern): Pattern - def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr - object TypeConverted { - def unapply(t: TypeTree): Option[TypeTree] = Some(t match { - case cct@CaseClassType(ccd, args) => CaseClassType(convertClassDef(ccd).asInstanceOf[CaseClassDef], args map convertType) - case act@AbstractClassType(acd, args) => AbstractClassType(convertClassDef(acd).asInstanceOf[AbstractClassDef], args map convertType) - case NAryType(es, builder) => - builder(es map convertType) - }) - } - object PatternConverted { - def unapply(e: Pattern): Option[Pattern] = Some(e match { - case InstanceOfPattern(binder, ct) => - 
InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) - case WildcardPattern(binder) => - WildcardPattern(binder.map(convertId)) - case CaseClassPattern(binder, ct, subpatterns) => - CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) - case TuplePattern(binder, subpatterns) => - TuplePattern(binder.map(convertId), subpatterns map convertPattern) - case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => - UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) - case PatternExtractor(es, builder) => - builder(es map convertPattern) - }) - } - - object ExprConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { - case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) - case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) - case Variable(id) => e - case pl @ FiniteLambda(mappings, default, tpe) => - FiniteLambda( - mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), - convertExpr(kv._2))), - convertExpr(default), convertType(tpe).asInstanceOf[FunctionType]) - case Lambda(args, body) => - val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() - val new_args = for(arg <- args) yield { - val in = arg.getType - val new_id = convertId(arg.id) - if(new_id ne arg.id) { - new_bindings += (arg.id -> new_id) - ValDef(new_id) - } else arg - } - val res = Lambda(new_args, convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) - res - case Let(a, expr, body) if isTypeToConvert(a.getType) => - val new_a = convertId(a) - val new_bindings = bindings + (a -> Variable(new_a)) - val expr2 = convertExpr(expr)(new_bindings) - val body2 = convertExpr(body)(new_bindings) - Let(new_a, expr2, body2).copiedFrom(e) - case CaseClass(CaseClassType(ccd, tpes), args) => - CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) - case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => - CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) - case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => - MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) - case FunctionInvocation(TypedFunDef(fd, tpes), args) => - FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) - case This(ct: ClassType) => - This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case IsInstanceOf(expr, ct) => - IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case AsInstanceOf(expr, ct) => - AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case Tuple(args) => - Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) - case MatchExpr(scrutinee, cases) => - MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { - MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) - }) - case Operator(es, builder) => - val rec = convertExpr _ - val newEs = es.map(rec) - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - case e => e - }) - } - - def 
convertModel(model: Model): Model = { - new Model(model.ids.map{i => - val id = convertId(i) - id -> convertExpr(model(i))(Map()) - }.toMap) - } - - def convertResult(result: EvaluationResults.Result[Expr]) = { - result match { - case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map())) - case result => result - } - } - } - - object Forward extends BidirectionalConverters { - /* The conversion between functions should already have taken place */ - def convertFunDef(fd: FunDef): FunDef = { - globalFdMap.getBorElse(fd, fd) - } - /* The conversion between classdefs should already have taken place */ - def convertClassDef(cd: ClassDef): ClassDef = { - globalClassMap.getBorElse(cd, cd) - } - def hasIdConversion(id: Identifier): Boolean = { - mappedVariables.containsA(id) - } - def convertId(id: Identifier): Identifier = { - mappedVariables.getB(id) match { - case Some(idB) => idB - case None => - if(isTypeToConvert(id.getType)) { - val new_id = FreshIdentifier(id.name, convertType(id.getType)) - mappedVariables += (id -> new_id) - new_id - } else id - } - } - def isTypeToConvert(tpe: TypeTree): Boolean = - TypeOps.exists(StringType == _)(tpe) - def convertType(tpe: TypeTree): TypeTree = tpe match { - case StringType => StringListTyped - case TypeConverted(t) => t - } - def convertPattern(e: Pattern): Pattern = e match { - case LiteralPattern(binder, StringLiteral(s)) => - s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { - case (elem, pattern) => - CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) - } - case PatternConverted(e) => e - } - - /** Method which can use recursively StringConverted in its body in unapply positions */ - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { - case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) - case StringLiteral(v) => - val stringEncoding = convertFromString(v) - convertExpr(stringEncoding).copiedFrom(e) - case StringLength(a) => - FunctionInvocation(StringSize.typed, Seq(convertExpr(a))).copiedFrom(e) - case StringConcat(a, b) => - FunctionInvocation(StringListConcat.typed, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) - case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(StringTake.typed, - Seq(FunctionInvocation(StringDrop.typed, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) - case SubString(a, start, end) => - FunctionInvocation(StringSlice.typed, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) - case MatchExpr(scrutinee, cases) => - MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { - MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) - }) - case ExprConverted(e) => e - } - } - - object Backward extends BidirectionalConverters { - def convertFunDef(fd: FunDef): FunDef = { - globalFdMap.getAorElse(fd, fd) - } - /* The conversion between classdefs should already have taken place */ - def convertClassDef(cd: ClassDef): ClassDef = { - globalClassMap.getAorElse(cd, cd) - } - def hasIdConversion(id: Identifier): Boolean = { - mappedVariables.containsB(id) - } - def convertId(id: Identifier): Identifier = { - mappedVariables.getA(id) match { - case Some(idA) => idA - case None => - if(isTypeToConvert(id.getType)) { - val old_type = convertType(id.getType) - val old_id = FreshIdentifier(id.name, old_type) - mappedVariables += (old_id -> 
id) - old_id - } else id - } - } - def convertIdToMapping(id: Identifier): (Identifier, Variable) = { - id -> Variable(convertId(id)) - } - def isTypeToConvert(tpe: TypeTree): Boolean = - TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) - def convertType(tpe: TypeTree): TypeTree = tpe match { - case StringListTyped | StringConsTyped | StringNilTyped => StringType - case TypeConverted(t) => t - } - def convertPattern(e: Pattern): Pattern = e match { - case CaseClassPattern(b, StringNilTyped, Seq()) => - LiteralPattern(b.map(convertId), StringLiteral("")) - case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => - convertPattern(subpattern) match { - case LiteralPattern(_, StringLiteral(s)) - => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) - case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) - } - case PatternConverted(e) => e - } - - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = - e match { - case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> - StringLiteral(convertToString(cc)) - case FunctionInvocation(StringSize, Seq(a)) => - StringLength(convertExpr(a)).copiedFrom(e) - case FunctionInvocation(StringListConcat, Seq(a, b)) => - StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) - case FunctionInvocation(StringTake, - Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => - val rstart = convertExpr(start) - SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) - case ExprConverted(e) => e - } - } -} diff --git a/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..99c75199a7d96b64ce42ba9039dd648032e60888 --- /dev/null +++ b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package z3 + +import purescala.Definitions._ + +import unrolling._ +import theories._ + +class Z3UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) + extends UnrollingSolver(context, program, underlying, theories = new StringEncoder) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 7079f905461f9912973fa7010e9fda809815538e..9308c54d1895440ea5c7abe760bae23e73c17a38 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -117,6 +117,7 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { } def generateForPC(ids: List[Identifier], pc: Expr, maxValid: Int = 400, maxEnumerated: Int = 1000): ExamplesBank = { + //println(program.definedClasses) val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) val datagen = new GrammarDataGen(evaluator, ValueGrammar) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index c81b045d09c1aeccf8e2381afe9d14cb9a11db33..23d79c1e75d7f5d562c69e2e1350d112bb600436 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -167,7 +167,7 @@ class Synthesizer(val context : LeonContext, val solutionExpr = sol.toSimplifiedExpr(context, program, ci.fd) - val (npr, fdMap) = replaceFunDefs(program)({ + val (npr, _, fdMap, _) = replaceFunDefs(program)({ case fd if fd 
eq ci.fd => val nfd = fd.duplicate() nfd.fullBody = replace(Map(ci.source -> solutionExpr), nfd.fullBody) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 7ac1efc043cdf890fbf5cbfd5a3d3a935a15e892..238d3198f355b7e0bb8a07c75fbcac938530c7a2 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -274,7 +274,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { private val phiFd = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), BooleanType) // The program with the body of the current function replaced by the current partial solution - private val (innerProgram, origFdMap) = { + private val (innerProgram, origIdMap, origFdMap, origCdMap) = { val outerSolution = { new PartialSolution(hctx.search.strat, true) @@ -313,7 +313,12 @@ abstract class CEGISLike(name: String) extends Rule(name) { case fd => Some(fd.duplicate()) } + } + private val outerToInner = new purescala.TreeTransformer { + override def transform(id: Identifier): Identifier = origIdMap.getOrElse(id, id) + override def transform(cd: ClassDef): ClassDef = origCdMap.getOrElse(cd, cd) + override def transform(fd: FunDef): FunDef = origFdMap.getOrElse(fd, fd) } /** @@ -322,9 +327,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { * to the CEGIS-specific program, 'outer' refers to the actual program on * which we do synthesis. */ - private def outerExprToInnerExpr(e: Expr): Expr = { - replaceFunCalls(e, {fd => origFdMap.getOrElse(fd, fd) }) - } + private def outerExprToInnerExpr(e: Expr): Expr = outerToInner.transform(e)(Map.empty) private val innerPc = outerExprToInnerExpr(p.pc) private val innerPhi = outerExprToInnerExpr(p.phi) diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index ec681763eec062c7b6a29267f9ed209b4646460c..b868948676ceb51e9b92b09fa21ce8297087bc60 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -22,7 +22,6 @@ import solvers.string.StringSolver import programsets.DirectProgramSet import programsets.JoinProgramSet - /** A template generator for a given type tree. * Extend this class using a concrete type tree, * Then use the apply method to get a hole which can be a placeholder for holes in the template. 
@@ -210,12 +209,12 @@ case object StringRender extends Rule("StringRender") { def askQuestion(input: List[Identifier], r: RuleClosed)(implicit c: LeonContext, p: Program): List[disambiguation.Question[StringLiteral]] = { //if !s.contains(EDIT_ME) val qb = new disambiguation.QuestionBuilder(input, r.solutions, (seq: Seq[Expr], expr: Expr) => expr match { - case s@StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s) + case s @ StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s) case _ => None }) qb.result() } - + /** Converts the stream of solutions to a RuleApplication */ def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)])(implicit program: Program): RuleApplication = { if(solutions.isEmpty) RuleFailed() else { @@ -361,8 +360,8 @@ case object StringRender extends Rule("StringRender") { def extractCaseVariants(cct: CaseClassType, ctx: StringSynthesisContext) : (Stream[WithIds[MatchCase]], StringSynthesisResult) = cct match { - case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => - val typeMap = tparams.zip(tparams2).toMap + case CaseClassType(ccd: CaseClassDef, tparams2) => + val typeMap = ccd.tparams.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) val (rhs, result) = createFunDefsTemplates(ctx.copy(currentCaseClassParent=Some(cct)), fields.map(Variable)) // Invoke functions for each of the fields. @@ -387,11 +386,11 @@ case object StringRender extends Rule("StringRender") { */ def constantPatternMatching(fd: FunDef, act: AbstractClassType): WithIds[MatchExpr] = { val cases = (ListBuffer[WithIds[MatchCase]]() /: act.knownCCDescendants) { - case (acc, cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => - val typeMap = tparams.zip(tparams2).toMap + case (acc, cct @ CaseClassType(ccd, tparams2)) => + val typeMap = ccd.tparams.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) - val rhs = StringLiteral(id.asString) + val rhs = StringLiteral(ccd.id.asString) MatchCase(pattern, None, rhs) acc += ((MatchCase(pattern, None, rhs), Nil)) case (acc, e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e) @@ -458,17 +457,17 @@ case object StringRender extends Rule("StringRender") { val fd = createEmptyFunDef(ctx, dependentType) val ctx2 = preUpdateFunDefBody(dependentType, fd, ctx) // Inserts the FunDef in the assignments so that it can already be used. t.root match { - case act@AbstractClassType(acd@AbstractClassDef(id, tparams, parent), tps) => + case act @ AbstractClassType(acd, tps) => // Create a complete FunDef body with pattern matching val allKnownDescendantsAreCCAndHaveZeroArgs = act.knownCCDescendants.forall { - case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => ccd.fields.isEmpty + case CaseClassType(ccd, tparams2) => ccd.fields.isEmpty case _ => false } //TODO: Test other templates not only with Wilcard patterns, but more cases options for non-recursive classes (e.g. Option, Boolean, Finite parameterless case classes.) 
val (ctx3, cases) = ((ctx2, ListBuffer[Stream[WithIds[MatchCase]]]()) /: act.knownCCDescendants) { - case ((ctx22, acc), cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => + case ((ctx22, acc), cct @ CaseClassType(ccd, tparams2)) => val (newCases, result) = extractCaseVariants(cct, ctx22) val ctx23 = ctx22.copy(result = result) (ctx23, acc += newCases) @@ -481,7 +480,7 @@ case object StringRender extends Rule("StringRender") { } else allMatchExprsEnd gatherInputs(ctx3.add(dependentType, fd, allMatchExprs), q, result += Stream((functionInvocation(fd, input::ctx.provided_functions.toList.map(Variable)), Nil))) - case cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => + case cct @ CaseClassType(ccd, tparams2) => val (newCases, result3) = extractCaseVariants(cct, ctx2) val allMatchExprs = newCases.map(acase => mergeMatchCases(fd)(Seq(acase))) gatherInputs(ctx2.copy(result = result3).add(dependentType, fd, allMatchExprs), q, @@ -580,4 +579,4 @@ case object StringRender extends Rule("StringRender") { case _ => Nil } } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index c01fa9fb3afe65e5b798b31db3ebc5cf23e9e626..ef6fde2e21638a9f26117f5ef855f7d9d85ec57f 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -58,11 +58,11 @@ class ChainProcessor( if (structuralDecreasing || numericDecreasing) Some(problem.funDefs map Cleared) else { - val chainsUnlooping = chains.flatMap(c1 => chains.flatMap(c2 => c1 compose c2)).forall { - chain => !definitiveSATwithModel(andJoin(chain.loop())).isDefined + val maybeReentrant = chains.flatMap(c1 => chains.flatMap(c2 => c1 compose c2)).exists { + chain => maybeSAT(andJoin(chain.loop())) } - if (chainsUnlooping) + if (!maybeReentrant) Some(problem.funDefs map Cleared) else None diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 1d0527127344510b803a21666f0ade2d92537dff..6f998b0edcfb80ed41110596c3c754e411a71cc6 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -34,7 +34,7 @@ trait Solvable extends Processor { val sizeUnit : UnitDef = UnitDef(FreshIdentifier("$size"),Seq(sizeModule)) val newProgram : Program = program.copy( units = sizeUnit :: program.units) - SolverFactory.getFromSettings(context, newProgram).withTimeout(10.seconds) + SolverFactory.getFromSettings(context, newProgram).withTimeout(1.seconds) } type Solution = (Option[Boolean], Map[Identifier, Expr]) @@ -52,23 +52,17 @@ trait Solvable extends Processor { res } - def maybeSAT(problem: Expr): Boolean = { - withoutPosts { - SimpleSolverAPI(solver).solveSAT(problem)._1 getOrElse true - } + def maybeSAT(problem: Expr): Boolean = withoutPosts { + SimpleSolverAPI(solver).solveSAT(problem)._1 getOrElse true } - def definitiveALL(problem: Expr): Boolean = { - withoutPosts { - SimpleSolverAPI(solver).solveSAT(Not(problem))._1.exists(!_) - } + def definitiveALL(problem: Expr): Boolean = withoutPosts { + SimpleSolverAPI(solver).solveSAT(Not(problem))._1.exists(!_) } - def definitiveSATwithModel(problem: Expr): Option[Model] = { - withoutPosts { - val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) - if (sat.isDefined && sat.get) Some(model) else None - } + def definitiveSATwithModel(problem: Expr): Option[Model] = withoutPosts { + 
val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) + if (sat.isDefined && sat.get) Some(model) else None } } diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala index efe8bf4387f631799472b64905de9350bac1da64..266e13aa2bfae67b32042af54f77816d54ccd1f8 100644 --- a/src/main/scala/leon/termination/TerminationChecker.scala +++ b/src/main/scala/leon/termination/TerminationChecker.scala @@ -9,7 +9,7 @@ import purescala.DefOps._ abstract class TerminationChecker(val context: LeonContext, initProgram: Program) extends LeonComponent { val program = { - val (pgm, _) = replaceFunDefs(initProgram){ fd => Some(fd.duplicate()) } + val (pgm, _, _, _) = replaceFunDefs(initProgram){ fd => Some(fd.duplicate()) } pgm } diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala index fa09f9dca145674224a067357192b41853631ad9..f5444dec5970169d6adca7375322e0e776c625f4 100644 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -40,7 +40,7 @@ abstract class ProgramTypeTransformer { val absType = ccdef.parent.get Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) } else None - val newclassDef = ccdef.copy(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) + val newclassDef = ccdef.duplicate(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) //important: register a child if a parent was newly created. if (newparent.isDefined) @@ -55,7 +55,7 @@ abstract class ProgramTypeTransformer { val absType = acdef.parent.get Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) } else None - val newClassDef = acdef.copy(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) + val newClassDef = acdef.duplicate(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) defmap += (acdef -> newClassDef) newClassDef.asInstanceOf[T] } diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index fecf8a75603fb7f94791d11d184351c17e06508e..251ebde9525a0d176ced699b5651317960700939 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -2,14 +2,16 @@ package leon.utils +import scala.collection.mutable.{Map => MutableMap} + object Bijection { def apply[A, B](a: Iterable[(A, B)]): Bijection[A, B] = new Bijection[A, B] ++= a def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) } class Bijection[A, B] extends Iterable[(A, B)] { - protected var a2b = Map[A, B]() - protected var b2a = Map[B, A]() + protected val a2b = MutableMap[A, B]() + protected val b2a = MutableMap[B, A]() def iterator = a2b.iterator @@ -28,8 +30,8 @@ class Bijection[A, B] extends Iterable[(A, B)] { } def clear(): Unit = { - a2b = Map() - b2a = Map() + a2b.clear() + b2a.clear() } def getA(b: B) = b2a.get(b) @@ -72,4 +74,9 @@ class Bijection[A, B] extends Iterable[(A, B)] { def composeB[C](c: B => C): Bijection[A, C] = { new Bijection[A, C] ++= this.a2b.map(kv => kv._1 -> c(kv._2)) } + + def swap: Bijection[B, A] = new Bijection[B, A] { + override protected val a2b = Bijection.this.b2a + override protected val b2a = Bijection.this.a2b + } } diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/leon/utils/IncrementalBijection.scala index 
6dab6d670aafb289b9d4f328a13bdb7a687d5063..411fcf31cf23d3753e4ba7dfad64429f8c8460fc 100644 --- a/src/main/scala/leon/utils/IncrementalBijection.scala +++ b/src/main/scala/leon/utils/IncrementalBijection.scala @@ -2,25 +2,24 @@ package leon.utils -class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { - private var a2bStack = List[Map[A,B]]() - private var b2aStack = List[Map[B,A]]() +import scala.collection.mutable.{Map => MutableMap, Stack} - private def recursiveGet[T,U](stack: List[Map[T,U]], t: T): Option[U] = stack match { - case t2u :: xs => t2u.get(t) orElse recursiveGet(xs, t) - case Nil => None - } +class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { + protected val a2bStack = Stack[MutableMap[A,B]]() + protected val b2aStack = Stack[MutableMap[B,A]]() override def getA(b: B) = b2a.get(b) match { case s @ Some(a) => s - case None => recursiveGet(b2aStack, b) + case None => b2aStack.view.flatMap(_.get(b)).headOption } override def getB(a: A) = a2b.get(a) match { case s @ Some(b) => s - case None => recursiveGet(a2bStack, a) + case None => a2bStack.view.flatMap(_.get(a)).headOption } + override def iterator = aToB.iterator + def aToB: Map[A,B] = { a2bStack.reverse.foldLeft(Map[A,B]()) { _ ++ _ } ++ a2b } @@ -37,22 +36,30 @@ class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { def reset() : Unit = { super.clear() - a2bStack = Nil - b2aStack = Nil + a2bStack.clear() + b2aStack.clear() } def push(): Unit = { - a2bStack = a2b :: a2bStack - b2aStack = b2a :: b2aStack - a2b = Map() - b2a = Map() + a2bStack.push(a2b.clone) + b2aStack.push(b2a.clone) + a2b.clear() + b2a.clear() } def pop(): Unit = { - a2b = a2bStack.head - b2a = b2aStack.head - a2bStack = a2bStack.tail - b2aStack = b2aStack.tail + a2b.clear() + a2b ++= a2bStack.head + b2a.clear() + b2a ++= b2aStack.head + a2bStack.pop() + b2aStack.pop() } + override def swap: IncrementalBijection[B, A] = new IncrementalBijection[B, A] { + override protected val a2b = IncrementalBijection.this.b2a + override protected val b2a = IncrementalBijection.this.a2b + override protected val a2bStack = IncrementalBijection.this.b2aStack + override protected val b2aStack = IncrementalBijection.this.a2bStack + } } diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/leon/utils/IncrementalMap.scala index aeaf32e093ead4aba05b128d3f543e53fffb8f13..d1c3fe0344a7f44913b81efbd253ab06e84bb8c2 100644 --- a/src/main/scala/leon/utils/IncrementalMap.scala +++ b/src/main/scala/leon/utils/IncrementalMap.scala @@ -60,6 +60,12 @@ class IncrementalMap[A, B] private(dflt: Option[B]) def getOrElse[B1 >: B](k: A, e: => B1) = stack.head.getOrElse(k, e) def values = stack.head.values + def cached(k: A)(b: => B): B = getOrElse(k, { + val ev = b + this += k -> ev + ev + }) + def iterator = stack.head.iterator def +=(kv: (A, B)) = { stack.head += kv; this } def -=(k: A) = { stack.head -= k; this } diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala index f1ea90043736d1ea59a5ae0d9747e15905509cd6..0cfc68d50381276e7c3ed5c28c2566086f8cb315 100644 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -55,9 +55,7 @@ object AntiAliasingPhase extends TransformationPhase { updateBody(fd, effects, updatedFunctions, varsInScope)(ctx) } - val res = replaceFunDefs(pgm)(fd => updatedFunctions.get(fd), (fi, fd) => None) - //println(res._1) - res._1 + 
replaceDefsInProgram(pgm)(updatedFunctions) } /* diff --git a/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..30acb0f91be22b7c53eac47ee5e1e4df05b9872e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap.scala @@ -0,0 +1,22 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => res.contains(x) ==> f.ens(x)) } +} + diff --git a/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala new file mode 100644 index 0000000000000000000000000000000000000000..f0cb1841376bf5b107da22f755d34bb33cc82fde --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/AbstractRefinementMap2.scala @@ -0,0 +1,24 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap2 { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + require(forall((x: A) => pre(x) ==> ens(f(x)))) + + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => /* res.contains(x) ==> */ f.ens(x)) } +} + diff --git a/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala b/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..c42579dce65454c7111bfc1a3f8c1b5a61204d62 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/AbstractRefinementMap.scala @@ -0,0 +1,24 @@ +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object AbstractRefinementMap { + + case class ~>[A,B](private val f: A => B, pre: A => Boolean, ens: B => Boolean) { + require(forall((x: A) => pre(x) ==> ens(f(x)))) + + def apply(x: A): B = { + require(pre(x)) + f(x) + } ensuring(ens) + } + + def map[A, B](l: List[A], f: A ~> B): List[B] = { + require(forall((x:A) => l.contains(x) ==> f.pre(x))) + l match { + case Cons(x, xs) => Cons[B](f(x), map(xs, f)) + case Nil() => Nil[B]() + } + } ensuring { res => forall((x: B) => res.contains(x) ==> f.ens(x)) } +} + diff --git a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala index 466a16dade7850e2604d111b232ffc5d349ba97d..628086b41c9df04e255aa91ec5d422e568d239a7 100644 --- a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala +++ b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala @@ -14,7 +14,7 @@ import leon.LeonContext import leon.solvers._ import leon.solvers.smtlib._ -import 
leon.solvers.combinators._ +import leon.solvers.unrolling._ import leon.solvers.z3._ class GlobalVariablesSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { diff --git a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala index 4ff5ccc71d159fd76ec1e48d99bce4be4dbfb125..0bc83e224c8632042d7dbf35b94971f9f23e49e3 100644 --- a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala @@ -13,24 +13,24 @@ import leon.LeonOption import leon.solvers._ import leon.solvers.smtlib._ -import leon.solvers.combinators._ +import leon.solvers.cvc4._ import leon.solvers.z3._ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { val sources = List() - override val leonOpts = List("checkmodels") + override val leonOpts = List("--checkmodels") val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new CVC4UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) } @@ -126,6 +126,7 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { checkSolver(solver, expr, true) } + /* test(s"Satisfiable quantified formula $ename in $sname with partial models") { implicit fix => val (ctx, pgm) = fix val newCtx = ctx.copy(options = ctx.options.filter(_ != UnrollingProcedure.optPartialModels) :+ @@ -133,6 +134,7 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { val solver = sf(newCtx, pgm) checkSolver(solver, expr, true) } + */ } for ((sname, sf) <- getFactories; (ename, expr) <- unsatisfiable) { diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index 2a42bc6a3ae82a0f830afc3787d696d1792f3865..8b0d8026e1540a170df577b3952140e4ef8fa910 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -13,7 +13,6 @@ import leon.LeonContext import leon.solvers._ import leon.solvers.smtlib._ -import leon.solvers.combinators._ import leon.solvers.z3._ class SolversSuite extends LeonTestSuiteWithProgram { @@ -22,13 +21,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm)) + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm 
=> new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) } @@ -49,7 +48,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { val vs = types.map(FreshIdentifier("v", _).toVariable) - // We need to make sure models are not co-finite + // We need to make sure models are not co-finite val cnstrs = vs.map(v => v.getType match { case UnitType => Equals(v, simplestValue(v.getType)) @@ -77,7 +76,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { fail(s"Solver $solver - Model does not contain "+v.id.uniqueName+" of type "+v.getType) } } - case _ => + case res => fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!? Solver was "+solver.getClass) } } finally { diff --git a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala index c0c724c81580221a099e246332c7d3182b19d3f4..fe02ce2ef2a8d1eed9e0bbab4d25d695865287f9 100644 --- a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala +++ b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala @@ -115,7 +115,7 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal val newProgram = DefOps.addFunDefs(synth.program, solutions.head.defs, synth.sctx.functionContext) val newFd = ci.fd.duplicate() newFd.body = Some(solutions.head.term) - val (newProgram2, _) = DefOps.replaceFunDefs(newProgram)({ fd => + val (newProgram2, _, _, _) = DefOps.replaceFunDefs(newProgram)({ fd => if(fd == ci.fd) { Some(newFd) } else None @@ -209,9 +209,10 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal | def listEdgeToString(l: List[Edge]): String = ??? by example |} """.stripMargin.replaceByExample) + implicit val (ctx, program) = getFixture() - - val synthesisInfos = SourceInfo.extractFromProgram(ctx, program).map(si => si.fd.id.name -> si ).toMap + + val synthesisInfos = SourceInfo.extractFromProgram(ctx, program).map(si => si.fd.id.name -> si).toMap def synthesizeAndTest(functionName: String, tests: (Seq[Expr], String)*) { val (fd, program) = applyStringRenderOn(functionName) @@ -260,6 +261,7 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal def apply(types: TypeTree*)(args: Expr*) = FunctionInvocation(fd.typed(types), args) } + // Mimics the file above, allows construction of expressions. 
case class Constructors(program: Program) { implicit val p = program @@ -401,4 +403,4 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal customListToString(Dummy2)(listDummy2, lambdaDummy2ToString))) } } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala index f286d9d655a8f2a4d3ae83cf1e1f6f4a1c430c56..413ee804ccb639fe479d36f918d8aee2fd6679f2 100644 --- a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala @@ -8,7 +8,6 @@ import leon.purescala.Types._ import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.solvers.z3._ -import leon.solvers.combinators._ class UnrollingSolverSuite extends LeonSolverSuite { @@ -27,7 +26,7 @@ class UnrollingSolverSuite extends LeonSolverSuite { ) def getSolver(implicit ctx: LeonContext, pgm: Program) = { - new UnrollingSolver(ctx, pgm, new UninterpretedZ3Solver(ctx, pgm)) + new Z3UnrollingSolver(ctx, pgm, new UninterpretedZ3Solver(ctx, pgm)) } test("'true' should be valid") { implicit fix => diff --git a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala index e24be6f2811da9e9a52bec29ef8786835e6f495e..c6158d66a71bb8796da2ff6aa52a8f9ab4544c6d 100644 --- a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala @@ -19,13 +19,13 @@ class TypeOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. val tp2 = TypeParameter.fresh("A") val tp3 = TypeParameter.fresh("B") - val listD = AbstractClassDef(FreshIdentifier("List"), Seq(tpD), None) + val listD = new AbstractClassDef(FreshIdentifier("List"), Seq(tpD), None) val listT = listD.typed - val nilD = CaseClassDef(FreshIdentifier("Nil"), Seq(tpD), Some(listT), false) + val nilD = new CaseClassDef(FreshIdentifier("Nil"), Seq(tpD), Some(listT), false) val nilT = nilD.typed - val consD = CaseClassDef(FreshIdentifier("Cons"), Seq(tpD), Some(listT), false) + val consD = new CaseClassDef(FreshIdentifier("Cons"), Seq(tpD), Some(listT), false) val consT = consD.typed // Simple tests for fixed types