diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index 2b6984d00d989e6958da3c43930b9614436cffbf..c39c4fe09268c3e9873b5e2796e18451562bf6d9 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -23,7 +23,7 @@ import leon.utils.Bijection import leon.solvers.z3.StringEcoSystem object Z3StringCapableSolver { - def convert(p: Program, force: Boolean = false): (Program, Option[Z3StringConversion]) = { + def convert(p: Program): (Program, Option[Z3StringConversion]) = { val converter = new Z3StringConversion(p) import converter.Forward._ var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() @@ -45,7 +45,7 @@ object Z3StringCapableSolver { } else None ) }) - if(!hasStrings && !force) { + if(!hasStrings) { (p, None) } else { converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) @@ -58,21 +58,16 @@ object Z3StringCapableSolver { } } -trait ForcedProgramConversion { self: Z3StringCapableSolver[_] => - override def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = { - Z3StringCapableSolver.convert(p, true) - } -} - abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( val context: LeonContext, val program: Program, val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) extends Solver { - def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = Z3StringCapableSolver.convert(p) - protected val (new_program, someConverter) = convertProgram(program) + protected val (new_program, optConverter) = Z3StringCapableSolver.convert(program) + var someConverter = optConverter val underlying = underlyingConstructor(new_program, someConverter) + var solverInvokedWithStrings = false def getModel: leon.solvers.Model = { val model = underlying.getModel @@ -98,24 +93,40 @@ abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( new PartialModel(original_ids.zip(original_exprs).toMap, new_domain) case _ => - new Model(original_ids.zip(original_exprs).toMap) - } + new Model(original_ids.zip(original_exprs).toMap) } } + } // Members declared in leon.utils.Interruptible def interrupt(): Unit = underlying.interrupt() def recoverInterrupt(): Unit = underlying.recoverInterrupt() + // Converts expression on the fly if needed, creating a string converter if needed. + def convertExprOnTheFly(expression: Expr, withConverter: Z3StringConversion => Expr): Expr = { + someConverter match { + case None => + if(solverInvokedWithStrings || exists(e => TypeOps.exists(StringType == _)(e.getType))(expression)) { // On the fly conversion + solverInvokedWithStrings = true + val c = new Z3StringConversion(program) + someConverter = Some(c) + withConverter(c) + } else expression + case Some(converter) => + withConverter(converter) + } + } + // Members declared in leon.solvers.Solver def assertCnstr(expression: Expr): Unit = { someConverter.map{converter => import converter.Forward._ val newExpression = convertExpr(expression)(Map()) underlying.assertCnstr(newExpression) - }.getOrElse(underlying.assertCnstr(expression)) + }.getOrElse{ + underlying.assertCnstr(convertExprOnTheFly(expression, _.Forward.convertExpr(expression)(Map()))) + } } - def getUnsatCore: Set[Expr] = { someConverter.map{converter => import converter.Backward._ @@ -150,7 +161,7 @@ class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { import converter._ super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) - .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) + .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) ) } } @@ -190,36 +201,36 @@ class Z3StringFairZ3Solver(context: LeonContext, program: Program) new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) with Z3StringEvaluatingSolver[FairZ3Solver] { - // Members declared in leon.solvers.z3.AbstractZ3Solver - protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + // Members declared in leon.solvers.z3.AbstractZ3Solver + protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions.map(e => this.convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } } - } } class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => new UnrollingSolver(context, program, underlyingSolverConstructor(program))) - with Z3StringNaiveAssumptionSolver[UnrollingSolver] + with Z3StringNaiveAssumptionSolver[UnrollingSolver] with Z3StringEvaluatingSolver[UnrollingSolver] { - override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore + override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore } class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { - override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - someConverter match { - case None => underlying.checkAssumptions(assumptions) - case Some(converter) => - underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } } - } } diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 6d41b5fd8f45def478849c6c2d0623ed85985ffb..1c713fdf1ceb600cca902b38fd1dbd64f1754f4a 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -112,28 +112,24 @@ object StringEcoSystem { } class Z3StringConversion(val p: Program) extends Z3StringConverters { - val stringBijection = new Bijection[String, Expr]() - import StringEcoSystem._ - - lazy val listchar = StringList.typed - lazy val conschar = StringCons.typed - lazy val nilchar = StringNil.typed - - lazy val list_size = StringSize.typed - lazy val list_++ = StringListConcat.typed - lazy val list_take = StringTake.typed - lazy val list_drop = StringDrop.typed - lazy val list_slice = StringSlice.typed - def getProgram = program_with_string_methods lazy val program_with_string_methods = { val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } +} - def convertToString(e: Expr)(implicit p: Program): String = +trait Z3StringConverters { + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalFdMap = new Bijection[FunDef, FunDef]() + + val stringBijection = new Bijection[String, Expr]() + + def convertToString(e: Expr): String = stringBijection.cachedA(e) { e match { case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) @@ -142,17 +138,10 @@ class Z3StringConversion(val p: Program) extends Z3StringConverters { } def convertFromString(v: String): Expr = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(nilchar, Seq())){ - case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) + v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ + case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) } } -} - -trait Z3StringConverters { self: Z3StringConversion => - import StringEcoSystem._ - val mappedVariables = new Bijection[Identifier, Identifier]() - - val globalFdMap = new Bijection[FunDef, FunDef]() trait BidirectionalConverters { def convertFunDef(fd: FunDef): FunDef @@ -291,18 +280,17 @@ trait Z3StringConverters { self: Z3StringConversion => def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => - // No string support for z3 at this moment. val stringEncoding = convertFromString(v) convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - FunctionInvocation(list_size, Seq(convertExpr(a))).copiedFrom(e) + FunctionInvocation(StringSize.typed, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - FunctionInvocation(list_++, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) + FunctionInvocation(StringListConcat.typed, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(list_take, - Seq(FunctionInvocation(list_drop, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - FunctionInvocation(list_slice, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + FunctionInvocation(StringSlice.typed, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) case MatchExpr(scrutinee, cases) => MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) @@ -351,13 +339,11 @@ trait Z3StringConverters { self: Z3StringConversion => } case PatternConverted(e) => e } - - - + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> - StringLiteral(convertToString(cc)(self.p)) + StringLiteral(convertToString(cc)) case FunctionInvocation(StringSize, Seq(a)) => StringLength(convertExpr(a)).copiedFrom(e) case FunctionInvocation(StringListConcat, Seq(a, b)) => diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index c95fb39d33c86dac08d5d2465c99772fc88dc58f..d2af2030bded5ae60375180661107346164ca0c6 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -22,13 +22,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm) with ForcedProgramConversion ) + ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm)) with ForcedProgramConversion ) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm)) with ForcedProgramConversion ) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) }