diff --git a/src/main/scala/inox/solvers/ADTManager.scala b/src/main/scala/inox/solvers/ADTManager.scala index 8ede73df1d14eaa0b10e2f71771d3c5bea3d077d..25f4021d032847c9808fd3f18b62512b9526258f 100644 --- a/src/main/scala/inox/solvers/ADTManager.scala +++ b/src/main/scala/inox/solvers/ADTManager.scala @@ -1,146 +1,149 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers -import purescala.Types._ -import purescala.Common._ +trait ADTManagers { + val program: Program + import program._ + import trees._ -case class DataType(sym: Identifier, cases: Seq[Constructor]) extends Printable { - def asString(implicit ctx: LeonContext) = { - "Datatype: "+sym.asString+"\n"+cases.map(c => " - "+c.asString(ctx)).mkString("\n") - } -} -case class Constructor(sym: Identifier, tpe: TypeTree, fields: Seq[(Identifier, TypeTree)]) extends Printable { - def asString(implicit ctx: LeonContext) = { - sym.asString(ctx)+" ["+tpe.asString(ctx)+"] "+fields.map(f => f._1.asString(ctx)+": "+f._2.asString(ctx)).mkString("(", ", ", ")") + case class DataType(sym: Identifier, cases: Seq[Constructor]) extends Printable { + def asString(implicit opts: PrinterOptions) = { + "Datatype: " + sym.asString(opts) + "\n" + cases.map(c => " - " + c.asString(opts)).mkString("\n") + } } -} - -class ADTManager(ctx: LeonContext) { - val reporter = ctx.reporter - - protected def freshId(id: Identifier): Identifier = freshId(id.name) - protected def freshId(name: String): Identifier = FreshIdentifier(name) - - protected def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct.parent match { - case Some(p) => - getHierarchy(p) - case None => (ct, ct match { - case act: AbstractClassType => - act.knownCCDescendants - case cct: CaseClassType => - List(cct) - }) + + case class Constructor(sym: Identifier, tpe: Type, fields: Seq[(Identifier, Type)]) extends Printable { + def asString(implicit opts: PrinterOptions) = { + sym.asString(opts) + " [" + tpe.asString(opts) + "] " + fields.map(f => f._1.asString(opts) + ": " + f._2.asString(opts)).mkString("(", ", ", ")") + } } - protected var defined = Set[TypeTree]() - protected var locked = Set[TypeTree]() + class ADTManager(ctx: InoxContext) { + val reporter = ctx.reporter - protected var discovered = Map[TypeTree, DataType]() + protected def freshId(id: Identifier): Identifier = freshId(id.name) - def defineADT(t: TypeTree): Either[Map[TypeTree, DataType], Set[TypeTree]] = { - discovered = Map() - locked = Set() + protected def freshId(name: String): Identifier = FreshIdentifier(name) - findDependencies(t) + protected def getHierarchy(ct: TypedClassDef): (TypedClassDef, Seq[TypedCaseClassDef]) = ct match { + case acd: TypedAbstractClassDef => + (acd, acd.descendants) + case ccd: TypedCaseClassDef => + (ccd, List(ccd)) - val conflicts = discovered.keySet & locked + } - if (conflicts(t)) { - // There is no way to solve this, the type we requested is in conflict - val str = "Encountered ADT that can't be defined.\n" + - "It appears it has recursive references through non-structural types (such as arrays, maps, or sets)." - val err = new Unsupported(t, str)(ctx) - reporter.warning(err.getMessage) - throw err - } else { - // We might be able to define some despite conflicts - if (conflicts.isEmpty) { - for ((t, dt) <- discovered) { - defined += t - } - Left(discovered) + protected var defined = Set[Type]() + protected var locked = Set[Type]() + + protected var discovered = Map[Type, DataType]() + + def defineADT(t: Type): Either[Map[Type, DataType], Set[Type]] = { + discovered = Map() + locked = Set() + + findDependencies(t) + + val conflicts = discovered.keySet & locked + + if (conflicts(t)) { + // There is no way to solve this, the type we requested is in conflict + val str = "Encountered ADT that can't be defined.\n" + + "It appears it has recursive references through non-structural types (such as arrays, maps, or sets)." + val err = new Unsupported(t, str)(ctx) + reporter.warning(err.getMessage) + throw err } else { - Right(conflicts) + // We might be able to define some despite conflicts + if (conflicts.isEmpty) { + for ((t, dt) <- discovered) { + defined += t + } + Left(discovered) + } else { + Right(conflicts) + } } } - } - def forEachType(t: TypeTree, alreadyDone: Set[TypeTree] = Set())(f: TypeTree => Unit): Unit = if(!alreadyDone(t)) { - t match { - case NAryType(tps, builder) => - f(t) - val alreadyDone2 = alreadyDone + t - // note: each of the tps could be abstract classes in which case we need to - // lock their dependencies, transitively. - tps.foreach { - case ct: ClassType => - val (root, sub) = getHierarchy(ct) - (root +: sub).flatMap(_.fields.map(_.getType)).foreach(subt => forEachType(subt, alreadyDone2)(f)) - case othert => - forEachType(othert, alreadyDone2)(f) + def forEachType(t: Type, alreadyDone: Set[Type] = Set())(f: Type => Unit): Unit = if (!alreadyDone(t)) { + t match { + case NAryType(tps, builder) => + f(t) + val alreadyDone2 = alreadyDone + t + // note: each of the tps could be abstract classes in which case we need to + // lock their dependencies, transitively. + tps.foreach { + case ct: ClassType => + val (root, sub) = getHierarchy(ct.lookupClass.get) + sub.flatMap(_.fields.map(_.getType)).foreach(subt => forEachType(subt, alreadyDone2)(f)) + case othert => + forEachType(othert, alreadyDone2)(f) + } } } - } - protected def findDependencies(t: TypeTree): Unit = t match { - case _: SetType | _: MapType => - forEachType(t) { tpe => - if (!defined(tpe)) { - locked += tpe + protected def findDependencies(t: Type): Unit = t match { + case _: SetType | _: MapType => + forEachType(t) { tpe => + if (!defined(tpe)) { + locked += tpe + } } - } - case ct: ClassType => - val (root, sub) = getHierarchy(ct) + case ct: ClassType => + val (root, sub) = getHierarchy(ct.lookupClass.get) - if (!(discovered contains root) && !(defined contains root)) { - val sym = freshId(root.id) + if (!(discovered contains root.toType) && !(defined contains root.toType)) { + val sym = freshId(root.id) - val conss = sub.map { case cct => - Constructor(freshId(cct.id), cct, cct.fields.map(vd => (freshId(vd.id), vd.getType))) - } + val conss = sub.map { case cct => + Constructor(freshId(cct.id), cct.toType, cct.fields.map(vd => (freshId(vd.id), vd.getType))) + } - discovered += (root -> DataType(sym, conss)) + discovered += (root.toType -> DataType(sym, conss)) - // look for dependencies - for (ct <- root +: sub; f <- ct.fields) { - findDependencies(f.getType) + // look for dependencies + for (ct <- sub; f <- ct.fields) { + findDependencies(f.getType) + } } - } - case tt @ TupleType(bases) => - if (!(discovered contains t) && !(defined contains t)) { - val sym = freshId("tuple"+bases.size) + case tt@TupleType(bases) => + if (!(discovered contains t) && !(defined contains t)) { + val sym = freshId("tuple" + bases.size) - val c = Constructor(freshId(sym.name), tt, bases.zipWithIndex.map { - case (tpe, i) => (freshId("_"+(i+1)), tpe) - }) + val c = Constructor(freshId(sym.name), tt, bases.zipWithIndex.map { + case (tpe, i) => (freshId("_" + (i + 1)), tpe) + }) - discovered += (tt -> DataType(sym, Seq(c))) + discovered += (tt -> DataType(sym, Seq(c))) - for (b <- bases) { - findDependencies(b) + for (b <- bases) { + findDependencies(b) + } } - } - case UnitType => - if (!(discovered contains t) && !(defined contains t)) { - discovered += (t -> DataType(freshId("Unit"), Seq(Constructor(freshId("Unit"), t, Nil)))) - } + case UnitType => + if (!(discovered contains t) && !(defined contains t)) { + discovered += (t -> DataType(freshId("Unit"), Seq(Constructor(freshId("Unit"), t, Nil)))) + } - case tp @ TypeParameter(id) => - if (!(discovered contains t) && !(defined contains t)) { - val sym = freshId(id.name) + case tp@TypeParameter(id) => + if (!(discovered contains t) && !(defined contains t)) { + val sym = freshId(id.name) - val c = Constructor(freshId(sym.name), tp, List( - (freshId("val"), IntegerType) - )) + val c = Constructor(freshId(sym.name), tp, List( + (freshId("val"), IntegerType) + )) - discovered += (tp -> DataType(sym, Seq(c))) - } + discovered += (tp -> DataType(sym, Seq(c))) + } - case _ => + case _ => + } } -} + +} \ No newline at end of file diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala index c75b140a228f2c8bb2fa280e1ba8b8fb49327c24..da60b767d133fec5f2387f5272f18775c0dcb643 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala @@ -1,14 +1,9 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package smtlib -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Definitions._ - import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => SMTFunDef, _} import _root_.smtlib.parser.Terms.{Identifier => _, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} @@ -16,8 +11,14 @@ import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} import theories._ import utils._ -abstract class SMTLIBSolver(val sctx: SolverContext, val program: Program, theories: TheoryEncoder) - extends Solver with SMTLIBTarget with NaiveAssumptionSolver { +trait SMTLIBSolver extends Solver with SMTLIBTarget { + val theories: TheoryEncoder { val trees: program.trees.type } + + import program._ + import trees._ + import symbols._ + import exprOps.variablesOf + import SolverResponses._ /* Solver name */ def targetName: String @@ -35,7 +36,7 @@ abstract class SMTLIBSolver(val sctx: SolverContext, val program: Program, theor /* Public solver interface */ def assertCnstr(raw: Expr): Unit = if (!hasError) { try { - val bindings = variablesOf(raw).map(id => id -> ids.cachedB(id)(theories.encode(id))).toMap + val bindings = variablesOf(raw).map(v => v -> ids.cachedB(v.id)(theories.encode(v.id))).toMap val expr = theories.encode(raw)(bindings) variablesOf(expr).foreach(declareVariable) @@ -53,13 +54,13 @@ abstract class SMTLIBSolver(val sctx: SolverContext, val program: Program, theor emit(Reset(), rawOut = true) match { case ErrorResponse(msg) => reporter.warning(s"Failed to reset $name: $msg") - throw new CantResetException(this) + throw new Exception() //CantResetException(this) case _ => } } - override def check: Option[Boolean] = { - if (hasError) None + override def check[R <: SolverResponse[Map[ValDef, Expr], Set[Expr]]](config: Configuration { type Response = R }): R = { + if (hasError) Unknown else emit(CheckSat()) match { case CheckSatStatus(SatStatus) => Some(true) case CheckSatStatus(UnsatStatus) => Some(false) diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index 445960ffaaa0b40562d09d90c2bcd84c3a8dcd46..a7f84c0867e671cf9e0a5ae8cd44bc03170ad9fc 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -1,20 +1,12 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package smtlib +import inox.utils.Interruptible import utils._ -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps._ -import purescala.Constructors._ -import purescala.Definitions._ - import _root_.smtlib.common._ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } import _root_.smtlib.parser.Commands.{ @@ -35,26 +27,27 @@ import _root_.smtlib.theories.{Constructors => SmtLibConstructors, _} import _root_.smtlib.theories.experimental._ import _root_.smtlib.interpreters.ProcessInterpreter -trait SMTLIBTarget extends Interruptible { - val context: LeonContext +trait SMTLIBTarget extends Interruptible with ADTManagers { val program: Program - + import program._ + import trees._ + import symbols._ def targetName: String implicit val debugSection: DebugSection - protected def interpreterOps(ctx: LeonContext): Seq[String] + protected def interpreterOps(ctx: InoxContext): Seq[String] - protected def getNewInterpreter(ctx: LeonContext): ProcessInterpreter + protected def getNewInterpreter(ctx: InoxContext): ProcessInterpreter protected def unsupported(t: Tree, str: String): Nothing - protected lazy val interpreter = getNewInterpreter(context) + protected lazy val interpreter = getNewInterpreter(ctx) /* Interruptible interface */ private var interrupted = false - context.interruptManager.registerForInterrupts(this) + ctx.interruptManager.registerForInterrupts(this) override def interrupt(): Unit = { interrupted = true @@ -66,14 +59,14 @@ trait SMTLIBTarget extends Interruptible { def free() = { interpreter.free() - context.interruptManager.unregisterForInterrupts(this) + ctx.interruptManager.unregisterForInterrupts(this) debugOut foreach { _.close } } /* Printing VCs */ protected lazy val debugOut: Option[java.io.FileWriter] = { - if (context.reporter.isDebugEnabled) { - val file = context.files.headOption.map(_.getName).getOrElse("NA") + if (ctx.reporter.isDebugEnabled) { + val file = ""//ctx.files.headOption.map(_.getName).getOrElse("NA") val n = DebugFileNumbers.next(targetName + file) val fileName = s"smt-sessions/$targetName-$file-$n.smt2" @@ -81,11 +74,11 @@ trait SMTLIBTarget extends Interruptible { val javaFile = new java.io.File(fileName) javaFile.getParentFile.mkdirs() - context.reporter.debug(s"Outputting smt session into $fileName") + ctx.reporter.debug(s"Outputting smt session into $fileName") val fw = new java.io.FileWriter(javaFile, false) - fw.write("; Options: " + interpreterOps(context).mkString(" ") + "\n") + fw.write("; Options: " + interpreterOps(ctx).mkString(" ") + "\n") Some(fw) } else { @@ -102,7 +95,7 @@ trait SMTLIBTarget extends Interruptible { } interpreter.eval(cmd) match { case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => - context.reporter.warning(s"Unexpected error from $targetName solver: $msg") + ctx.reporter.warning(s"Unexpected error from $targetName solver: $msg") //println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) // Store that there was an error. Now all following check() // invocations will return None @@ -116,7 +109,7 @@ trait SMTLIBTarget extends Interruptible { def parseSuccess() = { val res = interpreter.parser.parseGenResponse if (res != Success) { - context.reporter.warning("Unnexpected result from " + targetName + ": " + res + " expected success") + ctx.reporter.warning("Unnexpected result from " + targetName + ": " + res + " expected success") } } @@ -136,9 +129,7 @@ trait SMTLIBTarget extends Interruptible { import scala.language.implicitConversions protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = SimpleSymbol(s) - protected val adtManager = new ADTManager(context) - - protected val library = program.library + protected val adtManager = new ADTManager(ctx) protected def id2sym(id: Identifier): SSymbol = { SSymbol(id.uniqueNameDelimited("!").replace("|", "$pipe").replace("\\", "$backslash")) @@ -148,11 +139,11 @@ trait SMTLIBTarget extends Interruptible { protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) /* Metadata for CC, and variables */ - protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() - protected val testers = new IncrementalBijection[TypeTree, SSymbol]() + protected val constructors = new IncrementalBijection[Type, SSymbol]() + protected val selectors = new IncrementalBijection[(Type, Int), SSymbol]() + protected val testers = new IncrementalBijection[Type, SSymbol]() protected val variables = new IncrementalBijection[Identifier, SSymbol]() - protected val sorts = new IncrementalBijection[TypeTree, Sort]() + protected val sorts = new IncrementalBijection[Type, Sort]() protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() protected val lambdas = new IncrementalBijection[FunctionType, SSymbol]() protected val errors = new IncrementalBijection[Unit, Boolean]() @@ -161,76 +152,38 @@ trait SMTLIBTarget extends Interruptible { /* Helper functions */ - protected def normalizeType(t: TypeTree): TypeTree = t match { - case ct: ClassType => ct.root + protected def normalizeType(t: Type): Type = t match { + case ct: ClassType => ct.lookupClass.get.root.toType case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) case _ => t } - protected def quantifiedTerm( - quantifier: (SortedVar, Seq[SortedVar], Term) => Term, - vars: Seq[Identifier], - body: Expr)( - implicit bindings: Map[Identifier, Term]): Term = { + protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, + vars: Seq[ValDef], + body: Expr) + (implicit bindings: Map[Identifier, Term]) + : Term = { if (vars.isEmpty) toSMT(body) if (vars.isEmpty) toSMT(body)(Map()) else { - val sortedVars = vars map { id => - SortedVar(id2sym(id), declareSort(id.getType)) + val sortedVars = vars map { vd => + SortedVar(id2sym(vd.id), declareSort(vd.getType)) } quantifier( sortedVars.head, sortedVars.tail, - toSMT(body)(bindings ++ vars.map { id => id -> (id2sym(id): Term) }.toMap)) + toSMT(body)(bindings ++ vars.map { vd => vd.id -> (id2sym(vd.id): Term) }.toMap)) } } // Returns a quantified term where all free variables in the body have been quantified - protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr)( - implicit bindings: Map[Identifier, Term]): Term = - quantifiedTerm(quantifier, variablesOf(body).toSeq, body) - - protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { - case SetType(base) => - if (r.default != BooleanLiteral(false)) { - unsupported(r, "Solver returned a co-finite set which is not supported.") - } - require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") - FiniteSet(r.elems.keySet, base) - - case BagType(base) => - if (r.default != InfiniteIntegerLiteral(0)) { - unsupported(r, "Solver returned an infinite bag which is not supported.") - } - require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") - FiniteBag(r.elems, base) - - case RawArrayType(from, to) => - r - - case ft @ FunctionType(from, to) => - val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, from.size) -> v } - FiniteLambda(elems, r.default, ft) - - case MapType(from, to) => - // We expect a RawArrayValue with keys in from and values in Option[to], - // with default value == None - if (r.default.getType != library.noneType(to)) { - unsupported(r, "Solver returned a co-finite map which is not supported.") - } - require(r.keyTpe == from, s"Type error in solver model, expected $from, found ${r.keyTpe}") - - val elems = r.elems.flatMap { - case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x) - case (k, _) => None - }.toMap - FiniteMap(elems, from, to) - - case other => - unsupported(other, "Unable to extract from raw array for " + tpe) + protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr) + (implicit bindings: Map[Identifier, Term]) + : Term = { + quantifiedTerm(quantifier, exprOps.variablesOf(body).toSeq.map(_.toVal), body) } - protected def declareSort(t: TypeTree): Sort = { + protected def declareSort(t: Type): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { tpe match { @@ -240,16 +193,13 @@ trait SMTLIBTarget extends Interruptible { case Int32Type => FixedSizeBitVectors.BitVectorSort(32) case CharType => FixedSizeBitVectors.BitVectorSort(32) - case RawArrayType(from, to) => - Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) - case MapType(from, to) => - declareSort(RawArrayType(from, library.optionType(to))) + Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) case FunctionType(from, to) => Ints.IntSort() - case _: ClassType | _: TupleType | _: ArrayType | _: TypeParameter | UnitType => + case _: ClassType | _: TupleType | _: TypeParameter | UnitType => declareStructuralSort(tpe) case other => @@ -258,7 +208,7 @@ trait SMTLIBTarget extends Interruptible { } } - protected def declareDatatypes(datatypes: Map[TypeTree, DataType]): Unit = { + protected def declareDatatypes(datatypes: Map[Type, DataType]): Unit = { // We pre-declare ADTs for ((tpe, DataType(sym, _)) <- datatypes) { sorts += tpe -> Sort(SMTIdentifier(id2sym(sym))) @@ -288,7 +238,7 @@ trait SMTLIBTarget extends Interruptible { } - protected def declareStructuralSort(t: TypeTree): Sort = { + protected def declareStructuralSort(t: Type): Sort = { // Populates the dependencies of the structural type to define. adtManager.defineADT(t) match { case Left(adts) => @@ -301,10 +251,11 @@ trait SMTLIBTarget extends Interruptible { } } - protected def declareVariable(id: Identifier): SSymbol = { + protected def declareVariable(vd: ValDef): SSymbol = declareVariable(vd.id, vd.getType) + protected def declareVariable(id: Identifier, tp: Type) = { variables.cachedB(id) { val s = id2sym(id) - val cmd = DeclareFun(s, List(), declareSort(id.getType)) + val cmd = DeclareFun(s, List(), declareSort(tp)) emit(cmd) s } @@ -352,56 +303,54 @@ trait SMTLIBTarget extends Interruptible { } } - protected def toSMT(t: TypeTree): SExpr = { + protected def toSMT(t: Type): SExpr = { sortToSMT(declareSort(t)) } protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = { e match { - case Variable(id) => + case Variable(id, tp) => declareSort(e.getType) bindings.getOrElse(id, variables.toB(id)) case UnitLiteral() => declareSort(UnitType) + declareVariable(FreshIdentifier("Unit"), UnitType) - declareVariable(FreshIdentifier("Unit", UnitType)) - - case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) - case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) - case FractionalLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) - case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) - case BooleanLiteral(v) => Core.BoolConst(v) + case IntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) + case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) + case FractionLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) + case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) + case BooleanLiteral(v) => Core.BoolConst(v) case Let(b, d, e) => - val id = id2sym(b) + val id = id2sym(b.id) val value = toSMT(d) - val newBody = toSMT(e)(bindings + (b -> id)) + val newBody = toSMT(e)(bindings + (b.id -> id)) SMTLet( VarBinding(id, value), Seq(), newBody) - case er @ Error(tpe, _) => - declareVariable(FreshIdentifier("error_value", tpe)) - - case s @ CaseClassSelector(cct, e, id) => - declareSort(cct) - val selector = selectors.toB((cct, s.selectorIndex)) + case s @ CaseClassSelector(e, id) => + declareSort(e.getType) + val selector = selectors.toB((e.getType, s.classIndex.get._2)) FunctionApplication(selector, Seq(toSMT(e))) case AsInstanceOf(expr, cct) => toSMT(expr) - case IsInstanceOf(e, cct) => + case io@IsInstanceOf(e, cct) => declareSort(cct) - val cases = cct match { - case act: AbstractClassType => - act.knownCCDescendants - case cct: CaseClassType => + val cases = cct.lookupClass match { + case Some(act: TypedAbstractClassDef) => + act.descendants + case Some(cct: TypedCaseClassDef) => Seq(cct) + case None => + unsupported(io, "isInstanceOf on non-class") } - val oneOf = cases map testers.toB + val oneOf = cases map (_.toType) map testers.toB oneOf match { case Seq(tester) => FunctionApplication(tester, Seq(toSMT(e))) @@ -432,11 +381,11 @@ trait SMTLIBTarget extends Interruptible { val selector = selectors.toB((tpe, i - 1)) FunctionApplication(selector, Seq(toSMT(t))) - case al @ RawArraySelect(a, i) => + case al @ MapApply(a, i) => ArraysEx.Select(toSMT(a), toSMT(i)) - case al @ RawArrayUpdated(a, i, e) => - ArraysEx.Store(toSMT(a), toSMT(i), toSMT(e)) - case ra @ RawArrayValue(keyTpe, elems, default) => + case al @ MapUpdated(map, k, v) => + ArraysEx.Store(toSMT(map), toSMT(k), toSMT(v)) + case ra @ FiniteMap(elems, default, keyTpe) => val s = declareSort(ra.getType) var res: Term = FunctionApplication( @@ -448,56 +397,10 @@ trait SMTLIBTarget extends Interruptible { res - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, _, _) => - val mt @ MapType(from, to) = m.getType - declareSort(mt) - - toSMT(RawArrayValue(from, elems.map { - case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) - }.toMap, CaseClass(library.noneType(to), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, to) = m.getType - declareSort(mt) - // m(k) becomes - // (Some-value (select m k)) - FunctionApplication( - selectors.toB((library.someType(to), 0)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k)))) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, to) = m.getType - declareSort(mt) - // m.isDefinedAt(k) becomes - // (is-Some (select m k)) - FunctionApplication( - testers.toB(library.someType(to)), - Seq(ArraysEx.Select(toSMT(m), toSMT(k)))) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val MapType(_, t) = m1.getType - - elems.foldLeft(toSMT(m1)) { - case (m, (k, v)) => - ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) - } - - case p: Passes => - toSMT(matchToIfThenElse(p.asConstraint)) - - case m: MatchExpr => - toSMT(matchToIfThenElse(m)) - case gv @ GenericValue(tpe, n) => declareSort(tpe) val constructor = constructors.toB(tpe) - FunctionApplication(constructor, Seq(toSMT(InfiniteIntegerLiteral(n)))) - - case synthesis.utils.MutableExpr(ex) => - toSMT(ex) + FunctionApplication(constructor, Seq(toSMT(IntegerLiteral(n)))) /** * ===== Everything else ===== @@ -506,33 +409,54 @@ trait SMTLIBTarget extends Interruptible { val dyn = declareLambda(caller.getType.asInstanceOf[FunctionType]) FunctionApplication(dyn, (caller +: args).map(toSMT)) - case Not(u) => Core.Not(toSMT(u)) - case UMinus(u) => Ints.Neg(toSMT(u)) - case BVUMinus(u) => FixedSizeBitVectors.Neg(toSMT(u)) - case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) - case Assert(a, _, b) => toSMT(IfExpr(a, b, Error(b.getType, "assertion failed"))) + case Not(u) => u.getType match { + case BooleanType => Core.Not(toSMT(u)) + case Int32Type => FixedSizeBitVectors.Not(toSMT(u)) + } + case UMinus(u) => u.getType match { + case IntegerType => Ints.Neg(toSMT(u)) + case Int32Type => FixedSizeBitVectors.Neg(toSMT(u)) + } case Equals(a, b) => Core.Equals(toSMT(a), toSMT(b)) case Implies(a, b) => Core.Implies(toSMT(a), toSMT(b)) - case Plus(a, b) => Ints.Add(toSMT(a), toSMT(b)) - case Minus(a, b) => Ints.Sub(toSMT(a), toSMT(b)) - case Times(a, b) => Ints.Mul(toSMT(a), toSMT(b)) - case Division(a, b) => { + case Plus(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) + case IntegerType => Ints.Add(toSMT(a), toSMT(b)) + case RealType => Reals.Add(toSMT(a), toSMT(b)) + } + case Minus(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) + case IntegerType => Ints.Sub(toSMT(a), toSMT(b)) + case RealType => Reals.Sub(toSMT(a), toSMT(b)) + } + case Times(a, b) => a.getType match { + case Int32Type => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) + case IntegerType => Ints.Mul(toSMT(a), toSMT(b)) + case RealType => Reals.Mul(toSMT(a), toSMT(b)) + } + + //FIXME + case Division(a, b) => val ar = toSMT(a) val br = toSMT(b) - Core.ITE( Ints.GreaterEquals(ar, Ints.NumeralLit(0)), Ints.Div(ar, br), - Ints.Neg(Ints.Div(Ints.Neg(ar), br))) - } - case Remainder(a, b) => { + Ints.Neg(Ints.Div(Ints.Neg(ar), br)) + ) + + //case BVDivision(a, b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) + //case BVRemainder(a, b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) + //case RealDivision(a, b) => Reals.Div(toSMT(a), toSMT(b)) + + case Remainder(a, b) => val q = toSMT(Division(a, b)) Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q)) - } - case Modulo(a, b) => { + case Modulo(a, b) => Ints.Mod(toSMT(a), toSMT(b)) - } + // End FIXME + case LessThan(a, b) => a.getType match { case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) @@ -557,11 +481,7 @@ trait SMTLIBTarget extends Interruptible { case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) } - case BVPlus(a, b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) - case BVMinus(a, b) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b)) - case BVTimes(a, b) => FixedSizeBitVectors.Mul(toSMT(a), toSMT(b)) - case BVDivision(a, b) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b)) - case BVRemainder(a, b) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b)) + case BVAnd(a, b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) case BVOr(a, b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) case BVXOr(a, b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) @@ -569,29 +489,24 @@ trait SMTLIBTarget extends Interruptible { case BVAShiftRight(a, b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) case BVLShiftRight(a, b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) - case RealPlus(a, b) => Reals.Add(toSMT(a), toSMT(b)) - case RealMinus(a, b) => Reals.Sub(toSMT(a), toSMT(b)) - case RealTimes(a, b) => Reals.Mul(toSMT(a), toSMT(b)) - case RealDivision(a, b) => Reals.Div(toSMT(a), toSMT(b)) + case And(sub) => SmtLibConstructors.and(sub.map(toSMT)) case Or(sub) => SmtLibConstructors.or(sub.map(toSMT)) case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) - case f @ FunctionInvocation(_, sub) => - if (sub.isEmpty) declareFunction(f.tfd) else { - FunctionApplication( - declareFunction(f.tfd), - sub.map(toSMT)) - } + case FunctionInvocation(id, tps, sub) => + val fun = declareFunction(symbols.getFunction(id).typed(tps)) + if (sub.isEmpty) fun + else FunctionApplication(fun, sub.map(toSMT)) case Forall(vs, bd) => - quantifiedTerm(SMTForall, vs map { _.id }, bd)(Map()) + quantifiedTerm(SMTForall, vs, bd)(Map()) case o => unsupported(o, "") } } /* Translate an SMTLIB term back to a Leon Expr */ - protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + protected def fromSMT(t: Term, otpe: Option[Type] = None)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { object EQ { def unapply(t: Term): Option[(Term, Term)] = t match { @@ -629,8 +544,8 @@ trait SMTLIBTarget extends Interruptible { case Some(dynLambda) => letDefs.get(dynLambda) match { case None => simplestValue(ft) case Some(DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body))) => - val lambdaArgs = from.map(tpe => FreshIdentifier("x", tpe, true)) - val argsMap: Map[Term, Identifier] = (args.map(sv => symbolToQualifiedId(sv.name)) zip lambdaArgs).toMap + val lambdaArgs = from.map(tpe => ValDef(FreshIdentifier("x", true), tpe)) + val argsMap: Map[Term, ValDef] = (args.map(sv => symbolToQualifiedId(sv.name)) zip lambdaArgs).toMap val d = symbolToQualifiedId(dispatcher) def dispatch(t: Term): Term = t match { @@ -671,7 +586,7 @@ trait SMTLIBTarget extends Interruptible { } val body = recCases(t) - Lambda(lambdaArgs.map(ValDef(_)), body) + Lambda(lambdaArgs, body) } extract(dispatch(body)) @@ -699,16 +614,16 @@ trait SMTLIBTarget extends Interruptible { case (SDecimal(d), Some(RealType)) => // converting bigdecimal to a fraction if (d == BigDecimal(0)) - FractionalLiteral(0, 1) + FractionLiteral(0, 1) else { d.toBigIntExact() match { case Some(num) => - FractionalLiteral(num, 1) + FractionLiteral(num, 1) case _ => val scale = d.scale val num = BigInt(d.bigDecimal.scaleByPowerOfTen(scale).toBigInteger()) val denom = BigInt(new java.math.BigDecimal(1).scaleByPowerOfTen(scale).toBigInteger()) - FractionalLiteral(num, denom) + FractionLiteral(num, denom) } } @@ -716,17 +631,18 @@ trait SMTLIBTarget extends Interruptible { extractLambda(n, ft) case (SNumeral(n), Some(RealType)) => - FractionalLiteral(n, 1) + FractionLiteral(n, 1) case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => IfExpr( fromSMT(cond, Some(BooleanType)), fromSMT(thenn, t), - fromSMT(elze, t)) + fromSMT(elze, t) + ) // Best-effort case case (SNumeral(n), _) => - InfiniteIntegerLiteral(n) + IntegerLiteral(n) // EK: Since we have no type information, we cannot do type-directed // extraction of defs, instead, we expand them in smt-world @@ -739,54 +655,36 @@ trait SMTLIBTarget extends Interruptible { case (SimpleSymbol(s), _) if constructors.containsB(s) => constructors.toA(s) match { - case cct: CaseClassType => - CaseClass(cct, Nil) + case ct: ClassType => + CaseClass(ct, Nil) case t => unsupported(t, "woot? for a single constructor for non-case-object") } case (FunctionApplication(SimpleSymbol(s), List(e)), _) if testers.containsB(s) => testers.toA(s) match { - case cct: CaseClassType => + case cct: ClassType => IsInstanceOf(fromSMT(e, cct), cct) } case (FunctionApplication(SimpleSymbol(s), List(e)), _) if selectors.containsB(s) => selectors.toA(s) match { - case (cct: CaseClassType, i) => - CaseClassSelector(cct, fromSMT(e, cct), cct.fields(i).id) + case (ct: ClassType, i) => + CaseClassSelector(fromSMT(e, ct), ct.lookupClass.get.toCase.fields(i).id) } case (FunctionApplication(SimpleSymbol(s), args), _) if constructors.containsB(s) => constructors.toA(s) match { - case cct: CaseClassType => - val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) - CaseClass(cct, rargs) + case ct: ClassType => + val rargs = args.zip(ct.lookupClass.get.toCase.fields.map(_.getType)).map(fromSMT) + CaseClass(ct, rargs) case tt: TupleType => val rargs = args.zip(tt.bases).map(fromSMT) tupleWrap(rargs) - case at @ ArrayType(baseType) => - val IntLiteral(size) = fromSMT(args(0), Int32Type) - val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) - - if (size < 0) { - unsupported(at, "Cannot build an array of negative length: " + size) - } else if (size > 10) { - val definedElements = elems.collect { - case (IntLiteral(i), value) => (i, value) - } - finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) - - } else { - val entries = for (i <- 0 to size - 1) yield elems.getOrElse(IntLiteral(i), default) - - finiteArray(entries, None, baseType) - } - case tp @ TypeParameter(id) => - val InfiniteIntegerLiteral(n) = fromSMT(args(0), IntegerType) + val IntegerLiteral(n) = fromSMT(args(0), IntegerType) GenericValue(tp, n.toInt) case t => @@ -808,13 +706,13 @@ trait SMTLIBTarget extends Interruptible { LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) case ("+", args) => - args.map(fromSMT(_, otpe)).reduceLeft(plus _) + args.map(fromSMT(_, otpe)).reduceLeft(plus) case ("-", List(a)) if otpe == Some(RealType) => val aexpr = fromSMT(a, otpe) aexpr match { - case FractionalLiteral(na, da) => - FractionalLiteral(-na, da) + case FractionLiteral(na, da) => + FractionLiteral(-na, da) case _ => UMinus(aexpr) } @@ -822,8 +720,8 @@ trait SMTLIBTarget extends Interruptible { case ("-", List(a)) => val aexpr = fromSMT(a, otpe) aexpr match { - case InfiniteIntegerLiteral(v) => - InfiniteIntegerLiteral(-v) + case IntegerLiteral(v) => + IntegerLiteral(-v) case _ => UMinus(aexpr) } @@ -832,14 +730,14 @@ trait SMTLIBTarget extends Interruptible { Minus(fromSMT(a, otpe), fromSMT(b, otpe)) case ("*", args) => - args.map(fromSMT(_, otpe)).reduceLeft(times _) + args.map(fromSMT(_, otpe)).reduceLeft(times) case ("/", List(a, b)) if otpe == Some(RealType) => val aexpr = fromSMT(a, otpe) val bexpr = fromSMT(b, otpe) (aexpr, bexpr) match { - case (FractionalLiteral(na, da), FractionalLiteral(nb, db)) if da == 1 && db == 1 => - FractionalLiteral(na, nb) + case (FractionLiteral(na, da), FractionLiteral(nb, db)) if da == 1 && db == 1 => + FractionLiteral(na, nb) case _ => Division(aexpr, bexpr) } @@ -864,7 +762,7 @@ trait SMTLIBTarget extends Interruptible { Equals(ra, fromSMT(b, ra.getType)) case _ => - context.reporter.fatalError("Function " + app + " not handled in fromSMT: " + s) + ctx.reporter.fatalError("Function " + app + " not handled in fromSMT: " + s) } case (Core.True(), Some(BooleanType)) => BooleanLiteral(true) @@ -874,21 +772,21 @@ trait SMTLIBTarget extends Interruptible { fromSMT(lets(s), otpe) case (SimpleSymbol(s), otpe) => - variables.getA(s).map(_.toVariable).getOrElse { - throw new Exception() + variables.getA(s).map(Variable(_, otpe.get)).getOrElse { + ctx.reporter.fatalError("Could not find variable from SMT") } case _ => - context.reporter.fatalError(s"Unhandled case in fromSMT: $t : ${otpe.map(_.asString(context)).getOrElse("?")} (${t.getClass})") + ctx.reporter.fatalError(s"Unhandled case in fromSMT: $t : ${otpe.map(_.asString).getOrElse("?")} (${t.getClass})") } } - final protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + final protected def fromSMT(pair: (Term, Type))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { fromSMT(pair._1, Some(pair._2)) } - final protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + final protected def fromSMT(s: Term, tpe: Type)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { fromSMT(s, Some(tpe)) } }