From 056f09f97e5ad80f217c32d9884db9209b3115f1 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Tue, 22 Mar 2016 18:04:11 +0100 Subject: [PATCH] Added generic TheoryEncoders to Leon --- src/main/scala/leon/Main.scala | 2 +- .../scala/leon/codegen/runtime/Monitor.scala | 2 +- .../leon/evaluators/RecursiveEvaluator.scala | 2 +- .../leon/evaluators/StreamEvaluator.scala | 4 +- .../engine/UnfoldingTemplateSolver.scala | 2 +- .../leon/laziness/FreeVariableFactory.scala | 8 +- .../scala/leon/laziness/LazinessUtil.scala | 8 +- .../leon/laziness/LazyClosureConverter.scala | 2 +- .../leon/laziness/LazyClosureFactory.scala | 8 +- .../scala/leon/purescala/Definitions.scala | 6 +- .../scala/leon/purescala/Expressions.scala | 37 ++ .../scala/leon/purescala/Extractors.scala | 10 +- .../leon/purescala/TreeTransformer.scala | 225 +++++++++++ src/main/scala/leon/purescala/TypeOps.scala | 188 +-------- src/main/scala/leon/purescala/Types.scala | 1 + .../scala/leon/solvers/SolverFactory.scala | 16 +- .../combinators/Z3StringCapableSolver.scala | 271 ------------- .../solvers/cvc4/CVC4UnrollingSolver.scala | 13 + .../smtlib/SMTLIBQuantifiedTarget.scala | 3 +- .../leon/solvers/smtlib/SMTLIBSolver.scala | 18 +- .../leon/solvers/smtlib/SMTLIBTarget.scala | 1 - .../smtlib/SMTLIBZ3QuantifiedSolver.scala | 12 +- .../smtlib/SMTLIBZ3QuantifiedTarget.scala | 1 - .../leon/solvers/smtlib/SMTLIBZ3Solver.scala | 47 +-- .../leon/solvers/theories/BagEncoder.scala | 14 + .../leon/solvers/theories/StringEncoder.scala | 203 ++++++++++ .../leon/solvers/theories/TheoryEncoder.scala | 249 ++++++++++++ .../DatatypeManager.scala | 2 +- .../LambdaManager.scala | 6 +- .../QuantificationManager.scala | 2 +- .../TemplateEncoder.scala | 15 +- .../TemplateGenerator.scala | 24 +- .../TemplateInfo.scala | 2 +- .../TemplateManager.scala | 4 +- .../UnrollingBank.scala | 11 +- .../UnrollingSolver.scala | 61 +-- .../leon/solvers/z3/AbstractZ3Solver.scala | 2 +- .../scala/leon/solvers/z3/FairZ3Solver.scala | 28 +- .../leon/solvers/z3/Z3StringConversion.scala | 382 ------------------ .../leon/solvers/z3/Z3UnrollingSolver.scala | 13 + .../scala/leon/synthesis/ExamplesFinder.scala | 1 + src/main/scala/leon/utils/Bijection.scala | 15 +- .../leon/utils/IncrementalBijection.scala | 45 ++- .../scala/leon/utils/IncrementalMap.scala | 6 + .../solvers/GlobalVariablesSuite.scala | 2 +- .../solvers/QuantifierSolverSuite.scala | 6 +- .../integration/solvers/SolversSuite.scala | 11 +- .../solvers/StringRenderSuite.scala | 8 +- .../solvers/UnrollingSolverSuite.scala | 3 +- 49 files changed, 973 insertions(+), 1029 deletions(-) create mode 100644 src/main/scala/leon/purescala/TreeTransformer.scala delete mode 100644 src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala create mode 100644 src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala create mode 100644 src/main/scala/leon/solvers/theories/BagEncoder.scala create mode 100644 src/main/scala/leon/solvers/theories/StringEncoder.scala create mode 100644 src/main/scala/leon/solvers/theories/TheoryEncoder.scala rename src/main/scala/leon/solvers/{templates => unrolling}/DatatypeManager.scala (99%) rename src/main/scala/leon/solvers/{templates => unrolling}/LambdaManager.scala (98%) rename src/main/scala/leon/solvers/{templates => unrolling}/QuantificationManager.scala (99%) rename src/main/scala/leon/solvers/{templates => unrolling}/TemplateEncoder.scala (61%) rename src/main/scala/leon/solvers/{templates => unrolling}/TemplateGenerator.scala (96%) rename src/main/scala/leon/solvers/{templates => unrolling}/TemplateInfo.scala (98%) rename src/main/scala/leon/solvers/{templates => unrolling}/TemplateManager.scala (99%) rename src/main/scala/leon/solvers/{templates => unrolling}/UnrollingBank.scala (97%) rename src/main/scala/leon/solvers/{combinators => unrolling}/UnrollingSolver.scala (92%) delete mode 100644 src/main/scala/leon/solvers/z3/Z3StringConversion.scala create mode 100644 src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 393933277..efce1a2ff 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -36,7 +36,7 @@ object Main { // Add whatever you need here. lazy val allComponents : Set[LeonComponent] = allPhases.toSet ++ Set( - solvers.combinators.UnrollingProcedure, MainComponent, GlobalOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component + solvers.unrolling.UnrollingProcedure, MainComponent, GlobalOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component ) /* diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala index 0861ce3bc..6ed6b2d67 100644 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.{HashMap => MutableMap, Set => MutableSet} import scala.concurrent.duration._ import solvers.SolverFactory -import solvers.combinators.UnrollingProcedure +import solvers.unrolling.UnrollingProcedure import synthesis._ diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index ee1c689bc..5b8b9ccae 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -15,7 +15,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.DefOps import solvers.{PartialModel, Model, SolverFactory} -import solvers.combinators.UnrollingProcedure +import solvers.unrolling.UnrollingProcedure import scala.collection.mutable.{Map => MutableMap} import scala.concurrent.duration._ import org.apache.commons.lang3.StringEscapeUtils diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 16d42e892..d53869f89 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -14,7 +14,7 @@ import purescala.Expressions._ import purescala.Quantification._ import leon.solvers.{SolverFactory, PartialModel} -import leon.solvers.combinators.UnrollingProcedure +import leon.solvers.unrolling.UnrollingProcedure import leon.utils.StreamUtils._ import scala.concurrent.duration._ @@ -166,7 +166,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) val domainCnstr = orJoin(quorums.map { quorum => val quantifierDomains = quorum.flatMap { case (path, caller, args) => - val optMatcher = e(expr) match { + val optMatcher = e(caller) match { case Stream(l: Lambda) => Some(gctx.lambdas.getOrElse(l, l)) case Stream(ev) => Some(ev) case _ => None diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index 7634da75a..4cc4750d3 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -191,7 +191,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F } import leon.solvers._ - import leon.solvers.combinators.UnrollingSolver + import leon.solvers.unrolling.UnrollingSolver def solveUsingLeon(leonctx: LeonContext, p: Program, vc: VC) = { val solFactory = SolverFactory.uninterpreted(leonctx, program) val smtUnrollZ3 = new UnrollingSolver(ctx.leonContext, program, solFactory.getNewSolver()) with TimeoutSolver diff --git a/src/main/scala/leon/laziness/FreeVariableFactory.scala b/src/main/scala/leon/laziness/FreeVariableFactory.scala index 441da84ba..738822709 100644 --- a/src/main/scala/leon/laziness/FreeVariableFactory.scala +++ b/src/main/scala/leon/laziness/FreeVariableFactory.scala @@ -15,22 +15,22 @@ import purescala.Types._ */ object FreeVariableFactory { - val fvClass = AbstractClassDef(FreshIdentifier("FreeVar@"), Seq(), None) + val fvClass = new AbstractClassDef(FreshIdentifier("FreeVar@"), Seq(), None) val fvType = AbstractClassType(fvClass, Seq()) val varCase = { - val cdef = CaseClassDef(FreshIdentifier("Var@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("Var@"), Seq(), Some(fvType), false) cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) fvClass.registerChild(cdef) cdef } val nextCase = { - val cdef = CaseClassDef(FreshIdentifier("NextVar@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("NextVar@"), Seq(), Some(fvType), false) cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) fvClass.registerChild(cdef) cdef } val nilCase = { - val cdef = CaseClassDef(FreshIdentifier("NilVar@"), Seq(), Some(fvType), false) + val cdef = new CaseClassDef(FreshIdentifier("NilVar@"), Seq(), Some(fvType), false) fvClass.registerChild(cdef) cdef } diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala index eb9e35c31..90064561d 100644 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ b/src/main/scala/leon/laziness/LazinessUtil.scala @@ -134,14 +134,14 @@ object LazinessUtil { } def isLazyType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => - cid.name == "Lazy" + case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => + ccd.id.name == "Lazy" case _ => false } def isMemType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(CaseClassDef(cid, _, None, false), Seq(_)) => - cid.name == "Mem" + case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => + ccd.id.name == "Mem" case _ => false } diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala index 58fa16741..8e86581d8 100644 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -742,7 +742,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, } def transformCaseClasses = p.definedClasses.foreach { - case ccd @ CaseClassDef(id, tparamDefs, superClass, isCaseObj) if !ccd.flags.contains(Annotation("library", Seq())) && + case ccd: CaseClassDef if !ccd.flags.contains(Annotation("library", Seq())) && ccd.fields.exists(vd => isLazyType(vd.getType)) => val nfields = ccd.fields.map { fld => unwrapLazyType(fld.getType) match { diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala index 1a69fa89d..ae7da4aa2 100644 --- a/src/main/scala/leon/laziness/LazyClosureFactory.scala +++ b/src/main/scala/leon/laziness/LazyClosureFactory.scala @@ -59,7 +59,7 @@ class LazyClosureFactory(p: Program) { ops.tail.forall(op => isMemoized(op) == isMemoized(ops.head)) } val absTParams = (1 to tpcount).map(i => TypeParameterDef(TypeParameter.fresh("T" + i))) - tpename -> AbstractClassDef(FreshIdentifier(typeNameToADTName(tpename), Untyped), + tpename -> new AbstractClassDef(FreshIdentifier(typeNameToADTName(tpename), Untyped), absTParams, None) }.toMap var opToAdt = Map[FunDef, CaseClassDef]() @@ -76,7 +76,7 @@ class LazyClosureFactory(p: Program) { assert(opfd.tparams.size == absTParamsDef.size) val absType = AbstractClassType(absClass, opfd.tparams.map(_.tp)) val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) - val cdef = CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) + val cdef = new CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) val nfields = opfd.params.map { vd => val fldType = vd.getType unwrapLazyType(fldType) match { @@ -105,7 +105,7 @@ class LazyClosureFactory(p: Program) { case NAryType(tparams, tcons) => tcons(absTParams) } val eagerid = FreshIdentifier("Eager" + TypeUtil.typeNameWOParams(clresType)) - val eagerClosure = CaseClassDef(eagerid, absTParamsDef, + val eagerClosure = new CaseClassDef(eagerid, absTParamsDef, Some(AbstractClassType(absClass, absTParams)), isCaseObject = false) eagerClosure.setFields(Seq(ValDef(FreshIdentifier("a", clresType)))) absClass.registerChild(eagerClosure) @@ -166,7 +166,7 @@ class LazyClosureFactory(p: Program) { val fldType = SetType(AbstractClassType(absClass, tparams)) ValDef(FreshIdentifier(typeToFieldName(tn), fldType)) } - val ccd = CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) + val ccd = new CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) ccd.setFields(fields) ccd } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 3c26299ab..0103bd950 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -235,7 +235,7 @@ object Definitions { self => def subDefinitions = fields ++ methods ++ tparams - + val id: Identifier val tparams: Seq[TypeParameterDef] def fields: Seq[ValDef] @@ -364,7 +364,7 @@ object Definitions { val acd = new AbstractClassDef(id, tparams, parent) acd.addFlags(this.flags) if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => acd.setInvariant(inv)) - parent.map(_.classDef.ancestors.map(_.registerChild(acd))) + parent.foreach(_.classDef.registerChild(acd)) acd.copiedFrom(this) } } @@ -418,7 +418,7 @@ object Definitions { cd.setFields(fields) cd.addFlags(this.flags) if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => cd.setInvariant(inv)) - parent.map(_.classDef.ancestors.map(_.registerChild(cd))) + parent.foreach(_.classDef.registerChild(cd)) cd.copiedFrom(this) } } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 8a7294552..25d1da631 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -846,6 +846,43 @@ object Expressions { val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /* Bag operations */ + /** $encodingof `Bag[base](elements)` */ + case class FiniteBag(elements: Map[Expr, Int], base: TypeTree) extends Expr { + val getType = BagType(base).unveilUntyped + } + /** $encodingof `bag.get(element)` or `bag(element)` */ + case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { + val getType = IntegerType + } + /** $encodingof `bag.length` */ + /* + case class BagCardinality(bag: Expr) extends Expr { + val getType = IntegerType + } + */ + /** $encodingof `bag1.subsetOf(bag2)` */ + /* + case class SubbagOf(bag1: Expr, bag2: Expr) extends Expr { + val getType = BooleanType + } + */ + /** $encodingof `bag1.intersect(bag2)` */ + case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + /** $encodingof `bag1 ++ bag2` */ + case class BagUnion(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + /** $encodingof `bag1 -- bag2` */ + /* + case class SetDifference(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + } + */ + + // TODO: Add checks for these expressions too /* Map operations */ diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c7d8a1a40..5c09c2d29 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -144,6 +144,12 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) case SetDifference(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) + case MultiplicityInBag(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => MultiplicityInBag(es(0), es(1))) + case BagIntersection(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagIntersection(es(0), es(1))) + case BagUnion(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagUnion(es(0), es(1))) case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case MapUnion(t1, t2) => @@ -173,6 +179,9 @@ object Extractors { case SubString(t1, a, b) => Some((t1::a::b::Nil, es => SubString(es(0), es(1), es(2)))) case FiniteSet(els, base) => Some((els.toSeq, els => FiniteSet(els.toSet, base))) + case FiniteBag(els, base) => + val seq = els.toSeq + Some((seq.map(_._1), els => FiniteBag((els zip seq.map(_._2)).toMap, base))) case FiniteMap(args, f, t) => { val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq val builder = (as: Seq[Expr]) => { @@ -381,5 +390,4 @@ object Extractors { } } } - } diff --git a/src/main/scala/leon/purescala/TreeTransformer.scala b/src/main/scala/leon/purescala/TreeTransformer.scala new file mode 100644 index 000000000..88b267670 --- /dev/null +++ b/src/main/scala/leon/purescala/TreeTransformer.scala @@ -0,0 +1,225 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +trait TreeTransformer { + def transform(id: Identifier): Identifier = id + def transform(cd: ClassDef): ClassDef = cd + def transform(fd: FunDef): FunDef = fd + + def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = e match { + case Variable(id) if bindings contains id => Variable(bindings(id)).copiedFrom(e) + case Variable(id) => Variable(transform(id)).copiedFrom(e) + case FiniteLambda(mappings, default, tpe) => + FiniteLambda(mappings.map { case (ks, v) => (ks map transform, transform(v)) }, + transform(default), transform(tpe).asInstanceOf[FunctionType]).copiedFrom(e) + case Lambda(args, body) => + val newArgs = args.map(vd => ValDef(transform(vd.id))) + val newBindings = (args zip newArgs).filter(p => p._1 != p._2).map(p => p._1.id -> p._2.id) + Lambda(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) + case Forall(args, body) => + val newArgs = args.map(vd => ValDef(transform(vd.id))) + val newBindings = (args zip newArgs).filter(p => p._1 != p._2).map(p => p._1.id -> p._2.id) + Forall(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) + case Let(a, expr, body) => + val newA = transform(a) + Let(newA, transform(expr), transform(body)(bindings + (a -> newA))).copiedFrom(e) + case CaseClass(cct, args) => + CaseClass(transform(cct).asInstanceOf[CaseClassType], args map transform).copiedFrom(e) + case CaseClassSelector(cct, caseClass, selector) => + val newCct @ CaseClassType(ccd, _) = transform(cct) + val newSelector = ccd.fieldsIds(cct.classDef.fieldsIds.indexOf(selector)) + CaseClassSelector(newCct, transform(caseClass), newSelector).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(transform(fd), tpes map transform), args map transform).copiedFrom(e) + case MethodInvocation(rec, cd, TypedFunDef(fd, tpes), args) => + MethodInvocation(transform(rec), transform(cd), TypedFunDef(transform(fd), tpes map transform), args map transform).copiedFrom(e) + case This(ct) => + This(transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + AsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(transform(scrutinee), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { + val (newPattern, newBindings) = transform(pattern) + val allBindings = bindings ++ newBindings + MatchCase(newPattern, guard.map(g => transform(g)(allBindings)), transform(rhs)(allBindings)).copiedFrom(cse) + }).copiedFrom(e) + case FiniteSet(es, tpe) => + FiniteSet(es map transform, transform(tpe)).copiedFrom(e) + case FiniteBag(es, tpe) => + FiniteBag(es map { case (k, v) => transform(k) -> v }, transform(tpe)).copiedFrom(e) + case FiniteMap(pairs, from, to) => + FiniteMap(pairs map { case (k, v) => transform(k) -> transform(v) }, transform(from), transform(to)).copiedFrom(e) + case EmptyArray(tpe) => + EmptyArray(transform(tpe)).copiedFrom(e) + case Hole(tpe, alts) => + Hole(transform(tpe), alts map transform).copiedFrom(e) + case NoTree(tpe) => + NoTree(transform(tpe)).copiedFrom(e) + case Error(tpe, desc) => + Error(transform(tpe), desc).copiedFrom(e) + case Operator(es, builder) => + val newEs = es map transform + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + } + + def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case InstanceOfPattern(binder, ct) => + val newBinder = binder map transform + val newPat = InstanceOfPattern(newBinder, transform(ct).asInstanceOf[ClassType]).copiedFrom(pat) + (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap) + case WildcardPattern(binder) => + val newBinder = binder map transform + val newPat = WildcardPattern(newBinder).copiedFrom(pat) + (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap) + case CaseClassPattern(binder, ct, subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = CaseClassPattern(newBinder, transform(ct).asInstanceOf[CaseClassType], newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + case TuplePattern(binder, subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = TuplePattern(newBinder, newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subs) => + val newBinder = binder map transform + val (newSubs, subBinders) = (subs map transform).unzip + val newPat = UnapplyPattern(newBinder, TypedFunDef(transform(fd), tpes map transform), newSubs).copiedFrom(pat) + (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + case PatternExtractor(subs, builder) => + val (newSubs, subBinders) = (subs map transform).unzip + (builder(newSubs).copiedFrom(pat), subBinders.flatten.toMap) + } + + def transform(tpe: TypeTree): TypeTree = tpe match { + case cct @ CaseClassType(ccd, args) => + CaseClassType(transform(ccd).asInstanceOf[CaseClassDef], args map transform).copiedFrom(tpe) + case act @ AbstractClassType(acd, args) => + AbstractClassType(transform(acd).asInstanceOf[AbstractClassDef], args map transform).copiedFrom(tpe) + case NAryType(ts, builder) => builder(ts map transform).copiedFrom(tpe) + } +} + +trait TreeTraverser { + def traverse(id: Identifier): Unit = () + def traverse(cd: ClassDef): Unit = () + def traverse(fd: FunDef): Unit = () + + def traverse(e: Expr): Unit = e match { + case Variable(id) => traverse(id) + case FiniteLambda(mappings, default, tpe) => + (default +: mappings.toSeq.flatMap(p => p._2 +: p._1)) foreach traverse + traverse(tpe) + case Lambda(args, body) => + args foreach (vd => traverse(vd.id)) + traverse(body) + case Forall(args, body) => + args foreach (vd => traverse(vd.id)) + traverse(body) + case Let(a, expr, body) => + traverse(a) + traverse(expr) + traverse(body) + case CaseClass(cct, args) => + traverse(cct) + args foreach traverse + case CaseClassSelector(cct, caseClass, selector) => + traverse(cct) + traverse(caseClass) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + traverse(fd) + tpes foreach traverse + args foreach traverse + case MethodInvocation(rec, cd, TypedFunDef(fd, tpes), args) => + traverse(rec) + traverse(cd) + traverse(fd) + tpes foreach traverse + args foreach traverse + case This(ct) => + traverse(ct) + case IsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + case AsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + case MatchExpr(scrutinee, cases) => + traverse(scrutinee) + for (cse @ MatchCase(pattern, guard, rhs) <- cases) { + traverse(pattern) + guard foreach traverse + traverse(rhs) + } + case FiniteSet(es, tpe) => + es foreach traverse + traverse(tpe) + case FiniteBag(es, tpe) => + es foreach { case (k, _) => traverse(k) } + traverse(tpe) + case FiniteMap(pairs, from, to) => + pairs foreach { case (k, v) => traverse(k); traverse(v) } + traverse(from) + traverse(to) + case EmptyArray(tpe) => + traverse(tpe) + case Hole(tpe, alts) => + traverse(tpe) + alts foreach traverse + case NoTree(tpe) => + traverse(tpe) + case Error(tpe, desc) => + traverse(tpe) + case Operator(es, builder) => + es foreach traverse + case e => + } + + def traverse(pat: Pattern): Unit = pat match { + case InstanceOfPattern(binder, ct) => + binder foreach traverse + traverse(ct) + case WildcardPattern(binder) => + binder foreach traverse + case CaseClassPattern(binder, ct, subs) => + binder foreach traverse + traverse(ct) + subs foreach traverse + case TuplePattern(binder, subs) => + binder foreach traverse + subs foreach traverse + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subs) => + binder foreach traverse + traverse(fd) + tpes foreach traverse + subs foreach traverse + case PatternExtractor(subs, builder) => + subs foreach traverse + } + + def traverse(tpe: TypeTree): Unit = tpe match { + case cct @ CaseClassType(ccd, args) => + traverse(ccd) + args foreach traverse + case act @ AbstractClassType(acd, args) => + traverse(acd) + args foreach traverse + case NAryType(ts, builder) => + ts foreach traverse + } +} diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 99da89dbc..1a33a0883 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -244,192 +244,12 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp }) _ } - def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = { - - // Simple rec without affecting map - val srec = rec(idsMap) _ - - def onMatchLike(e: Expr, cases : Seq[MatchCase]) = { - - val newTpe = tpeSub(e.getType) - - def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { - maps.flatten.toMap - } - - def trCase(c: MatchCase) = c match { - case SimpleCase(p, b) => - val (newP, newIds) = trPattern(p, newTpe) - SimpleCase(newP, rec(idsMap ++ newIds)(b)) - - case GuardedCase(p, g, b) => - val (newP, newIds) = trPattern(p, newTpe) - GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b)) - } - - def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match { - case (InstanceOfPattern(ob, ct), _) => - val newCt = tpeSub(ct).asInstanceOf[ClassType] - val newOb = ob.map(id => freshId(id, newCt)) - - (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap) - - case (TuplePattern(ob, sps), tpt @ TupleType(stps)) => - val newOb = ob.map(id => freshId(id, tpt)) - - val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (CaseClassPattern(ob, cct, sps), _) => - val newCt = tpeSub(cct).asInstanceOf[CaseClassType] - - val newOb = ob.map(id => freshId(id, newCt)) - - val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (up@UnapplyPattern(ob, fd, sps), tp) => - val newFd = if ((fd.tps map tpeSub) == fd.tps) fd else fd.fd.typed(fd.tps map tpeSub) - val newOb = ob.map(id => freshId(id,tp)) - val exType = tpeSub(up.someType.tps.head) - val exTypes = unwrapTupleType(exType, exType.isInstanceOf[TupleType]) - val (newSps, newMaps) = (sps zip exTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - (UnapplyPattern(newOb, newFd, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (WildcardPattern(ob), expTpe) => - val newOb = ob.map(id => freshId(id, expTpe)) - - (WildcardPattern(newOb), (ob zip newOb).toMap) - - case (LiteralPattern(ob, lit), expType) => - val newOb = ob.map(id => freshId(id, expType)) - (LiteralPattern(newOb,lit), (ob zip newOb).toMap) - - case _ => - sys.error(s"woot!? $p:$expType") - } - - (srec(e), cases.map(trCase))//.copiedFrom(m) - } - - e match { - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi) - - case mi @ MethodInvocation(r, cd, TypedFunDef(fd, tps), args) => - MethodInvocation(srec(r), cd, TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(mi) - - case th @ This(ct) => - This(tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(th) - - case cc @ CaseClass(ct, args) => - CaseClass(tpeSub(ct).asInstanceOf[CaseClassType], args.map(srec)).copiedFrom(cc) - - case cc @ CaseClassSelector(ct, e, sel) => - caseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) - - case cc @ IsInstanceOf(e, ct) => - IsInstanceOf(srec(e), tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(cc) - - case cc @ AsInstanceOf(e, ct) => - AsInstanceOf(srec(e), tpeSub(ct).asInstanceOf[ClassType]).copiedFrom(cc) - - case l @ Let(id, value, body) => - val newId = freshId(id, tpeSub(id.getType)) - Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l) - - case l @ LetDef(fds, bd) => - val fdsMapping = for(fd <- fds) yield { - val id = fd.id.freshen - val tparams = fd.tparams map { p => - TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) - } - val returnType = tpeSub(fd.returnType) - val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) - val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = ExprOps.preMap { - case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => - Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) - case _ => - None - } _ - (fd, newFd, subCalls) - } - // We group the subcalls functions all in once - val subCalls = fdsMapping.map(_._3).reduceLeft { _ andThen _ } - - // We apply all the functions mappings at once - val newFds = for((fd, newFd, _) <- fdsMapping) yield { - val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody)) - newFd.fullBody = fullBody - newFd - } - val newBd = srec(subCalls(bd)).copiedFrom(bd) - - letDef(newFds, newBd).copiedFrom(l) - - case l @ Lambda(args, body) => - val newArgs = args.map { arg => - val tpe = tpeSub(arg.getType) - arg.copy(id = freshId(arg.id, tpe)) - } - val mapping = args.map(_.id) zip newArgs.map(_.id) - Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) - - case f @ Forall(args, body) => - val newArgs = args.map { arg => - val tpe = tpeSub(arg.getType) - arg.copy(id = freshId(arg.id, tpe)) - } - val mapping = args.map(_.id) zip newArgs.map(_.id) - Forall(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(f) - - case p @ Passes(in, out, cases) => - val (newIn, newCases) = onMatchLike(in, cases) - passes(newIn, srec(out), newCases).copiedFrom(p) - - case m @ MatchExpr(e, cases) => - val (newE, newCases) = onMatchLike(e, cases) - matchExpr(newE, newCases).copiedFrom(m) - - case Error(tpe, desc) => - Error(tpeSub(tpe), desc).copiedFrom(e) - - case Hole(tpe, alts) => - Hole(tpeSub(tpe), alts.map(srec)).copiedFrom(e) - - case g @ GenericValue(tpar, id) => - tpeSub(tpar) match { - case newTpar : TypeParameter => - GenericValue(newTpar, id).copiedFrom(g) - case other => // FIXME any better ideas? - throw LeonFatalError(Some(s"Tried to substitute $tpar with $other within GenericValue $g")) - } - - case s @ FiniteSet(elems, tpe) => - FiniteSet(elems.map(srec), tpeSub(tpe)).copiedFrom(s) - - case m @ FiniteMap(elems, from, to) => - FiniteMap(elems.map{ case (k, v) => (srec(k), srec(v)) }, tpeSub(from), tpeSub(to)).copiedFrom(m) - - case f @ FiniteLambda(mapping, dflt, FunctionType(from, to)) => - FiniteLambda(mapping.map { case (ks, v) => ks.map(srec) -> srec(v) }, srec(dflt), - FunctionType(from.map(tpeSub), tpeSub(to))).copiedFrom(f) - - case v @ Variable(id) if idsMap contains id => - Variable(idsMap(id)).copiedFrom(v) - - case n @ Operator(es, builder) => - builder(es.map(srec)).copiedFrom(n) - - case _ => - e - } + val transformer = new TreeTransformer { + override def transform(id: Identifier): Identifier = freshId(id, transform(id.getType)) + override def transform(tpe: TypeTree): TypeTree = tpeSub(tpe) } - rec(ids)(e) + transformer.transform(e)(ids) } } } diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 0ee936c81..e2838104f 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -80,6 +80,7 @@ object Types { } case class SetType(base: TypeTree) extends TypeTree + case class BagType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree case class FunctionType(from: Seq[TypeTree], to: TypeTree) extends TypeTree case class ArrayType(base: TypeTree) extends TypeTree diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 865651579..8b84cd05e 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -4,7 +4,9 @@ package leon package solvers import combinators._ +import unrolling._ import z3._ +import cvc4._ import smtlib._ import purescala.Definitions._ @@ -79,12 +81,10 @@ object SolverFactory { def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => - // Previously: new FairZ3Solver(ctx, program) with TimeoutSolver - SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver) + SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) case "unrollz3" => - // Previously: new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver - SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new Z3UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) case "enum" => SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) @@ -93,15 +93,13 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - // Previously: new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver - SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new Z3UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) case "smt-z3-q" => - // Previously: new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver - SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) + SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) case "smt-cvc4" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) + SolverFactory(() => new CVC4UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) case "smt-cvc4-proof" => SolverFactory(() => new SMTLIBCVC4ProofSolver(ctx, program) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala deleted file mode 100644 index b94233f28..000000000 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Quantification._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.DefOps -import purescala.TypeOps -import purescala.Extractors._ -import utils._ -import templates._ -import evaluators._ -import Template._ -import leon.solvers.z3.Z3StringConversion -import leon.utils.Bijection -import leon.solvers.z3.StringEcoSystem - -object Z3StringCapableSolver { - def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t) - def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e) - def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType) - def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id) - def thatShouldBeConverted(fd: FunDef): Boolean = { - (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted) - } - def thatShouldBeConverted(cd: ClassDef): Boolean = cd match { - case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted) - case _ => false - } - def thatShouldBeConverted(p: Program): Boolean = { - (p.definedFunctions exists thatShouldBeConverted) || - (p.definedClasses exists thatShouldBeConverted) - } - - def convert(p: Program): (Program, Option[Z3StringConversion]) = { - val converter = new Z3StringConversion(p) - import converter.Forward._ - var hasStrings = false - val program_with_strings = converter.getProgram - val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { - val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceCaseClassDefs(program_with_strings)((cd: ClassDef) => { - cd match { - case acd:AbstractClassDef => None - case ccd:CaseClassDef => - if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { - Some((parent: Option[AbstractClassType]) => ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) - } else None - } - }) - converter.mappedVariables.clear() // We will compose them later, they have been stored in idMap - res - } else { - (program_with_strings, Map[ClassDef, ClassDef](), Map[Identifier, Identifier](), Map[FunDef, FunDef]()) - } - val fdMapInverse = fdMap.map(kv => kv._2 -> kv._1).toMap - val idMapInverse = idMap.map(kv => kv._2 -> kv._1).toMap - var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() - val (new_program, _) = DefOps.replaceFunDefs(program_with_correct_classes)((fd: FunDef) => { - globalFdMap.get(fd).map(_._2).orElse( - if(thatShouldBeConverted(fd)) { - val idMap = fd.params.zip(fd.params).map(origvd_vd => origvd_vd._1.id -> convertId(origvd_vd._2.id)).toMap - val newFdId = convertId(fd.id) - val newFd = fd.duplicate(newFdId, - fd.tparams, - fd.params.map(vd => ValDef(idMap(vd.id))), - convertType(fd.returnType)) - globalFdMap += fd -> ((idMap, newFd)) - hasStrings = hasStrings || (program_with_strings.library.escape.get != fd) - Some(newFd) - } else None - ) - }) - if(!hasStrings) { - (p, None) - } else { - converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) - for((fd, (idMap, newFd)) <- globalFdMap) { - implicit val idVarMap = idMap.mapValues(id => Variable(id)) - newFd.fullBody = convertExpr(newFd.fullBody) - } - converter.mappedVariables.composeA(id => idMapInverse.getOrElse(id, id)) - converter.globalFdMap.composeA(fd => fdMapInverse.getOrElse(fd, fd)) - converter.globalClassMap ++= cdMap - (new_program, Some(converter)) - } - } -} - -abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( - val context: LeonContext, - val program: Program, - val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) extends Solver { - - protected val (new_program, optConverter) = Z3StringCapableSolver.convert(program) - var someConverter = optConverter - - val underlying = underlyingConstructor(new_program, someConverter) - var solverInvokedWithStrings = false - - def getModel: leon.solvers.Model = { - val model = underlying.getModel - someConverter match { - case None => model - case Some(converter) => - val ids = model.ids.toSeq - val exprs = ids.map(model.apply) - import converter.Backward._ - val original_ids = ids.map(convertId) - val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } - - model match { - case hm: PartialModel => - val new_domain = new Domains( - hm.domains.lambdas.map(kv => - (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], - kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, - hm.domains.tpes.map(kv => - (convertType(kv._1), - kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap - ) - - new PartialModel(original_ids.zip(original_exprs).toMap, new_domain) - case _ => - new Model(original_ids.zip(original_exprs).toMap) - } - } - } - - // Members declared in leon.utils.Interruptible - def interrupt(): Unit = underlying.interrupt() - def recoverInterrupt(): Unit = underlying.recoverInterrupt() - - // Converts expression on the fly if needed, creating a string converter if needed. - def convertExprOnTheFly(expression: Expr, withConverter: Z3StringConversion => Expr): Expr = { - someConverter match { - case None => - if(solverInvokedWithStrings || exists(e => TypeOps.exists(StringType == _)(e.getType))(expression)) { // On the fly conversion - solverInvokedWithStrings = true - val c = new Z3StringConversion(program) - someConverter = Some(c) - withConverter(c) - } else expression - case Some(converter) => - withConverter(converter) - } - } - - // Members declared in leon.solvers.Solver - def assertCnstr(expression: Expr): Unit = { - someConverter.map{converter => - import converter.Forward._ - val newExpression = convertExpr(expression)(Map()) - underlying.assertCnstr(newExpression) - }.getOrElse{ - underlying.assertCnstr(convertExprOnTheFly(expression, _.Forward.convertExpr(expression)(Map()))) - } - } - def getUnsatCore: Set[Expr] = { - someConverter.map{converter => - import converter.Backward._ - underlying.getUnsatCore map (e => convertExpr(e)(Map())) - }.getOrElse(underlying.getUnsatCore) - } - - def check: Option[Boolean] = underlying.check - def free(): Unit = underlying.free() - def pop(): Unit = underlying.pop() - def push(): Unit = underlying.push() - def reset(): Unit = underlying.reset() - def name: String = underlying.name -} - -import z3._ - -trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] => -} - -trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self: Z3StringCapableSolver[TUnderlying] => -} - -trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self: Z3StringCapableSolver[TUnderlying] => - // Members declared in leon.solvers.EvaluatingSolver - val useCodeGen: Boolean = underlying.useCodeGen -} - -class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) - extends CodeGenEvaluator(context, originalProgram) { - - override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { - import converter._ - super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) - .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) - ) - } -} - -class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) - extends DefaultEvaluator(context, originalProgram) { - - override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = { - import converter._ - Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model))) - } -} - -class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program, - originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) { - override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator - someConverter match { - case Some(converter) => - if (useCodeGen) { - new ConvertibleCodeGenEvaluator(context, originalProgram, converter) - } else { - new ConvertibleDefaultEvaluator(context, originalProgram, converter) - } - case None => - if (useCodeGen) { - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - } - } -} - -class Z3StringFairZ3Solver(context: LeonContext, program: Program) - extends Z3StringCapableSolver(context, program, - (prgm: Program, someConverter: Option[Z3StringConversion]) => - new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) - with Z3StringEvaluatingSolver[FairZ3Solver] { - - // Members declared in leon.solvers.z3.AbstractZ3Solver - protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) - } - } -} - -class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) - extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => - new UnrollingSolver(context, program, underlyingSolverConstructor(program))) - with Z3StringNaiveAssumptionSolver[UnrollingSolver] - with Z3StringEvaluatingSolver[UnrollingSolver] { - - override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore -} - -class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) - extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => - new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { - - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) - } - } -} - diff --git a/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala new file mode 100644 index 000000000..79498f97b --- /dev/null +++ b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package cvc4 + +import purescala.Definitions._ + +import unrolling._ +import theories._ + +class CVC4UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) + extends UnrollingSolver(context, program, underlying, theories = NoEncoder) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala index 980d87fa8..76037a945 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala @@ -26,7 +26,7 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget { protected def withInductiveHyp(cond: Expr): Expr = { val inductiveHyps = for { - fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq + fi @ FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq } yield { val post = application( tfd.withParamSubst(args, tfd.postOrTrue), @@ -38,6 +38,5 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget { // We want to check if the negation of the vc is sat under inductive hyp. // So we need to see if (indHyp /\ !vc) is satisfiable liftLets(matchToIfThenElse(andJoin(inductiveHyps :+ not(cond)))) - } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 72b475c5c..95f924a46 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -13,13 +13,18 @@ import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => SMTFunDef, import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} -abstract class SMTLIBSolver(val context: LeonContext, val program: Program) +import theories._ +import utils._ + +abstract class SMTLIBSolver(val context: LeonContext, val program: Program, theories: TheoryEncoder = NoEncoder) extends Solver with SMTLIBTarget with NaiveAssumptionSolver { /* Solver name */ def targetName: String override def name: String = "smt-"+targetName + private val ids = new IncrementalBijection[Identifier, Identifier]() + override def dbg(msg: => Any) = { debugOut foreach { o => o.write(msg.toString) @@ -28,8 +33,10 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } /* Public solver interface */ - def assertCnstr(expr: Expr): Unit = if(!hasError) { + def assertCnstr(raw: Expr): Unit = if (!hasError) { try { + val bindings = variablesOf(raw).map(id => id -> ids.cachedB(id)(theories.encode(id))).toMap + val expr = theories.encode(raw)(bindings) variablesOf(expr).foreach(declareVariable) val term = toSMT(expr)(Map()) @@ -85,7 +92,8 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) for (me <- smodel) me match { case DefineFun(SMTFunDef(s, args, kind, e)) if syms(s) => val id = variables.toA(s) - model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) + val value = fromSMT(e, id.getType)(Map(), modelFunDefs) + model += ids.getAorElse(id, id) -> theories.decode(value)(variablesOf(value).map(id => id -> ids.toA(id)).toMap) case _ => } @@ -101,9 +109,10 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } } - override def getModel: Model = getModel( _ => true) + override def getModel: Model = getModel(_ => true) override def push(): Unit = { + ids.push() constructors.push() selectors.push() testers.push() @@ -117,6 +126,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } override def pop(): Unit = { + ids.pop() constructors.pop() selectors.pop() testers.pop() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 546e62603..7418526c9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -294,7 +294,6 @@ trait SMTLIBTarget extends Interruptible { conflicts.foreach { declareStructuralSort } declareStructuralSort(t) } - } protected def declareVariable(id: Identifier): SSymbol = { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala index 083e07f34..0be498f13 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala @@ -1,14 +1,18 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package leon -package solvers.smtlib +package solvers +package smtlib import purescala.Definitions.Program +import theories._ + /** * This solver models function definitions as universally quantified formulas. * It is not meant as an underlying solver to UnrollingSolver, and does not handle HOFs. */ -class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) extends SMTLIBZ3Solver(context, program) - with SMTLIBQuantifiedSolver - with SMTLIBZ3QuantifiedTarget +class SMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) + extends SMTLIBZ3Solver(context, program) + with SMTLIBQuantifiedSolver + with SMTLIBZ3QuantifiedTarget diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala index fbd4b581c..f355392d0 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala @@ -65,6 +65,5 @@ trait SMTLIBZ3QuantifiedTarget extends SMTLIBZ3Target with SMTLIBQuantifiedTarge } functions.toB(tfd) - } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 58a88d1a6..e12b53fb5 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -13,49 +13,10 @@ import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.parser.CommandsResponses.GetModelResponseSuccess import _root_.smtlib.theories.Core.{Equals => _, _} -class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBZ3Target { +import theories._ - def getProgram: Program = program - - // EK: We use get-model instead in order to extract models for arrays - override def getModel: Model = { - - val res = emit(GetModel()) - - val smodel: Seq[SExpr] = res match { - case GetModelResponseSuccess(model) => model - case _ => Nil - } - - var modelFunDefs = Map[SSymbol, DefineFun]() - - // First pass to gather functions (arrays defs) - for (me <- smodel) me match { - case me @ DefineFun(SMTFunDef(a, args, _, _)) if args.nonEmpty => - modelFunDefs += a -> me - case _ => - } - - var model = Map[Identifier, Expr]() - - for (me <- smodel) me match { - case DefineFun(SMTFunDef(s, args, kind, e)) => - if(args.isEmpty) { - variables.getA(s) match { - case Some(id) => - // EK: this is a little hack, we pass models for array functions as let-defs - try { - model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) - } catch { - case _ : Unsupported => - - } - case _ => // function, should be handled elsewhere - } - } - case _ => - } - new Model(model) - } +class SMTLIBZ3Solver(context: LeonContext, program: Program) + extends SMTLIBSolver(context, program, StringEncoder) with SMTLIBZ3Target { + def getProgram: Program = program } diff --git a/src/main/scala/leon/solvers/theories/BagEncoder.scala b/src/main/scala/leon/solvers/theories/BagEncoder.scala new file mode 100644 index 000000000..fb6674224 --- /dev/null +++ b/src/main/scala/leon/solvers/theories/BagEncoder.scala @@ -0,0 +1,14 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Types._ + +object BagEncoder extends TheoryEncoder { + val encoder = new Encoder + val decoder = new Decoder +} diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/leon/solvers/theories/StringEncoder.scala new file mode 100644 index 000000000..f8b55303b --- /dev/null +++ b/src/main/scala/leon/solvers/theories/StringEncoder.scala @@ -0,0 +1,203 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Types._ +import purescala.Definitions._ +import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator +import leon.evaluators.EvaluationResults + +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } + + val StringList = new AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = new CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d + } + + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = new CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) + + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd + } + + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd + } + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd + } + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd + } + + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } + + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) +} + +object StringEncoder extends TheoryEncoder { + import StringEcoSystem._ + + private val stringBijection = new Bijection[String, Expr]() + + private def convertToString(e: Expr): String = stringBijection.cachedA(e)(e match { + case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) + case CaseClass(_, Seq()) => "" + }) + + private def convertFromString(v: String): Expr = stringBijection.cachedB(v) { + v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ + case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) + } + } + + val encoder = new Encoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case StringLiteral(v) => + convertFromString(v) + case StringLength(a) => + FunctionInvocation(StringSize.typed, Seq(transform(a))).copiedFrom(e) + case StringConcat(a, b) => + FunctionInvocation(StringListConcat.typed, Seq(transform(a), transform(b))).copiedFrom(e) + case SubString(a, start, Plus(start2, length)) if start == start2 => + FunctionInvocation(StringTake.typed, Seq(FunctionInvocation(StringDrop.typed, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e) + case SubString(a, start, end) => + FunctionInvocation(StringSlice.typed, Seq(transform(a), transform(start), transform(end))).copiedFrom(e) + case _ => super.transform(e) + } + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case StringType => StringListTyped + case _ => super.transform(tpe) + } + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case LiteralPattern(binder, StringLiteral(s)) => + val newBinder = binder map transform + val newPattern = s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + (newPattern.copy(binder = newBinder), (binder zip newBinder).filter(p => p._1 != p._2).toMap) + case _ => super.transform(pat) + } + } + + val decoder = new Decoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)).copiedFrom(cc) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(transform(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(transform(a), transform(b)).copiedFrom(e) + case FunctionInvocation(StringTake, Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = transform(start) + SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e) + case _ => super.transform(e) + } + + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case StringListTyped | StringConsTyped | StringNilTyped => StringType + case _ => super.transform(tpe) + } + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + val newBinder = b map transform + (LiteralPattern(newBinder , StringLiteral("")), (b zip newBinder).filter(p => p._1 != p._2).toMap) + + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { + case (LiteralPattern(_, StringLiteral(s)), binders) => + val newBinder = b map transform + (LiteralPattern(newBinder, StringLiteral(elem + s)), (b zip newBinder).filter(p => p._1 != p._2).toMap ++ binders) + case (e, binders) => + (LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)), binders) + } + + case _ => super.transform(pat) + } + } +} + diff --git a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala new file mode 100644 index 000000000..e0de93c43 --- /dev/null +++ b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala @@ -0,0 +1,249 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package theories + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +trait TheoryEncoder { self => + protected val encoder: Encoder + protected val decoder: Decoder + + private val idMap = new Bijection[Identifier, Identifier] + private val fdMap = new Bijection[FunDef , FunDef ] + private val cdMap = new Bijection[ClassDef , ClassDef ] + + def encode(id: Identifier): Identifier = encoder.transform(id) + def decode(id: Identifier): Identifier = decoder.transform(id) + + def encode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = encoder.transform(expr) + def decode(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = decoder.transform(expr) + + def encode(tpe: TypeTree): TypeTree = encoder.transform(tpe) + def decode(tpe: TypeTree): TypeTree = decoder.transform(tpe) + + def encode(fd: FunDef): FunDef = encoder.transform(fd) + def decode(fd: FunDef): FunDef = decoder.transform(fd) + + protected trait Converter extends purescala.TreeTransformer { + private[TheoryEncoder] val idMap: Bijection[Identifier, Identifier] + private[TheoryEncoder] val fdMap: Bijection[FunDef , FunDef ] + private[TheoryEncoder] val cdMap: Bijection[ClassDef , ClassDef ] + + override def transform(id: Identifier): Identifier = idMap.cachedB(id) { + val ntpe = transform(id.getType) + if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) + } + + override def transform(fd: FunDef): FunDef = fdMap.getBorElse(fd, if (tmpDefs(fd)) fd else { + transformDefs(fd) + fdMap.toB(fd) + }) + + override def transform(cd: ClassDef): ClassDef = cdMap.getBorElse(cd, if (tmpDefs(cd)) cd else { + transformDefs(cd) + cdMap.toB(cd) + }) + + private val deps: MutableMap[Definition, Set[Definition]] = MutableMap.empty + private val tmpDefs: MutableSet[Definition] = MutableSet.empty + + private class DependencyFinder(var current: Definition) extends purescala.TreeTraverser { + val deps: MutableMap[Definition, MutableSet[Definition]] = MutableMap.empty + deps(current) = MutableSet.empty + + private def withCurrent[T](d: Definition)(b: => T): T = { + if (!(deps contains d)) deps(d) = MutableSet.empty + val c = current + current = d + val res = b + current = c + res + } + + override def traverse(id: Identifier): Unit = traverse(id.getType) + + override def traverse(cd: ClassDef): Unit = if (!deps(current)(cd)) { + deps(current) += cd + if (!(Converter.this.deps contains cd) && !(deps contains cd)) { + for (cd <- cd.root.knownDescendants :+ cd) { + cd.invariant foreach (fd => withCurrent(cd)(traverse(fd))) + withCurrent(cd)(cd.fieldsIds foreach traverse) + cd.parent foreach { p => + deps(p.classDef) = deps.getOrElse(p.classDef, MutableSet.empty) + cd + deps(cd) = deps.getOrElse(cd, MutableSet.empty) + p.classDef + } + } + } + } + + override def traverse(fd: FunDef): Unit = if (!deps(current)(fd)) { + deps(current) += fd + if (!(Converter.this.deps contains fd) && !(deps contains fd)) withCurrent(fd) { + fd.params foreach (vd => traverse(vd.id)) + traverse(fd.returnType) + traverse(fd.fullBody) + } + } + + def dependencies: Set[Definition] = { + current match { + case fd: FunDef => traverse(fd) + case cd: ClassDef => traverse(cd) + case _ => + } + + for ((d, ds) <- deps) { + Converter.this.deps(d) = Converter.this.deps.getOrElse(d, Set.empty) ++ ds + } + + var changed = false + do { + for ((d, ds) <- Converter.this.deps.toSeq) { + val next = ds.flatMap(d => Converter.this.deps.getOrElse(d, Set.empty)) + if (!(next subsetOf ds)) { + Converter.this.deps(d) = next + changed = true + } + } + } while (changed) + + Converter.this.deps(current) + } + } + + private def dependencies(d: Definition): Set[Definition] = deps.getOrElse(d, { + new DependencyFinder(d).dependencies + }) + + private def transformDefs(base: Definition): Unit = { + val deps = dependencies(base) + val (cds, fds) = { + val (c, f) = deps.partition(_.isInstanceOf[ClassDef]) + (c.map(_.asInstanceOf[ClassDef]), f.map(_.asInstanceOf[FunDef])) + } + + tmpDefs ++= cds.filterNot(cdMap containsA _) ++ fds.filterNot(fdMap containsA _) + + var requireCache: Map[Definition, Boolean] = Map.empty + def required(d: Definition): Boolean = requireCache.getOrElse(d, { + val res = d match { + case fd: FunDef => + val newReturn = transform(fd.returnType) + lazy val newParams = fd.params.map(vd => ValDef(transform(vd.id))) + lazy val newBody = transform(fd.fullBody)((fd.params.map(_.id) zip newParams.map(_.id)).toMap) + newReturn != fd.returnType || newParams != fd.params || newBody != fd.fullBody + + case cd: ClassDef => + cd.fieldsIds.exists(id => transform(id.getType) != id.getType) || + cd.invariant.exists(required) + + case _ => scala.sys.error("Should never happen!?") + } + + requireCache += d -> res + res + }) + + val req = deps filter required + val allReq = req ++ (deps filter (d => (dependencies(d) & req).nonEmpty)) + val requiredCds = allReq collect { case cd: ClassDef => cd } + val requiredFds = allReq collect { case fd: FunDef => fd } + tmpDefs --= deps + + val nonReq = deps filterNot allReq + cdMap ++= nonReq collect { case cd: ClassDef => cd -> cd } + fdMap ++= nonReq collect { case fd: FunDef => fd -> fd } + + def trCd(cd: ClassDef): ClassDef = cdMap.cachedB(cd) { + val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) + cd match { + case acd: AbstractClassDef => acd.duplicate(id = transform(acd.id), parent = parent) + case ccd: CaseClassDef => ccd.duplicate(id = transform(ccd.id), parent = parent) + } + } + + for (cd <- requiredCds) trCd(cd) + for (fd <- requiredFds) { + val newReturn = transform(fd.returnType) + val newParams = fd.params map (vd => ValDef(transform(vd.id))) + fdMap += fd -> fd.duplicate(id = transform(fd.id), params = newParams, returnType = newReturn) + } + + for (ccd <- requiredCds collect { case ccd: CaseClassDef => ccd }) { + val newCcd = cdMap.toB(ccd).asInstanceOf[CaseClassDef] + newCcd.setFields(ccd.fields.map(vd => ValDef(transform(vd.id)))) + newCcd.invariant.foreach(fd => ccd.setInvariant(transform(fd))) + } + + for (fd <- requiredFds) { + val nfd = fdMap.toB(fd) + fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) + } + } + } + + protected class Encoder extends Converter { + private[TheoryEncoder] final val idMap: Bijection[Identifier, Identifier] = TheoryEncoder.this.idMap + private[TheoryEncoder] final val fdMap: Bijection[FunDef , FunDef ] = TheoryEncoder.this.fdMap + private[TheoryEncoder] final val cdMap: Bijection[ClassDef , ClassDef ] = TheoryEncoder.this.cdMap + } + + protected class Decoder extends Converter { + private[TheoryEncoder] final val idMap: Bijection[Identifier, Identifier] = TheoryEncoder.this.idMap.swap + private[TheoryEncoder] final val fdMap: Bijection[FunDef , FunDef ] = TheoryEncoder.this.fdMap.swap + private[TheoryEncoder] final val cdMap: Bijection[ClassDef , ClassDef ] = TheoryEncoder.this.cdMap.swap + } + + def >>(that: TheoryEncoder): TheoryEncoder = new TheoryEncoder { + val encoder = new Encoder { + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + val mapSeq = bindings.toSeq + val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.encoder.transform(id.getType)) } + val e2 = self.encoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) + that.encoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + } + + override def transform(tpe: TypeTree): TypeTree = that.encoder.transform(self.encoder.transform(tpe)) + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { + val (pat2, bindings) = self.encoder.transform(pat) + val (pat3, bindings2) = that.encoder.transform(pat2) + (pat3, bindings2.map { case (id, id2) => id -> bindings2(id2) }) + } + } + + val decoder = new Decoder { + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = { + val mapSeq = bindings.toSeq + val intermediate = mapSeq.map { case (id, _) => id.duplicate(tpe = self.decoder.transform(id.getType)) } + val e2 = that.decoder.transform(expr)((mapSeq zip intermediate).map { case ((id, _), id2) => id -> id2 }.toMap) + self.decoder.transform(e2)((intermediate zip mapSeq).map { case (id, (_, id2)) => id -> id2 }.toMap) + } + + override def transform(tpe: TypeTree): TypeTree = self.decoder.transform(that.decoder.transform(tpe)) + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = { + val (pat2, bindings) = that.decoder.transform(pat) + val (pat3, bindings2) = self.decoder.transform(pat2) + (pat3, bindings.map { case (id, id2) => id -> bindings2(id2) }) + } + } + } +} + +object NoEncoder extends TheoryEncoder { + val encoder = new Encoder + val decoder = new Decoder +} + diff --git a/src/main/scala/leon/solvers/templates/DatatypeManager.scala b/src/main/scala/leon/solvers/unrolling/DatatypeManager.scala similarity index 99% rename from src/main/scala/leon/solvers/templates/DatatypeManager.scala rename to src/main/scala/leon/solvers/unrolling/DatatypeManager.scala index dcfa67e83..275621828 100644 --- a/src/main/scala/leon/solvers/templates/DatatypeManager.scala +++ b/src/main/scala/leon/solvers/unrolling/DatatypeManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala similarity index 98% rename from src/main/scala/leon/solvers/templates/LambdaManager.scala rename to src/main/scala/leon/solvers/unrolling/LambdaManager.scala index 1e715ce1d..f1ad010d8 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -196,9 +196,9 @@ class LambdaTemplate[T] private ( } class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(encoder) { - private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + private[unrolling] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) - protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected[unrolling] val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Map[(Expr, Seq[T]), LambdaTemplate[T]]].withDefaultValue(Map.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala similarity index 99% rename from src/main/scala/leon/solvers/templates/QuantificationManager.scala rename to src/main/scala/leon/solvers/unrolling/QuantificationManager.scala index 4509e6069..84f7d0ae5 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import leon.utils._ import purescala.Common._ diff --git a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala b/src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala similarity index 61% rename from src/main/scala/leon/solvers/templates/TemplateEncoder.scala rename to src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala index c2a2051b1..74488aa8e 100644 --- a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala @@ -2,10 +2,18 @@ package leon package solvers -package templates +package unrolling -import purescala.Common.Identifier -import purescala.Expressions.Expr +import purescala.Common._ +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ + +import utils._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} trait TemplateEncoder[T] { def encodeId(id: Identifier): T @@ -21,3 +29,4 @@ trait TemplateEncoder[T] { def extractNot(v: T): Option[T] } + diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala similarity index 96% rename from src/main/scala/leon/solvers/templates/TemplateGenerator.scala rename to src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala index 9126f77c1..127c1004e 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Expressions._ @@ -14,13 +14,15 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ +import theories._ import utils.SeqUtils._ import Instantiation._ -class TemplateGenerator[T](val encoder: TemplateEncoder[T], +class TemplateGenerator[T](val theories: TheoryEncoder, + val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { private var cache = Map[TypedFunDef, FunctionTemplate[T]]() - private var cacheExpr = Map[Expr, FunctionTemplate[T]]() + private var cacheExpr = Map[Expr, (FunctionTemplate[T], Map[Identifier, Identifier])]() private type Clauses = ( Map[Identifier,T], @@ -45,20 +47,24 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val manager = new QuantificationManager[T](encoder) - def mkTemplate(body: Expr): FunctionTemplate[T] = { - if (cacheExpr contains body) { - return cacheExpr(body) + def mkTemplate(raw: Expr): (FunctionTemplate[T], Map[Identifier, Identifier]) = { + if (cacheExpr contains raw) { + return cacheExpr(raw) } - val arguments = variablesOf(body).toSeq.map(ValDef(_)) + val mapping = variablesOf(raw).map(id => id -> theories.encode(id)).toMap + val body = theories.encode(raw)(mapping) + + val arguments = mapping.values.toSeq.map(ValDef(_)) val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType) fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable)))) fakeFunDef.body = Some(body) val res = mkTemplate(fakeFunDef.typed, false) - cacheExpr += body -> res - res + val p = (res, mapping) + cacheExpr += raw -> p + p } def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] = { diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/unrolling/TemplateInfo.scala similarity index 98% rename from src/main/scala/leon/solvers/templates/TemplateInfo.scala rename to src/main/scala/leon/solvers/unrolling/TemplateInfo.scala index 455704dc4..3dc60c89f 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateInfo.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Definitions.TypedFunDef import Template.Arg diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala similarity index 99% rename from src/main/scala/leon/solvers/templates/TemplateManager.scala rename to src/main/scala/leon/solvers/unrolling/TemplateManager.scala index 8c4752dbf..36bad0b58 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -470,7 +470,7 @@ class FunctionTemplate[T] private( override def toString : String = str } -class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { +class TemplateManager[T](protected[unrolling] val encoder: TemplateEncoder[T]) extends IncrementalState { private val condImplies = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) private val condImplied = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala similarity index 97% rename from src/main/scala/leon/solvers/templates/UnrollingBank.scala rename to src/main/scala/leon/solvers/unrolling/UnrollingBank.scala index 2383f65f6..feb618827 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala @@ -2,7 +2,7 @@ package leon package solvers -package templates +package unrolling import purescala.Common._ import purescala.Expressions._ @@ -161,13 +161,14 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat clause } - def getClauses(expr: Expr, bindings: Map[Expr, T]): Seq[T] = { + def getClauses(expr: Expr, bindings: Map[Identifier, T]): Seq[T] = { // OK, now this is subtle. This `getTemplate` will return // a template for a "fake" function. Now, this template will // define an activating boolean... - val template = templateGenerator.mkTemplate(expr) + val (template, mapping) = templateGenerator.mkTemplate(expr) + val reverse = mapping.map(p => p._2 -> p._1) - val trArgs = template.tfd.params.map(vd => Left(bindings(Variable(vd.id)))) + val trArgs = template.tfd.params.map(vd => Left(bindings(reverse(vd.id)))) // ...now this template defines clauses that are all guarded // by that activating boolean. If that activating boolean is @@ -211,7 +212,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def promoteBlocker(b: T) = { if (callInfos contains b) { val (_, origGen, notB, fis) = callInfos(b) - + callInfos += b -> (1, origGen, notB, fis) } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala similarity index 92% rename from src/main/scala/leon/solvers/combinators/UnrollingSolver.scala rename to src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala index 8a924762a..99fdbd048 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala @@ -2,7 +2,7 @@ package leon package solvers -package combinators +package unrolling import purescala.Common._ import purescala.Definitions._ @@ -15,7 +15,7 @@ import purescala.Types._ import purescala.TypeOps.bestRealType import utils._ -import templates._ +import theories._ import evaluators._ import Template._ @@ -55,9 +55,7 @@ trait AbstractUnrollingSolver[T] protected var definitiveModel : Model = Model.empty protected var definitiveCore : Set[Expr] = Set.empty - def check: Option[Boolean] = { - genericCheck(Set.empty) - } + def check: Option[Boolean] = genericCheck(Set.empty) def getModel: Model = if (foundDefinitiveAnswer && definitiveAnswer.getOrElse(false)) { definitiveModel @@ -71,14 +69,14 @@ trait AbstractUnrollingSolver[T] Set.empty } - private val freeVars = new IncrementalMap[Identifier, T]() private val constraints = new IncrementalSeq[Expr]() + private val freeVars = new IncrementalMap[Identifier, T]() protected var interrupted : Boolean = false protected val reporter = context.reporter - lazy val templateGenerator = new TemplateGenerator(templateEncoder, assumePreHolds) + lazy val templateGenerator = new TemplateGenerator(theoryEncoder, templateEncoder, assumePreHolds) lazy val unrollingBank = new UnrollingBank(context, templateGenerator) def push(): Unit = { @@ -110,11 +108,15 @@ trait AbstractUnrollingSolver[T] interrupted = false } - def assertCnstr(expression: Expr, bindings: Map[Identifier, T]): Unit = { + protected def declareVariable(id: Identifier): T + + def assertCnstr(expression: Expr): Unit = { constraints += expression - freeVars ++= bindings + val bindings = variablesOf(expression).map(id => id -> freeVars.cached(id) { + declareVariable(theoryEncoder.encode(id)) + }).toMap - val newClauses = unrollingBank.getClauses(expression, bindings.map { case (k, v) => Variable(k) -> v }) + val newClauses = unrollingBank.getClauses(expression, bindings) for (cl <- newClauses) { solverAssert(cl) } @@ -128,6 +130,8 @@ trait AbstractUnrollingSolver[T] } implicit val printable: T => Printable + + val theoryEncoder: TheoryEncoder val templateEncoder: TemplateEncoder[T] def solverAssert(cnstr: T): Unit @@ -166,8 +170,16 @@ trait AbstractUnrollingSolver[T] def solverUnsatCore: Option[Seq[T]] trait ModelWrapper { - def get(id: Identifier): Option[Expr] - def eval(elem: T, tpe: TypeTree): Option[Expr] + def modelEval(elem: T, tpe: TypeTree): Option[Expr] + + def eval(elem: T, tpe: TypeTree): Option[Expr] = modelEval(elem, theoryEncoder.encode(tpe)).map { + expr => theoryEncoder.decode(expr)(Map.empty) + } + + def get(id: Identifier): Option[Expr] = eval(freeVars(id), theoryEncoder.encode(id.getType)).filter { + case Variable(_) => false + case _ => true + } private[AbstractUnrollingSolver] def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe @@ -228,6 +240,7 @@ trait AbstractUnrollingSolver[T] def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { foundDefinitiveAnswer = false + // TODO: theory encoder for assumptions!? val encoder = templateGenerator.encoder.encodeExpr(freeVars.toMap) _ val assumptionsSeq : Seq[Expr] = assumptions.toSeq val encodedAssumptions : Seq[T] = assumptionsSeq.map(encoder) @@ -241,7 +254,7 @@ trait AbstractUnrollingSolver[T] }).toSet } - while(!foundDefinitiveAnswer && !interrupted) { + while (!foundDefinitiveAnswer && !interrupted) { reporter.debug(" - Running search...") var quantify = false @@ -432,8 +445,12 @@ trait AbstractUnrollingSolver[T] } } -class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) - extends AbstractUnrollingSolver[Expr] { +class UnrollingSolver( + val context: LeonContext, + val program: Program, + underlying: Solver, + theories: TheoryEncoder = NoEncoder +) extends AbstractUnrollingSolver[Expr] { override val name = "U:"+underlying.name @@ -444,10 +461,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val printable = (e: Expr) => e val templateEncoder = new TemplateEncoder[Expr] { - def encodeId(id: Identifier): Expr= { - Variable(id.freshen) - } - + def encodeId(id: Identifier): Expr= Variable(id.freshen) def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { replaceFromIDs(bindings, e) } @@ -468,11 +482,11 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying } } + val theoryEncoder = theories + val solver = underlying - def assertCnstr(expression: Expr): Unit = { - assertCnstr(expression, variablesOf(expression).map(id => id -> id.toVariable).toMap) - } + def declareVariable(id: Identifier): Variable = id.toVariable def solverAssert(cnstr: Expr): Unit = { solver.assertCnstr(cnstr) @@ -491,8 +505,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def solverGetModel: ModelWrapper = new ModelWrapper { val model = solver.getModel - def get(id: Identifier): Option[Expr] = model.get(id) - def eval(elem: Expr, tpe: TypeTree): Option[Expr] = evaluator.eval(elem, model).result + def modelEval(elem: Expr, tpe: TypeTree): Option[Expr] = evaluator.eval(elem, model).result override def toString = model.toMap.mkString("\n") } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 4e24697b6..41a8f3722 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -252,7 +252,7 @@ trait AbstractZ3Solver extends Solver { protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if (initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 5de176ff2..780a9fc68 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -4,7 +4,6 @@ package leon package solvers package z3 -import utils._ import _root_.z3.scala._ import purescala.Common._ @@ -13,8 +12,8 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ -import solvers.templates._ -import solvers.combinators._ +import unrolling._ +import theories._ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver @@ -32,6 +31,10 @@ class FairZ3Solver(val context: LeonContext, val program: Program) override protected val reporter = context.reporter override def reset(): Unit = super[AbstractZ3Solver].reset() + def declareVariable(id: Identifier): Z3AST = variables.cachedB(Variable(id)) { + templateEncoder.encodeId(id) + } + def solverCheck[R](clauses: Seq[Z3AST])(block: Option[Boolean] => R): R = { solver.push() for (cls <- clauses) solver.assertCnstr(cls) @@ -80,14 +83,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) */ - def get(id: Identifier): Option[Expr] = variables.getB(id.toVariable).flatMap { - z3ID => eval(z3ID, id.getType) match { - case Some(Variable(id)) => None - case e => e - } - } - - def eval(elem: Z3AST, tpe: TypeTree): Option[Expr] = tpe match { + def modelEval(elem: Z3AST, tpe: TypeTree): Option[Expr] = tpe match { case BooleanType => model.evalAs[Boolean](elem).map(BooleanLiteral) case Int32Type => model.evalAs[Int](elem).map(IntLiteral).orElse { model.eval(elem).flatMap(t => softFromZ3Formula(model, t, Int32Type)) @@ -106,6 +102,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def asString(implicit ctx: LeonContext) = z3.toString } + val theoryEncoder = StringEncoder + val templateEncoder = new TemplateEncoder[Z3AST] { def encodeId(id: Identifier): Z3AST = { idToFreshZ3Id(id) @@ -170,13 +168,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } } - def assertCnstr(expression: Expr): Unit = { + override def assertCnstr(expression: Expr): Unit = { try { - val bindings = variablesOf(expression).map(id => id -> variables.cachedB(Variable(id)) { - templateGenerator.encoder.encodeId(id) - }).toMap - - assertCnstr(expression, bindings) + super.assertCnstr(expression) } catch { case _: Unsupported => addError() diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala deleted file mode 100644 index c5fa2b0ba..000000000 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ /dev/null @@ -1,382 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package solvers -package z3 - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Types._ -import purescala.Definitions._ -import leon.utils.Bijection -import leon.purescala.DefOps -import leon.purescala.TypeOps -import leon.purescala.Extractors.Operator -import leon.evaluators.EvaluationResults - -object StringEcoSystem { - private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { - val id = FreshIdentifier(name, tpe) - f(id) - } - - private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { - withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) - } - - val StringList = new AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) - val StringListTyped = StringList.typed - val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => - val d = new CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) - d.setFields(Seq(ValDef(head), ValDef(tail))) - d - } - - StringList.registerChild(StringCons) - val StringConsTyped = StringCons.typed - val StringNil = new CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) - val StringNilTyped = StringNil.typed - StringList.registerChild(StringNil) - - val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => - val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) - fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(lengthArg), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) - )) - }) - fd - } - - val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => - val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(x), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) - ))) - } - ) - fd - } - - val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => - val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) - fd.body = Some{ - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringNilTyped, Seq()), - CaseClass(StringConsTyped, Seq(Variable(h), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) - )))) - } - } - } - fd - } - - val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => - val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) - )))) - }} - ) - fd - } - - val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => - val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) - fd.body = Some( - FunctionInvocation(StringTake.typed, - Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), - Minus(Variable(to), Variable(from))))) - fd - } } - - val classDefs = Seq(StringList, StringCons, StringNil) - val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) -} - -class Z3StringConversion(val p: Program) extends Z3StringConverters { - import StringEcoSystem._ - def getProgram = program_with_string_methods - - lazy val program_with_string_methods = { - val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) - DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) - } -} - -trait Z3StringConverters { - import StringEcoSystem._ - val mappedVariables = new Bijection[Identifier, Identifier]() - - val globalClassMap = new Bijection[ClassDef, ClassDef]() // To be added manually - - val globalFdMap = new Bijection[FunDef, FunDef]() - - val stringBijection = new Bijection[String, Expr]() - - def convertToString(e: Expr): String = - stringBijection.cachedA(e) { - e match { - case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) - case CaseClass(_, Seq()) => "" - } - } - def convertFromString(v: String): Expr = - stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ - case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) - } - } - - trait BidirectionalConverters { - def convertFunDef(fd: FunDef): FunDef - def hasIdConversion(id: Identifier): Boolean - def convertId(id: Identifier): Identifier - def convertClassDef(d: ClassDef): ClassDef - def isTypeToConvert(tpe: TypeTree): Boolean - def convertType(tpe: TypeTree): TypeTree - def convertPattern(pattern: Pattern): Pattern - def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr - object TypeConverted { - def unapply(t: TypeTree): Option[TypeTree] = Some(t match { - case cct@CaseClassType(ccd, args) => CaseClassType(convertClassDef(ccd).asInstanceOf[CaseClassDef], args map convertType) - case act@AbstractClassType(acd, args) => AbstractClassType(convertClassDef(acd).asInstanceOf[AbstractClassDef], args map convertType) - case NAryType(es, builder) => - builder(es map convertType) - }) - } - object PatternConverted { - def unapply(e: Pattern): Option[Pattern] = Some(e match { - case InstanceOfPattern(binder, ct) => - InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) - case WildcardPattern(binder) => - WildcardPattern(binder.map(convertId)) - case CaseClassPattern(binder, ct, subpatterns) => - CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) - case TuplePattern(binder, subpatterns) => - TuplePattern(binder.map(convertId), subpatterns map convertPattern) - case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => - UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) - case PatternExtractor(es, builder) => - builder(es map convertPattern) - }) - } - - object ExprConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { - case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) - case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) - case Variable(id) => e - case pl @ FiniteLambda(mappings, default, tpe) => - FiniteLambda( - mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), - convertExpr(kv._2))), - convertExpr(default), convertType(tpe).asInstanceOf[FunctionType]) - case Lambda(args, body) => - val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() - val new_args = for(arg <- args) yield { - val in = arg.getType - val new_id = convertId(arg.id) - if(new_id ne arg.id) { - new_bindings += (arg.id -> new_id) - ValDef(new_id) - } else arg - } - val res = Lambda(new_args, convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) - res - case Let(a, expr, body) if isTypeToConvert(a.getType) => - val new_a = convertId(a) - val new_bindings = bindings + (a -> Variable(new_a)) - val expr2 = convertExpr(expr)(new_bindings) - val body2 = convertExpr(body)(new_bindings) - Let(new_a, expr2, body2).copiedFrom(e) - case CaseClass(CaseClassType(ccd, tpes), args) => - CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) - case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => - CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) - case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => - MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) - case FunctionInvocation(TypedFunDef(fd, tpes), args) => - FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) - case This(ct: ClassType) => - This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case IsInstanceOf(expr, ct) => - IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case AsInstanceOf(expr, ct) => - AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) - case Tuple(args) => - Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) - case MatchExpr(scrutinee, cases) => - MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { - MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) - }) - case Operator(es, builder) => - val rec = convertExpr _ - val newEs = es.map(rec) - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - case e => e - }) - } - - def convertModel(model: Model): Model = { - new Model(model.ids.map{i => - val id = convertId(i) - id -> convertExpr(model(i))(Map()) - }.toMap) - } - - def convertResult(result: EvaluationResults.Result[Expr]) = { - result match { - case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map())) - case result => result - } - } - } - - object Forward extends BidirectionalConverters { - /* The conversion between functions should already have taken place */ - def convertFunDef(fd: FunDef): FunDef = { - globalFdMap.getBorElse(fd, fd) - } - /* The conversion between classdefs should already have taken place */ - def convertClassDef(cd: ClassDef): ClassDef = { - globalClassMap.getBorElse(cd, cd) - } - def hasIdConversion(id: Identifier): Boolean = { - mappedVariables.containsA(id) - } - def convertId(id: Identifier): Identifier = { - mappedVariables.getB(id) match { - case Some(idB) => idB - case None => - if(isTypeToConvert(id.getType)) { - val new_id = FreshIdentifier(id.name, convertType(id.getType)) - mappedVariables += (id -> new_id) - new_id - } else id - } - } - def isTypeToConvert(tpe: TypeTree): Boolean = - TypeOps.exists(StringType == _)(tpe) - def convertType(tpe: TypeTree): TypeTree = tpe match { - case StringType => StringListTyped - case TypeConverted(t) => t - } - def convertPattern(e: Pattern): Pattern = e match { - case LiteralPattern(binder, StringLiteral(s)) => - s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { - case (elem, pattern) => - CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) - } - case PatternConverted(e) => e - } - - /** Method which can use recursively StringConverted in its body in unapply positions */ - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { - case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) - case StringLiteral(v) => - val stringEncoding = convertFromString(v) - convertExpr(stringEncoding).copiedFrom(e) - case StringLength(a) => - FunctionInvocation(StringSize.typed, Seq(convertExpr(a))).copiedFrom(e) - case StringConcat(a, b) => - FunctionInvocation(StringListConcat.typed, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) - case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(StringTake.typed, - Seq(FunctionInvocation(StringDrop.typed, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) - case SubString(a, start, end) => - FunctionInvocation(StringSlice.typed, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) - case MatchExpr(scrutinee, cases) => - MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { - MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) - }) - case ExprConverted(e) => e - } - } - - object Backward extends BidirectionalConverters { - def convertFunDef(fd: FunDef): FunDef = { - globalFdMap.getAorElse(fd, fd) - } - /* The conversion between classdefs should already have taken place */ - def convertClassDef(cd: ClassDef): ClassDef = { - globalClassMap.getAorElse(cd, cd) - } - def hasIdConversion(id: Identifier): Boolean = { - mappedVariables.containsB(id) - } - def convertId(id: Identifier): Identifier = { - mappedVariables.getA(id) match { - case Some(idA) => idA - case None => - if(isTypeToConvert(id.getType)) { - val old_type = convertType(id.getType) - val old_id = FreshIdentifier(id.name, old_type) - mappedVariables += (old_id -> id) - old_id - } else id - } - } - def convertIdToMapping(id: Identifier): (Identifier, Variable) = { - id -> Variable(convertId(id)) - } - def isTypeToConvert(tpe: TypeTree): Boolean = - TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) - def convertType(tpe: TypeTree): TypeTree = tpe match { - case StringListTyped | StringConsTyped | StringNilTyped => StringType - case TypeConverted(t) => t - } - def convertPattern(e: Pattern): Pattern = e match { - case CaseClassPattern(b, StringNilTyped, Seq()) => - LiteralPattern(b.map(convertId), StringLiteral("")) - case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => - convertPattern(subpattern) match { - case LiteralPattern(_, StringLiteral(s)) - => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) - case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) - } - case PatternConverted(e) => e - } - - def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = - e match { - case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> - StringLiteral(convertToString(cc)) - case FunctionInvocation(StringSize, Seq(a)) => - StringLength(convertExpr(a)).copiedFrom(e) - case FunctionInvocation(StringListConcat, Seq(a, b)) => - StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) - case FunctionInvocation(StringTake, - Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => - val rstart = convertExpr(start) - SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) - case ExprConverted(e) => e - } - } -} diff --git a/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala new file mode 100644 index 000000000..ec1428e39 --- /dev/null +++ b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package z3 + +import purescala.Definitions._ + +import unrolling._ +import theories._ + +class Z3UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) + extends UnrollingSolver(context, program, underlying, theories = StringEncoder) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 7079f9054..9308c54d1 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -117,6 +117,7 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { } def generateForPC(ids: List[Identifier], pc: Expr, maxValid: Int = 400, maxEnumerated: Int = 1000): ExamplesBank = { + //println(program.definedClasses) val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) val datagen = new GrammarDataGen(evaluator, ValueGrammar) diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index fecf8a756..251ebde95 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -2,14 +2,16 @@ package leon.utils +import scala.collection.mutable.{Map => MutableMap} + object Bijection { def apply[A, B](a: Iterable[(A, B)]): Bijection[A, B] = new Bijection[A, B] ++= a def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) } class Bijection[A, B] extends Iterable[(A, B)] { - protected var a2b = Map[A, B]() - protected var b2a = Map[B, A]() + protected val a2b = MutableMap[A, B]() + protected val b2a = MutableMap[B, A]() def iterator = a2b.iterator @@ -28,8 +30,8 @@ class Bijection[A, B] extends Iterable[(A, B)] { } def clear(): Unit = { - a2b = Map() - b2a = Map() + a2b.clear() + b2a.clear() } def getA(b: B) = b2a.get(b) @@ -72,4 +74,9 @@ class Bijection[A, B] extends Iterable[(A, B)] { def composeB[C](c: B => C): Bijection[A, C] = { new Bijection[A, C] ++= this.a2b.map(kv => kv._1 -> c(kv._2)) } + + def swap: Bijection[B, A] = new Bijection[B, A] { + override protected val a2b = Bijection.this.b2a + override protected val b2a = Bijection.this.a2b + } } diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/leon/utils/IncrementalBijection.scala index 6dab6d670..411fcf31c 100644 --- a/src/main/scala/leon/utils/IncrementalBijection.scala +++ b/src/main/scala/leon/utils/IncrementalBijection.scala @@ -2,25 +2,24 @@ package leon.utils -class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { - private var a2bStack = List[Map[A,B]]() - private var b2aStack = List[Map[B,A]]() +import scala.collection.mutable.{Map => MutableMap, Stack} - private def recursiveGet[T,U](stack: List[Map[T,U]], t: T): Option[U] = stack match { - case t2u :: xs => t2u.get(t) orElse recursiveGet(xs, t) - case Nil => None - } +class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { + protected val a2bStack = Stack[MutableMap[A,B]]() + protected val b2aStack = Stack[MutableMap[B,A]]() override def getA(b: B) = b2a.get(b) match { case s @ Some(a) => s - case None => recursiveGet(b2aStack, b) + case None => b2aStack.view.flatMap(_.get(b)).headOption } override def getB(a: A) = a2b.get(a) match { case s @ Some(b) => s - case None => recursiveGet(a2bStack, a) + case None => a2bStack.view.flatMap(_.get(a)).headOption } + override def iterator = aToB.iterator + def aToB: Map[A,B] = { a2bStack.reverse.foldLeft(Map[A,B]()) { _ ++ _ } ++ a2b } @@ -37,22 +36,30 @@ class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { def reset() : Unit = { super.clear() - a2bStack = Nil - b2aStack = Nil + a2bStack.clear() + b2aStack.clear() } def push(): Unit = { - a2bStack = a2b :: a2bStack - b2aStack = b2a :: b2aStack - a2b = Map() - b2a = Map() + a2bStack.push(a2b.clone) + b2aStack.push(b2a.clone) + a2b.clear() + b2a.clear() } def pop(): Unit = { - a2b = a2bStack.head - b2a = b2aStack.head - a2bStack = a2bStack.tail - b2aStack = b2aStack.tail + a2b.clear() + a2b ++= a2bStack.head + b2a.clear() + b2a ++= b2aStack.head + a2bStack.pop() + b2aStack.pop() } + override def swap: IncrementalBijection[B, A] = new IncrementalBijection[B, A] { + override protected val a2b = IncrementalBijection.this.b2a + override protected val b2a = IncrementalBijection.this.a2b + override protected val a2bStack = IncrementalBijection.this.b2aStack + override protected val b2aStack = IncrementalBijection.this.a2bStack + } } diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/leon/utils/IncrementalMap.scala index aeaf32e09..d1c3fe034 100644 --- a/src/main/scala/leon/utils/IncrementalMap.scala +++ b/src/main/scala/leon/utils/IncrementalMap.scala @@ -60,6 +60,12 @@ class IncrementalMap[A, B] private(dflt: Option[B]) def getOrElse[B1 >: B](k: A, e: => B1) = stack.head.getOrElse(k, e) def values = stack.head.values + def cached(k: A)(b: => B): B = getOrElse(k, { + val ev = b + this += k -> ev + ev + }) + def iterator = stack.head.iterator def +=(kv: (A, B)) = { stack.head += kv; this } def -=(k: A) = { stack.head -= k; this } diff --git a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala index 466a16dad..628086b41 100644 --- a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala +++ b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala @@ -14,7 +14,7 @@ import leon.LeonContext import leon.solvers._ import leon.solvers.smtlib._ -import leon.solvers.combinators._ +import leon.solvers.unrolling._ import leon.solvers.z3._ class GlobalVariablesSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { diff --git a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala index fa2260afc..0bc83e224 100644 --- a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala @@ -13,7 +13,7 @@ import leon.LeonOption import leon.solvers._ import leon.solvers.smtlib._ -import leon.solvers.combinators._ +import leon.solvers.cvc4._ import leon.solvers.z3._ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { @@ -27,10 +27,10 @@ class QuantifierSolverSuite extends LeonTestSuiteWithProgram { ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new CVC4UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) } diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index 2a42bc6a3..8b0d8026e 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -13,7 +13,6 @@ import leon.LeonContext import leon.solvers._ import leon.solvers.smtlib._ -import leon.solvers.combinators._ import leon.solvers.z3._ class SolversSuite extends LeonTestSuiteWithProgram { @@ -22,13 +21,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm)) + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) } @@ -49,7 +48,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { val vs = types.map(FreshIdentifier("v", _).toVariable) - // We need to make sure models are not co-finite + // We need to make sure models are not co-finite val cnstrs = vs.map(v => v.getType match { case UnitType => Equals(v, simplestValue(v.getType)) @@ -77,7 +76,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { fail(s"Solver $solver - Model does not contain "+v.id.uniqueName+" of type "+v.getType) } } - case _ => + case res => fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!? Solver was "+solver.getClass) } } finally { diff --git a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala index c0c724c81..c8a3000dc 100644 --- a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala +++ b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala @@ -209,9 +209,10 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal | def listEdgeToString(l: List[Edge]): String = ??? by example |} """.stripMargin.replaceByExample) + implicit val (ctx, program) = getFixture() - - val synthesisInfos = SourceInfo.extractFromProgram(ctx, program).map(si => si.fd.id.name -> si ).toMap + + val synthesisInfos = SourceInfo.extractFromProgram(ctx, program).map(si => si.fd.id.name -> si).toMap def synthesizeAndTest(functionName: String, tests: (Seq[Expr], String)*) { val (fd, program) = applyStringRenderOn(functionName) @@ -260,6 +261,7 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal def apply(types: TypeTree*)(args: Expr*) = FunctionInvocation(fd.typed(types), args) } + // Mimics the file above, allows construction of expressions. case class Constructors(program: Program) { implicit val p = program @@ -401,4 +403,4 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal customListToString(Dummy2)(listDummy2, lambdaDummy2ToString))) } } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala index f286d9d65..413ee804c 100644 --- a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala @@ -8,7 +8,6 @@ import leon.purescala.Types._ import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.solvers.z3._ -import leon.solvers.combinators._ class UnrollingSolverSuite extends LeonSolverSuite { @@ -27,7 +26,7 @@ class UnrollingSolverSuite extends LeonSolverSuite { ) def getSolver(implicit ctx: LeonContext, pgm: Program) = { - new UnrollingSolver(ctx, pgm, new UninterpretedZ3Solver(ctx, pgm)) + new Z3UnrollingSolver(ctx, pgm, new UninterpretedZ3Solver(ctx, pgm)) } test("'true' should be valid") { implicit fix => -- GitLab