From 6577e312968f2b610461b2e0b3a5874c779bfb6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch> Date: Wed, 20 Jan 2016 20:00:37 +0100 Subject: [PATCH] Resolving issues with Z3 and string conversion in multiple solvers. corrected a web benchmark. --- .../leon/solvers/smtlib/SMTLIBZ3Solver.scala | 2 + .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 89 +-- .../leon/solvers/z3/AbstractZ3Solver.scala | 597 +++++++++--------- .../leon/solvers/z3/Z3StringConversion.scala | 94 +++ .../web/synthesis/24_String_DoubleList.scala | 2 +- 5 files changed, 425 insertions(+), 359 deletions(-) create mode 100644 src/main/scala/leon/solvers/z3/Z3StringConversion.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 86cf40f47..32adcd50f 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -15,6 +15,8 @@ import _root_.smtlib.theories.Core.{Equals => _, _} class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolver(context, program) with SMTLIBZ3Target { + def getProgram: Program = program + // EK: We use get-model instead in order to extract models for arrays override def getModel: Model = { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index b15bd9102..1731b94ae 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -14,10 +14,9 @@ import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx -import utils.Bijection - -trait SMTLIBZ3Target extends SMTLIBTarget { +import leon.solvers.z3.Z3StringConversion +trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { def targetName = "z3" def interpreterOps(ctx: LeonContext) = { @@ -41,11 +40,11 @@ trait SMTLIBZ3Target extends SMTLIBTarget { override protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { - tpe match { + convertType(tpe) match { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) - case _ => + case t => super.declareSort(t) } } @@ -70,34 +69,13 @@ trait SMTLIBZ3Target extends SMTLIBTarget { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - val stringBijection = new Bijection[String, CaseClass]() - - lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Cons in Z3 solver") - } - lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Nil in Z3 solver") - } - lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find List in Z3 solver") - } - def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { - case Some(fd) => fd - case _ => throw new Exception("Could not find function "+s+" in program") - } - lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) - lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) - lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) - lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) - lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) - - - override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, expected_otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - (t, otpe) match { + val otpe = expected_otpe match { + case Some(StringType) => Some(listchar) + case _ => expected_otpe + } + val res = (t, otpe) match { case (SimpleSymbol(s), Some(tp: TypeParameter)) => val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) @@ -119,28 +97,19 @@ trait SMTLIBZ3Target extends SMTLIBTarget { fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) - case (SimpleSymbol(s), Some(StringType)) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType if cct == nilchar => - StringLiteral("") - case t => - unsupported(t, "woot? for a single constructor for non-case-object") - } - case (FunctionApplication(SimpleSymbol(s), args), Some(StringType)) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType if cct == conschar => - val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) - val s = ("" /: rargs) { - case (acc, c@CharLiteral(s)) => acc + s - case _ => unsupported(cct, "Cannot extract string out of list of any") - } - StringLiteral(s) - case t => unsupported(t, "Cannot extract string") - } - case _ => super.fromSMT(t, otpe) } + expected_otpe match { + case Some(StringType) => + StringLiteral(convertToString(res)(program)) + case _ => res + } + } + + def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = toSMT(e) + def targetApplication(tfd: TypedFunDef, args: Seq[Term])(implicit bindings: Map[Identifier, Term]): Term = { + FunctionApplication(declareFunction(tfd), args) } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { @@ -177,23 +146,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - case StringLiteral(v) => - // No string support for z3 at this moment. - val stringEncoding = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(nilchar, Seq())){ - case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) - } - } - toSMT(stringEncoding) - case StringLength(a) => - FunctionApplication(declareFunction(list_size), Seq(toSMT(a))) - case StringConcat(a, b) => - FunctionApplication(declareFunction(list_++), Seq(toSMT(a), toSMT(b))) - case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionApplication(declareFunction(list_take), - Seq(FunctionApplication(declareFunction(list_drop), Seq(toSMT(a), toSMT(start))), toSMT(length))) - case SubString(a, start, end) => - FunctionApplication(declareFunction(list_slice), Seq(toSMT(a), toSMT(start), toSMT(end))) + case StringConverted(result) => result case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index f0b7d5f91..ac1e8855a 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -264,311 +264,322 @@ trait AbstractZ3Solver extends Solver { case other => throw SolverUnsupportedError(other, this) } + + protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + implicit var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } - - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => { - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry + new Z3StringConversion[Z3AST] { + def getProgram = AbstractZ3Solver.this.program + def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { + rec(e) } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb - } - - case p @ Passes(_, _, _) => - rec(p.asConstraint) - - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) - - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - - case Waypoint(_, e, _) => rec(e) - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => { - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => - ast - case None => { - variables.getB(v) match { - case Some(ast) => - ast - - case None => - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } - } - } - - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec): _*) - case Or(exs) => z3.mkOr(exs.map(rec): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) - case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) - case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => { - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - } - case Remainder(l, r) => { - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - } - case Modulo(l, r) => { - z3.mkMod(rec(l), rec(r)) - } - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) - case BVNot(e) => z3.mkBVNot(rec(e)) - case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) - case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) - case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) - case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) - case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) - case LessThan(l, r) => l.getType match { - case IntegerType => z3.mkLT(rec(l), rec(r)) - case RealType => z3.mkLT(rec(l), rec(r)) - case Int32Type => z3.mkBVSlt(rec(l), rec(r)) - case CharType => z3.mkBVSlt(rec(l), rec(r)) - } - case LessEquals(l, r) => l.getType match { - case IntegerType => z3.mkLE(rec(l), rec(r)) - case RealType => z3.mkLE(rec(l), rec(r)) - case Int32Type => z3.mkBVSle(rec(l), rec(r)) - case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") - } - case GreaterThan(l, r) => l.getType match { - case IntegerType => z3.mkGT(rec(l), rec(r)) - case RealType => z3.mkGT(rec(l), rec(r)) - case Int32Type => z3.mkBVSgt(rec(l), rec(r)) - case CharType => z3.mkBVSgt(rec(l), rec(r)) - } - case GreaterEquals(l, r) => l.getType match { - case IntegerType => z3.mkGE(rec(l), rec(r)) - case RealType => z3.mkGE(rec(l), rec(r)) - case Int32Type => z3.mkBVSge(rec(l), rec(r)) - case CharType => z3.mkBVSge(rec(l), rec(r)) - } - - case u : UnitLiteral => - val tpe = normalizeType(u.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor() - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor(es.map(rec): _*) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val selector = selectors.toB((tpe, i-1)) - selector(rec(t)) - - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) - constructor(args.map(rec): _*) - - case c @ CaseClassSelector(cct, cc, sel) => - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) - selector(rec(cc)) - - case AsInstanceOf(expr, ct) => - rec(expr) - - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + def targetApplication(tfd: TypedFunDef, args: Seq[Z3AST])(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { + z3.mkApp(functionDefToDecl(tfd), args: _*) } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val content = selectors.toB((tpe, 1))(sa) - - z3.mkSelect(content, rec(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val ssize = selectors.toB((tpe, 0))(sa) - val scontent = selectors.toB((tpe, 1))(sa) - - val newcontent = z3.mkStore(scontent, rec(i), rec(e)) - - val constructor = constructors.toB(tpe) - - constructor(ssize, newcontent) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val sa = rec(a) - selectors.toB((tpe, 0))(sa) - - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = normalizeType(arr.getType) - typeToSort(at) - - val default = oDefault.getOrElse(simplestValue(base)) - - val ar = rec(RawArrayValue(Int32Type, elems.map { - case (i, e) => IntLiteral(i) -> e - }, default)) - - constructors.toB(at)(rec(length), ar) - - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) - - case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) - val funDecl = lambdas.cachedB(ft) { - val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) - val returnSort = typeToSort(to) - - val name = FreshIdentifier("dynLambda").uniqueName - z3.mkFreshFuncDecl(name, sortSeq, returnSort) - } - z3.mkApp(funDecl, (caller +: args).map(rec): _*) - - case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - - case RawArrayValue(keyTpe, elems, default) => - val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) - - elems.foldLeft(ar) { - case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) - } - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) - - rec(RawArrayValue(from, elems.map{ - case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }.toMap, CaseClass(library.noneType(t), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - // Really ?!? We don't check that it is actually != None? - selectors.toB(library.someType(t), 0)(el) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - testers.toB(library.someType(t))(el) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) - typeToSort(mt) - - elems.foldLeft(rec(m1)) { case (m, (k,v)) => - z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + def rec(ex: Expr): Z3AST = ex match { + + // TODO: Leave that as a specialization? + case LetTuple(ids, e, b) => { + z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => + val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) + entry + } + val rb = rec(b) + z3Vars = z3Vars -- ids + rb + } + + case p @ Passes(_, _, _) => + rec(p.asConstraint) + + case me @ MatchExpr(s, cs) => + rec(matchToIfThenElse(me)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + + case Waypoint(_, e, _) => rec(e) + case a @ Assert(cond, err, body) => + rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) + + case e @ Error(tpe, _) => { + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => + ast + case None => { + variables.getB(v) match { + case Some(ast) => + ast + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST + } + } + } + + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec): _*) + case Or(exs) => z3.mkOr(exs.map(rec): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) + case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => { + val rl = rec(l) + val rr = rec(r) + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + } + case Remainder(l, r) => { + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => { + z3.mkMod(rec(l), rec(r)) + } + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + + case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) + case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) + case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) + case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) + case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) + + case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) + case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) + case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) + case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) + case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) + case BVUMinus(e) => z3.mkBVNeg(rec(e)) + case BVNot(e) => z3.mkBVNot(rec(e)) + case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) + case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) + case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) + case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) + case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) + case LessThan(l, r) => l.getType match { + case IntegerType => z3.mkLT(rec(l), rec(r)) + case RealType => z3.mkLT(rec(l), rec(r)) + case Int32Type => z3.mkBVSlt(rec(l), rec(r)) + case CharType => z3.mkBVSlt(rec(l), rec(r)) + } + case LessEquals(l, r) => l.getType match { + case IntegerType => z3.mkLE(rec(l), rec(r)) + case RealType => z3.mkLE(rec(l), rec(r)) + case Int32Type => z3.mkBVSle(rec(l), rec(r)) + case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") + } + case GreaterThan(l, r) => l.getType match { + case IntegerType => z3.mkGT(rec(l), rec(r)) + case RealType => z3.mkGT(rec(l), rec(r)) + case Int32Type => z3.mkBVSgt(rec(l), rec(r)) + case CharType => z3.mkBVSgt(rec(l), rec(r)) + } + case GreaterEquals(l, r) => l.getType match { + case IntegerType => z3.mkGE(rec(l), rec(r)) + case RealType => z3.mkGE(rec(l), rec(r)) + case Int32Type => z3.mkBVSge(rec(l), rec(r)) + case CharType => z3.mkBVSge(rec(l), rec(r)) + } + + case StringConverted(result) => + result + + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = constructors.toB(ct) + constructor(args.map(rec): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = selectors.toB(cct, c.selectorIndex) + selector(rec(cc)) + + case AsInstanceOf(expr, ct) => + rec(expr) + + case IsInstanceOf(e, act: AbstractClassType) => + act.knownCCDescendants match { + case Seq(cct) => + rec(IsInstanceOf(e, cct)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + } + + case IsInstanceOf(e, cct: CaseClassType) => + typeToSort(cct) // Making sure the sort is defined + val tester = testers.toB(cct) + tester(rec(e)) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + + case fa @ Application(caller, args) => + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) + + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) + } + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val MapType(_, t) = normalizeType(m.getType) + + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }.toMap, CaseClass(library.noneType(t), Seq()))) + + case MapApply(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(_, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } + + + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case other => + unsupported(other) } - - - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) - - case other => - unsupported(other) - } - - rec(expr) + }.rec(expr) } protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { - - def rec(t: Z3AST, tpe: TypeTree): Expr = { + def rec(t: Z3AST, expected_tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - kind match { + val tpe = Z3StringTypeConversion.convert(expected_tpe)(program) + val res = kind match { case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { @@ -758,6 +769,11 @@ trait AbstractZ3Solver extends Solver { } case _ => unsound(t, "unexpected AST") } + expected_tpe match { + case StringType => + StringLiteral(Z3StringTypeConversion.convertToString(res)(program)) + case _ => res + } } rec(tree, normalizeType(tpe)) @@ -774,7 +790,8 @@ trait AbstractZ3Solver extends Solver { } def idToFreshZ3Id(id: Identifier): Z3AST = { - z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) + val correctType = Z3StringTypeConversion.convert(id.getType)(program) + z3.mkFreshConst(id.uniqueName, typeToSort(correctType)) } def reset() = { diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala new file mode 100644 index 000000000..3daf1ad49 --- /dev/null +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -0,0 +1,94 @@ +package leon +package solvers +package z3 + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Types._ +import purescala.Definitions._ +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} +import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} +import _root_.smtlib.interpreters.Z3Interpreter +import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} +import _root_.smtlib.theories.ArraysEx +import leon.utils.Bijection + +object Z3StringTypeConversion { + def convert(t: TypeTree)(implicit p: Program) = new Z3StringTypeConversion { def getProgram = p }.convertType(t) + def convertToString(e: Expr)(implicit p: Program) = new Z3StringTypeConversion{ def getProgram = p }.convertToString(e) +} + +trait Z3StringTypeConversion { + val stringBijection = new Bijection[String, Expr]() + + lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { + case Some(cc) => cc.typed(Seq(CharType)) + case _ => throw new Exception("Could not find Cons in Z3 solver") + } + lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { + case Some(cc) => cc.typed(Seq(CharType)) + case _ => throw new Exception("Could not find Nil in Z3 solver") + } + lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { + case Some(cc) => cc.typed(Seq(CharType)) + case _ => throw new Exception("Could not find List in Z3 solver") + } + def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { + case Some(fd) => fd + case _ => throw new Exception("Could not find function "+s+" in program") + } + lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) + lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) + lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) + lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) + lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) + + private lazy val program = getProgram + + def getProgram: Program + + def convertType(t: TypeTree): TypeTree = t match { + case StringType => listchar + case _ => t + } + def convertToString(e: Expr)(implicit p: Program): String = + stringBijection.cachedA(e) { + e match { + case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) + case CaseClass(_, Seq()) => "" + } + } + def convertFromString(v: String) = + stringBijection.cachedB(v) { + v.toList.foldRight(CaseClass(nilchar, Seq())){ + case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) + } + } +} + +trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { + def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType + def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType + + object StringConverted { + def unapply(e: Expr)(implicit bindings: Map[Identifier, TargetType]): Option[TargetType] = e match { + case StringLiteral(v) => + // No string support for z3 at this moment. + val stringEncoding = convertFromString(v) + Some(convertToTarget(stringEncoding)) + case StringLength(a) => + Some(targetApplication(list_size, Seq(convertToTarget(a)))) + case StringConcat(a, b) => + Some(targetApplication(list_++, Seq(convertToTarget(a), convertToTarget(b)))) + case SubString(a, start, Plus(start2, length)) if start == start2 => + Some(targetApplication(list_take, + Seq(targetApplication(list_drop, Seq(convertToTarget(a), convertToTarget(start))), convertToTarget(length)))) + case SubString(a, start, end) => + Some(targetApplication(list_slice, Seq(convertToTarget(a), convertToTarget(start), convertToTarget(end)))) + case _ => None + } + + def apply(t: TypeTree): TypeTree = convertType(t) + } +} \ No newline at end of file diff --git a/testcases/web/synthesis/24_String_DoubleList.scala b/testcases/web/synthesis/24_String_DoubleList.scala index 0d89d3f23..652f5e4b8 100644 --- a/testcases/web/synthesis/24_String_DoubleList.scala +++ b/testcases/web/synthesis/24_String_DoubleList.scala @@ -26,7 +26,7 @@ object DoubleListRender { (res : String) => (a, res) passes { case N() => "[]" - case B(NN()) => + case B(NN(), N()) => "[()]" } } -- GitLab