diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index 2f9ae2bca4e87cf0ec681350ad6fc119850480b4..e59547df2838296802ae3935f48ff4b6834f62e0 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -26,6 +26,8 @@ trait Extractors { self: Trees => /* Unary operators */ case Not(t) => Some((Seq(t), (es: Seq[Expr]) => Not(es.head))) + case BVNot(t) => + Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head))) case UMinus(t) => Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head))) case StringLength(t) => diff --git a/src/main/scala/inox/solvers/Solver.scala b/src/main/scala/inox/solvers/Solver.scala index a353f24596fc5042113775daeedcc043fa9c0a6c..97d46095a18fa8b6d5197caf1a004527279081f0 100644 --- a/src/main/scala/inox/solvers/Solver.scala +++ b/src/main/scala/inox/solvers/Solver.scala @@ -21,99 +21,63 @@ case object DebugSectionSolver extends DebugSection("solver") object optCheckModels extends InoxFlagOptionDef("checkmodels", "Double-check counter-examples with evaluator", false) object optSilentErrors extends InoxFlagOptionDef("silenterrors", "Fail silently into UNKNOWN when encountering an error", false) -trait Solver extends Interruptible { +trait AbstractSolver extends Interruptible { def name: String val program: Program val options: SolverOptions - import program._ - import program.trees._ - - import SolverResponses._ - - sealed trait Configuration { - type Response <: SolverResponse[Map[ValDef, Expr], Set[Expr]] - - def max(that: Configuration): Configuration = (this, that) match { - case (All , _ ) => All - case (_ , All ) => All - case (Model, Cores) => All - case (Cores, Model) => All - case (Model, _ ) => Model - case (_ , Model) => Model - case (Cores, _ ) => Cores - case (_ , Cores) => Cores - case _ => Simple - } - - def min(that: Configuration): Configuration = (this, that) match { - case (o1, o2) if o1 == o2 => o1 - case (Simple, _) => Simple - case (_, Simple) => Simple - case (Model, Cores) => Simple - case (Cores, Model) => Simple - case (All, o) => o - case (o, All) => o - } - - def in(solver: Solver): solver.Configuration = this match { - case Simple => solver.Simple - case Model => solver.Model - case Cores => solver.Cores - case All => solver.All - } - - def cast(resp: SolverResponse[Map[ValDef, Expr], Set[Expr]]): Response = ((this, resp) match { - case (_ , Unknown) => Unknown - case (Simple | Cores, Sat) => Sat - case (Model | All , s @ SatWithModel(_)) => s - case (Simple | Model, Unsat) => Unsat - case (Cores | All , u @ UnsatWithCores(_)) => u - case _ => throw FatalError("Unexpected response " + resp + " for configuration " + this) - }).asInstanceOf[Response] - } + type Trees + type Model + type Cores - object Configuration { - def apply(model: Boolean = false, cores: Boolean = false): Configuration = - if (model && cores) All - else if (model) Model - else if (cores) Cores - else Simple - } - - case object Simple extends Configuration { type Response = SimpleResponse } - case object Model extends Configuration { type Response = ResponseWithModel[Map[ValDef, Expr]] } - case object Cores extends Configuration { type Response = ResponseWithCores[Set[Expr]] } - case object All extends Configuration { type Response = ResponseWithModelAndCores[Map[ValDef, Expr], Set[Expr]] } - - object SolverUnsupportedError { - def msg(t: Tree, reason: Option[String]) = { - s"(of ${t.getClass}) is unsupported by solver ${name}" + reason.map(":\n " + _ ).getOrElse("") - } - } - - case class SolverUnsupportedError(t: Tree, reason: Option[String] = None) - extends Unsupported(t, SolverUnsupportedError.msg(t,reason)) + type Configuration = SolverResponses.Configuration[Model, Cores] + val Simple = SolverResponses.Simple + val Model = SolverResponses.Model[Model]() + val Cores = SolverResponses.Cores[Cores]() + val All = SolverResponses.All[Model, Cores]() lazy val reporter = program.ctx.reporter // This is ugly, but helpful for smtlib solvers def dbg(msg: => Any) {} - def assertCnstr(expression: Expr): Unit + def assertCnstr(expression: Trees): Unit def check[R](config: Configuration { type Response <: R }): R - def checkAssumptions[R](config: Configuration { type Response <: R })(assumptions: Set[Expr]): R + def checkAssumptions[R](config: Configuration { type Response <: R })(assumptions: Set[Trees]): R def getResultSolver: Option[Solver] = Some(this) - def free() + def free(): Unit - def reset() + def reset(): Unit def push(): Unit def pop(): Unit + implicit val debugSection = DebugSectionSolver + + private[solvers] def debugS(msg: String) = { + reporter.debug("["+name+"] "+msg) + } +} + +trait Solver extends AbstractSolver { + import program.trees._ + + type Trees = Expr + type Model = Map[ValDef, Expr] + type Cores = Set[Expr] + + object SolverUnsupportedError { + def msg(t: Tree, reason: Option[String]) = { + s"(of ${t.getClass}) is unsupported by solver ${name}" + reason.map(":\n " + _ ).getOrElse("") + } + } + + case class SolverUnsupportedError(t: Tree, reason: Option[String] = None) + extends Unsupported(t, SolverUnsupportedError.msg(t,reason)) + protected def unsupported(t: Tree): Nothing = { val err = SolverUnsupportedError(t, None) reporter.warning(err.getMessage) @@ -125,10 +89,4 @@ trait Solver extends Interruptible { reporter.warning(err.getMessage) throw err } - - implicit val debugSection = DebugSectionSolver - - private[solvers] def debugS(msg: String) = { - reporter.debug("["+name+"] "+msg) - } } diff --git a/src/main/scala/inox/solvers/SolverResponses.scala b/src/main/scala/inox/solvers/SolverResponses.scala index 0f5f7f87faa89e8d12aa75db8b30b776fcede277..f4572e78d4503674e9d1bda0b2aba41d7e6619b1 100644 --- a/src/main/scala/inox/solvers/SolverResponses.scala +++ b/src/main/scala/inox/solvers/SolverResponses.scala @@ -42,5 +42,64 @@ object SolverResponses { case Unknown => None } } + + sealed trait Configuration[+Model, +Cores] { + type Response <: SolverResponse[Model, Cores] + + def max[M >: Model,C >: Cores](that: Configuration[M, C]): Configuration[M, C] = (this, that) match { + case (All() , _ ) => All() + case (_ , All() ) => All() + case (Model(), Cores()) => All() + case (Cores(), Model()) => All() + case (Model(), _ ) => Model() + case (_ , Model()) => Model() + case (Cores(), _ ) => Cores() + case (_ , Cores()) => Cores() + case _ => Simple + } + + def min[M >: Model, C >: Cores](that: Configuration[M, C]): Configuration[M, C] = (this, that) match { + case (o1, o2) if o1 == o2 => o1 + case (Simple, _) => Simple + case (_, Simple) => Simple + case (Model(), Cores()) => Simple + case (Cores(), Model()) => Simple + case (All(), o) => o + case (o, All()) => o + } + + def cast[M <: Model, C <: Cores](resp: SolverResponse[M, C]): Response = ((this, resp) match { + case (_, Unknown) => Unknown + case (Simple | Cores(), Sat) => Sat + case (Model() | All() , s @ SatWithModel(_)) => s + case (Simple | Model(), Unsat) => Unsat + case (Cores() | All() , u @ UnsatWithCores(_)) => u + case _ => throw FatalError("Unexpected response " + resp + " for configuration " + this) + }).asInstanceOf[Response] + } + + object Configuration { + def apply[M,C](model: Boolean = false, cores: Boolean = false): Configuration[M,C] = + if (model && cores) All() + else if (model) Model() + else if (cores) Cores() + else Simple + } + + case object Simple extends Configuration[Nothing,Nothing] { + type Response = SimpleResponse + } + + case class Model[Model]() extends Configuration[Model,Nothing] { + type Response = ResponseWithModel[Model] + } + + case class Cores[Cores]() extends Configuration[Nothing,Cores] { + type Response = ResponseWithCores[Cores] + } + + case class All[Model,Cores]() extends Configuration[Model,Cores] { + type Response = ResponseWithModelAndCores[Model, Cores] + } } diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 2b3f4079438ae65ca5369418db0836205cd9f890..a5cc2f2d2676bb6f692a1ec270bde838c2902785 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -21,20 +21,27 @@ trait AbstractUnrollingSolver import program.trees._ import program.symbols._ - val theories: TheoryEncoder { val trees: program.trees.type } + protected type Encoded + protected implicit val printable: Encoded => Printable - type Encoded - implicit val printable: Encoded => Printable + protected val theories: TheoryEncoder { val trees: program.trees.type } - val templates: Templates { + protected val templates: Templates { val program: theories.targetProgram.type type Encoded = AbstractUnrollingSolver.this.Encoded } - val evaluator: DeterministicEvaluator with SolvingEvaluator { + protected val evaluator: DeterministicEvaluator with SolvingEvaluator { val program: AbstractUnrollingSolver.this.program.type } + protected val underlying: AbstractSolver { + val program: AbstractUnrollingSolver.this.program.type + type Trees = Encoded + type Model = ModelWrapper + type Cores = Set[Encoded] + } + val unfoldFactor = options.findOptionOrDefault(optUnrollFactor) val feelingLucky = options.findOptionOrDefault(optFeelingLucky) val checkModels = options.findOptionOrDefault(optCheckModels) @@ -88,7 +95,7 @@ trait AbstractUnrollingSolver val newClauses = templates.instantiateExpr(expression, bindings) for (cl <- newClauses) { - solverAssert(cl) + underlying.assertCnstr(cl) } } diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala index ce48599a8968545461340508caff4507ec230233..c5c19a2495d9bf82ee6d7b872f5932c845c6d6aa 100644 --- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala @@ -1,33 +1,29 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers.z3 -import leon.utils._ +import utils._ import z3.scala.{Z3Solver => ScalaZ3Solver, _} import solvers._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.ExprOps._ -import purescala.Types._ case class UnsoundExtractionException(ast: Z3AST, msg: String) extends Exception("Can't extract " + ast + " : " + msg) // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" -trait AbstractZ3Solver extends Z3Solver { - - val program : Program +trait AbstractZ3Solver extends AbstractSolver { + context.interruptManager.registerForInterrupts(this) - val library = program.library + import program._ + import program.trees._ + import program.symbols._ + import program.symbols.typeOps.bestRealType - context.interruptManager.registerForInterrupts(this) + type Trees = Z3AST + type Model = Z3Model + type Cores = Set[Z3AST] private[this] var freed = false val traceE = new Exception() @@ -88,11 +84,11 @@ trait AbstractZ3Solver extends Z3Solver { protected val lambdas = new IncrementalBijection[FunctionType, Z3FuncDecl]() protected val variables = new IncrementalBijection[Expr, Z3AST]() - protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() - protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val constructors = new IncrementalBijection[Type, Z3FuncDecl]() + protected val selectors = new IncrementalBijection[(Type, Int), Z3FuncDecl]() + protected val testers = new IncrementalBijection[Type, Z3FuncDecl]() - protected val sorts = new IncrementalMap[TypeTree, Z3Sort]() + protected val sorts = new IncrementalMap[Type, Z3Sort]() var isInitialized = false protected[leon] def initZ3(): Unit = { @@ -115,19 +111,19 @@ trait AbstractZ3Solver extends Z3Solver { initZ3() - def rootType(ct: TypeTree): TypeTree = ct match { + def rootType(ct: Type): Type = ct match { case ct: ClassType => ct.root case t => t } - def declareStructuralSort(t: TypeTree): Z3Sort = { + def declareStructuralSort(t: Type): Z3Sort = { //println("///"*40) //println("Declaring for: "+t) adtManager.defineADT(t) match { case Left(adts) => declareDatatypes(adts.toSeq) - sorts(normalizeType(t)) + sorts(bestRealType(t)) case Right(conflicts) => conflicts.foreach { declareStructuralSort } @@ -135,12 +131,12 @@ trait AbstractZ3Solver extends Z3Solver { } } - def declareDatatypes(adts: Seq[(TypeTree, DataType)]): Unit = { + def declareDatatypes(adts: Seq[(Type, DataType)]): Unit = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} - val indexMap: Map[TypeTree, Int] = adts.map(_._1).zipWithIndex.toMap + val indexMap: Map[Type, Int] = adts.map(_._1).zipWithIndex.toMap - def typeToSortRef(tt: TypeTree): ADTSortReference = { + def typeToSortRef(tt: Type): ADTSortReference = { val tpe = rootType(tt) if (indexMap contains tpe) { @@ -201,12 +197,8 @@ trait AbstractZ3Solver extends Z3Solver { selectors.clear() } - def normalizeType(t: TypeTree): TypeTree = { - bestRealType(t) - } - // assumes prepareSorts has been called.... - protected[leon] def typeToSort(oldtt: TypeTree): Z3Sort = normalizeType(oldtt) match { + protected[leon] def typeToSort(oldtt: Type): Z3Sort = bestRealType(oldtt) match { case Int32Type | BooleanType | IntegerType | RealType | CharType => sorts(oldtt) @@ -220,14 +212,11 @@ trait AbstractZ3Solver extends Z3Solver { z3.mkSetSort(typeToSort(base)) } - case tt @ MapType(fromType, toType) => - typeToSort(RawArrayType(fromType, library.optionType(toType))) - case tt @ BagType(base) => - typeToSort(RawArrayType(base, IntegerType)) + typeToSort(MapType(base, IntegerType)) - case rat @ RawArrayType(from, to) => - sorts.cached(rat) { + case tt @ MapType(from, to) => + sorts.cached(tt) { val fromSort = typeToSort(from) val toSort = typeToSort(to) @@ -244,57 +233,20 @@ trait AbstractZ3Solver extends Z3Solver { unsupported(other) } - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { + protected[leon] def toZ3Formula(expr: Expr, bindings: Map[Variable, Z3AST] = Map.empty): Z3AST = { - var z3Vars: Map[Identifier,Z3AST] = if (initialMap.nonEmpty) { - initialMap - } else { - // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of - // variable has to remain in a map etc. - variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } - } - - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry - } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb + def rec(ex: Expr)(implicit bindings: Map[Variable, Z3AST]): Z3AST = ex match { - case p @ Passes(_, _, _) => - rec(p.asConstraint) + case Let(vd, e, b) => + val re = rec(e) + rec(b)(bindings + (vd.toVariable -> re)) - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) + case a @ Assume(cond, body) => + val (rc, rb) = (rec(cond), rec(body)) + z3.mkITE(rc, rb, z3.mkFreshConst("fail", typeToSort(body.getType))) - case Let(i, e, b) => - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - - case v @ Variable(id) => z3Vars.getOrElse(id, - variables.getB(v).getOrElse { - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } + case v: Variable => bindings.getOrElse(v, + variables.cachedB(v)(z3.mkFreshConst(v.id.uniqueName, typeToSort(v.getType))) ) case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) @@ -309,36 +261,42 @@ trait AbstractZ3Solver extends Z3Solver { case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Plus(l, r) => l.getType match { + case BVType(_) => z3.mkBVAdd(rec(l), rec(r)) + case _ => z3.mkAdd(rec(l), rec(r)) + } + case Minus(l, r) => l.getType match { + case BVType(_) => z3.mkBVSub(rec(l), rec(r)) + case _ => z3.mkSub(rec(l), rec(r)) + } + case Times(l, r) => l.getType match { + case BVType(_) => z3.mkBVMul(rec(l), rec(r)) + case _ => z3.mkMul(rec(l), rec(r)) + } case Division(l, r) => - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - case Remainder(l, r) => - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - case Modulo(l, r) => - z3.mkMod(rec(l), rec(r)) - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) + val (rl, rr) = (rec(l), rec(r)) + l.getType match { + case IntegerType => + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + case BVType(_) => z3.mkBVSdiv(rl, rr) + case _ => z3.mkDiv(rl, rr) + } + case Remainder(l, r) => l.getType match { + case BVType(_) => z3.mkBVSrem(rec(l), rec(r)) + case _ => + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => z3.mkMod(rec(l), rec(r)) + case UMinus(e) => e.getType match { + case BVType(_) => z3.mkBVNeg(rec(e)) + case _ => z3.mkUnaryMinus(rec(e)) + } + case BVNot(e) => z3.mkBVNot(rec(e)) case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) @@ -357,7 +315,6 @@ trait AbstractZ3Solver extends Z3Solver { case RealType => z3.mkLE(rec(l), rec(r)) case Int32Type => z3.mkBVSle(rec(l), rec(r)) case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") } case GreaterThan(l, r) => l.getType match { case IntegerType => z3.mkGT(rec(l), rec(r)) @@ -373,19 +330,19 @@ trait AbstractZ3Solver extends Z3Solver { } case u : UnitLiteral => - val tpe = normalizeType(u.getType) + val tpe = bestRealType(u.getType) typeToSort(tpe) val constructor = constructors.toB(tpe) constructor() case t @ Tuple(es) => - val tpe = normalizeType(t.getType) + val tpe = bestRealType(t.getType) typeToSort(tpe) val constructor = constructors.toB(tpe) constructor(es.map(rec): _*) case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) + val tpe = bestRealType(t.getType) typeToSort(tpe) val selector = selectors.toB((tpe, i-1)) selector(rec(t)) @@ -395,33 +352,35 @@ trait AbstractZ3Solver extends Z3Solver { val constructor = constructors.toB(ct) constructor(args.map(rec): _*) - case c @ CaseClassSelector(cct, cc, sel) => + case c @ CaseClassSelector(cc, sel) => + val cct = cc.getType typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) + val selector = selectors.toB(cct, sel) selector(rec(cc)) case AsInstanceOf(expr, ct) => rec(expr) - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) - } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) + case IsInstanceOf(e, ct) => ct.tcd match { + case tacd: TypedAbstractClassDef => + tacd.descendants match { + case Seq(tccd) => + rec(IsInstanceOf(e, tccd.toType)) + case more => + val v = Variable(FreshIdentifier("e", true), ct) + rec(Let(v.toVal, e, orJoin(more map (IsInstanceOf(v, _))))) + } + case tccd: TypedCaseClassDef => + typeToSort(ct) + val tester = tester.toB(ct) + tester(rec(e)) + } - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + case f @ FunctionInvocation(id, tps, args) => + z3.mkApp(functionDefToDecl(getFunction(id, tps)), args.map(rec): _*) case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val ft @ FunctionType(froms, to) = bestRealType(caller.getType) val funDecl = lambdas.cachedB(ft) { val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) val returnSort = typeToSort(to) @@ -440,12 +399,11 @@ trait AbstractZ3Solver extends Z3Solver { case fb @ FiniteBag(elems, base) => typeToSort(fb.getType) - rec(RawArrayValue(base, elems, InfiniteIntegerLiteral(0))) + rec(FiniteMap(base, elems, IntegerLiteral(0))) case BagAdd(b, e) => - val bag = rec(b) - val elem = rec(e) - z3.mkStore(bag, elem, z3.mkAdd(z3.mkSelect(bag, elem), rec(InfiniteIntegerLiteral(1)))) + val (bag, elem) = (rec(b), rec(e)) + z3.mkStore(bag, elem, z3.mkAdd(z3.mkSelect(bag, elem), rec(IntegerLiteral(1)))) case MultiplicityInBag(e, b) => z3.mkSelect(rec(b), rec(e)) @@ -467,8 +425,9 @@ trait AbstractZ3Solver extends Z3Solver { val withNeg = z3.mkArrayMap(minus, rec(b1), rec(b2)) z3.mkArrayMap(div, z3.mkArrayMap(plus, withNeg, z3.mkArrayMap(abs, withNeg)), all2) - case al @ RawArraySelect(a, i) => + case al @ MapApply(a, i) => z3.mkSelect(rec(a), rec(i)) + case al @ RawArrayUpdated(a, i, e) => z3.mkStore(rec(a), rec(i), rec(e)) case RawArrayValue(keyTpe, elems, default) => @@ -482,14 +441,14 @@ trait AbstractZ3Solver extends Z3Solver { * ===== Map operations ===== */ case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) + val MapType(_, t) = bestRealType(m.getType) rec(RawArrayValue(from, elems.map{ case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) }, CaseClass(library.noneType(t), Seq()))) case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) + val mt @ MapType(_, t) = bestRealType(m.getType) typeToSort(mt) val el = z3.mkSelect(rec(m), rec(k)) @@ -498,7 +457,7 @@ trait AbstractZ3Solver extends Z3Solver { selectors.toB(library.someType(t), 0)(el) case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) + val mt @ MapType(_, t) = bestRealType(m.getType) typeToSort(mt) val el = z3.mkSelect(rec(m), rec(k)) @@ -506,7 +465,7 @@ trait AbstractZ3Solver extends Z3Solver { testers.toB(library.someType(t))(el) case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) + val mt @ MapType(_, t) = bestRealType(m1.getType) typeToSort(mt) elems.foldLeft(rec(m1)) { case (m, (k,v)) => @@ -528,9 +487,9 @@ trait AbstractZ3Solver extends Z3Solver { rec(expr) } - protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { + protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: Type): Expr = { - def rec(t: Z3AST, tpe: TypeTree): Expr = { + def rec(t: Z3AST, tpe: Type): Expr = { val kind = z3.getASTKind(t) kind match { case Z3NumeralIntAST(Some(v)) => @@ -734,10 +693,10 @@ trait AbstractZ3Solver extends Z3Solver { } } - rec(tree, normalizeType(tpe)) + rec(tree, bestRealType(tpe)) } - protected[leon] def softFromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree) : Option[Expr] = { + protected[leon] def softFromZ3Formula(model: Z3Model, tree: Z3AST, tpe: Type) : Option[Expr] = { try { Some(fromZ3Formula(model, tree, tpe)) } catch { @@ -747,8 +706,8 @@ trait AbstractZ3Solver extends Z3Solver { } } - def idToFreshZ3Id(id: Identifier): Z3AST = { - z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) + def symbolToFreshZ3Symbol(v: Variable): Z3AST = { + z3.mkFreshConst(v.id.uniqueName, typeToSort(v.getType)) } def reset(): Unit = { diff --git a/src/main/scala/inox/solvers/z3/FairZ3Solver.scala b/src/main/scala/inox/solvers/z3/FairZ3Solver.scala index c51dc6033d7532b3ed95fe98eb587582507bda8d..5223f915a7a3287e811f74f5556b24ddb8c12e99 100644 --- a/src/main/scala/inox/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/FairZ3Solver.scala @@ -1,41 +1,83 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package z3 import _root_.z3.scala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ - import unrolling._ import theories._ import utils._ -class FairZ3Solver(val sctx: SolverContext, val program: Program) +trait FairZ3Solver extends AbstractZ3Solver - with AbstractUnrollingSolver[Z3AST] { - - enclosing => + with AbstractUnrollingSolver { - protected val errors = new IncrementalBijection[Unit, Boolean]() - protected def hasError = errors.getB(()) contains true - protected def addError() = errors += () -> true + import program._ + import program.trees._ + import program.symbols._ override val name = "Z3-f" override val description = "Fair Z3 Solver" + type Encoded = Z3AST + val printable = (z3: Z3AST) => new Printable { + def asString(implicit ctx: LeonContext) = z3.toString + } + + object theories extends { + val trees: program.trees.type = program.trees + } with StringEncoder + + object templates extends { + val program: FairZ3Solver.this.program.type = FairZ3Solver.this.program + type Encoded = FairZ3Solver.this.Encoded + } with Templates { + + def encodeSymbol(v: Variable): Z3AST = symbolToFreshZ3Symbol(v) + + def encodeExpr(bindings: Map[Variable, Z3AST])(e: Expr): Z3AST = { + toZ3Formula(e, bindings) + } + + def substitute(substMap: Map[Z3AST, Z3AST]): Z3AST => Z3AST = { + val (from, to) = substMap.unzip + val (fromArray, toArray) = (from.toArray, to.toArray) + + (c: Z3AST) => z3.substitute(c, fromArray, toArray) + } + + def mkNot(e: Z3AST) = z3.mkNot(e) + def mkOr(es: Z3AST*) = z3.mkOr(es : _*) + def mkAnd(es: Z3AST*) = z3.mkAnd(es : _*) + def mkEquals(l: Z3AST, r: Z3AST) = z3.mkEq(l, r) + def mkImplies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) + + def extractNot(l: Z3AST): Option[Z3AST] = z3.getASTKind(l) match { + case Z3AppAST(decl, args) => z3.getDeclKind(decl) match { + case OpNot => Some(args.head) + case _ => None + } + case ast => None + } + } + override def reset(): Unit = super[AbstractZ3Solver].reset() - def declareVariable(id: Identifier): Z3AST = variables.cachedB(Variable(id)) { - templateEncoder.encodeId(id) + def declareVariable(v: Variable): Z3AST = variables.cachedB(v) { + templates.encodeSymbol(v) + } + + def solverAssert(cnstr: Z3AST): Unit = { + val timer = context.timers.solvers.z3.assert.start() + solver.assertCnstr(cnstr) + timer.stop() } - def solverCheck[R](clauses: Seq[Z3AST])(block: Option[Boolean] => R): R = { + def solverCheck[R](config: Configuration) + (clauses: Seq[Z3AST]) + (block: Response => R): R = { solver.push() for (cls <- clauses) solver.assertCnstr(cls) val res = solver.check @@ -103,43 +145,6 @@ class FairZ3Solver(val sctx: SolverContext, val program: Program) override def toString = model.toString } - val printable = (z3: Z3AST) => new Printable { - def asString(implicit ctx: LeonContext) = z3.toString - } - - val theoryEncoder = new StringEncoder(context, program) >> new BagEncoder(context, program) >> new ArrayEncoder(context, program) - - val templateEncoder = new TemplateEncoder[Z3AST] { - def encodeId(id: Identifier): Z3AST = { - idToFreshZ3Id(id) - } - - def encodeExpr(bindings: Map[Identifier, Z3AST])(e: Expr): Z3AST = { - toZ3Formula(e, bindings) - } - - def substitute(substMap: Map[Z3AST, Z3AST]): Z3AST => Z3AST = { - val (from, to) = substMap.unzip - val (fromArray, toArray) = (from.toArray, to.toArray) - - (c: Z3AST) => z3.substitute(c, fromArray, toArray) - } - - def mkNot(e: Z3AST) = z3.mkNot(e) - def mkOr(es: Z3AST*) = z3.mkOr(es : _*) - def mkAnd(es: Z3AST*) = z3.mkAnd(es : _*) - def mkEquals(l: Z3AST, r: Z3AST) = z3.mkEq(l, r) - def mkImplies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) - - def extractNot(l: Z3AST): Option[Z3AST] = z3.getASTKind(l) match { - case Z3AppAST(decl, args) => z3.getDeclKind(decl) match { - case OpNot => Some(args.head) - case _ => None - } - case ast => None - } - } - private val incrementals: List[IncrementalState] = List( errors, functions, lambdas, sorts, variables, constructors, selectors, testers @@ -173,21 +178,6 @@ class FairZ3Solver(val sctx: SolverContext, val program: Program) } } - override def assertCnstr(expression: Expr): Unit = { - try { - super.assertCnstr(expression) - } catch { - case u: Unsupported => - addError() - } - } - - def solverAssert(cnstr: Z3AST): Unit = { - val timer = context.timers.solvers.z3.assert.start() - solver.assertCnstr(cnstr) - timer.stop() - } - def solverUnsatCore = Some(solver.getUnsatCore) override def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { diff --git a/src/main/scala/inox/solvers/z3/Z3Solver.scala b/src/main/scala/inox/solvers/z3/Z3Solver.scala index 24feda9fff650925419fb14450f0e198ad9a29ae..a3648904ab857014fc1d483064ad4f5d0fb6f4a0 100644 --- a/src/main/scala/inox/solvers/z3/Z3Solver.scala +++ b/src/main/scala/inox/solvers/z3/Z3Solver.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package z3 diff --git a/src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala b/src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala index 3dd919f99784438a65dc9dfc851f1e961797f153..454cc37dc2de78e38b038f835a2ec38f24c8d539 100644 --- a/src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala @@ -10,5 +10,5 @@ import unrolling._ import theories._ class Z3UnrollingSolver(context: SolverContext, program: Program, underlying: Z3Solver) - extends UnrollingSolver(context, program, underlying, new StringEncoder(context.context, program) >> new ArrayEncoder(context.context, program)) + extends UnrollingSolver(context, program, underlying, new StringEncoder(context.context, program)) with Z3Solver