From bc7d332a586afa81076beb912c3301ab37ce0d9e Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Mon, 10 Oct 2016 22:39:48 +0200 Subject: [PATCH] Lots of refactoring to enable encoding solvers --- src/it/scala/inox/InoxTestSuite.scala | 12 +- .../scala/inox/solvers/SolvingTestSuite.scala | 12 +- .../AssociativeQuantifiersSuite.scala | 12 +- .../inox/{InoxContext.scala => Context.scala} | 12 +- .../inox/{InoxOptions.scala => Options.scala} | 58 +-- src/main/scala/inox/Program.scala | 12 +- src/main/scala/inox/ast/Definitions.scala | 34 +- src/main/scala/inox/ast/Extractors.scala | 65 +-- src/main/scala/inox/ast/Paths.scala | 2 +- src/main/scala/inox/ast/Printers.scala | 10 +- src/main/scala/inox/ast/SymbolOps.scala | 4 +- src/main/scala/inox/ast/TreeOps.scala | 383 ++++++++++++++---- src/main/scala/inox/ast/Trees.scala | 4 +- .../scala/inox/evaluators/Evaluator.scala | 4 +- .../inox/evaluators/RecursiveEvaluator.scala | 4 +- .../inox/evaluators/SolvingEvaluator.scala | 14 +- .../inox/grammars/ExpressionGrammars.scala | 2 +- src/main/scala/inox/package.scala | 4 +- src/main/scala/inox/solvers/Solver.scala | 6 +- .../scala/inox/solvers/SolverFactory.scala | 30 +- .../solvers/combinators/EncodingSolver.scala | 57 +++ .../inox/solvers/smtlib/CVC4Solver.scala | 15 +- .../inox/solvers/theories/BagEncoder.scala | 4 +- .../inox/solvers/theories/StringEncoder.scala | 4 +- .../inox/solvers/theories/TheoryEncoder.scala | 66 +-- .../scala/inox/solvers/theories/package.scala | 4 +- .../solvers/unrolling/UnrollingSolver.scala | 95 +++-- .../inox/solvers/z3/NativeZ3Solver.scala | 21 +- src/test/scala/inox/ast/TreeTestsSuite.scala | 2 +- .../inox/evaluators/EvaluatorSuite.scala | 4 +- 30 files changed, 612 insertions(+), 344 deletions(-) rename src/main/scala/inox/{InoxContext.scala => Context.scala} (67%) rename src/main/scala/inox/{InoxOptions.scala => Options.scala} (73%) create mode 100644 src/main/scala/inox/solvers/combinators/EncodingSolver.scala diff --git a/src/it/scala/inox/InoxTestSuite.scala b/src/it/scala/inox/InoxTestSuite.scala index 89b1d0649..6ee309d99 100644 --- a/src/it/scala/inox/InoxTestSuite.scala +++ b/src/it/scala/inox/InoxTestSuite.scala @@ -9,19 +9,19 @@ import utils._ trait InoxTestSuite extends FunSuite with Matchers with Timeouts { - val configurations: Seq[Seq[InoxOption[Any]]] = Seq(Seq.empty) + val configurations: Seq[Seq[OptionValue[_]]] = Seq(Seq.empty) - private def optionsString(options: InoxOptions): String = { + private def optionsString(options: Options): String = { "solver=" + options.findOptionOrDefault(InoxOptions.optSelectedSolvers).head + " " + "feelinglucky=" + options.findOptionOrDefault(solvers.unrolling.optFeelingLucky) + " " + "checkmodels=" + options.findOptionOrDefault(solvers.optCheckModels) + " " + "unrollassumptions=" + options.findOptionOrDefault(solvers.unrolling.optUnrollAssumptions) } - protected def test(name: String, tags: Tag*)(body: InoxContext => Unit): Unit = { + protected def test(name: String, tags: Tag*)(body: Context => Unit): Unit = { for (config <- configurations) { val reporter = new TestSilentReporter - val ctx = InoxContext(reporter, new InterruptManager(reporter), InoxOptions(config)) + val ctx = Context(reporter, new InterruptManager(reporter), Options(config)) try { super.test(name + " " + optionsString(ctx.options))(body(ctx)) } catch { @@ -32,9 +32,9 @@ trait InoxTestSuite extends FunSuite with Matchers with Timeouts { } } - protected def ignore(name: String, tags: Tag*)(body: InoxContext => Unit): Unit = { + protected def ignore(name: String, tags: Tag*)(body: Context => Unit): Unit = { for (config <- configurations) { - super.ignore(name + " " + optionsString(InoxOptions(config)))(()) + super.ignore(name + " " + optionsString(Options(config)))(()) } } } diff --git a/src/it/scala/inox/solvers/SolvingTestSuite.scala b/src/it/scala/inox/solvers/SolvingTestSuite.scala index 63bc04664..7129c9fb7 100644 --- a/src/it/scala/inox/solvers/SolvingTestSuite.scala +++ b/src/it/scala/inox/solvers/SolvingTestSuite.scala @@ -11,11 +11,11 @@ trait SolvingTestSuite extends InoxTestSuite { feelingLucky <- Seq(false, true) unrollAssumptions <- Seq(false, true) } yield Seq( - InoxOption(InoxOptions.optSelectedSolvers)(Set(solverName)), - InoxOption(optCheckModels)(checkModels), - InoxOption(unrolling.optFeelingLucky)(feelingLucky), - InoxOption(unrolling.optUnrollAssumptions)(unrollAssumptions), - InoxOption(InoxOptions.optTimeout)(300), - InoxOption(ast.optPrintUniqueIds)(true) + InoxOptions.optSelectedSolvers(Set(solverName)), + optCheckModels(checkModels), + unrolling.optFeelingLucky(feelingLucky), + unrolling.optUnrollAssumptions(unrollAssumptions), + InoxOptions.optTimeout(300), + ast.optPrintUniqueIds(true) ) } diff --git a/src/it/scala/inox/solvers/unrolling/AssociativeQuantifiersSuite.scala b/src/it/scala/inox/solvers/unrolling/AssociativeQuantifiersSuite.scala index 3693f2a10..28be5f5ef 100644 --- a/src/it/scala/inox/solvers/unrolling/AssociativeQuantifiersSuite.scala +++ b/src/it/scala/inox/solvers/unrolling/AssociativeQuantifiersSuite.scala @@ -16,12 +16,12 @@ class AssociativeQuantifiersSuite extends InoxTestSuite { ("nativez3", false, false, true ), ("smt-cvc4", false, false, true ) ).map { case (solverName, checkModels, feelingLucky, unrollAssumptions) => Seq( - InoxOption(InoxOptions.optSelectedSolvers)(Set(solverName)), - InoxOption(optCheckModels)(checkModels), - InoxOption(optFeelingLucky)(feelingLucky), - InoxOption(optUnrollAssumptions)(unrollAssumptions), - InoxOption(InoxOptions.optTimeout)(300), - InoxOption(ast.optPrintUniqueIds)(true) + InoxOptions.optSelectedSolvers(Set(solverName)), + optCheckModels(checkModels), + optFeelingLucky(feelingLucky), + optUnrollAssumptions(unrollAssumptions), + InoxOptions.optTimeout(300), + ast.optPrintUniqueIds(true) )} val isAssociativeID = FreshIdentifier("isAssociative") diff --git a/src/main/scala/inox/InoxContext.scala b/src/main/scala/inox/Context.scala similarity index 67% rename from src/main/scala/inox/InoxContext.scala rename to src/main/scala/inox/Context.scala index 42086b486..1cf9d8dab 100644 --- a/src/main/scala/inox/InoxContext.scala +++ b/src/main/scala/inox/Context.scala @@ -8,24 +8,24 @@ import inox.utils._ * Contexts are immutable, and so should all their fields (with the possible * exception of the reporter). */ -case class InoxContext( +case class Context( reporter: Reporter, interruptManager: InterruptManager, - options: InoxOptions = InoxOptions(Seq()), + options: Options = Options(Seq()), timers: TimerStorage = new TimerStorage) -object InoxContext { +object Context { def empty = { val reporter = new DefaultReporter(Set()) - InoxContext(reporter, new InterruptManager(reporter)) + Context(reporter, new InterruptManager(reporter)) } def printNames = { val reporter = new DefaultReporter(Set()) - InoxContext( + Context( reporter, new InterruptManager(reporter), - options = InoxOptions.empty + InoxOption[Set[DebugSection]](InoxOptions.optDebug)(Set(ast.DebugSectionTrees)) + options = Options.empty + InoxOptions.optDebug(Set(ast.DebugSectionTrees: DebugSection)) ) } } diff --git a/src/main/scala/inox/InoxOptions.scala b/src/main/scala/inox/Options.scala similarity index 73% rename from src/main/scala/inox/InoxOptions.scala rename to src/main/scala/inox/Options.scala index dee27c901..e19fdff3b 100644 --- a/src/main/scala/inox/InoxOptions.scala +++ b/src/main/scala/inox/Options.scala @@ -7,7 +7,7 @@ import OptionParsers._ import scala.util.Try import scala.reflect.ClassTag -abstract class InoxOptionDef[+A] { +abstract class OptionDef[A] { val name: String val description: String def default: A @@ -32,49 +32,51 @@ abstract class InoxOptionDef[+A] { ) } - def parse(s: String)(implicit reporter: Reporter): InoxOption[A] = - InoxOption(this)(parseValue(s)) + def parse(s: String)(implicit reporter: Reporter): OptionValue[A] = + OptionValue(this)(parseValue(s)) - def withDefaultValue: InoxOption[A] = - InoxOption(this)(default) + def withDefaultValue: OptionValue[A] = + OptionValue(this)(default) + + def apply(value: A): OptionValue[A] = OptionValue(this)(value) // @mk: FIXME: Is this cool? override def equals(other: Any) = other match { - case that: InoxOptionDef[_] => this.name == that.name + case that: OptionDef[_] => this.name == that.name case _ => false } override def hashCode = name.hashCode } -case class InoxFlagOptionDef(name: String, description: String, default: Boolean) extends InoxOptionDef[Boolean] { +case class FlagOptionDef(name: String, description: String, default: Boolean) extends OptionDef[Boolean] { val parser = booleanParser val usageRhs = "" } -case class InoxStringOptionDef(name: String, description: String, default: String, usageRhs: String) extends InoxOptionDef[String] { +case class StringOptionDef(name: String, description: String, default: String, usageRhs: String) extends OptionDef[String] { val parser = stringParser } -case class InoxLongOptionDef(name: String, description: String, default: Long, usageRhs: String) extends InoxOptionDef[Long] { +case class LongOptionDef(name: String, description: String, default: Long, usageRhs: String) extends OptionDef[Long] { val parser = longParser } -class InoxOption[+A] private (val optionDef: InoxOptionDef[A], val value: A) { +class OptionValue[A] private (val optionDef: OptionDef[A], val value: A) { override def toString = s"--${optionDef.name}=$value" override def equals(other: Any) = other match { - case InoxOption(optionDef, value) => + case OptionValue(optionDef, value) => optionDef.name == this.optionDef.name && value == this.value case _ => false } override def hashCode = optionDef.hashCode + value.hashCode } -object InoxOption { - def apply[A](optionDef: InoxOptionDef[A])(value: A) = { - new InoxOption(optionDef, value) +object OptionValue { + def apply[A](optionDef: OptionDef[A])(value: A) = { + new OptionValue(optionDef, value) } - def unapply[A](opt: InoxOption[A]) = Some((opt.optionDef, opt.value)) + def unapply[A](opt: OptionValue[A]) = Some((opt.optionDef, opt.value)) } object OptionParsers { @@ -154,32 +156,34 @@ object OptionsHelpers { } } -case class InoxOptions(options: Seq[InoxOption[Any]]) { +case class Options(options: Seq[OptionValue[_]]) { - def findOption[A: ClassTag](optDef: InoxOptionDef[A]): Option[A] = options.collectFirst { - case InoxOption(`optDef`, value: A) => value + def findOption[A: ClassTag](optDef: OptionDef[A]): Option[A] = options.collectFirst { + case OptionValue(`optDef`, value: A) => value } - def findOptionOrDefault[A: ClassTag](optDef: InoxOptionDef[A]): A = findOption(optDef).getOrElse(optDef.default) + def findOptionOrDefault[A: ClassTag](optDef: OptionDef[A]): A = findOption(optDef).getOrElse(optDef.default) - def +(newOpt: InoxOption[Any]): InoxOptions = InoxOptions( + def +(newOpt: OptionValue[_]): Options = Options( options.filter(_.optionDef != newOpt.optionDef) :+ newOpt ) - def ++(newOpts: Seq[InoxOption[Any]]): InoxOptions = InoxOptions { + def ++(newOpts: Seq[OptionValue[_]]): Options = Options { val defs = newOpts.map(_.optionDef).toSet options.filter(opt => !defs(opt.optionDef)) ++ newOpts } @inline - def ++(that: InoxOptions): InoxOptions = this ++ that.options + def ++(that: Options): Options = this ++ that.options } -object InoxOptions { +object Options { + def empty: Options = Options(Seq()) +} - def empty: InoxOptions = InoxOptions(Seq()) +object InoxOptions { - val optSelectedSolvers = new InoxOptionDef[Set[String]] { + val optSelectedSolvers = new OptionDef[Set[String]] { val name = "solvers" val description = "Use solvers s1, s2,...\n" + solvers.SolverFactory.solversPretty @@ -188,7 +192,7 @@ object InoxOptions { val usageRhs = "s1,s2,..." } - val optDebug = new InoxOptionDef[Set[DebugSection]] { + val optDebug = new OptionDef[Set[DebugSection]] { import OptionParsers._ val name = "debug" val description = { @@ -218,7 +222,7 @@ object InoxOptions { } } - val optTimeout = InoxLongOptionDef( + val optTimeout = LongOptionDef( "timeout", "Set a timeout for attempting to prove a verification condition/ repair a function (in sec.)", 0L, diff --git a/src/main/scala/inox/Program.scala b/src/main/scala/inox/Program.scala index 7dfbf25ee..374464d05 100644 --- a/src/main/scala/inox/Program.scala +++ b/src/main/scala/inox/Program.scala @@ -19,17 +19,25 @@ import ast._ trait Program { val trees: Trees implicit val symbols: trees.Symbols - implicit val ctx: InoxContext + implicit val ctx: Context implicit def implicitProgram: this.type = this implicit def printerOpts: trees.PrinterOptions = trees.PrinterOptions.fromSymbols(symbols, ctx) - def transform(t: trees.TreeTransformer): Program { val trees: Program.this.trees.type } = new Program { + def transform(t: trees.SelfTransformer): Program { val trees: Program.this.trees.type } = new Program { val trees: Program.this.trees.type = Program.this.trees val symbols = Program.this.symbols.transform(t) val ctx = Program.this.ctx } + def transform(t: SymbolTransformer { + val transformer: TreeTransformer { val s: trees.type } + }): Program { val trees: t.t.type } = new Program { + val trees: t.t.type = t.t + val symbols = t.transform(Program.this.symbols) + val ctx = Program.this.ctx + } + def withFunctions(functions: Seq[trees.FunDef]): Program { val trees: Program.this.trees.type } = new Program { val trees: Program.this.trees.type = Program.this.trees val symbols = Program.this.symbols.withFunctions(functions) diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 75090875c..9fbac46da 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -129,32 +129,16 @@ trait Definitions { self: Trees => def getFunction(id: Identifier): FunDef = lookupFunction(id).getOrElse(throw FunctionLookupException(id)) def getFunction(id: Identifier, tps: Seq[Type]): TypedFunDef = lookupFunction(id, tps).getOrElse(throw FunctionLookupException(id)) - override def toString: String = asString(PrinterOptions.fromSymbols(this, InoxContext.printNames)) + override def toString: String = asString(PrinterOptions.fromSymbols(this, Context.printNames)) override def asString(implicit opts: PrinterOptions): String = { adts.map(p => prettyPrint(p._2, opts)).mkString("\n\n") + "\n\n-----------\n\n" + functions.map(p => prettyPrint(p._2, opts)).mkString("\n\n") } - def transform(t: TreeTransformer): Symbols = NoSymbols.withFunctions { - functions.values.toSeq.map(fd => new FunDef( - fd.id, - fd.tparams, // type parameters can't be transformed! - fd.params.map(vd => t.transform(vd)), - t.transform(fd.returnType), - t.transform(fd.fullBody), - fd.flags)) - }.withADTs { - adts.values.toSeq.map { - case sort: ADTSort => sort - case cons: ADTConstructor => new ADTConstructor( - cons.id, - cons.tparams, - cons.sort, - cons.fields.map(t.transform), - cons.flags) - } - } + def transform(trans: SelfTransformer): Symbols = new SymbolTransformer { + val transformer: trans.type = trans + }.transform(this) override def equals(that: Any): Boolean = that match { case sym: AbstractSymbols => functions == sym.functions && adts == sym.adts @@ -172,7 +156,11 @@ trait Definitions { self: Trees => val id = tp.id } - /** Represents source code annotations and some other meaningful flags. */ + /** Represents source code annotations and some other meaningful flags. + * + * In order to enable transformations on [[Flag]] instances, there is an + * implicit contract on [[args]] such that for each argument, either + * {{{arg: Expr | Type}}}, or there exists no tree instance within arg. */ abstract class Flag(name: String, args: Seq[Any]) extends Printable { def asString(implicit opts: PrinterOptions): String = name + (if (args.isEmpty) "" else { args.map(arg => self.asString(arg)(opts)).mkString("(", ", ", ")") @@ -182,7 +170,9 @@ trait Definitions { self: Trees => /** Denotes that this adt is refined by invariant ''id'' */ case class HasADTInvariant(id: Identifier) extends Flag("invariant", Seq(id)) - // Compiler annotations given in the source code as @annot + /** Compiler annotations given in the source code as @annot. + * + * @see [[Flag]] for some notes on the actual type of [[args]]. */ case class Annotation(val name: String, val args: Seq[Any]) extends Flag(name, args) def extractFlag(name: String, args: Seq[Any]): Flag = (name, args) match { diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index cda2998ec..cecd9b426 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -210,59 +210,20 @@ trait TreeDeconstructor { case s.StringType => (Seq(), _ => t.StringType) } - def translate(v: s.Variable): t.Variable = { - val newTpe = translate(v.tpe) - if (v.tpe ne newTpe) { - t.Variable(v.id, newTpe).copiedFrom(v) - } else { - v.asInstanceOf[t.Variable] - } - } - - def translate(e: s.Expr): t.Expr = { - val (vs, es, tps, builder) = deconstruct(e) - - var changed = false - val newVs = for (v <- vs) yield { - val newV = translate(v) - if (v ne newV) changed = true - newV - } - - val newEs = for (e <- es) yield { - val newE = translate(e) - if (e ne newE) changed = true - newE - } - - val newTps = for (tp <- tps) yield { - val newTp = translate(tp) - if (tp ne newTp) changed = true - newTp - } - - if (changed || (s ne t)) { - builder(newVs, newEs, newTps).copiedFrom(e) - } else { - e.asInstanceOf[t.Expr] - } - } + def deconstruct(f: s.Flag): (Seq[s.Expr], Seq[s.Type], (Seq[t.Expr], Seq[t.Type]) => t.Flag) = f match { + case s.HasADTInvariant(id) => + (Seq(), Seq(), (_, _) => t.HasADTInvariant(id)) + case s.Annotation(name, args) => + val withIndex = args.zipWithIndex + val (exprs, exprIndexes) = withIndex.collect { case (e: s.Expr, i) => e -> i }.unzip + val (types, typeIndexes) = withIndex.collect { case (tp: s.Type, i) => tp -> i }.unzip - def translate(tp: s.Type): t.Type = { - val (tps, builder) = deconstruct(tp) - - var changed = false - val newTps = for (tp <- tps) yield { - val newTp = translate(tp) - if (tp ne newTp) changed = true - newTp - } - - if (changed || (s ne t)) { - builder(newTps).copiedFrom(tp) - } else { - tp.asInstanceOf[t.Type] - } + // we use the implicit contract on Flags here that states that a flag is either + // an instance of Expr | Type, or it has nothing to do with a tree + val rest = withIndex.filterNot(_._1.isInstanceOf[s.Tree]) + (exprs, types, (es, tps) => t.Annotation(name, + ((es zip exprIndexes) ++ (tps zip typeIndexes) ++ rest).sortBy(_._2).map(_._1) + )) } } diff --git a/src/main/scala/inox/ast/Paths.scala b/src/main/scala/inox/ast/Paths.scala index 88e00a617..9c78d2677 100644 --- a/src/main/scala/inox/ast/Paths.scala +++ b/src/main/scala/inox/ast/Paths.scala @@ -224,7 +224,7 @@ trait Paths { self: TypeOps with Constructors => override def hashCode: Int = elements.hashCode - override def toString = asString(PrinterOptions.fromContext(InoxContext.printNames)) + override def toString = asString(PrinterOptions.fromContext(Context.printNames)) def asString(implicit opts: PrinterOptions): String = fullClause.asString } } diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index c2385a931..998527990 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -7,9 +7,9 @@ import utils._ import org.apache.commons.lang3.StringEscapeUtils import scala.language.implicitConversions -object optPrintPositions extends InoxFlagOptionDef("printpositions", "Attach positions to trees when printing", false) -object optPrintUniqueIds extends InoxFlagOptionDef("printids", "Always print unique ids", false) -object optPrintTypes extends InoxFlagOptionDef("printtypes", "Attach types to trees when printing", false) +object optPrintPositions extends FlagOptionDef("printpositions", "Attach positions to trees when printing", false) +object optPrintUniqueIds extends FlagOptionDef("printids", "Always print unique ids", false) +object optPrintTypes extends FlagOptionDef("printtypes", "Attach types to trees when printing", false) trait Printers { self: Trees => @@ -36,7 +36,7 @@ trait Printers { } object PrinterOptions { - def fromContext(ctx: InoxContext): PrinterOptions = { + def fromContext(ctx: Context): PrinterOptions = { PrinterOptions( baseIndent = 0, printPositions = ctx.options.findOptionOrDefault(optPrintPositions), @@ -46,7 +46,7 @@ trait Printers { ) } - def fromSymbols(s: Symbols, ctx: InoxContext): PrinterOptions = { + def fromSymbols(s: Symbols, ctx: Context): PrinterOptions = { fromContext(ctx).copy(symbols = Some(s)) } } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 35c1c8416..8cbbaf280 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -108,7 +108,7 @@ trait SymbolOps { self: TypeOps => (Seq[ValDef], Expr, Map[Variable, Expr]) = synchronized { val vars = args.map(_.toVariable).toSet - class Normalizer extends TreeTransformer { + class Normalizer extends SelfTreeTransformer { var subst: Map[Variable, Expr] = Map.empty var varSubst: Map[Identifier, Identifier] = Map.empty var remainingIds: Map[Type, List[Identifier]] = typedIds.toMap @@ -660,7 +660,7 @@ trait SymbolOps { self: TypeOps => } // Helpers for instantiateType - class TypeInstantiator(tps: Map[TypeParameter, Type]) extends TreeTransformer { + class TypeInstantiator(tps: Map[TypeParameter, Type]) extends SelfTreeTransformer { override def transform(tpe: Type): Type = tpe match { case tp: TypeParameter => tps.getOrElse(tp, super.transform(tpe)) case _ => super.transform(tpe) diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala index 78ccb8c66..a5cc023e1 100644 --- a/src/main/scala/inox/ast/TreeOps.scala +++ b/src/main/scala/inox/ast/TreeOps.scala @@ -4,81 +4,33 @@ package ast trait TreeOps { self: Trees => - trait TreeTransformer { - def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, transform(tpe)) - - def transform(v: Variable): Variable = { - val (id, tpe) = transform(v.id, v.tpe) - if ((id ne v.id) || (tpe ne v.tpe)) { - Variable(id, tpe).copiedFrom(v) - } else { - v - } - } - - def transform(vd: ValDef): ValDef = { - val (id, es, Seq(tpe), builder) = deconstructor.deconstruct(vd) - val (newId, newTpe) = transform(id, tpe) - - var changed = false - val newEs = for (e <- es) yield { - val newE = transform(e) - if (e ne newE) changed = true - newE - } - - if ((id ne newId) || (tpe ne newTpe) || changed) { - builder(newId, newEs, Seq(newTpe)).copiedFrom(vd).asInstanceOf[ValDef] - } else { - vd - } - } - - def transform(e: Expr): Expr = { - val (vs, es, tps, builder) = deconstructor.deconstruct(e) - - var changed = false - val newVs = for (v <- vs) yield { - val newV = transform(v) - if (v ne newV) changed = true - newV - } - - val newEs = for (e <- es) yield { - val newE = transform(e) - if (e ne newE) changed = true - newE - } - - val newTps = for (tp <- tps) yield { - val newTp = transform(tp) - if (tp ne newTp) changed = true - newTp - } + type SelfTransformer = TreeTransformer { + val s: self.type + val t: self.type + } - if (changed) { - builder(newVs, newEs, newTps).copiedFrom(e) - } else { - e - } - } + trait SelfTreeTransformer extends TreeTransformer { + val s: self.type = self + val t: self.type = self - def transform(t: Type): Type = { - val (tps, builder) = deconstructor.deconstruct(t) + lazy val deconstructor: TreeDeconstructor { + val s: self.type + val t: self.type + } = self.deconstructor + } - var changed = false - val newTps = for (tp <- tps) yield { - val newTp = transform(tp) - if (tp ne newTp) changed = true - newTp - } + lazy val TreeIdentity = new SelfTreeTransformer { + override def transform(id: Identifier, tpe: s.Type): (Identifier, t.Type) = (id, tpe) + override def transform(v: s.Variable): t.Variable = v + override def transform(vd: s.ValDef): t.ValDef = vd + override def transform(e: s.Expr): t.Expr = e + override def transform(tpe: s.Type): t.Type = tpe + override def transform(flag: s.Flag): t.Flag = flag + } - if (changed) { - builder(newTps).copiedFrom(t) - } else { - t - } - } + lazy val SymbolIdentity = new SymbolTransformer { + val transformer = TreeIdentity + override def transform(syms: s.Symbols): t.Symbols = syms } trait TreeTraverser { @@ -99,3 +51,292 @@ trait TreeOps { self: Trees => } } } + +trait TreeTransformer { + val s: Trees + val t: Trees + + val deconstructor: TreeDeconstructor { + val s: TreeTransformer.this.s.type + val t: TreeTransformer.this.t.type + } + + def transform(id: Identifier, tpe: s.Type): (Identifier, t.Type) = (id, transform(tpe)) + + def transform(v: s.Variable): t.Variable = { + val (id, tpe) = transform(v.id, v.tpe) + if ((id ne v.id) || (tpe ne v.tpe) || (s ne t)) { + t.Variable(id, tpe).copiedFrom(v) + } else { + v.asInstanceOf[t.Variable] + } + } + + def transform(vd: s.ValDef): t.ValDef = { + val (id, es, Seq(tpe), builder) = deconstructor.deconstruct(vd) + val (newId, newTpe) = transform(id, tpe) + + var changed = false + val newEs = for (e <- es) yield { + val newE = transform(e) + if (e ne newE) changed = true + newE + } + + if ((id ne newId) || (tpe ne newTpe) || changed || (s ne t)) { + builder(newId, newEs, Seq(newTpe)).copiedFrom(vd).asInstanceOf[t.ValDef] + } else { + vd.asInstanceOf[t.ValDef] + } + } + + def transform(e: s.Expr): t.Expr = { + val (vs, es, tps, builder) = deconstructor.deconstruct(e) + + var changed = false + val newVs = for (v <- vs) yield { + val newV = transform(v) + if (v ne newV) changed = true + newV + } + + val newEs = for (e <- es) yield { + val newE = transform(e) + if (e ne newE) changed = true + newE + } + + val newTps = for (tp <- tps) yield { + val newTp = transform(tp) + if (tp ne newTp) changed = true + newTp + } + + if (changed || (s ne t)) { + builder(newVs, newEs, newTps).copiedFrom(e) + } else { + e.asInstanceOf[t.Expr] + } + } + + def transform(tpe: s.Type): t.Type = { + val (tps, builder) = deconstructor.deconstruct(tpe) + + var changed = false + val newTps = for (tp <- tps) yield { + val newTp = transform(tp) + if (tp ne newTp) changed = true + newTp + } + + if (changed || (s ne t)) { + builder(newTps).copiedFrom(tpe) + } else { + tpe.asInstanceOf[t.Type] + } + } + + def transform(flag: s.Flag): t.Flag = { + val (es, tps, builder) = deconstructor.deconstruct(flag) + + var changed = false + val newEs = for (e <- es) yield { + val newE = transform(e) + if (e ne newE) changed = true + newE + } + + val newTps = for (tp <- tps) yield { + val newTp = transform(tp) + if (tp ne newTp) changed = true + newTp + } + + if (changed || (s ne t)) { + builder(newEs, newTps) + } else { + flag.asInstanceOf[t.Flag] + } + } + + protected trait TreeTransformerComposition extends TreeTransformer { + protected val t1: TreeTransformer + protected val t2: TreeTransformer { val s: t1.t.type } + + lazy val s: t1.s.type = t1.s + lazy val t: t2.t.type = t2.t + + override def transform(id: Identifier, tpe: s.Type): (Identifier, t.Type) = { + val (id1, tp1) = t1.transform(id, tpe) + t2.transform(id1, tp1) + } + + override def transform(v: s.Variable): t.Variable = t2.transform(t1.transform(v)) + override def transform(vd: s.ValDef): t.ValDef = t2.transform(t1.transform(vd)) + override def transform(e: s.Expr): t.Expr = t2.transform(t1.transform(e)) + override def transform(tpe: s.Type): t.Type = t2.transform(t1.transform(tpe)) + override def transform(flag: s.Flag): t.Flag = t2.transform(t1.transform(flag)) + } + + def compose(that: TreeTransformer { val t: TreeTransformer.this.s.type }): TreeTransformer { + val s: that.s.type + val t: TreeTransformer.this.t.type + } = { + // the scala type checker doesn't realize that this relation must hold here + that andThen this.asInstanceOf[TreeTransformer { + val s: that.t.type + val t: TreeTransformer.this.t.type + }] + } + + def andThen(that: TreeTransformer { val s: TreeTransformer.this.t.type }): TreeTransformer { + val s: TreeTransformer.this.s.type + val t: that.t.type + } = new TreeTransformerComposition { + val t1: TreeTransformer.this.type = TreeTransformer.this + val t2: that.type = that + + lazy val deconstructor: TreeDeconstructor { + val s: TreeTransformer.this.s.type + val t: that.t.type + } = new TreeDeconstructor { + protected val s: TreeTransformer.this.s.type = TreeTransformer.this.s + protected val t: that.t.type = that.t + } + } +} + +/** Symbol table transformer */ +trait SymbolTransformer { + val transformer: TreeTransformer + lazy val s: transformer.s.type = transformer.s + lazy val t: transformer.t.type = transformer.t + + def transform(id: Identifier, tpe: s.Type): (Identifier, t.Type) = transformer.transform(id, tpe) + def transform(v: s.Variable): t.Variable = transformer.transform(v) + def transform(vd: s.ValDef): t.ValDef = transformer.transform(vd) + def transform(e: s.Expr): t.Expr = transformer.transform(e) + def transform(tpe: s.Type): t.Type = transformer.transform(tpe) + def transform(flag: s.Flag): t.Flag = transformer.transform(flag) + + /* Type parameters can't be modified by transformed but they need to be + * translated into the new tree definitions given by `t`. */ + protected def transformTypeParams(tparams: Seq[s.TypeParameterDef]): Seq[t.TypeParameterDef] = { + if (s eq t) tparams.asInstanceOf[Seq[t.TypeParameterDef]] + else tparams.map(tdef => t.TypeParameterDef(t.TypeParameter(tdef.id))) + } + + def transform(syms: s.Symbols): t.Symbols = t.NoSymbols.withFunctions { + syms.functions.values.toSeq.map(fd => new t.FunDef( + fd.id, + transformTypeParams(fd.tparams), + fd.params.map(vd => transformer.transform(vd)), + transformer.transform(fd.returnType), + transformer.transform(fd.fullBody), + fd.flags.map(f => transformer.transform(f)))) + }.withADTs { + syms.adts.values.toSeq.map { + case sort: s.ADTSort if (s eq t) => sort.asInstanceOf[t.ADTSort] + case sort: s.ADTSort => new t.ADTSort( + sort.id, + transformTypeParams(sort.tparams), + sort.cons, + sort.flags.map(f => transformer.transform(f))) + case cons: s.ADTConstructor => new t.ADTConstructor( + cons.id, + transformTypeParams(cons.tparams), + cons.sort, + cons.fields.map(vd => transformer.transform(vd)), + cons.flags.map(f => transformer.transform(f))) + } + } + + def compose(that: SymbolTransformer { + val transformer: TreeTransformer { val t: SymbolTransformer.this.s.type } + }): SymbolTransformer { + val transformer: TreeTransformer { + val s: that.s.type + val t: SymbolTransformer.this.t.type + } + } = new SymbolTransformer { + val transformer = SymbolTransformer.this.transformer compose that.transformer + override def transform(syms: s.Symbols): t.Symbols = SymbolTransformer.this.transform(that.transform(syms)) + } + + def andThen(that: SymbolTransformer { + val transformer: TreeTransformer { val s: SymbolTransformer.this.t.type } + }): SymbolTransformer { + val transformer: TreeTransformer { + val s: SymbolTransformer.this.s.type + val t: that.t.type + } + } = { + // the scala compiler doesn't realize that this relation must hold here + that compose this.asInstanceOf[SymbolTransformer { + val transformer: TreeTransformer { + val s: SymbolTransformer.this.s.type + val t: that.s.type + } + }] + } +} + +trait TreeBijection { + val s: Trees + val t: Trees + + val encoder: SymbolTransformer { val transformer: TreeTransformer { + val s: TreeBijection.this.s.type + val t: TreeBijection.this.t.type + }} + + val decoder: SymbolTransformer { val transformer: TreeTransformer { + val s: TreeBijection.this.t.type + val t: TreeBijection.this.s.type + }} + + def encode(vd: s.ValDef): t.ValDef = encoder.transform(vd) + def decode(vd: t.ValDef): s.ValDef = decoder.transform(vd) + + def encode(v: s.Variable): t.Variable = encoder.transform(v) + def decode(v: t.Variable): s.Variable = decoder.transform(v) + + def encode(e: s.Expr): t.Expr = encoder.transform(e) + def decode(e: t.Expr): s.Expr = decoder.transform(e) + + def encode(tpe: s.Type): t.Type = encoder.transform(tpe) + def decode(tpe: t.Type): s.Type = decoder.transform(tpe) + + def inverse: TreeBijection { + val s: TreeBijection.this.t.type + val t: TreeBijection.this.s.type + } = new TreeBijection { + val s: TreeBijection.this.t.type = TreeBijection.this.t + val t: TreeBijection.this.s.type = TreeBijection.this.s + + val encoder = TreeBijection.this.decoder + val decoder = TreeBijection.this.encoder + } + + def compose(that: TreeBijection { val t: TreeBijection.this.s.type }): TreeBijection { + val s: that.s.type + val t: TreeBijection.this.t.type + } = new TreeBijection { + val s: that.s.type = that.s + val t: TreeBijection.this.t.type = TreeBijection.this.t + + val encoder = TreeBijection.this.encoder compose that.encoder + val decoder = that.decoder compose TreeBijection.this.decoder + } + + def andThen(that: TreeBijection { val s: TreeBijection.this.t.type }): TreeBijection { + val s: TreeBijection.this.s.type + val t: that.t.type + } = new TreeBijection { + val s: TreeBijection.this.s.type = TreeBijection.this.s + val t: that.t.type = that.t + + val encoder = TreeBijection.this.encoder andThen that.encoder + val decoder = that.decoder andThen TreeBijection.this.decoder + } +} diff --git a/src/main/scala/inox/ast/Trees.scala b/src/main/scala/inox/ast/Trees.scala index ac08971ac..57a9bdc17 100644 --- a/src/main/scala/inox/ast/Trees.scala +++ b/src/main/scala/inox/ast/Trees.scala @@ -19,7 +19,7 @@ trait Trees type Identifier = ast.Identifier val FreshIdentifier = ast.FreshIdentifier - class Unsupported(t: Tree, msg: String)(implicit ctx: InoxContext) + class Unsupported(t: Tree, msg: String)(implicit ctx: Context) extends Exception(s"${t.asString(PrinterOptions.fromContext(ctx))}@${t.getPos} $msg") abstract class Tree extends utils.Positioned with Serializable { @@ -29,7 +29,7 @@ trait Trees def asString(implicit opts: PrinterOptions): String = prettyPrint(this, opts) - override def toString = asString(PrinterOptions.fromContext(InoxContext.printNames)) + override def toString = asString(PrinterOptions.fromContext(Context.printNames)) } val exprOps: ExprOps { val trees: Trees.this.type } = new { diff --git a/src/main/scala/inox/evaluators/Evaluator.scala b/src/main/scala/inox/evaluators/Evaluator.scala index eea69bb55..4283f2bfb 100644 --- a/src/main/scala/inox/evaluators/Evaluator.scala +++ b/src/main/scala/inox/evaluators/Evaluator.scala @@ -9,12 +9,12 @@ object EvaluatorOptions { ) } -object optIgnoreContracts extends InoxFlagOptionDef( +object optIgnoreContracts extends FlagOptionDef( "ignorecontracts", "Don't fail on invalid contracts during evaluation", false) trait Evaluator { val program: Program - val options: InoxOptions + val options: Options import program.trees._ /** The type of value that this [[Evaluator]] calculates diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index 977158077..d3b0e2da2 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -510,13 +510,13 @@ trait RecursiveEvaluator } object RecursiveEvaluator { - def apply(p: InoxProgram, opts: InoxOptions): RecursiveEvaluator { val program: p.type } = { + def apply(p: InoxProgram, opts: Options): RecursiveEvaluator { val program: p.type } = { new { val program: p.type = p } with RecursiveEvaluator with HasDefaultGlobalContext with HasDefaultRecContext { val options = opts val maxSteps = 50000 - def getSolver(moreOpts: InoxOption[Any]*) = solvers.SolverFactory(p, opts ++ moreOpts) + def getSolver(moreOpts: OptionValue[_]*) = solvers.SolverFactory(p, opts ++ moreOpts) } } diff --git a/src/main/scala/inox/evaluators/SolvingEvaluator.scala b/src/main/scala/inox/evaluators/SolvingEvaluator.scala index 18a7e5c07..dd0e9fa5a 100644 --- a/src/main/scala/inox/evaluators/SolvingEvaluator.scala +++ b/src/main/scala/inox/evaluators/SolvingEvaluator.scala @@ -12,7 +12,7 @@ trait SolvingEvaluator extends Evaluator { import program.trees._ import program.symbols._ - private object optForallCache extends InoxOptionDef[MutableMap[program.trees.Forall, Boolean]] { + private object optForallCache extends OptionDef[MutableMap[program.trees.Forall, Boolean]] { val parser = { (_: String) => throw FatalError("Unparsable option \"bankOption\"") } val name = "bank-option" val description = "Evaluation bank shared between solver and evaluator" @@ -20,7 +20,7 @@ trait SolvingEvaluator extends Evaluator { def default = MutableMap.empty } - def getSolver(opts: InoxOption[Any]*): SolverFactory { val program: SolvingEvaluator.this.program.type } + def getSolver(opts: OptionValue[_]*): SolverFactory { val program: SolvingEvaluator.this.program.type } private val chooseCache: MutableMap[Choose, Expr] = MutableMap.empty private val forallCache: MutableMap[Forall, Expr] = MutableMap.empty @@ -29,7 +29,7 @@ trait SolvingEvaluator extends Evaluator { val timer = ctx.timers.evaluators.specs.start() val sf = getSolver(options.options.collect { - case o @ InoxOption(opt, _) if opt == optForallCache => o + case o @ OptionValue(opt, _) if opt == optForallCache => o } : _*) import SolverResponses._ @@ -54,10 +54,10 @@ trait SolvingEvaluator extends Evaluator { val timer = ctx.timers.evaluators.forall.start() val sf = getSolver( - InoxOption(optSilentErrors)(true), - InoxOption(optCheckModels)(false), // model is checked manually!! (see below) - InoxOption(unrolling.optFeelingLucky)(false), - InoxOption(optForallCache)(cache) + optSilentErrors(true), + optCheckModels(false), // model is checked manually!! (see below) + unrolling.optFeelingLucky(false), + optForallCache(cache) ) import SolverResponses._ diff --git a/src/main/scala/inox/grammars/ExpressionGrammars.scala b/src/main/scala/inox/grammars/ExpressionGrammars.scala index 36c44cd61..0bc1bf875 100644 --- a/src/main/scala/inox/grammars/ExpressionGrammars.scala +++ b/src/main/scala/inox/grammars/ExpressionGrammars.scala @@ -33,7 +33,7 @@ trait ExpressionGrammars { self: GrammarsUniverse => */ def computeProductions(lab: Label): Seq[ProductionRule[Label, Expr]] - protected def applyAspects(lab: Label, ps: Seq[ProductionRule[Label, Expr]])(implicit ctx: InoxContext) = { + protected def applyAspects(lab: Label, ps: Seq[ProductionRule[Label, Expr]]) = { lab.aspects.foldLeft(ps) { case (ps, a) => a.applyTo(lab, ps) } diff --git a/src/main/scala/inox/package.scala b/src/main/scala/inox/package.scala index 755de6a11..bf2325f00 100644 --- a/src/main/scala/inox/package.scala +++ b/src/main/scala/inox/package.scala @@ -24,7 +24,7 @@ package object inox { type InoxProgram = Program { val trees: inox.trees.type } object InoxProgram { - def apply(ictx: InoxContext, + def apply(ictx: Context, functions: Seq[inox.trees.FunDef], adts: Seq[inox.trees.ADTDefinition]): InoxProgram = new Program { val trees = inox.trees @@ -34,7 +34,7 @@ package object inox { adts.map(cd => cd.id -> cd).toMap) } - def apply(ictx: InoxContext, sym: inox.trees.Symbols): InoxProgram = new Program { + def apply(ictx: Context, sym: inox.trees.Symbols): InoxProgram = new Program { val trees = inox.trees val ctx = ictx val symbols = sym diff --git a/src/main/scala/inox/solvers/Solver.scala b/src/main/scala/inox/solvers/Solver.scala index 34d317c62..ca04942d7 100644 --- a/src/main/scala/inox/solvers/Solver.scala +++ b/src/main/scala/inox/solvers/Solver.scala @@ -16,10 +16,10 @@ object SolverOptions { ) } -object optCheckModels extends InoxFlagOptionDef( +object optCheckModels extends FlagOptionDef( "checkmodels", "Double-check counter-examples with evaluator", false) -object optSilentErrors extends InoxFlagOptionDef( +object optSilentErrors extends FlagOptionDef( "silenterrors", "Fail silently into UNKNOWN when encountering an error", false) case object DebugSectionSolver extends DebugSection("solver") @@ -27,7 +27,7 @@ case object DebugSectionSolver extends DebugSection("solver") trait AbstractSolver extends Interruptible { def name: String val program: Program - val options: InoxOptions + val options: Options import program._ import program.trees._ diff --git a/src/main/scala/inox/solvers/SolverFactory.scala b/src/main/scala/inox/solvers/SolverFactory.scala index 6f0b46875..9ce18f966 100644 --- a/src/main/scala/inox/solvers/SolverFactory.scala +++ b/src/main/scala/inox/solvers/SolverFactory.scala @@ -36,7 +36,7 @@ object SolverFactory { import evaluators._ import combinators._ - private val solverNames = Map( + val solverNames = Map( "nativez3" -> "Native Z3 with z3-templates for unrolling", "unrollz3" -> "Native Z3 with inox-templates for unrolling", "smt-cvc4" -> "CVC4 through SMT-LIB", @@ -45,12 +45,14 @@ object SolverFactory { ) def getFromName(name: String) - (p: InoxProgram, opts: InoxOptions) - (ev: DeterministicEvaluator with SolvingEvaluator { val program: p.type }): + (p: Program, opts: Options) + (ev: DeterministicEvaluator with SolvingEvaluator { val program: p.type }, + enc: ProgramEncoder { val sourceProgram: p.type; val t: inox.trees.type }): SolverFactory { val program: p.type; type S <: TimeoutSolver } = name match { case "nativez3" => create(p)(name, () => new { val program: p.type = p val options = opts + val encoder = enc } with z3.NativeZ3Solver with TimeoutSolver { val evaluator = ev }) @@ -58,11 +60,12 @@ object SolverFactory { case "unrollz3" => create(p)(name, () => new { val program: p.type = p val options = opts + val encoder = enc } with unrolling.UnrollingSolver with theories.Z3Theories with TimeoutSolver { val evaluator = ev object underlying extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram val options = opts } with z3.UninterpretedZ3Solver }) @@ -70,11 +73,12 @@ object SolverFactory { case "smt-cvc4" => create(p)(name, () => new { val program: p.type = p val options = opts + val encoder = enc } with unrolling.UnrollingSolver with theories.CVC4Theories with TimeoutSolver { val evaluator = ev object underlying extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram val options = opts } with smtlib.CVC4Solver }) @@ -82,11 +86,12 @@ object SolverFactory { case "smt-z3" => create(p)(name, () => new { val program: p.type = p val options = opts + val encoder = enc } with unrolling.UnrollingSolver with theories.Z3Theories with TimeoutSolver { val evaluator = ev object underlying extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram val options = opts } with smtlib.Z3Solver }) @@ -108,12 +113,19 @@ object SolverFactory { case (name, desc) => f"\n $name%-14s : $desc" } - def apply(p: InoxProgram, opts: InoxOptions): SolverFactory { val program: p.type; type S <: TimeoutSolver } = + val solvers: Set[String] = solverNames.map(_._1).toSet + + def apply(name: String, p: InoxProgram, opts: Options): + SolverFactory { val program: p.type; type S <: TimeoutSolver } = { + getFromName(name)(p, opts)(RecursiveEvaluator(p, opts), ProgramEncoder.empty(p)) + } + + def apply(p: InoxProgram, opts: Options): SolverFactory { val program: p.type; type S <: TimeoutSolver } = p.ctx.options.findOptionOrDefault(InoxOptions.optSelectedSolvers).toSeq match { case Seq() => throw FatalError("No selected solver") - case Seq(single) => getFromName(single)(p, opts)(RecursiveEvaluator(p, opts)) + case Seq(single) => apply(single, p, opts) case multiple => PortfolioSolverFactory(p) { - multiple.map(name => getFromName(name)(p, opts)(RecursiveEvaluator(p, opts))) + multiple.map(name => apply(name, p, opts)) } } diff --git a/src/main/scala/inox/solvers/combinators/EncodingSolver.scala b/src/main/scala/inox/solvers/combinators/EncodingSolver.scala new file mode 100644 index 000000000..f32d2bd6d --- /dev/null +++ b/src/main/scala/inox/solvers/combinators/EncodingSolver.scala @@ -0,0 +1,57 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package solvers +package combinators + +import ast._ + +trait ProgramEncoder extends TreeBijection { + val sourceProgram: Program + lazy val s: sourceProgram.trees.type = sourceProgram.trees + lazy val targetProgram: Program { val trees: t.type } = sourceProgram.transform(encoder) + + /* @nv XXX: ideally, we would want to replace `>>` by `override def andThen`, however this + * seems to break the scala compiler for some weird reason... */ + def >>(that: TreeBijection { val s: ProgramEncoder.this.t.type }): ProgramEncoder { + val sourceProgram: ProgramEncoder.this.sourceProgram.type + val t: that.t.type + } = new ProgramEncoder { + val sourceProgram: ProgramEncoder.this.sourceProgram.type = ProgramEncoder.this.sourceProgram + val t: that.t.type = that.t + + val encoder = ProgramEncoder.this.encoder andThen that.encoder + val decoder = that.decoder andThen ProgramEncoder.this.decoder + } +} + +object ProgramEncoder { + def empty(p: Program): ProgramEncoder { + val sourceProgram: p.type + val t: p.trees.type + } = new ProgramEncoder { + val sourceProgram: p.type = p + val t: p.trees.type = p.trees + + val encoder = p.trees.SymbolIdentity + val decoder = p.trees.SymbolIdentity + } +} + +trait EncodingSolver extends Solver { + import program.trees._ + + protected val programEncoder: ProgramEncoder { val sourceProgram: program.type } + + protected def encode(vd: ValDef): programEncoder.t.ValDef = programEncoder.encode(vd) + protected def decode(vd: programEncoder.t.ValDef): ValDef = programEncoder.decode(vd) + + protected def encode(v: Variable): programEncoder.t.Variable = programEncoder.encode(v) + protected def decode(v: programEncoder.t.Variable): Variable = programEncoder.decode(v) + + protected def encode(e: Expr): programEncoder.t.Expr = programEncoder.encode(e) + protected def decode(e: programEncoder.t.Expr): Expr = programEncoder.decode(e) + + protected def encode(tpe: Type): programEncoder.t.Type = programEncoder.encode(tpe) + protected def decode(tpe: programEncoder.t.Type): Type = programEncoder.decode(tpe) +} diff --git a/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala b/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala index 738e344fa..428431877 100644 --- a/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala +++ b/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala @@ -6,6 +6,14 @@ package smtlib import inox.OptionParsers._ +object optCVC4Options extends OptionDef[Set[String]] { + val name = "solver:cvc4" + val description = "Pass extra arguments to CVC4" + val default = Set[String]() + val parser = setParser(stringParser) + val usageRhs = "<cvc4-opt>" +} + trait CVC4Solver extends SMTLIBSolver with CVC4Target { import program.trees._ import SolverResponses._ @@ -36,10 +44,3 @@ trait CVC4Solver extends SMTLIBSolver with CVC4Target { } } -object optCVC4Options extends InoxOptionDef[Set[String]] { - val name = "solver:cvc4" - val description = "Pass extra arguments to CVC4" - val default = Set[String]() - val parser = setParser(stringParser) - val usageRhs = "<cvc4-opt>" -} diff --git a/src/main/scala/inox/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala index 6a026a1e3..b91f3e52d 100644 --- a/src/main/scala/inox/solvers/theories/BagEncoder.scala +++ b/src/main/scala/inox/solvers/theories/BagEncoder.scala @@ -74,7 +74,7 @@ trait BagEncoder extends TheoryEncoder { override val newFunctions = Seq(Get, Add, Union, Difference, Intersect, BagEquals) override val newADTs = Seq(bagADT) - val encoder = new TreeTransformer { + val treeEncoder = new SelfTreeTransformer { import sourceProgram._ override def transform(e: Expr): Expr = e match { @@ -117,7 +117,7 @@ trait BagEncoder extends TheoryEncoder { } } - val decoder = new TreeTransformer { + val treeDecoder = new SelfTreeTransformer { import targetProgram._ override def transform(e: Expr): Expr = e match { diff --git a/src/main/scala/inox/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala index ddf7c4287..b37b9aaef 100644 --- a/src/main/scala/inox/solvers/theories/StringEncoder.scala +++ b/src/main/scala/inox/solvers/theories/StringEncoder.scala @@ -95,7 +95,7 @@ trait StringEncoder extends TheoryEncoder { v.toList.foldRight(StringNil()){ case (char, l) => StringCons(E(char), l) } } - val encoder = new TreeTransformer { + val treeEncoder = new SelfTreeTransformer { override def transform(e: Expr): Expr = e match { case StringLiteral(v) => convertFromString(v) case StringLength(a) => Size(transform(a)).copiedFrom(e) @@ -113,7 +113,7 @@ trait StringEncoder extends TheoryEncoder { } } - val decoder = new TreeTransformer { + val treeDecoder = new SelfTreeTransformer { override def transform(e: Expr): Expr = e match { case cc @ ADT(adt, args) if adt == StringNil || adt == StringCons => StringLiteral(convertToString(cc)).copiedFrom(cc) diff --git a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala index 816291d41..95ece9fb3 100644 --- a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala @@ -6,75 +6,47 @@ package theories import utils._ -trait TheoryEncoder { +trait TheoryEncoder extends ast.TreeBijection { val sourceProgram: Program lazy val trees: sourceProgram.trees.type = sourceProgram.trees + + lazy val s: trees.type = trees + lazy val t: trees.type = trees + lazy val targetProgram: Program { val trees: TheoryEncoder.this.trees.type } = { sourceProgram.transform(encoder).withFunctions(newFunctions).withADTs(newADTs) } import trees._ - protected val encoder: TreeTransformer - protected val decoder: TreeTransformer - - val newFunctions: Seq[FunDef] = Seq.empty - val newADTs: Seq[ADTDefinition] = Seq.empty + protected val treeEncoder: SelfTransformer + protected val treeDecoder: SelfTransformer - def encode(v: Variable): Variable = encoder.transform(v) - def decode(v: Variable): Variable = decoder.transform(v) + lazy val encoder = new ast.SymbolTransformer { + val transformer: treeEncoder.type = treeEncoder + } - def encode(expr: Expr): Expr = encoder.transform(expr) - def decode(expr: Expr): Expr = decoder.transform(expr) + lazy val decoder = new ast.SymbolTransformer { + val transformer: treeEncoder.type = treeEncoder + } - def encode(tpe: Type): Type = encoder.transform(tpe) - def decode(tpe: Type): Type = decoder.transform(tpe) + val newFunctions: Seq[FunDef] = Seq.empty + val newADTs: Seq[ADTDefinition] = Seq.empty def >>(that: TheoryEncoder { val sourceProgram: TheoryEncoder.this.targetProgram.type }): TheoryEncoder { val sourceProgram: TheoryEncoder.this.sourceProgram.type } = new TheoryEncoder { val sourceProgram: TheoryEncoder.this.sourceProgram.type = TheoryEncoder.this.sourceProgram - val encoder = new TreeTransformer { - override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { - val (id1, tpe1) = TheoryEncoder.this.encoder.transform(id, tpe) - that.encoder.transform(id1, tpe1) - } - - override def transform(expr: Expr): Expr = - that.encoder.transform(TheoryEncoder.this.encoder.transform(expr)) - - override def transform(tpe: Type): Type = - that.encoder.transform(TheoryEncoder.this.encoder.transform(tpe)) - } - - val decoder = new TreeTransformer { - override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { - val (id1, tpe1) = that.decoder.transform(id, tpe) - TheoryEncoder.this.decoder.transform(id1, tpe1) - } - - override def transform(expr: Expr): Expr = - TheoryEncoder.this.decoder.transform(that.decoder.transform(expr)) - - override def transform(tpe: Type): Type = - TheoryEncoder.this.decoder.transform(that.decoder.transform(tpe)) - } + protected val treeEncoder: SelfTransformer = TheoryEncoder.this.treeEncoder andThen that.treeEncoder + protected val treeDecoder: SelfTransformer = that.treeDecoder andThen TheoryEncoder.this.treeDecoder } } trait NoEncoder extends TheoryEncoder { import trees._ - private object NoTransformer extends TreeTransformer { - override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, tpe) - override def transform(v: Variable): Variable = v - override def transform(vd: ValDef): ValDef = vd - override def transform(expr: Expr): Expr = expr - override def transform(tpe: Type): Type = tpe - } - - val encoder: TreeTransformer = NoTransformer - val decoder: TreeTransformer = NoTransformer + protected val treeEncoder: SelfTransformer = TreeIdentity + protected val treeDecoder: SelfTransformer = TreeIdentity } diff --git a/src/main/scala/inox/solvers/theories/package.scala b/src/main/scala/inox/solvers/theories/package.scala index b2bd02c4b..617dde8e0 100644 --- a/src/main/scala/inox/solvers/theories/package.scala +++ b/src/main/scala/inox/solvers/theories/package.scala @@ -7,13 +7,13 @@ package object theories { trait Z3Theories { self: unrolling.AbstractUnrollingSolver => object theories extends { - val sourceProgram: self.program.type = self.program + val sourceProgram: self.encoder.targetProgram.type = self.encoder.targetProgram } with StringEncoder } trait CVC4Theories { self: unrolling.AbstractUnrollingSolver => object theories extends { - val sourceProgram: self.program.type = self.program + val sourceProgram: self.encoder.targetProgram.type = self.encoder.targetProgram } with BagEncoder } } diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 6dcf57485..57daddc41 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -8,33 +8,43 @@ import utils._ import theories._ import evaluators._ +import combinators._ import scala.collection.mutable.{Map => MutableMap} -object optUnrollFactor extends InoxLongOptionDef( +object optUnrollFactor extends LongOptionDef( "unrollfactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") -object optFeelingLucky extends InoxFlagOptionDef( +object optFeelingLucky extends FlagOptionDef( "feelinglucky", "Use evaluator to find counter-examples early", false) -object optUnrollAssumptions extends InoxFlagOptionDef( +object optUnrollAssumptions extends FlagOptionDef( "unrollassumptions", "Use unsat-assumptions to drive unfolding while remaining fair", false) -trait AbstractUnrollingSolver - extends Solver { +trait AbstractUnrollingSolver extends Solver with EncodingSolver { import program._ import program.trees._ import program.symbols._ - import SolverResponses._ protected type Encoded - protected val theories: TheoryEncoder { val sourceProgram: program.type } + protected val encoder: ProgramEncoder { val sourceProgram: program.type } + + protected val theories: TheoryEncoder { val sourceProgram: AbstractUnrollingSolver.this.encoder.targetProgram.type } + + protected lazy val programEncoder = encoder >> theories + + protected lazy val s: programEncoder.s.type = programEncoder.s + protected lazy val t: programEncoder.t.type = programEncoder.t + protected lazy val targetProgram: programEncoder.targetProgram.type = programEncoder.targetProgram + + protected def encode(tpe: FunctionType): t.FunctionType = + programEncoder.encode(tpe).asInstanceOf[t.FunctionType] protected val templates: Templates { - val program: theories.targetProgram.type + val program: targetProgram.type type Encoded = AbstractUnrollingSolver.this.Encoded } @@ -43,7 +53,7 @@ trait AbstractUnrollingSolver } protected val underlying: AbstractSolver { - val program: AbstractUnrollingSolver.this.theories.targetProgram.type + val program: targetProgram.type type Trees = Encoded } @@ -89,15 +99,15 @@ trait AbstractUnrollingSolver interrupted = false } - protected def declareVariable(v: Variable): Encoded + protected def declareVariable(v: t.Variable): Encoded def assertCnstr(expression: Expr): Unit = { constraints += expression val bindings = exprOps.variablesOf(expression).map(v => v -> freeVars.cached(v) { - declareVariable(theories.encode(v)) + declareVariable(encode(v)) }).toMap - val newClauses = templates.instantiateExpr(expression, bindings) + val newClauses = templates.instantiateExpr(encode(expression), bindings.map(p => encode(p._1) -> p._2)) for (cl <- newClauses) { underlying.assertCnstr(cl) } @@ -106,17 +116,17 @@ trait AbstractUnrollingSolver protected def wrapModel(model: underlying.Model): ModelWrapper trait ModelWrapper { - protected def modelEval(elem: Encoded, tpe: Type): Option[Expr] + protected def modelEval(elem: Encoded, tpe: t.Type): Option[t.Expr] - def eval(elem: Encoded, tpe: Type): Option[Expr] = modelEval(elem, theories.encode(tpe)).flatMap { + def eval(elem: Encoded, tpe: s.Type): Option[Expr] = modelEval(elem, encode(tpe)).flatMap { expr => try { - Some(theories.decode(expr)) + Some(decode(expr)) } catch { case u: Unsupported => None } } - def get(v: Variable): Option[Expr] = eval(freeVars(v), v.getType).filter { + def get(v: Variable): Option[Expr] = eval(freeVars(v), v.tpe).filter { case v: Variable => false case _ => true } @@ -196,7 +206,7 @@ trait AbstractUnrollingSolver val id = Variable(FreshIdentifier("v"), tpe) val (functions, recons) = functionsOf(value, id) recons(functions.map { case (f, selector) => - val encoded = templates.mkEncoder(Map(id -> v))(selector) + val encoded = templates.mkEncoder(Map(encode(id) -> v))(encode(selector)) val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] extractFunction(encoded, tpe) }) @@ -239,18 +249,21 @@ trait AbstractUnrollingSolver def extractFunction(f: Encoded, tpe: FunctionType): Expr = { def extractLambda(f: Encoded, tpe: FunctionType): Option[Lambda] = { - val optEqTemplate = templates.getLambdaTemplates(tpe).find { tmpl => + val optEqTemplate = templates.getLambdaTemplates(encode(tpe)).find { tmpl => wrapped.eval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) && wrapped.eval(templates.mkEquals(tmpl.ids._2, f), BooleanType) == Some(BooleanLiteral(true)) } optEqTemplate.map { tmpl => - val localsSubst = tmpl.structure.locals.map(p => p._1 -> wrapped.eval(p._2, p._1.tpe).getOrElse { - scala.sys.error("Unexpectedly failed to extract " + templates.asString(p._2) + - " with expected type " + p._1.tpe.asString) - }).toMap + val localsSubst = tmpl.structure.locals.map { case (v, ev) => + val dv = decode(v) + dv -> wrapped.eval(ev, dv.tpe).getOrElse { + scala.sys.error("Unexpectedly failed to extract " + templates.asString(ev) + + " with expected type " + dv.tpe.asString) + } + }.toMap - exprOps.replaceFromSymbols(localsSubst, tmpl.structure.lambda).asInstanceOf[Lambda] + exprOps.replaceFromSymbols(localsSubst, decode(tmpl.structure.lambda)).asInstanceOf[Lambda] } } @@ -266,9 +279,9 @@ trait AbstractUnrollingSolver case ft: FunctionType => val nextParams = params.tail val nextArguments = arguments.map(_.tail) - extract(templates.mkApp(caller, tpe, Seq.empty), ft, nextParams, nextArguments, dflt) + extract(templates.mkApp(caller, encode(tpe), Seq.empty), ft, nextParams, nextArguments, dflt) case _ => - (extractValue(templates.mkApp(caller, tpe, Seq.empty), tpe.to), false) + (extractValue(templates.mkApp(caller, encode(tpe), Seq.empty), tpe.to), false) } (Lambda(Seq.empty, result), real) @@ -279,7 +292,7 @@ trait AbstractUnrollingSolver case (currCond, arguments) => tpe.to match { case ft: FunctionType => val (currArgs, restArgs) = (arguments.head.head._1, arguments.map(_.tail)) - val newCaller = templates.mkApp(caller, tpe, currArgs) + val newCaller = templates.mkApp(caller, encode(tpe), currArgs) val (res, real) = extract(newCaller, ft, params.tail, restArgs, dflt) val mappings: Seq[(Expr, Expr)] = if (real) { Seq(BooleanLiteral(true) -> res) @@ -291,7 +304,7 @@ trait AbstractUnrollingSolver case _ => val currArgs = arguments.head.head._1 - val res = extractValue(templates.mkApp(caller, tpe, currArgs), tpe.to) + val res = extractValue(templates.mkApp(caller, encode(tpe), currArgs), tpe.to) Seq(currCond -> res) } } @@ -329,7 +342,7 @@ trait AbstractUnrollingSolver rec(tpe) } - val arguments = templates.getGroundInstantiations(f, tpe).flatMap { case (b, eArgs) => + val arguments = templates.getGroundInstantiations(f, encode(tpe)).flatMap { case (b, eArgs) => wrapped.eval(b, BooleanType).filter(_ == BooleanLiteral(true)).map(_ => eArgs) }.distinct @@ -367,7 +380,7 @@ trait AbstractUnrollingSolver } val default = extractValue(unflatten(flatArguments.last._1).foldLeft(f -> (tpe: Type)) { - case ((f, tpe: FunctionType), args) => (templates.mkApp(f, tpe, args), tpe.to) + case ((f, tpe: FunctionType), args) => (templates.mkApp(f, encode(tpe), args), tpe.to) }._1, tpe) extract(f, tpe, params, allArguments, default)._1 @@ -382,7 +395,7 @@ trait AbstractUnrollingSolver val assumptionsSeq : Seq[Expr] = assumptions.toSeq val encodedAssumptions : Seq[Encoded] = assumptionsSeq.map { expr => val vars = exprOps.variablesOf(expr) - templates.mkEncoder(vars.map(v => theories.encode(v) -> freeVars(v)).toMap)(expr) + templates.mkEncoder(vars.map(v => encode(v) -> freeVars(v)).toMap)(encode(expr)) } val encodedToAssumptions : Map[Encoded, Expr] = (encodedAssumptions zip assumptionsSeq).toMap @@ -585,11 +598,11 @@ trait UnrollingSolver extends AbstractUnrollingSolver { import program.trees._ import program.symbols._ - type Encoded = Expr + type Encoded = t.Expr val underlying: Solver { - val program: theories.targetProgram.type - type Trees = Expr - type Model = Map[ValDef, Expr] + val program: targetProgram.type + type Trees = t.Expr + type Model = Map[t.ValDef, t.Expr] } override val name = "U:"+underlying.name @@ -599,8 +612,12 @@ trait UnrollingSolver extends AbstractUnrollingSolver { } object templates extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram } with Templates { + import program._ + import program.trees._ + import program.symbols._ + type Encoded = Expr def asString(expr: Expr): String = expr.asString @@ -618,11 +635,13 @@ trait UnrollingSolver extends AbstractUnrollingSolver { def mkImplies(l: Expr, r: Expr) = implies(l, r) } - protected def declareVariable(v: Variable): Variable = v - protected def wrapModel(model: Map[ValDef, Expr]): super.ModelWrapper = ModelWrapper(model) + protected def declareVariable(v: t.Variable): t.Variable = v + protected def wrapModel(model: Map[t.ValDef, t.Expr]): super.ModelWrapper = + ModelWrapper(model.map(p => decode(p._1) -> decode(p._2))) private case class ModelWrapper(model: Map[ValDef, Expr]) extends super.ModelWrapper { - def modelEval(elem: Expr, tpe: Type): Option[Expr] = evaluator.eval(elem, model).result + def modelEval(elem: t.Expr, tpe: t.Type): Option[t.Expr] = + evaluator.eval(decode(elem), model).result.map(encode) override def toString = model.mkString("\n") } diff --git a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala index 7d1eeccc7..2b62d3683 100644 --- a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala @@ -21,19 +21,22 @@ trait NativeZ3Solver type Encoded = Z3AST object theories extends { - val sourceProgram: program.type = program + val sourceProgram: NativeZ3Solver.this.encoder.targetProgram.type = + NativeZ3Solver.this.encoder.targetProgram } with StringEncoder protected object underlying extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram val options = NativeZ3Solver.this.options } with AbstractZ3Solver private lazy val z3 = underlying.z3 object templates extends { - val program: theories.targetProgram.type = theories.targetProgram + val program: targetProgram.type = targetProgram } with Templates { + import program.trees._ + type Encoded = NativeZ3Solver.this.Encoded def asString(ast: Z3AST): String = ast.toString @@ -65,19 +68,19 @@ trait NativeZ3Solver def free(): Unit = underlying.free() - protected def declareVariable(v: Variable): Z3AST = underlying.declareVariable(v) + protected def declareVariable(v: t.Variable): Z3AST = underlying.declareVariable(v) protected def wrapModel(model: Z3Model): super.ModelWrapper = ModelWrapper(model) private case class ModelWrapper(model: Z3Model) extends super.ModelWrapper { - def modelEval(elem: Z3AST, tpe: Type): Option[Expr] = { + def modelEval(elem: Z3AST, tpe: t.Type): Option[t.Expr] = { val timer = ctx.timers.solvers.z3.eval.start() val res = tpe match { - case BooleanType => model.evalAs[Boolean](elem).map(BooleanLiteral) - case Int32Type => model.evalAs[Int](elem).map(IntLiteral(_)).orElse { - model.eval(elem).flatMap(t => underlying.softFromZ3Formula(model, t, Int32Type)) + case t.BooleanType => model.evalAs[Boolean](elem).map(t.BooleanLiteral) + case t.Int32Type => model.evalAs[Int](elem).map(t.IntLiteral(_)).orElse { + model.eval(elem).flatMap(term => underlying.softFromZ3Formula(model, term, t.Int32Type)) } - case IntegerType => model.evalAs[Int](elem).map(IntegerLiteral(_)) + case t.IntegerType => model.evalAs[Int](elem).map(t.IntegerLiteral(_)) case other => model.eval(elem) match { case None => None case Some(t) => underlying.softFromZ3Formula(model, t, other) diff --git a/src/test/scala/inox/ast/TreeTestsSuite.scala b/src/test/scala/inox/ast/TreeTestsSuite.scala index 7d2603361..9ba994d21 100644 --- a/src/test/scala/inox/ast/TreeTestsSuite.scala +++ b/src/test/scala/inox/ast/TreeTestsSuite.scala @@ -13,7 +13,7 @@ class TreeTestsSuite extends FunSuite { //TODO dont like the fact that we need to create an empty program // to get access to the and/or constructors - val pgm = InoxProgram(InoxContext.empty, Seq(), Seq()) + val pgm = InoxProgram(Context.empty, Seq(), Seq()) import pgm.symbols._ val x = Variable(FreshIdentifier("x"), BooleanType) diff --git a/src/test/scala/inox/evaluators/EvaluatorSuite.scala b/src/test/scala/inox/evaluators/EvaluatorSuite.scala index c835e9dbc..66ed0ef04 100644 --- a/src/test/scala/inox/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/inox/evaluators/EvaluatorSuite.scala @@ -8,10 +8,10 @@ import org.scalatest._ class EvaluatorSuite extends FunSuite { import inox.trees._ - val ctx = InoxContext.empty + val ctx = Context.empty val symbols = new Symbols(Map.empty, Map.empty) - def evaluator(ctx: InoxContext): DeterministicEvaluator { val program: InoxProgram } = { + def evaluator(ctx: Context): DeterministicEvaluator { val program: InoxProgram } = { val program = InoxProgram(ctx, symbols) RecursiveEvaluator.default(program) } -- GitLab