diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 8cae0aa72c7a6d7ac18f3af251eeccb4056a6441..b5e5f1921807482229a2756a67749bc21f9443fc 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,7 +275,15 @@ object DefOps { None } - /** Returns the new program with a map from the old functions to the new functions */ + /** Clones the given program by replacing some functions by other functions. + * + * @param p The original program + * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. + * May be called once each time a function appears (definition and invocation), + * so make sure to output the same if the argument is the same. + * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. + * By default it is the function invocation using the new function definition. + * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { @@ -297,7 +305,6 @@ object DefOps { df match { case f : FunDef => val newF = fdMap(f) - newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF) newF case d => d @@ -307,7 +314,11 @@ object DefOps { } ) }) - + for(fd <- newP.definedFunctions) { + if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache(fd) != None case _ => false }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) + } + } (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } @@ -320,32 +331,38 @@ object DefOps { }(e) } - def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false val res = p.copy(units = for (u <- p.units) yield { u.copy( - defs = u.defs.map { + defs = u.defs.flatMap { case m: ModuleDef => val newdefs = for (df <- m.defs) yield { df match { case `after` => found = true - after +: fds.toSeq - case d => - Seq(d) + after +: cds.toSeq + case d => Seq(d) } } - m.copy(defs = newdefs.flatten) - case d => d + Seq(m.copy(defs = newdefs.flatten)) + case `after` => + found = true + after +: cds.toSeq + case d => Seq(d) } ) }) if (!found) { - println("addFunDefs could not find anchor function!") + println("addDefs could not find anchor definition!") } res } + + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) + + def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) // @Note: This function does not filter functions in classdefs def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index bb70a6676923db89648df770318f8c04129bad3e..69dbc4429aac6d31444cd024944d03096c7b6cc8 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -364,6 +364,22 @@ object Expressions { someValue.id ) } + + object PatternExtractor extends SubTreeOps.Extractor[Pattern] { + def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { + case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => + Some(Seq(), es => e) + case CaseClassPattern(binder, ct, subpatterns) => + Some(subpatterns, es => CaseClassPattern(binder, ct, es)) + case TuplePattern(binder, subpatterns) => + Some(subpatterns, es => TuplePattern(binder, es)) + case UnapplyPattern(binder, unapplyFun, subpatterns) => + Some(subpatterns, es => UnapplyPattern(binder, unapplyFun, es)) + case _ => None + } + } + + object PatternOps extends { val Deconstructor = PatternExtractor } with SubTreeOps[Pattern] /** Symbolic I/O examples as a match/case. * $encodingof `out == (in match { cases; case _ => out })` diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index c57958bc726b9384b49ace471e44c2bc595d5358..9648c17853d84d034db21be1ea2d5f347a20ea45 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -19,58 +19,39 @@ import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoCh import templates._ import evaluators._ import Template._ -import leon.solvers.z3.Z3StringConversionReverse +import leon.solvers.z3.Z3StringConversion +import leon.utils.Bijection +import leon.solvers.z3.StringEcoSystem object Z3StringCapableSolver { - def convert(p: Program): ((Program, Map[FunDef, FunDef]), Z3StringConversionReverse, Map[Identifier, Identifier]) = { - val converter = new Z3StringConversionReverse { - def getProgram = p - - def convertId(id: Identifier): (Identifier, Variable) = { - id -> Variable(FreshIdentifier(id.name, convertType(id.getType))) - } - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { - case Variable(id) if bindings contains id => bindings(id) - case Let(a, expr, body) if TypeOps.exists( _ == StringType)(a.getType) => - val new_a_bid = convertId(a) - val new_bindings = bindings + new_a_bid - val expr2 = convertToTarget(expr)(new_bindings) - val body2 = convertToTarget(expr)(new_bindings) - Let(new_a_bid._1, expr2, body2) - case StringConverted(p) => p - case Operator(es, builder) => - val rec = convertToTarget _ - val newEs = es.map(rec) - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - } - def targetApplication(fd: TypedFunDef, args: Seq[Expr])(implicit bindings: Map[Identifier, Expr]): Expr = { - FunctionInvocation(fd, args) - } - } + def convert(p: Program): ((Program, Map[FunDef, FunDef]), Z3StringConversion, Map[Identifier, Identifier]) = { + val converter = new Z3StringConversion(p) import converter._ + import converter.Forward._ var globalIdMap = Map[Identifier, Identifier]() - (DefOps.replaceFunDefs(p)((fd: FunDef) => { - if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || - fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { - - val idMap = fd.params.map(vd => vd.id -> FreshIdentifier(vd.id.name, convertType(vd.id.getType))).toMap - globalIdMap ++= idMap - implicit val idVarMap = idMap.mapValues(id => Variable(id)) - - val newFd = fd.duplicate(FreshIdentifier(fd.id.name, convertType(fd.id.getType)), - fd.tparams, - fd.params.map(vd => ValDef(idMap(vd.id))), - convertType(fd.returnType)) - fd.body foreach { body => - newFd.body = Some(convertToTarget(body)) - } - Some(newFd) - } else None - }), converter, globalIdMap) + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + val (new_program, fdMap) = DefOps.replaceFunDefs(converter.getProgram)((fd: FunDef) => { + globalFdMap.get(fd).map(_._2).orElse( + if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || + fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { + val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap + globalIdMap ++= idMap + val newFdId = convertId(fd.id) + val newFd = fd.duplicate(newFdId, + fd.tparams, + fd.params.map(vd => ValDef(idMap(vd.id))), + convertType(fd.returnType)) + globalFdMap += fd -> ((idMap, newFd)) + Some(newFd) + } else None + ) + }) + converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) + for((fd, (idMap, newFd)) <- globalFdMap) { + implicit val idVarMap = idMap.mapValues(id => Variable(id)) + newFd.fullBody = convertExpr(newFd.fullBody) + } + ((new_program, fdMap), converter, globalIdMap) } } @@ -94,16 +75,30 @@ class Z3StringCapableSolver(val context: LeonContext, val program: Program, f: P // Members declared in leon.solvers.QuantificationSolver def getModel: leon.solvers.HenkinModel = { val model = underlying.getModel - val ids = model.ids.toSet + val ids = model.ids.toSeq val exprs = ids.map(model.apply) val original_ids = ids.map(idMapReverse) // Should exist. - val original_exprs = exprs.map(e => converter.StringConversion.reverse(e)) - new HenkinModel(original_ids.zip(original_exprs).toMap, model.doms) // TODO: Convert the domains as well + import converter.Backward._ + val original_exprs = exprs.zip(original_ids).map{ case (e, id) => convertExpr(e)(Map()) } + + val new_domain = new HenkinDomains( + model.doms.lambdas.map(kv => + (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, + model.doms.tpes.map(kv => + (convertType(kv._1), + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap + ) + + new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain) } // Members declared in leon.solvers.Solver def assertCnstr(expression: leon.purescala.Expressions.Expr): Unit = { - underlying.assertCnstr(converter.convertToTarget(expression)(Map())) + val expression2 = DefOps.replaceFunCalls(expression, mappings.withDefault { x => x }.apply _) + import converter.Forward._ + val newExpression = convertExpr(expression2)(idMap.mapValues(Variable)) + underlying.assertCnstr(newExpression) } def check: Option[Boolean] = underlying.check def free(): Unit = underlying.free() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 47017bcf1471770f1b1e9fb574a81bed34f4c515..d05b45c35e2e3847be4e1a6650c779de91da91d9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -104,6 +104,7 @@ trait SMTLIBTarget extends Interruptible { interpreter.eval(cmd) match { case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => reporter.warning(s"Unexpected error from $targetName solver: $msg") + //println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) // Store that there was an error. Now all following check() // invocations will return None addError() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 1731b94ae5f4f91b87248deddc50db5339915552..3d4a06a838a5057a693d85754bb5113e1ce7d0ae 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -8,15 +8,15 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ -import purescala.Definitions._ + import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx -import leon.solvers.z3.Z3StringConversion -trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { +trait SMTLIBZ3Target extends SMTLIBTarget { + def targetName = "z3" def interpreterOps(ctx: LeonContext) = { @@ -40,11 +40,11 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { override protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { - convertType(tpe) match { + tpe match { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) - case t => + case _ => super.declareSort(t) } } @@ -69,13 +69,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(t: Term, expected_otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - val otpe = expected_otpe match { - case Some(StringType) => Some(listchar) - case _ => expected_otpe - } - val res = (t, otpe) match { + (t, otpe) match { case (SimpleSymbol(s), Some(tp: TypeParameter)) => val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) @@ -100,16 +96,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case _ => super.fromSMT(t, otpe) } - expected_otpe match { - case Some(StringType) => - StringLiteral(convertToString(res)(program)) - case _ => res - } - } - - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = toSMT(e) - def targetApplication(tfd: TypedFunDef, args: Seq[Term])(implicit bindings: Map[Identifier, Term]): Term = { - FunctionApplication(declareFunction(tfd), args) } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { @@ -146,7 +132,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - case StringConverted(result) => result case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index b16083bad85958f5f2f906368a1938cb91dd25cd..59be52775406a78c22a55773dfd161db003b21bb 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -207,7 +207,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { - assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean") + assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean." + ( + purescala.ExprOps.fold[String]{ (e, se) => + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + }(expr) + )) val prev = guardedExprs.getOrElse(guardVar, Nil) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index ac1e8855a4a53ed5ef2f66c34d4b015d9a26ba3f..f0b7d5f91a9f53f32dd6eb9590c988618f0536d4 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -264,322 +264,311 @@ trait AbstractZ3Solver extends Solver { case other => throw SolverUnsupportedError(other, this) } - - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - implicit var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } - new Z3StringConversion[Z3AST] { - def getProgram = AbstractZ3Solver.this.program - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - rec(e) - } - def targetApplication(tfd: TypedFunDef, args: Seq[Z3AST])(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - z3.mkApp(functionDefToDecl(tfd), args: _*) + + def rec(ex: Expr): Z3AST = ex match { + + // TODO: Leave that as a specialization? + case LetTuple(ids, e, b) => { + z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => + val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) + entry } - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => { - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry - } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb - } - - case p @ Passes(_, _, _) => - rec(p.asConstraint) - - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) - - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - - case Waypoint(_, e, _) => rec(e) - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => { - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => + val rb = rec(b) + z3Vars = z3Vars -- ids + rb + } + + case p @ Passes(_, _, _) => + rec(p.asConstraint) + + case me @ MatchExpr(s, cs) => + rec(matchToIfThenElse(me)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + + case Waypoint(_, e, _) => rec(e) + case a @ Assert(cond, err, body) => + rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) + + case e @ Error(tpe, _) => { + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => + ast + case None => { + variables.getB(v) match { + case Some(ast) => ast - case None => { - variables.getB(v) match { - case Some(ast) => - ast - - case None => - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } - } - } - - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec): _*) - case Or(exs) => z3.mkOr(exs.map(rec): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) - case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) - case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => { - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - } - case Remainder(l, r) => { - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - } - case Modulo(l, r) => { - z3.mkMod(rec(l), rec(r)) - } - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) - case BVNot(e) => z3.mkBVNot(rec(e)) - case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) - case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) - case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) - case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) - case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) - case LessThan(l, r) => l.getType match { - case IntegerType => z3.mkLT(rec(l), rec(r)) - case RealType => z3.mkLT(rec(l), rec(r)) - case Int32Type => z3.mkBVSlt(rec(l), rec(r)) - case CharType => z3.mkBVSlt(rec(l), rec(r)) - } - case LessEquals(l, r) => l.getType match { - case IntegerType => z3.mkLE(rec(l), rec(r)) - case RealType => z3.mkLE(rec(l), rec(r)) - case Int32Type => z3.mkBVSle(rec(l), rec(r)) - case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") - } - case GreaterThan(l, r) => l.getType match { - case IntegerType => z3.mkGT(rec(l), rec(r)) - case RealType => z3.mkGT(rec(l), rec(r)) - case Int32Type => z3.mkBVSgt(rec(l), rec(r)) - case CharType => z3.mkBVSgt(rec(l), rec(r)) - } - case GreaterEquals(l, r) => l.getType match { - case IntegerType => z3.mkGE(rec(l), rec(r)) - case RealType => z3.mkGE(rec(l), rec(r)) - case Int32Type => z3.mkBVSge(rec(l), rec(r)) - case CharType => z3.mkBVSge(rec(l), rec(r)) - } - - case StringConverted(result) => - result - - case u : UnitLiteral => - val tpe = normalizeType(u.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor() - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor(es.map(rec): _*) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val selector = selectors.toB((tpe, i-1)) - selector(rec(t)) - - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) - constructor(args.map(rec): _*) - - case c @ CaseClassSelector(cct, cc, sel) => - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) - selector(rec(cc)) - - case AsInstanceOf(expr, ct) => - rec(expr) - - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) - } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val content = selectors.toB((tpe, 1))(sa) - - z3.mkSelect(content, rec(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val ssize = selectors.toB((tpe, 0))(sa) - val scontent = selectors.toB((tpe, 1))(sa) - - val newcontent = z3.mkStore(scontent, rec(i), rec(e)) - - val constructor = constructors.toB(tpe) - - constructor(ssize, newcontent) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val sa = rec(a) - selectors.toB((tpe, 0))(sa) - - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = normalizeType(arr.getType) - typeToSort(at) - - val default = oDefault.getOrElse(simplestValue(base)) - - val ar = rec(RawArrayValue(Int32Type, elems.map { - case (i, e) => IntLiteral(i) -> e - }, default)) - - constructors.toB(at)(rec(length), ar) - - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) - - case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) - val funDecl = lambdas.cachedB(ft) { - val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) - val returnSort = typeToSort(to) - - val name = FreshIdentifier("dynLambda").uniqueName - z3.mkFreshFuncDecl(name, sortSeq, returnSort) - } - z3.mkApp(funDecl, (caller +: args).map(rec): _*) - - case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - - case RawArrayValue(keyTpe, elems, default) => - val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) - - elems.foldLeft(ar) { - case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) - } - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) - - rec(RawArrayValue(from, elems.map{ - case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }.toMap, CaseClass(library.noneType(t), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - // Really ?!? We don't check that it is actually != None? - selectors.toB(library.someType(t), 0)(el) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - testers.toB(library.someType(t))(el) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) - typeToSort(mt) - - elems.foldLeft(rec(m1)) { case (m, (k,v)) => - z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) - } - - - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) - - case other => - unsupported(other) + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST } - }.rec(expr) + } + } + + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec): _*) + case Or(exs) => z3.mkOr(exs.map(rec): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) + case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => { + val rl = rec(l) + val rr = rec(r) + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + } + case Remainder(l, r) => { + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => { + z3.mkMod(rec(l), rec(r)) + } + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + + case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) + case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) + case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) + case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) + case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) + + case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) + case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) + case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) + case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) + case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) + case BVUMinus(e) => z3.mkBVNeg(rec(e)) + case BVNot(e) => z3.mkBVNot(rec(e)) + case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) + case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) + case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) + case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) + case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) + case LessThan(l, r) => l.getType match { + case IntegerType => z3.mkLT(rec(l), rec(r)) + case RealType => z3.mkLT(rec(l), rec(r)) + case Int32Type => z3.mkBVSlt(rec(l), rec(r)) + case CharType => z3.mkBVSlt(rec(l), rec(r)) + } + case LessEquals(l, r) => l.getType match { + case IntegerType => z3.mkLE(rec(l), rec(r)) + case RealType => z3.mkLE(rec(l), rec(r)) + case Int32Type => z3.mkBVSle(rec(l), rec(r)) + case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") + } + case GreaterThan(l, r) => l.getType match { + case IntegerType => z3.mkGT(rec(l), rec(r)) + case RealType => z3.mkGT(rec(l), rec(r)) + case Int32Type => z3.mkBVSgt(rec(l), rec(r)) + case CharType => z3.mkBVSgt(rec(l), rec(r)) + } + case GreaterEquals(l, r) => l.getType match { + case IntegerType => z3.mkGE(rec(l), rec(r)) + case RealType => z3.mkGE(rec(l), rec(r)) + case Int32Type => z3.mkBVSge(rec(l), rec(r)) + case CharType => z3.mkBVSge(rec(l), rec(r)) + } + + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = constructors.toB(ct) + constructor(args.map(rec): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = selectors.toB(cct, c.selectorIndex) + selector(rec(cc)) + + case AsInstanceOf(expr, ct) => + rec(expr) + + case IsInstanceOf(e, act: AbstractClassType) => + act.knownCCDescendants match { + case Seq(cct) => + rec(IsInstanceOf(e, cct)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + } + + case IsInstanceOf(e, cct: CaseClassType) => + typeToSort(cct) // Making sure the sort is defined + val tester = testers.toB(cct) + tester(rec(e)) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + + case fa @ Application(caller, args) => + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) + + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) + } + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val MapType(_, t) = normalizeType(m.getType) + + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }.toMap, CaseClass(library.noneType(t), Seq()))) + + case MapApply(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(_, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } + + + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case other => + unsupported(other) + } + + rec(expr) } protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { - def rec(t: Z3AST, expected_tpe: TypeTree): Expr = { + + def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - val tpe = Z3StringTypeConversion.convert(expected_tpe)(program) - val res = kind match { + kind match { case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { @@ -769,11 +758,6 @@ trait AbstractZ3Solver extends Solver { } case _ => unsound(t, "unexpected AST") } - expected_tpe match { - case StringType => - StringLiteral(Z3StringTypeConversion.convertToString(res)(program)) - case _ => res - } } rec(tree, normalizeType(tpe)) @@ -790,8 +774,7 @@ trait AbstractZ3Solver extends Solver { } def idToFreshZ3Id(id: Identifier): Z3AST = { - val correctType = Z3StringTypeConversion.convert(id.getType)(program) - z3.mkFreshConst(id.uniqueName, typeToSort(correctType)) + z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) } def reset() = { diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 68e966ad2a2e38582d633ee4fac1a81037d831b4..5682fcf65219a7ac2677f4c9af8b87a2faf0e2da 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -7,57 +7,131 @@ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ import purescala.Definitions._ -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} -import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.interpreters.Z3Interpreter -import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator -object Z3StringTypeConversion { - def convert(t: TypeTree)(implicit p: Program) = new Z3StringTypeConversion { def getProgram = p }.convertType(t) - def convertToString(e: Expr)(implicit p: Program) = new Z3StringTypeConversion{ def getProgram = p }.convertToString(e) -} - -trait Z3StringTypeConversion { - val stringBijection = new Bijection[String, Expr]() +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } - lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Cons in Z3 solver") + val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d } - lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Nil in Z3 solver") + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) + + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd } - lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find List in Z3 solver") + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd } - def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { - case Some(fd) => fd - case _ => throw new Exception("Could not find function "+s+" in program") + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd + } + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd } - lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) - lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) - lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) - lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) - lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) - private lazy val program = getProgram + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } - def getProgram: Program + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) +} + +class Z3StringConversion(val p: Program) extends Z3StringConverters { + val stringBijection = new Bijection[String, Expr]() + + import StringEcoSystem._ + + lazy val listchar = StringList.typed + lazy val conschar = StringCons.typed + lazy val nilchar = StringNil.typed + + lazy val list_size = StringSize.typed + lazy val list_++ = StringListConcat.typed + lazy val list_take = StringTake.typed + lazy val list_drop = StringDrop.typed + lazy val list_slice = StringSlice.typed - def convertType(t: TypeTree): TypeTree = t match { - case StringType => listchar - case NAryType(subtypes, builder) => - builder(subtypes.map(convertType)) - } - def convertTypeBack(expected_type: TypeTree)(t: TypeTree): TypeTree = (expected_type, t) match { - case (StringType, `listchar`) => StringType - case (NAryType(ex, builder), NAryType(cur, builder2)) => - builder2(ex.zip(cur).map(ex_cur => convertTypeBack(ex_cur._1)(ex_cur._2))) + def getProgram = program_with_string_methods + + lazy val program_with_string_methods = { + val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) + DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } + def convertToString(e: Expr)(implicit p: Program): String = stringBijection.cachedA(e) { e match { @@ -73,59 +147,206 @@ trait Z3StringTypeConversion { } } -trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { - /** Method which can use recursively StringConverted in its body in unapply positions */ - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType - /** How the application (or function invocation) of a given fundef is performed in the target type. */ - def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType +trait Z3StringConverters { self: Z3StringConversion => + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalFdMap = new Bijection[FunDef, FunDef]() + + trait BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef + def hasIdConversion(id: Identifier): Boolean + def convertId(id: Identifier): Identifier + def isTypeToConvert(tpe: TypeTree): Boolean + def convertType(tpe: TypeTree): TypeTree + def convertPattern(pattern: Pattern): Pattern + def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr + + object PatternConverted { + def unapply(e: Pattern): Option[Pattern] = Some(e match { + case InstanceOfPattern(binder, ct) => + InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) + case WildcardPattern(binder) => + WildcardPattern(binder.map(convertId)) + case CaseClassPattern(binder, ct, subpatterns) => + CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) + case TuplePattern(binder, subpatterns) => + TuplePattern(binder.map(convertId), subpatterns map convertPattern) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => + UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) + case PatternExtractor(es, builder) => + builder(es map convertPattern) + }) + } + + object ExprConverted { + def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { + case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) + case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) + case Variable(id) => e + case pl@PartialLambda(mappings, default, tpe) => + PartialLambda( + mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), + convertExpr(kv._2))), + default.map(d => convertExpr(d)), convertType(tpe).asInstanceOf[FunctionType]) + case Lambda(args, body) => + val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() + for(arg <- args) { + val in = arg.getType + if(convertType(in) ne in) { + new_bindings += (arg.id -> convertId(arg.id)) + } + } + val new_args = new_bindings.map(x => ValDef(x._2)) + Lambda(new_args, + convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) + case Let(a, expr, body) if isTypeToConvert(a.getType) => + val new_a = convertId(a) + val new_bindings = bindings + (a -> Variable(new_a)) + val expr2 = convertExpr(expr)(new_bindings) + val body2 = convertExpr(body)(new_bindings) + Let(new_a, expr2, body2).copiedFrom(e) + case CaseClass(CaseClassType(ccd, tpes), args) => + CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) + case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => + CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) + case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => + MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case This(ct: ClassType) => + This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case Tuple(args) => + Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case Operator(es, builder) => + val rec = convertExpr _ + val newEs = es.map(rec) + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + }) + } + } - object StringConverted { - def unapply(e: Expr)(implicit replacement: Map[Identifier, TargetType]): Option[TargetType] = e match { + object Forward extends BidirectionalConverters { + /* The conversion between functions should already have taken place */ + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getBorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsA(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getB(id) match { + case Some(idB) => idB + case None => + val new_id = FreshIdentifier(id.name, convertType(id.getType)) + mappedVariables += (id -> new_id) + new_id + } + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(StringType == _)(tpe) + def convertType(tpe: TypeTree): TypeTree = + TypeOps.preMap{ case StringType => Some(StringList.typed) case e => None}(tpe) + def convertPattern(e: Pattern): Pattern = e match { + case LiteralPattern(binder, StringLiteral(s)) => + s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => + CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + case PatternConverted(e) => e + } + + /** Method which can use recursively StringConverted in its body in unapply positions */ + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { + case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => // No string support for z3 at this moment. val stringEncoding = convertFromString(v) - Some(convertToTarget(stringEncoding)) + convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - Some(targetApplication(list_size, Seq(convertToTarget(a)))) + FunctionInvocation(list_size, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - Some(targetApplication(list_++, Seq(convertToTarget(a), convertToTarget(b)))) + FunctionInvocation(list_++, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - Some(targetApplication(list_take, - Seq(targetApplication(list_drop, Seq(convertToTarget(a), convertToTarget(start))), convertToTarget(length)))) + FunctionInvocation(list_take, + Seq(FunctionInvocation(list_drop, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - Some(targetApplication(list_slice, Seq(convertToTarget(a), convertToTarget(start), convertToTarget(end)))) - case _ => None + FunctionInvocation(list_slice, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case ExprConverted(e) => e } - - def apply(t: TypeTree): TypeTree = convertType(t) } -} - -trait Z3StringConversionReverse extends Z3StringConversion[Expr] { - - object StringConversion { - def reverse(e: Expr): Expr = unapply(e).getOrElse(e) - def unapply(e: Expr): Option[Expr] = e match { - case CaseClass(`conschar`, Seq(CharLiteral(c), l)) => - reverse(l) match { - case StringLiteral(s) => Some(StringLiteral(c + s)) - case _ => None - } - case CaseClass(`nilchar`, Seq()) => - Some(StringLiteral("")) - case FunctionInvocation(`list_size`, Seq(a)) => - Some(StringLength(reverse(a))) - case FunctionInvocation(`list_++`, Seq(a, b)) => - Some(StringConcat(reverse(a), reverse(b))) - case FunctionInvocation(`list_take`, - Seq(FunctionInvocation(`list_drop`, Seq(a, start)), length)) => - val rstart = reverse(start) - Some(SubString(reverse(a), rstart, Plus(rstart, reverse(length)))) - case purescala.Extractors.Operator(es, builder) => - Some(builder(es.map(reverse _))) - case _ => None + + object Backward extends BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getAorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsB(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getA(id) match { + case Some(idA) => idA + case None => + val old_type = convertType(id.getType) + val old_id = FreshIdentifier(id.name, old_type) + mappedVariables += (old_id -> id) + old_id + } + } + def convertIdToMapping(id: Identifier): (Identifier, Variable) = { + id -> Variable(convertId(id)) + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) + def convertType(tpe: TypeTree): TypeTree = { + TypeOps.preMap{ + case StringList | StringCons | StringNil => Some(StringType) + case e => None}(tpe) } + def convertPattern(e: Pattern): Pattern = e match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e } + - + + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)(self.p)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e + } + } } \ No newline at end of file diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 57a62b665c797b10fab2d099fabd3a722f6e7d27..3680930639a2cfba46490d4a21bab7772d7fd0c8 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -11,8 +11,13 @@ class Bijection[A, B] { b2a += b -> a } - def +=(t: (A,B)): Unit = { - this += (t._1, t._2) + def +=(t: (A,B)): this.type = { + +=(t._1, t._2) + this + } + + def ++=(t: Iterable[(A,B)]) = { + (this /: t){ case (b, elem) => b += elem } } def clear(): Unit = { @@ -22,6 +27,9 @@ class Bijection[A, B] { def getA(b: B) = b2a.get(b) def getB(a: A) = a2b.get(a) + + def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse) + def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse) def toA(b: B) = getA(b).get def toB(a: A) = getB(a).get