From 1949076dc894ff30dd53918d90051405e078e766 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Wed, 3 Jun 2015 00:10:06 +0200 Subject: [PATCH] Prevent user in case of an invalid usage of ADTs - Detect definitions of ADTs that wil lbe problematic for solvers to handle. e.g. case class Ls(elems: List[Cons[T]]) - Use IncrementalBijections in native z3 solvers --- .../leon/purescala/CheckADTFieldsTypes.scala | 30 ++++++++++++ src/main/scala/leon/solvers/ADTManager.scala | 13 +++-- .../leon/solvers/smtlib/SMTLIBSolver.scala | 2 +- .../leon/solvers/z3/AbstractZ3Solver.scala | 47 ++++++++++--------- .../scala/leon/solvers/z3/FairZ3Solver.scala | 12 ++--- .../solvers/z3/Z3ModelReconstruction.scala | 2 +- .../leon/utils/IncrementalBijection.scala | 8 ++++ .../scala/leon/utils/PreprocessingPhase.scala | 3 +- 8 files changed, 81 insertions(+), 36 deletions(-) create mode 100644 src/main/scala/leon/purescala/CheckADTFieldsTypes.scala diff --git a/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala b/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala new file mode 100644 index 000000000..9de511180 --- /dev/null +++ b/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala @@ -0,0 +1,30 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Types._ +import TypeOps._ +import Expressions._ + +object CheckADTFieldsTypes extends UnitPhase[Program] { + + val name = "ADT Fields" + val description = "Check that fields of ADTs are hierarchy roots" + + def apply(ctx: LeonContext, program: Program) = { + program.definedClasses.foreach { + case ccd: CaseClassDef => + for(vd <- ccd.fields) { + val tpe = vd.getType + if (bestRealType(tpe) != tpe) { + ctx.reporter.warning("Definition of "+ccd.id+" has a field of a sub-type ("+vd+"): this type is not supported as-is by solvers and will be up-casted. This may cause issues such as crashes.") + } + } + case _ => + } + } + +} diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index 296025d18..e28255267 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -2,20 +2,23 @@ package leon package solvers import purescala.Types._ +import purescala.TypeOps._ import purescala.Common._ case class DataType(sym: Identifier, cases: Seq[Constructor]) { - override def toString = { - "Datatype: "+sym.uniqueName+"\n"+cases.map(c => " - "+c.toString).mkString("\n") + def asString(implicit ctx: LeonContext) = { + "Datatype: "+sym.asString+"\n"+cases.map(c => " - "+c.asString(ctx)).mkString("\n") } } case class Constructor(sym: Identifier, tpe: TypeTree, fields: Seq[(Identifier, TypeTree)]) { - override def toString = { - sym.uniqueName+" ["+tpe+"] "+fields.map(f => f._1.uniqueName+": "+f._2).mkString("(", ", ", ")") + def asString(implicit ctx: LeonContext) = { + sym.asString(ctx)+" ["+tpe.asString(ctx)+"] "+fields.map(f => f._1.asString(ctx)+": "+f._2.asString(ctx)).mkString("(", ", ", ")") } } -class ADTManager(reporter: Reporter) { +class ADTManager(ctx: LeonContext) { + val reporter = ctx.reporter + protected def freshId(id: Identifier): Identifier = freshId(id.name) protected def freshId(name: String): Identifier = FreshIdentifier(name) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 87a147ee3..61b39b0ce 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -99,7 +99,7 @@ abstract class SMTLIBSolver(val context: LeonContext, QualifiedIdentifier(SMTIdentifier(s)) } - protected val adtManager = new ADTManager(reporter) + protected val adtManager = new ADTManager(context) protected val library = program.library diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 2e1f05135..fe7b0fec7 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -72,7 +72,7 @@ trait AbstractZ3Solver } def functionDefToDecl(tfd: TypedFunDef): Z3FuncDecl = { - functions.toZ3OrCompute(tfd) { + functions.cachedB(tfd) { val sortSeq = tfd.params.map(vd => typeToSort(vd.getType)) val returnSort = typeToSort(tfd.returnType) @@ -81,7 +81,7 @@ trait AbstractZ3Solver } def genericValueToDecl(gv: GenericValue): Z3FuncDecl = { - generics.toZ3OrCompute(gv) { + generics.cachedB(gv) { z3.mkFreshFuncDecl(gv.tp.toString+"#"+gv.id+"!val", Seq(), typeToSort(gv.tp)) } } @@ -132,13 +132,13 @@ trait AbstractZ3Solver } // ADT Manager - protected val adtManager = new ADTManager(reporter) + protected val adtManager = new ADTManager(context) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs - protected[leon] var functions = new Bijection[TypedFunDef, Z3FuncDecl] - protected[leon] var generics = new Bijection[GenericValue, Z3FuncDecl] - protected[leon] var sorts = new Bijection[TypeTree, Z3Sort] - protected[leon] var variables = new Bijection[Expr, Z3AST] + protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() + protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() + protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() + protected val variables = new IncrementalBijection[Expr, Z3AST]() protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() @@ -155,6 +155,9 @@ trait AbstractZ3Solver generics.clear() sorts.clear() variables.clear() + constructors.clear() + selectors.clear() + testers.clear() prepareSorts() @@ -186,7 +189,7 @@ trait AbstractZ3Solver adtManager.defineADT(t) match { case Left(adts) => declareDatatypes(adts.toSeq) - sorts.toZ3(t) + sorts.toB(t) case Right(conflicts) => conflicts.foreach { declareStructuralSort } @@ -269,16 +272,16 @@ trait AbstractZ3Solver // assumes prepareSorts has been called.... protected[leon] def typeToSort(oldtt: TypeTree): Z3Sort = normalizeType(oldtt) match { case Int32Type | BooleanType | IntegerType | CharType => - sorts.toZ3(oldtt) + sorts.toB(oldtt) case tpe @ (_: ClassType | _: ArrayType | _: TupleType | UnitType) => - sorts.toZ3OrCompute(tpe) { + sorts.cachedB(tpe) { declareStructuralSort(tpe) } case tt @ SetType(base) => - sorts.toZ3OrCompute(tt) { + sorts.cachedB(tt) { z3.mkSetSort(typeToSort(base)) } @@ -286,7 +289,7 @@ trait AbstractZ3Solver typeToSort(RawArrayType(fromType, library.optionType(toType))) case rat @ RawArrayType(from, to) => - sorts.toZ3OrCompute(rat) { + sorts.cachedB(rat) { val fromSort = typeToSort(from) val toSort = typeToSort(to) @@ -294,7 +297,7 @@ trait AbstractZ3Solver } case tt @ TypeParameter(id) => - sorts.toZ3OrCompute(tt) { + sorts.cachedB(tt) { val symbol = z3.mkFreshStringSymbol(id.name) val newTPSort = z3.mkUninterpretedSort(symbol) @@ -302,7 +305,7 @@ trait AbstractZ3Solver } case ft @ FunctionType(from, to) => - sorts.toZ3OrCompute(ft) { + sorts.cachedB(ft) { val fromSort = typeToSort(tupleTypeWrap(from)) val toSort = typeToSort(to) @@ -310,7 +313,7 @@ trait AbstractZ3Solver } case other => - sorts.toZ3OrCompute(other) { + sorts.cachedB(other) { reporter.warning(other.getPos, "Resorting to uninterpreted type for : " + other) val symbol = z3.mkIntSymbol(FreshIdentifier("unint").globalId) z3.mkUninterpretedSort(symbol) @@ -324,7 +327,7 @@ trait AbstractZ3Solver } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. - variables.leonToZ3.collect{ case (Variable(id), p2) => id -> p2 } + variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } def rec(ex: Expr): Z3AST = ex match { @@ -642,14 +645,14 @@ trait AbstractZ3Solver } case Z3AppAST(decl, args) => val argsSize = args.size - if(argsSize == 0 && (variables containsZ3 t)) { - variables.toLeon(t) - } else if(functions containsZ3 decl) { - val tfd = functions.toLeon(decl) + if(argsSize == 0 && (variables containsB t)) { + variables.toA(t) + } else if(functions containsB decl) { + val tfd = functions.toA(decl) assert(tfd.params.size == argsSize) FunctionInvocation(tfd, args.zip(tfd.params).map{ case (a, p) => rec(a, p.getType) }) - } else if (generics containsZ3 decl) { - generics.toLeon(decl) + } else if (generics containsB decl) { + generics.toA(decl) } else if (constructors containsB decl) { constructors.toA(decl) match { case cct: CaseClassType => diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 0439ef1ed..8a845754b 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -66,8 +66,8 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if (functions containsZ3 p._1) { - val tfd = functions.toLeon(p._1) + if (functions containsB p._1) { + val tfd = functions.toA(p._1) if (!tfd.hasImplementation) { val (cses, default) = p._2 val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( @@ -82,8 +82,8 @@ class FairZ3Solver(val context : LeonContext, val program: Program) }) val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(functions containsZ3 p._1) { - val tfd = functions.toLeon(p._1) + if(functions containsB p._1) { + val tfd = functions.toA(p._1) if(!tfd.hasImplementation) { Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) } else Seq() @@ -202,14 +202,14 @@ class FairZ3Solver(val context : LeonContext, val program: Program) // We make sure all free variables are registered as variables freeVars.foreach { v => - variables.toZ3OrCompute(Variable(v)) { + variables.cachedB(Variable(v)) { templateGenerator.encoder.encodeId(v) } } frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail - val newClauses = unrollingBank.getClauses(expression, variables.leonToZ3) + val newClauses = unrollingBank.getClauses(expression, variables.aToB) for (cl <- newClauses) { solver.assertCnstr(cl) diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala index 2abe02272..62f43d736 100644 --- a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala @@ -19,7 +19,7 @@ trait Z3ModelReconstruction { def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe - variables.getZ3(id.toVariable).flatMap { z3ID => + variables.getB(id.toVariable).flatMap { z3ID => expectedType match { case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral) case Int32Type => diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/leon/utils/IncrementalBijection.scala index 6ba7be81d..99d0d70cd 100644 --- a/src/main/scala/leon/utils/IncrementalBijection.scala +++ b/src/main/scala/leon/utils/IncrementalBijection.scala @@ -27,6 +27,14 @@ class IncrementalBijection[A,B] extends Bijection[A,B] { case None => recursiveGet(a2bStack, a) } + def aToB: Map[A,B] = { + a2bStack.reverse.foldLeft(Map[A,B]()) { _ ++ _ } ++ a2b + } + + def bToA: Map[B,A] = { + b2aStack.reverse.foldLeft(Map[B,A]()) { _ ++ _ } ++ b2a + } + override def containsA(a: A) = getB(a).isDefined override def containsB(b: B) = getA(b).isDefined diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 0f51eb6ed..7802d4509 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -6,7 +6,7 @@ package utils import purescala.Definitions.Program import purescala.ScalaPrinter -import purescala.{MethodLifting, CompleteAbstractDefinitions} +import purescala.{MethodLifting, CompleteAbstractDefinitions,CheckADTFieldsTypes} import synthesis.{ConvertWithOracle, ConvertHoles} import verification.InjectAsserts @@ -24,6 +24,7 @@ object PreprocessingPhase extends TransformationPhase { ConvertWithOracle andThen ConvertHoles andThen CompleteAbstractDefinitions andThen + CheckADTFieldsTypes andThen InjectAsserts -- GitLab