diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index 7987ca439bc1553372b860ce727560e5a624df6a..296025d189b09c2094a7e3dcec864823e6c5e239 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -4,10 +4,18 @@ package solvers import purescala.Types._ import purescala.Common._ -case class DataType(sym: Identifier, cases: Seq[Constructor]) -case class Constructor(sym: Identifier, tpe: TypeTree, fields: Seq[(Identifier, TypeTree)]) +case class DataType(sym: Identifier, cases: Seq[Constructor]) { + override def toString = { + "Datatype: "+sym.uniqueName+"\n"+cases.map(c => " - "+c.toString).mkString("\n") + } +} +case class Constructor(sym: Identifier, tpe: TypeTree, fields: Seq[(Identifier, TypeTree)]) { + override def toString = { + sym.uniqueName+" ["+tpe+"] "+fields.map(f => f._1.uniqueName+": "+f._2).mkString("(", ", ", ")") + } +} -class ADTManager { +class ADTManager(reporter: Reporter) { protected def freshId(id: Identifier): Identifier = freshId(id.name) protected def freshId(name: String): Identifier = FreshIdentifier(name) @@ -24,68 +32,93 @@ class ADTManager { } protected var defined = Set[TypeTree]() + protected var locked = Set[TypeTree]() + + protected var discovered = Map[TypeTree, DataType]() + + def defineADT(t: TypeTree): Either[Map[TypeTree, DataType], Set[TypeTree]] = { + discovered = Map() + locked = Set() - def defineADT(t: TypeTree): Map[TypeTree, DataType] = { - val adts = findDependencies(t) - for ((t, dt) <- adts) { - defined += t + findDependencies(t) + + val conflicts = discovered.keySet & locked + + if (conflicts(t)) { + // There is no way to solve this, the type we requested is in conflict + reporter.warning("Encountered ADT '"+t+"' that can't be defined.") + reporter.warning("It appears it has recursive references through non-structural types (such as arrays, maps, or sets).") + throw new IllegalArgumentException + } else { + // We might be able to define some despite conflicts + if (conflicts.isEmpty) { + for ((t, dt) <- discovered) { + defined += t + } + Left(discovered) + } else { + Right(conflicts) + } } - adts } - protected def findDependencies(t: TypeTree, dts: Map[TypeTree, DataType] = Map()): Map[TypeTree, DataType] = t match { + def forEachType(t: TypeTree)(f: TypeTree => Unit): Unit = t match { + case NAryType(tps, builder) => + f(t) + tps.foreach(forEachType(_)(f)) + } + + protected def findDependencies(t: TypeTree): Unit = t match { + case _: SetType | _: MapType => + forEachType(t) { tpe => + if (!defined(tpe)) { + locked += tpe + } + } + case ct: ClassType => val (root, sub) = getHierarchy(ct) - if (!(dts contains root) && !(defined contains root)) { + if (!(discovered contains root) && !(defined contains root)) { val sym = freshId(ct.id) val conss = sub.map { case cct => Constructor(freshId(cct.id), cct, cct.fields.map(vd => (freshId(vd.id), vd.getType))) } - var cdts = dts + (root -> DataType(sym, conss)) + discovered += (root -> DataType(sym, conss)) // look for dependencies for (ct <- root +: sub; f <- ct.fields) { - cdts ++= findDependencies(f.getType, cdts) + findDependencies(f.getType) } - - cdts - } else { - dts } case tt @ TupleType(bases) => - if (!(dts contains t) && !(defined contains t)) { + if (!(discovered contains t) && !(defined contains t)) { val sym = freshId("tuple"+bases.size) val c = Constructor(freshId(sym.name), tt, bases.zipWithIndex.map { case (tpe, i) => (freshId("_"+(i+1)), tpe) }) - var cdts = dts + (tt -> DataType(sym, Seq(c))) + discovered += (tt -> DataType(sym, Seq(c))) for (b <- bases) { - cdts ++= findDependencies(b, cdts) + findDependencies(b) } - cdts - } else { - dts } case UnitType => - if (!(dts contains t) && !(defined contains t)) { + if (!(discovered contains t) && !(defined contains t)) { val sym = freshId("Unit") - dts + (t -> DataType(sym, Seq(Constructor(freshId(sym.name), t, Nil)))) - } else { - dts + discovered += (t -> DataType(sym, Seq(Constructor(freshId(sym.name), t, Nil)))) } case at @ ArrayType(base) => - if (!(dts contains t) && !(defined contains t)) { + if (!(discovered contains t) && !(defined contains t)) { val sym = freshId("array") val c = Constructor(freshId(sym.name), at, List( @@ -93,14 +126,11 @@ class ADTManager { (freshId("content"), RawArrayType(Int32Type, base)) )) - val cdts = dts + (at -> DataType(sym, Seq(c))) + discovered += (at -> DataType(sym, Seq(c))) - findDependencies(base, cdts) - } else { - dts + findDependencies(base) } case _ => - dts } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 85f09790298b1fa951f89186efa556b8c5406c4f..87a147ee33fe794676c359beefc2e2e9ec83ab27 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -99,9 +99,9 @@ abstract class SMTLIBSolver(val context: LeonContext, QualifiedIdentifier(SMTIdentifier(s)) } - val adtManager = new ADTManager + protected val adtManager = new ADTManager(reporter) - val library = program.library + protected val library = program.library protected def id2sym(id: Identifier): SSymbol = SSymbol(id.name+"!"+id.globalId) @@ -251,11 +251,16 @@ abstract class SMTLIBSolver(val context: LeonContext, protected def declareStructuralSort(t: TypeTree): Sort = { // Populates the dependencies of the structural type to define. - val datatypes = adtManager.defineADT(t) - - declareDatatypes(datatypes) + adtManager.defineADT(t) match { + case Left(adts) => + declareDatatypes(adts) + sorts.toB(t) + + case Right(conflicts) => + conflicts.foreach { declareStructuralSort } + declareStructuralSort(t) + } - sorts.toB(t) } protected def declareVariable(id: Identifier): SSymbol = { diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 1ded84a6fcd702cb5c448650e236c6d942de5611..2e1f051357b092b172edd53258109d30eb69fc23 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -47,8 +47,6 @@ trait AbstractZ3Solver } } - class CantTranslateException(t: Z3AST) extends Exception("Can't translate from Z3 tree: " + t) - protected[leon] val z3cfg : Z3Config protected[leon] var z3 : Z3Context = null @@ -134,7 +132,7 @@ trait AbstractZ3Solver } // ADT Manager - protected[leon] val adtManager = new ADTManager + protected val adtManager = new ADTManager(reporter) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs protected[leon] var functions = new Bijection[TypedFunDef, Z3FuncDecl] @@ -182,12 +180,22 @@ trait AbstractZ3Solver } def declareStructuralSort(t: TypeTree): Z3Sort = { - import Z3Context.{ADTSortReference, RecursiveType, RegularSort} - //println("///"*40) //println("Declaring for: "+ct) - val adts = adtManager.defineADT(t).toSeq + adtManager.defineADT(t) match { + case Left(adts) => + declareDatatypes(adts.toSeq) + sorts.toZ3(t) + + case Right(conflicts) => + conflicts.foreach { declareStructuralSort } + declareStructuralSort(t) + } + } + + def declareDatatypes(adts: Seq[(TypeTree, DataType)]): Unit = { + import Z3Context.{ADTSortReference, RecursiveType, RegularSort} val indexMap: Map[TypeTree, Int] = adts.map(_._1).zipWithIndex.toMap @@ -228,9 +236,6 @@ trait AbstractZ3Solver } } - //println("\\\\\\"*40) - - sorts.toZ3(t) } // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. @@ -312,9 +317,7 @@ trait AbstractZ3Solver } } - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier,Z3AST] = Map.empty) : Option[Z3AST] = { - - class CantTranslateException extends Exception + protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap @@ -568,16 +571,11 @@ trait AbstractZ3Solver case _ => { reporter.warning(ex.getPos, "Can't handle this in translation to Z3: " + ex) - throw new CantTranslateException + throw new IllegalArgumentException } } - try { - val res = Some(rec(expr)) - res - } catch { - case e: CantTranslateException => None - } + rec(expr) } protected def fromRawArray(r: Expr, tpe: TypeTree): Expr = r match { @@ -614,12 +612,12 @@ trait AbstractZ3Solver case Int32Type => IntLiteral(hexa.toInt) case CharType => CharLiteral(hexa.toInt.toChar) case _ => - println("Unexpected target type for BV value: " + tpe) - throw new CantTranslateException(t) + reporter.warning("Unexpected target type for BV value: " + tpe) + throw new IllegalArgumentException } case None => { - println("Z3NumeralIntAST with None: " + t) - throw new CantTranslateException(t) + reporter.warning("Z3NumeralIntAST with None: " + t) + throw new IllegalArgumentException } } } else { @@ -633,12 +631,12 @@ trait AbstractZ3Solver case Int32Type => IntLiteral(hexa.toInt) case CharType => CharLiteral(hexa.toInt.toChar) case _ => - println("Unexpected target type for BV value: " + tpe) - throw new CantTranslateException(t) + reporter.warning("Unexpected target type for BV value: " + tpe) + throw new IllegalArgumentException } case None => { - println("Z3NumeralIntAST with None: " + t) - throw new CantTranslateException(t) + reporter.warning("Z3NumeralIntAST with None: " + t) + throw new IllegalArgumentException } } } @@ -671,12 +669,12 @@ trait AbstractZ3Solver case (s : IntLiteral, RawArrayValue(_, elems, default)) => val entries = elems.map { case (IntLiteral(i), v) => i -> v - case _ => throw new CantTranslateException(t) + case _ => throw new IllegalArgumentException } finiteArray(entries, Some(s, default), to) case _ => - throw new CantTranslateException(t) + throw new IllegalArgumentException } } } else { @@ -690,7 +688,7 @@ trait AbstractZ3Solver } RawArrayValue(from, entries, default) - case None => throw new CantTranslateException(t) + case None => throw new IllegalArgumentException } case tp: TypeParameter => @@ -719,7 +717,7 @@ trait AbstractZ3Solver case FunctionType(fts, tt) => model.getArrayValue(t) match { - case None => throw new CantTranslateException(t) + case None => throw new IllegalArgumentException case Some((map, elseZ3Value)) => val leonElseValue = rec(elseZ3Value, tt) val leonMap = map.toSeq.map(p => rec(p._1, tupleTypeWrap(fts)) -> rec(p._2, tt)) @@ -728,7 +726,7 @@ trait AbstractZ3Solver case tpe @ SetType(dt) => model.getSetValue(t) match { - case None => throw new CantTranslateException(t) + case None => throw new IllegalArgumentException case Some(set) => val elems = set.map(e => rec(e, dt)) finiteSet(elems, dt) @@ -759,18 +757,17 @@ trait AbstractZ3Solver // case OpIDiv => Division(rargs(0), rargs(1)) // case OpMod => Modulo(rargs(0), rargs(1)) case other => - System.err.println("Don't know what to do with this declKind : " + other) - System.err.println("Expected type: " + tpe) - System.err.println("Tree: " + t) - System.err.println("The arguments are : " + args) - new Exception().printStackTrace - throw new CantTranslateException(t) + reporter.warning("Don't know what to do with this declKind : " + other) + reporter.warning("Expected type: " + tpe) + reporter.warning("Tree: " + t) + reporter.warning("The arguments are : " + args) + throw new IllegalArgumentException } } } case _ => - System.err.println("Can't handle "+t) - throw new CantTranslateException(t) + reporter.warning("Can't handle "+t) + throw new IllegalArgumentException } } rec(tree, tpe) @@ -780,7 +777,7 @@ trait AbstractZ3Solver try { Some(fromZ3Formula(model, tree, tpe)) } catch { - case e: CantTranslateException => None + case e: IllegalArgumentException => None } } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index ede0429dd5725d42c0587a2c10ed46a2c53747a6..0439ef1ed4ee4f613cef32e5d1fea5e3da313976 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -4,6 +4,7 @@ package leon package solvers package z3 +import utils.IncrementalBijection import _root_.z3.scala._ import purescala.Common._ @@ -32,6 +33,10 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val evalGroundApps = context.findOptionOrDefault(optEvalGround) val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) + protected val errors = new IncrementalBijection[Unit, Boolean]() + protected def hasError = errors.getB(()) contains true + protected def addError() = errors += () -> true + private val evaluator: Evaluator = if(useCodeGen) { // TODO If somehow we could not recompile each time we create a solver, @@ -122,9 +127,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } def encodeExpr(bindings: Map[Identifier, Z3AST])(e: Expr): Z3AST = { - toZ3Formula(e, bindings).getOrElse { - reporter.fatalError("Failed to translate "+e+" to z3 ("+e.getClass+")") - } + toZ3Formula(e, bindings) } def substitute(substMap: Map[Z3AST, Z3AST]): Z3AST => Z3AST = { @@ -153,6 +156,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val unrollingBank = new UnrollingBank(reporter, templateGenerator) def push() { + errors.push() solver.push() unrollingBank.push() varsInVC = Set[Identifier]() :: varsInVC @@ -160,6 +164,10 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } def pop(lvl: Int = 1) { + for (i <- 1 until lvl) { + errors.pop() + } + solver.pop(lvl) unrollingBank.pop(lvl) varsInVC = varsInVC.drop(lvl) @@ -167,11 +175,19 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } override def check: Option[Boolean] = { - fairCheck(Set()) + if (hasError) { + None + } else { + fairCheck(Set()) + } } override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - fairCheck(assumptions) + if (hasError) { + None + } else { + fairCheck(assumptions) + } } var foundDefinitiveAnswer = false @@ -180,22 +196,27 @@ class FairZ3Solver(val context : LeonContext, val program: Program) var definitiveCore : Set[Expr] = Set.empty def assertCnstr(expression: Expr) { - val freeVars = variablesOf(expression) - varsInVC = (varsInVC.head ++ freeVars) :: varsInVC.tail - - // We make sure all free variables are registered as variables - freeVars.foreach { v => - variables.toZ3OrCompute(Variable(v)) { - templateGenerator.encoder.encodeId(v) + try { + val freeVars = variablesOf(expression) + varsInVC = (varsInVC.head ++ freeVars) :: varsInVC.tail + + // We make sure all free variables are registered as variables + freeVars.foreach { v => + variables.toZ3OrCompute(Variable(v)) { + templateGenerator.encoder.encodeId(v) + } } - } - frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail - val newClauses = unrollingBank.getClauses(expression, variables.leonToZ3) + val newClauses = unrollingBank.getClauses(expression, variables.leonToZ3) - for (cl <- newClauses) { - solver.assertCnstr(cl) + for (cl <- newClauses) { + solver.assertCnstr(cl) + } + } catch { + case _: IllegalArgumentException => + addError() } } @@ -220,7 +241,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } // these are the optional sequence of assumption literals - val assumptionsAsZ3: Seq[Z3AST] = assumptions.flatMap(toZ3Formula(_)).toSeq + val assumptionsAsZ3: Seq[Z3AST] = assumptions.map(toZ3Formula(_)).toSeq val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 08589a2e35cc909435aa865535c5c141806ebe8b..4b67702a953618e7da3b3ced528e1522896cc6a1 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -51,14 +51,14 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) private var freeVariables = Set[Identifier]() def assertCnstr(expression: Expr) { freeVariables ++= variablesOf(expression) - solver.assertCnstr(toZ3Formula(expression).getOrElse(scala.sys.error("Failed to compile to Z3: "+expression))) + solver.assertCnstr(toZ3Formula(expression)) } override def check: Option[Boolean] = solver.check() override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { freeVariables ++= assumptions.flatMap(variablesOf) - solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) + solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_)) : _*) } def getModel = {