From 9574980de6d773f22839d1d4562dc07d670c3330 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Sat, 16 Jul 2016 16:53:12 +0200 Subject: [PATCH] Progress on inox trees --- src/main/scala/inox/GlobalOptions.scala | 83 - .../{LeonContext.scala => InoxContext.scala} | 16 +- .../{LeonOption.scala => InoxOptions.scala} | 85 +- src/main/scala/inox/LeonComponent.scala | 13 - src/main/scala/inox/LeonExceptions.scala | 14 - src/main/scala/inox/Printable.scala | 6 +- src/main/scala/inox/Program.scala | 9 + src/main/scala/inox/Reporter.scala | 9 +- .../scala/inox/{trees => ast}/CallGraph.scala | 18 +- .../inox/{trees => ast}/Constructors.scala | 75 +- .../inox/{trees => ast}/Definitions.scala | 258 +- src/main/scala/inox/ast/ExprOps.scala | 163 ++ .../inox/{trees => ast}/Expressions.scala | 346 +-- .../inox/{trees => ast}/Extractors.scala | 64 +- .../inox/{trees => ast}/GenTreeOps.scala | 10 +- .../scala/inox/{trees => ast}/Paths.scala | 61 +- .../inox/{trees => ast}/PrinterOptions.scala | 0 .../scala/inox/{trees => ast}/Printers.scala | 392 +-- src/main/scala/inox/ast/SymbolOps.scala | 1050 ++++++++ src/main/scala/inox/ast/TreeOps.scala | 200 ++ .../scala/inox/{trees => ast}/Trees.scala | 51 +- .../scala/inox/{trees => ast}/TypeOps.scala | 204 +- .../scala/inox/{trees => ast}/Types.scala | 32 +- .../scala/inox/{trees => ast}/package.scala | 4 +- src/main/scala/inox/package.scala | 4 +- src/main/scala/inox/trees/ExprOps.scala | 2337 ----------------- src/main/scala/inox/utils/Benchmarks.scala | 72 - src/main/scala/inox/utils/DebugSections.scala | 49 - .../scala/inox/utils/FileOutputPhase.scala | 37 - src/main/scala/inox/utils/FilesWatcher.scala | 69 - src/main/scala/inox/utils/GraphOps.scala | 2 +- src/main/scala/inox/utils/GraphPrinters.scala | 2 +- src/main/scala/inox/utils/Graphs.scala | 2 +- src/main/scala/inox/utils/InliningPhase.scala | 48 - .../scala/inox/utils/InterruptManager.scala | 2 +- src/main/scala/inox/utils/Interruptible.scala | 2 +- src/main/scala/inox/utils/Library.scala | 49 - src/main/scala/inox/utils/Positions.scala | 2 +- src/main/scala/inox/utils/StreamUtils.scala | 6 +- .../inox/utils/TemporaryInputPhase.scala | 28 - src/main/scala/inox/utils/Timer.scala | 2 +- src/main/scala/inox/utils/TypingPhase.scala | 100 - src/main/scala/inox/utils/UniqueCounter.scala | 2 +- .../scala/inox/utils/UnitElimination.scala | 155 -- src/main/scala/inox/utils/package.scala | 4 +- 45 files changed, 2149 insertions(+), 3988 deletions(-) delete mode 100644 src/main/scala/inox/GlobalOptions.scala rename src/main/scala/inox/{LeonContext.scala => InoxContext.scala} (84%) rename src/main/scala/inox/{LeonOption.scala => InoxOptions.scala} (59%) delete mode 100644 src/main/scala/inox/LeonComponent.scala delete mode 100644 src/main/scala/inox/LeonExceptions.scala create mode 100644 src/main/scala/inox/Program.scala rename src/main/scala/inox/{trees => ast}/CallGraph.scala (83%) rename src/main/scala/inox/{trees => ast}/Constructors.scala (82%) rename src/main/scala/inox/{trees => ast}/Definitions.scala (61%) create mode 100644 src/main/scala/inox/ast/ExprOps.scala rename src/main/scala/inox/{trees => ast}/Expressions.scala (66%) rename src/main/scala/inox/{trees => ast}/Extractors.scala (84%) rename src/main/scala/inox/{trees => ast}/GenTreeOps.scala (98%) rename src/main/scala/inox/{trees => ast}/Paths.scala (79%) rename src/main/scala/inox/{trees => ast}/PrinterOptions.scala (100%) rename src/main/scala/inox/{trees => ast}/Printers.scala (53%) create mode 100644 src/main/scala/inox/ast/SymbolOps.scala create mode 100644 src/main/scala/inox/ast/TreeOps.scala rename src/main/scala/inox/{trees => ast}/Trees.scala (68%) rename src/main/scala/inox/{trees => ast}/TypeOps.scala (53%) rename src/main/scala/inox/{trees => ast}/Types.scala (78%) rename src/main/scala/inox/{trees => ast}/package.scala (97%) delete mode 100644 src/main/scala/inox/trees/ExprOps.scala delete mode 100644 src/main/scala/inox/utils/Benchmarks.scala delete mode 100644 src/main/scala/inox/utils/DebugSections.scala delete mode 100644 src/main/scala/inox/utils/FileOutputPhase.scala delete mode 100644 src/main/scala/inox/utils/FilesWatcher.scala delete mode 100644 src/main/scala/inox/utils/InliningPhase.scala delete mode 100644 src/main/scala/inox/utils/Library.scala delete mode 100644 src/main/scala/inox/utils/TemporaryInputPhase.scala delete mode 100644 src/main/scala/inox/utils/TypingPhase.scala delete mode 100644 src/main/scala/inox/utils/UnitElimination.scala diff --git a/src/main/scala/inox/GlobalOptions.scala b/src/main/scala/inox/GlobalOptions.scala deleted file mode 100644 index 37eeeef3f..000000000 --- a/src/main/scala/inox/GlobalOptions.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -import leon.utils.{DebugSections, DebugSection} -import OptionParsers._ - -/** This object contains options that are shared among different modules of Leon. - * - * Options that determine the pipeline of Leon are not stored here, - * but in [[Main.MainComponent]] instead. - */ -object GlobalOptions extends LeonComponent { - - val name = "sharedOptions" - val description = "Options shared by multiple components of Leon" - - val optStrictPhases = LeonFlagOptionDef("strict", "Terminate after each phase if there is an error", true) - - val optBenchmark = LeonFlagOptionDef("benchmark", "Dump benchmarking information in a data file", false) - - val optWatch = LeonFlagOptionDef("watch", "Rerun pipeline when file changes", false) - - val optSilent = LeonFlagOptionDef("silent", "Do not display progress messages or results to the console", false) - - val optFunctions = new LeonOptionDef[Seq[String]] { - val name = "functions" - val description = "Only consider functions f1, f2, ..." - val default = Seq[String]() - val parser = seqParser(stringParser) - val usageRhs = "f1,f2,..." - } - - val optSelectedSolvers = new LeonOptionDef[Set[String]] { - val name = "solvers" - val description = "Use solvers s1, s2,...\n" + solvers.SolverFactory.availableSolversPretty - val default = Set("fairz3") - val parser = setParser(stringParser) - val usageRhs = "s1,s2,..." - } - - val optDebug = new LeonOptionDef[Set[DebugSection]] { - import OptionParsers._ - val name = "debug" - val description = { - val sects = DebugSections.all.toSeq.map(_.name).sorted - val (first, second) = sects.splitAt(sects.length/2 + 1) - "Enable detailed messages per component.\nAvailable:\n" + - " " + first.mkString(", ") + ",\n" + - " " + second.mkString(", ") - } - val default = Set[DebugSection]() - val usageRhs = "d1,d2,..." - private val debugParser: OptionParser[Set[DebugSection]] = s => { - if (s == "all") { - Some(DebugSections.all) - } else { - DebugSections.all.find(_.name == s).map(Set(_)) - } - } - val parser: String => Option[Set[DebugSection]] = { - setParser[Set[DebugSection]](debugParser)(_).map(_.flatten) - } - } - - val optTimeout = LeonLongOptionDef( - "timeout", - "Set a timeout for attempting to prove a verification condition/ repair a function (in sec.)", - 0L, - "t" - ) - - override val definedOptions: Set[LeonOptionDef[Any]] = Set( - optStrictPhases, - optBenchmark, - optFunctions, - optSelectedSolvers, - optDebug, - optWatch, - optTimeout, - optSilent - ) -} diff --git a/src/main/scala/inox/LeonContext.scala b/src/main/scala/inox/InoxContext.scala similarity index 84% rename from src/main/scala/inox/LeonContext.scala rename to src/main/scala/inox/InoxContext.scala index 8e0a414f1..0f6a3ed7d 100644 --- a/src/main/scala/inox/LeonContext.scala +++ b/src/main/scala/inox/InoxContext.scala @@ -1,10 +1,8 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox -import leon.utils._ - -import java.io.File +import inox.utils._ import scala.reflect.ClassTag @@ -12,14 +10,12 @@ import scala.reflect.ClassTag * LeonContexts are immutable, and so should all their fields (with the possible * exception of the reporter). */ -case class LeonContext( +case class Context( reporter: Reporter, interruptManager: InterruptManager, options: Seq[LeonOption[Any]] = Seq(), - files: Seq[File] = Seq(), - classDir: Option[File] = None, - timers: TimerStorage = new TimerStorage -) { + timers: TimerStorage = new TimerStorage, + bank: evaluators.EvaluationBank) { def findOption[A: ClassTag](optDef: LeonOptionDef[A]): Option[A] = options.collectFirst { case LeonOption(`optDef`, value:A) => value @@ -31,7 +27,7 @@ case class LeonContext( def toSctx = solvers.SolverContext(this, new evaluators.EvaluationBank) } -object LeonContext { +object Context { def empty = { val reporter = new DefaultReporter(Set()) LeonContext(reporter, new InterruptManager(reporter)) diff --git a/src/main/scala/inox/LeonOption.scala b/src/main/scala/inox/InoxOptions.scala similarity index 59% rename from src/main/scala/inox/LeonOption.scala rename to src/main/scala/inox/InoxOptions.scala index d1f93981f..32fdc8c72 100644 --- a/src/main/scala/inox/LeonOption.scala +++ b/src/main/scala/inox/InoxOptions.scala @@ -1,15 +1,12 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox import OptionParsers._ -import purescala.Definitions._ -import purescala.DefOps.fullName - import scala.util.Try -abstract class LeonOptionDef[+A] { +abstract class InoxOptionDef[+A] { val name: String val description: String val default: A @@ -32,49 +29,48 @@ abstract class LeonOptionDef[+A] { ) } - def parse(s: String)(implicit reporter: Reporter): LeonOption[A] = - LeonOption(this)(parseValue(s)) + def parse(s: String)(implicit reporter: Reporter): InoxOption[A] = + InoxOption(this)(parseValue(s)) - def withDefaultValue: LeonOption[A] = - LeonOption(this)(default) + def withDefaultValue: InoxOption[A] = + InoxOption(this)(default) // @mk: FIXME: Is this cool? override def equals(other: Any) = other match { - case that: LeonOptionDef[_] => this.name == that.name + case that: InoxOptionDef[_] => this.name == that.name case _ => false } override def hashCode = name.hashCode } -case class LeonFlagOptionDef(name: String, description: String, default: Boolean) extends LeonOptionDef[Boolean] { +case class InoxFlagOptionDef(name: String, description: String, default: Boolean) extends InoxOptionDef[Boolean] { val parser = booleanParser val usageRhs = "" } -case class LeonStringOptionDef(name: String, description: String, default: String, usageRhs: String) extends LeonOptionDef[String] { +case class InoxStringOptionDef(name: String, description: String, default: String, usageRhs: String) extends InoxOptionDef[String] { val parser = stringParser } -case class LeonLongOptionDef(name: String, description: String, default: Long, usageRhs: String) extends LeonOptionDef[Long] { +case class InoxLongOptionDef(name: String, description: String, default: Long, usageRhs: String) extends InoxOptionDef[Long] { val parser = longParser } - -class LeonOption[+A] private (val optionDef: LeonOptionDef[A], val value: A) { +class InoxOption[+A] private (val optionDef: InoxOptionDef[A], val value: A) { override def toString = s"--${optionDef.name}=$value" override def equals(other: Any) = other match { - case LeonOption(optionDef, value) => + case InoxOption(optionDef, value) => optionDef.name == this.optionDef.name && value == this.value case _ => false } override def hashCode = optionDef.hashCode + value.hashCode } -object LeonOption { - def apply[A](optionDef: LeonOptionDef[A])(value: A) = { - new LeonOption(optionDef, value) +object InoxOption { + def apply[A](optionDef: InoxOptionDef[A])(value: A) = { + new InoxOption(optionDef, value) } - def unapply[A](opt: LeonOption[A]) = Some((opt.optionDef, opt.value)) + def unapply[A](opt: InoxOption[A]) = Some((opt.optionDef, opt.value)) } object OptionParsers { @@ -98,13 +94,12 @@ object OptionParsers { ) foo } + def setParser[A](base: OptionParser[A]): OptionParser[Set[A]] = { seqParser(base)(_).map(_.toSet) } - } - object OptionsHelpers { private val matcher = s"--(.*)=(.*)".r @@ -142,10 +137,6 @@ object OptionsHelpers { (name: String) => regexPatterns.exists(p => p.matcher(name).matches()) } - def fdMatcher(pgm: Program)(patterns: Traversable[String]): FunDef => Boolean = { - { (fd: FunDef) => fullName(fd)(pgm) } andThen matcher(patterns) - } - def filterInclusive[T](included: Option[T => Boolean], excluded: Option[T => Boolean]): T => Boolean = { included match { case Some(i) => @@ -158,3 +149,45 @@ object OptionsHelpers { } } } + +object InoxOptions { + + val optSelectedSolvers = new LeonOptionDef[Set[String]] { + val name = "solvers" + val description = "Use solvers s1, s2,...\n" + solvers.SolverFactory.availableSolversPretty + val default = Set("fairz3") + val parser = setParser(stringParser) + val usageRhs = "s1,s2,..." + } + + val optDebug = new LeonOptionDef[Set[DebugSection]] { + import OptionParsers._ + val name = "debug" + val description = { + val sects = DebugSections.all.toSeq.map(_.name).sorted + val (first, second) = sects.splitAt(sects.length/2 + 1) + "Enable detailed messages per component.\nAvailable:\n" + + " " + first.mkString(", ") + ",\n" + + " " + second.mkString(", ") + } + val default = Set[DebugSection]() + val usageRhs = "d1,d2,..." + private val debugParser: OptionParser[Set[DebugSection]] = s => { + if (s == "all") { + Some(DebugSections.all) + } else { + DebugSections.all.find(_.name == s).map(Set(_)) + } + } + val parser: String => Option[Set[DebugSection]] = { + setParser[Set[DebugSection]](debugParser)(_).map(_.flatten) + } + } + + val optTimeout = LeonLongOptionDef( + "timeout", + "Set a timeout for attempting to prove a verification condition/ repair a function (in sec.)", + 0L, + "t" + ) +} diff --git a/src/main/scala/inox/LeonComponent.scala b/src/main/scala/inox/LeonComponent.scala deleted file mode 100644 index 7aeaacb10..000000000 --- a/src/main/scala/inox/LeonComponent.scala +++ /dev/null @@ -1,13 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -/** A common trait for everything that is important enough to be named, - * and that defines command line options. And important category are - * [[LeonPhase]]s. */ -trait LeonComponent { - val name : String - val description : String - - val definedOptions : Set[LeonOptionDef[Any]] = Set() -} diff --git a/src/main/scala/inox/LeonExceptions.scala b/src/main/scala/inox/LeonExceptions.scala deleted file mode 100644 index ad526d782..000000000 --- a/src/main/scala/inox/LeonExceptions.scala +++ /dev/null @@ -1,14 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -import purescala.Common.Tree - -case class LeonFatalError(msg: Option[String]) extends Exception(msg.getOrElse("")) - -object LeonFatalError { - def apply(msg: String) = new LeonFatalError(Some(msg)) -} - -class Unsupported(t: Tree, msg: String)(implicit ctx: LeonContext) - extends Exception(s"${t.asString}@${t.getPos} $msg") \ No newline at end of file diff --git a/src/main/scala/inox/Printable.scala b/src/main/scala/inox/Printable.scala index e55272e57..d92b53d80 100644 --- a/src/main/scala/inox/Printable.scala +++ b/src/main/scala/inox/Printable.scala @@ -1,8 +1,8 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox -/** A trait for objects that can be pretty-printed given a [[leon.LeonContext]] */ +/** A trait for objects that can be pretty-printed given a [[inox.Context]] */ trait Printable { - def asString(implicit ctx: LeonContext): String + def asString(implicit ctx: Context): String } diff --git a/src/main/scala/inox/Program.scala b/src/main/scala/inox/Program.scala new file mode 100644 index 000000000..0e0649f50 --- /dev/null +++ b/src/main/scala/inox/Program.scala @@ -0,0 +1,9 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox + +import ast._ + +class Program(val trees: Trees)(val symbols: trees.Symbols, val ctx: Context) { + +} diff --git a/src/main/scala/inox/Reporter.scala b/src/main/scala/inox/Reporter.scala index ff376b905..8e6e2c09c 100644 --- a/src/main/scala/inox/Reporter.scala +++ b/src/main/scala/inox/Reporter.scala @@ -1,9 +1,12 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox import utils._ +abstract class DebugSection(val name: String, val mask: Int) +case object DebugSectionSolver extends DebugSection("solver", 1 << 0) + abstract class Reporter(val debugSections: Set[DebugSection]) { abstract class Severity @@ -34,9 +37,7 @@ abstract class Reporter(val debugSections: Set[DebugSection]) { def emit(msg: Message): Unit - def onFatal(): Nothing = { - throw LeonFatalError(None) - } + def onFatal(): Nothing = throw FatalError("") def onCompilerProgress(current: Int, total: Int) = {} diff --git a/src/main/scala/inox/trees/CallGraph.scala b/src/main/scala/inox/ast/CallGraph.scala similarity index 83% rename from src/main/scala/inox/trees/CallGraph.scala rename to src/main/scala/inox/ast/CallGraph.scala index 0c5d1b2b0..fd00dd082 100644 --- a/src/main/scala/inox/trees/CallGraph.scala +++ b/src/main/scala/inox/ast/CallGraph.scala @@ -1,24 +1,24 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees - -import Definitions._ -import Expressions._ -import ExprOps._ +package ast import utils.Graphs._ -class CallGraph(p: Program) { +trait CallGraph { + val trees: Trees + import trees._ + import trees.exprOps._ + val symbols: Symbols private def collectCallsInPats(fd: FunDef)(p: Pattern): Set[(FunDef, FunDef)] = (p match { - case u: UnapplyPattern => Set((fd, u.unapplyFun.fd)) + case u: UnapplyPattern => Set((fd, symbols.getFunction(u.id))) case _ => Set() }) ++ p.subPatterns.flatMap(collectCallsInPats(fd)) private def collectCalls(fd: FunDef)(e: Expr): Set[(FunDef, FunDef)] = e match { - case f @ FunctionInvocation(f2, _) => Set((fd, f2.fd)) + case f @ FunctionInvocation(id, tps, _) => Set((fd, symbols.getFunction(id))) case MatchExpr(_, cases) => cases.toSet.flatMap((mc: MatchCase) => collectCallsInPats(fd)(mc.pattern)) case _ => Set() } @@ -26,7 +26,7 @@ class CallGraph(p: Program) { lazy val graph: DiGraph[FunDef, SimpleEdge[FunDef]] = { var g = DiGraph[FunDef, SimpleEdge[FunDef]]() - for (fd <- p.definedFunctions; c <- collect(collectCalls(fd))(fd.fullBody)) { + for ((_, fd) <- symbols.functions; c <- collect(collectCalls(fd))(fd.fullBody)) { g += SimpleEdge(c._1, c._2) } diff --git a/src/main/scala/inox/trees/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala similarity index 82% rename from src/main/scala/inox/trees/Constructors.scala rename to src/main/scala/inox/ast/Constructors.scala index 85ec99a73..d351540c2 100644 --- a/src/main/scala/inox/trees/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -1,7 +1,7 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast /** Provides constructors for [[purescala.Expressions]]. * @@ -9,8 +9,12 @@ package trees * potentially use a different expression node if one is more suited. * @define encodingof Encoding of * */ -trait Constructors { self: ExprOps => +trait Constructors { + val trees: Trees import trees._ + import trees.exprOps._ + implicit val symbols: Symbols + import symbols._ /** If `isTuple`: * `tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> x` @@ -37,16 +41,16 @@ trait Constructors { self: ExprOps => /** $encodingof ``val id = e; bd``, and returns `bd` if the identifier is not bound in `bd`. * @see [[purescala.Expressions.Let]] */ - def let(id: Identifier, e: Expr, bd: Expr) = { - if (exprOps.variablesOf(bd) contains id) - Let(id, e, bd) + def let(vd: ValDef, e: Expr, bd: Expr) = { + if (exprOps.variablesOf(bd) contains vd.toVariable) + Let(vd, e, bd) else bd } /** $encodingof ``val (...binders...) = value; body`` which is translated to ``value match { case (...binders...) => body }``, and returns `body` if the identifiers are not bound in `body`. * @see [[purescala.Expressions.Let]] */ - def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match { + def letTuple(binders: Seq[ValDef], value: Expr, body: Expr) = binders match { case Nil => body case x :: Nil => @@ -60,7 +64,7 @@ trait Constructors { self: ExprOps => s"In letTuple: '$value' is being assigned as a tuple of arity ${xs.size}; yet its type is '${value.getType}' (body is '$body')" ) - Extractors.LetPattern(TuplePattern(None,binders map { b => WildcardPattern(Some(b)) }), value, body) + LetPattern(TuplePattern(None,binders map { b => WildcardPattern(Some(b)) }), value, body) } /** Wraps the sequence of expressions as a tuple. If the sequence contains a single expression, it is returned instead. @@ -87,7 +91,7 @@ trait Constructors { self: ExprOps => * If the sequence is empty, the [[purescala.Types.UnitType UnitType]] is returned. * @see [[purescala.Types.TupleType]] */ - def tupleTypeWrap(tps: Seq[TypeTree]) = tps match { + def tupleTypeWrap(tps: Seq[Type]) = tps match { case Seq() => UnitType case Seq(elem) => elem case more => TupleType(more) @@ -103,20 +107,20 @@ trait Constructors { self: ExprOps => val formalType = tupleTypeWrap(fd.params map { _.getType }) val actualType = tupleTypeWrap(args map { _.getType }) - canBeSupertypeOf(formalType, actualType) match { + symbols.canBeSupertypeOf(formalType, actualType) match { case Some(tmap) => - FunctionInvocation(fd.typed(fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }), args) - case None => throw LeonFatalError(s"$args:$actualType cannot be a subtype of $formalType!") + FunctionInvocation(fd.id, fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }, args) + case None => throw FatalError(s"$args:$actualType cannot be a subtype of $formalType!") } } /** Simplifies the provided case class selector. * @see [[purescala.Expressions.CaseClassSelector]] */ - def caseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { + def caseClassSelector(classType: ClassType, caseClass: Expr, selector: Identifier): Expr = { caseClass match { - case CaseClass(ct, fields) if ct.classDef == classType.classDef && !ct.classDef.hasInvariant => - fields(ct.classDef.selectorID2Index(selector)) + case CaseClass(ct, fields) if ct == classType && !ct.tcd.hasInvariant => + fields(ct.tcd.cd.asInstanceOf[CaseClassDef].selectorID2Index(selector)) case _ => CaseClassSelector(classType, caseClass, selector) } @@ -126,11 +130,11 @@ trait Constructors { self: ExprOps => * @see [[purescala.Expressions.CaseClassPattern MatchExpr]] * @see [[purescala.Expressions.CaseClassPattern CaseClassPattern]] */ - private def filterCases(scrutType: TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = { + private def filterCases(scrutType: Type, resType: Option[Type], cases: Seq[MatchCase]): Seq[MatchCase] = { val casesFiltered = scrutType match { - case c: CaseClassType => + case c: ClassType if !c.tcd.cd.isAbstract => cases.filter(_.pattern match { - case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false + case CaseClassPattern(_, cct, _) if cct.id != c.id => false case _ => true }) @@ -140,7 +144,7 @@ trait Constructors { self: ExprOps => resType match { case Some(tpe) => - casesFiltered.filter(c => typesCompatible(c.rhs.getType, tpe)) + casesFiltered.filter(c => symbols.typesCompatible(c.rhs.getType, tpe)) case None => casesFiltered } @@ -149,7 +153,7 @@ trait Constructors { self: ExprOps => /** $encodingof `... match { ... }` but simplified if possible. Simplifies to [[Error]] if no case can match the scrutined expression. * @see [[purescala.Expressions.MatchExpr MatchExpr]] */ - def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : Expr ={ + def matchExpr(scrutinee: Expr, cases: Seq[MatchCase]): Expr = { val filtered = filterCases(scrutinee.getType, None, cases) if (filtered.nonEmpty) MatchExpr(scrutinee, filtered) @@ -268,20 +272,20 @@ trait Constructors { self: ExprOps => case Lambda(formalArgs, body) => assert(realArgs.size == formalArgs.size, "Invoking lambda with incorrect number of arguments") - var defs: Seq[(Identifier, Expr)] = Seq() + var defs: Seq[(ValDef, Expr)] = Seq() val subst = formalArgs.zip(realArgs).map { - case (ValDef(from), to:Variable) => - from -> to - case (ValDef(from), e) => - val fresh = from.freshen + case (vd, to:Variable) => + vd -> to + case (vd, e) => + val fresh = vd.freshen defs :+= (fresh -> e) - from -> Variable(fresh) + vd -> fresh.toVariable }.toMap - val (ids, bds) = defs.unzip + val (vds, bds) = defs.unzip - letTuple(ids, tupleWrap(bds), replaceFromIDs(subst, body)) + letTuple(vds, tupleWrap(bds), exprOps.replaceFromSymbols(subst, body)) case _ => Application(fn, realArgs) @@ -341,7 +345,7 @@ trait Constructors { self: ExprOps => /** $encodingof expr.asInstanceOf[tpe], returns `expr` it it already is of type `tpe`. */ def asInstOf(expr: Expr, tpe: ClassType) = { - if (isSubtypeOf(expr.getType, tpe)) { + if (symbols.isSubtypeOf(expr.getType, tpe)) { expr } else { AsInstanceOf(expr, tpe) @@ -349,7 +353,7 @@ trait Constructors { self: ExprOps => } def isInstOf(expr: Expr, tpe: ClassType) = { - if (isSubtypeOf(expr.getType, tpe)) { + if (symbols.isSubtypeOf(expr.getType, tpe)) { BooleanLiteral(true) } else { IsInstanceOf(expr, tpe) @@ -362,6 +366,19 @@ trait Constructors { self: ExprOps => case _ => Require(pred, body) } + def tupleWrapArg(fun: Expr) = fun.getType match { + case FunctionType(args, res) if args.size > 1 => + val newArgs = fun match { + case Lambda(args, _) => args + case _ => args map (tpe => ValDef(FreshIdentifier("x", alwaysShowUniqueID = true), tpe)) + } + val res = ValDef(FreshIdentifier("res", alwaysShowUniqueID = true), TupleType(args)) + val patt = TuplePattern(None, newArgs map (arg => WildcardPattern(Some(arg)))) + Lambda(Seq(res), MatchExpr(res.toVariable, Seq(SimpleCase(patt, application(fun, newArgs map (_.toVariable)))))) + case _ => + fun + } + def ensur(e: Expr, pred: Expr) = { Ensuring(e, tupleWrapArg(pred)) } diff --git a/src/main/scala/inox/trees/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala similarity index 61% rename from src/main/scala/inox/trees/Definitions.scala rename to src/main/scala/inox/ast/Definitions.scala index 7d4c8efd6..0274053bc 100644 --- a/src/main/scala/inox/trees/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -1,7 +1,7 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon -package purescala +package inox +package ast import scala.collection.mutable.{Map => MutableMap} @@ -18,45 +18,72 @@ trait Definitions { self: Trees => override def hashCode = id.hashCode } - abstract class LookupException(id: Identifier, what: String) extends Exception("Lookup failed for " + what + " with symbol " + id) + abstract class LookupException(id: Identifier, what: String) + extends Exception("Lookup failed for " + what + " with symbol " + id) case class FunctionLookupException(id: Identifier) extends LookupException(id, "function") case class ClassLookupException(id: Identifier) extends LookupException(id, "class") - /** - * A ValDef declares a formal parameter (with symbol [[id]]) to be of a certain type. - */ - case class ValDef(id: Identifier, tpe: Type) extends Definition with Typed { - def getType(implicit p: Program): Type = tpe + case class NotWellFormedException(id: Identifier, s: Symbols) + extends Exception(s"$id not well formed in ${s.asString}") + + /** Common super-type for [[ValDef]] and [[Expressions.Variable]]. + * + * Both types share much in common and being able to reason about them + * in a uniform manner can be useful in certain cases. + */ + private[ast] trait VariableSymbol extends Typed { + val id: Identifier + val tpe: Type + + def getType(implicit s: Symbols): Type = tpe + + override def equals(that: Any): Boolean = that match { + case vs: VariableSymbol => id == vs.id && tpe == vs.tpe + case _ => false + } + def hashCode: Int = 61 * id.hashCode + tpe.hashCode + } + + /** + * A ValDef declares a formal parameter (with symbol [[id]]) to be of a certain type. + */ + case class ValDef(id: Identifier, tpe: Type) extends Definition with VariableSymbol { /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] */ def toVariable: Variable = Variable(id, tpe) + def freshen: ValDef = ValDef(id.freshen, tpe).copiedFrom(this) } /** A wrapper for a program. For now a program is simply a single object. */ - case class Program(classes: Map[Identifier, ClassDef], functions: Map[Identifier, FunDef]) extends Tree { - lazy val callGraph = new CallGraph(this) - - private val typedClassCache: MutableMap[(Identifier, Seq[Type]), TypedClassDef] = MutableMap.empty + case class Symbols(classes: Map[Identifier, ClassDef], functions: Map[Identifier, FunDef]) + extends Tree + with TypeOps + with SymbolOps + with CallGraph + with Constructors + with Paths { + + val trees: self.type = self + val symbols: this.type = this + private implicit def s: Symbols = symbols + + private val typedClassCache: MutableMap[(Identifier, Seq[Type]), Option[TypedClassDef]] = MutableMap.empty def lookupClass(id: Identifier): Option[ClassDef] = classes.get(id) def lookupClass(id: Identifier, tps: Seq[Type]): Option[TypedClassDef] = - typedClassCache.getOrElseUpdated(id -> tps, lookupClass(id).typed(tps)) + typedClassCache.getOrElseUpdate(id -> tps, lookupClass(id).map(_.typed(tps))) def getClass(id: Identifier): ClassDef = lookupClass(id).getOrElse(throw ClassLookupException(id)) def getClass(id: Identifier, tps: Seq[Type]): TypedClassDef = lookupClass(id, tps).getOrElse(throw ClassLookupException(id)) - private val typedFunctionCache: MutableMap[(Identifier, Seq[Type]), TypedFunDef] = MutableMap.empty + private val typedFunctionCache: MutableMap[(Identifier, Seq[Type]), Option[TypedFunDef]] = MutableMap.empty def lookupFunction(id: Identifier): Option[FunDef] = functions.get(id) def lookupFunction(id: Identifier, tps: Seq[Type]): Option[TypedFunDef] = - typedFunctionCache.getOrElseUpdated(id -> tps, lookupFunction(id).typed(tps)) + typedFunctionCache.getOrElseUpdate(id -> tps, lookupFunction(id).map(_.typed(tps)(this))) def getFunction(id: Identifier): FunDef = lookupFunction(id).getOrElse(throw FunctionLookupException(id)) def getFunction(id: Identifier, tps: Seq[Type]): TypedFunDef = lookupFunction(id, tps).getOrElse(throw FunctionLookupException(id)) } - object Program { - lazy val empty: Program = Program(Nil) - } - case class TypeParameterDef(tp: TypeParameter) extends Definition { def freshen = TypeParameterDef(tp.freshen) val id = tp.id @@ -90,71 +117,69 @@ trait Definitions { self: Trees => sealed trait ClassDef extends Definition { val id: Identifier val tparams: Seq[TypeParameterDef] - val fields: Seq[ValDef] val flags: Set[ClassFlag] - val parent: Option[Identifier] - val children: Seq[Identifier] + def annotations: Set[String] = extAnnotations.keySet + def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap - def hasParent = parent.isDefined + def root(implicit s: Symbols): ClassDef + def invariant(implicit s: Symbols): Option[FunDef] = { + val rt = root + if (rt ne this) rt.invariant + else flags.collect { case HasADTInvariant(id) => id }.headOption.map(s.getFunction) + } - def invariant(implicit p: Program): Option[FunDef] = - flags.collect { case HasADTInvariant(id) => id }.map(p.getFunction) + def hasInvariant(implicit s: Symbols): Boolean = invariant.isDefined - def annotations: Set[String] = extAnnotations.keySet - def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap + val isAbstract: Boolean + + def typeArgs = tparams map (_.tp) - def ancestors(implicit p: Program): Seq[ClassDef] = parent - .map(p.getClass).toSeq - .flatMap(parentCls => parentCls +: parentCls.ancestors) + def typed(tps: Seq[Type])(implicit s: Symbols): TypedClassDef + def typed(implicit s: Symbols): TypedClassDef + } - def root(implicit p: Program) = ancestors.lastOption.getOrElse(this) + /** Abstract classes. */ + class AbstractClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val children: Seq[Identifier], + val flags: Set[ClassFlag]) extends ClassDef { + val isAbstract = true - def descendants(implicit p: Program): Seq[ClassDef] = children - .map(p.getClass) - .flatMap(cd => cd +: cd.descendants) + def descendants(implicit s: Symbols): Seq[ClassDef] = children + .map(id => s.getClass(id) match { + case ccd: CaseClassDef => ccd + case _ => throw NotWellFormedException(id, s) + }) - def ccDescendants(implicit p: Program): Seq[CaseClassDef] = + def ccDescendants(implicit s: Symbols): Seq[CaseClassDef] = descendants collect { case ccd: CaseClassDef => ccd } - def isInductive(implicit p: Program): Boolean = { + def isInductive(implicit s: Symbols): Boolean = { def induct(tpe: Type, seen: Set[ClassDef]): Boolean = tpe match { case ct: ClassType => val tcd = ct.lookupClass.getOrElse(throw ClassLookupException(ct.id)) val root = tcd.cd.root - seen(root) || tcd.fields.forall(vd => induct(vd.getType, seen + root)) + seen(root) || (tcd match { + case tccd: TypedCaseClassDef => + tccd.fields.exists(vd => induct(vd.getType, seen + root)) + case _ => false + }) case TupleType(tpes) => - tpes.forall(tpe => induct(tpe, seen)) - case _ => true + tpes.exists(tpe => induct(tpe, seen)) + case _ => false } if (this == root && !this.isAbstract) false - else if (this != root) root.isInductive - else ccDescendants.forall { ccd => - ccd.fields.forall(vd => induct(vd.getType, Set(root))) + else ccDescendants.exists { ccd => + ccd.fields.exists(vd => induct(vd.getType, Set(root))) } } - val isAbstract: Boolean - - def typeArgs = tparams map (_.tp) + def root(implicit s: Symbols): ClassDef = this - def typed(tps: Seq[Type]): TypedClassDef - def typed: TypedClassDef - } - - /** Abstract classes. */ - class AbstractClassDef(val id: Identifier, - val tparams: Seq[TypeParameterDef], - val parent: Option[Identifier], - val children: Seq[Identifier], - val flags: Set[Flag]) extends ClassDef { - - val fields = Nil - val isAbstract = true - - def typed: TypedAbstractClassDef = typed(tparams.map(_.tp)) - def typed(tps: Seq[Type]): TypedAbstractClassDef = { + def typed(implicit s: Symbols): TypedAbstractClassDef = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type])(implicit s: Symbols): TypedAbstractClassDef = { require(tps.length == tparams.length) TypedAbstractClassDef(this, tps) } @@ -165,9 +190,8 @@ trait Definitions { self: Trees => val tparams: Seq[TypeParameterDef], val parent: Option[Identifier], val fields: Seq[ValDef], - val flags: Set[Flag]) extends ClassDef { + val flags: Set[ClassFlag]) extends ClassDef { - val children = Nil val isAbstract = false def selectorID2Index(id: Identifier) : Int = { @@ -181,8 +205,10 @@ trait Definitions { self: Trees => } else index } - def typed: TypedCaseClassDef = typed(tparams.map(_.tp)) - def typed(tps: Seq[Type]): TypedCaseClassDef = { + def root(implicit s: Symbols): ClassDef = parent.map(id => s.getClass(id).root).getOrElse(this) + + def typed(implicit s: Symbols): TypedCaseClassDef = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type])(implicit s: Symbols): TypedCaseClassDef = { require(tps.length == tparams.length) TypedCaseClassDef(this, tps) } @@ -191,33 +217,37 @@ trait Definitions { self: Trees => sealed abstract class TypedClassDef extends Tree { val cd: ClassDef val tps: Seq[Type] - implicit val program: Program + implicit val symbols: Symbols val id: Identifier = cd.id + + lazy val root: TypedClassDef = cd.root.typed(tps) + lazy val invariant: Option[TypedFunDef] = cd.invariant.map(_.typed(tps)) + lazy val hasInvariant: Boolean = invariant.isDefined + + def toType = ClassType(cd.id, tps) + } + + case class TypedAbstractClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit symbols: Symbols) extends TypedClassDef { + def descendants: Seq[TypedClassDef] = cd.descendants.map(_.typed(tps)) + def ccDescendants: Seq[TypedCaseClassDef] = cd.ccDescendants.map(_.typed(tps)) + } + + case class TypedCaseClassDef(cd: CaseClassDef, tps: Seq[Type])(implicit symbols: Symbols) extends TypedClassDef { lazy val fields: Seq[ValDef] = { val tmap = (cd.typeArgs zip tps).toMap if (tmap.isEmpty) cd.fields - else classDef.fields.map(vd => vd.copy(tpe = instantiateType(vd.getType, tmap))) + else cd.fields.map(vd => vd.copy(tpe = symbols.instantiateType(vd.getType, tmap))) } - lazy val parent: Option[TypedAbstractClassDef] = cd.parent.map(id => p.getClass(id) match { + lazy val fieldsTypes = fields.map(_.tpe) + + lazy val parent: Option[TypedAbstractClassDef] = cd.parent.map(id => symbols.getClass(id) match { case acd: AbstractClassDef => TypedAbstractClassDef(acd, tps) - case _ => scala.sys.error("Expected parent to be an AbstractClassDef") + case _ => throw NotWellFormedException(id, symbols) }) - - lazy val invariant: Option[TypedFunDef] = cd.invariant.map { fd => - TypedFunDef(fd, tps) - } - - lazy val root = parent.map(_.root).getOrElse(this) - - def descendants: Seq[TypedClassDef] = cd.descendants.map(_.typed(tps)) - def ccDescendants: Seq[TypedCaseClassDef] = cd.ccDescendants.map(_.typed(tps)) } - case class TypedAbstractClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit program: Program) extends TypedClassDef - case class TypedCaseClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit program: Program) extends TypedClassDef - /** Function/method definition. * @@ -238,18 +268,18 @@ trait Definitions { self: Trees => val params: Seq[ValDef], val returnType: Type, val fullBody: Expr, - val flags: Set[Flag] + val flags: Set[FunctionFlag] ) extends Definition { /* Body manipulation */ - lazy val body: Option[Expr] = withoutSpec(fullBody) - lazy val precondition = preconditionOf(fullBody) + lazy val body: Option[Expr] = exprOps.withoutSpec(fullBody) + lazy val precondition = exprOps.preconditionOf(fullBody) lazy val precOrTrue = precondition getOrElse BooleanLiteral(true) - lazy val postcondition = postconditionOf(fullBody) + lazy val postcondition = exprOps.postconditionOf(fullBody) lazy val postOrTrue = postcondition getOrElse { - val arg = ValDef(FreshIdentifier("res", returnType, alwaysShowUniqueID = true)) + val arg = ValDef(FreshIdentifier("res", alwaysShowUniqueID = true), returnType) Lambda(Seq(arg), BooleanLiteral(true)) } @@ -262,36 +292,28 @@ trait Definitions { self: Trees => case Annotation(s, args) => s -> args }.toMap - def canBeLazyField = flags.contains(IsField(true)) && params.isEmpty && tparams.isEmpty - def canBeStrictField = flags.contains(IsField(false)) && params.isEmpty && tparams.isEmpty - def canBeField = canBeLazyField || canBeStrictField - def isRealFunction = !canBeField - def isInvariant = flags contains IsADTInvariant - /* Wrapping in TypedFunDef */ - def typed(tps: Seq[Type]): TypedFunDef = { + def typed(tps: Seq[Type])(implicit s: Symbols): TypedFunDef = { assert(tps.size == tparams.size) TypedFunDef(this, tps) } - def typed: TypedFunDef = typed(tparams.map(_.tp)) + def typed(implicit s: Symbols): TypedFunDef = typed(tparams.map(_.tp)) /* Auxiliary methods */ - def isRecursive(implicit p: Program) = p.callGraph.transitiveCallees(this) contains this - - def paramIds = params map { _.id } + def isRecursive(implicit s: Symbols) = s.transitiveCallees(this) contains this def typeArgs = tparams map (_.tp) - def applied(args: Seq[Expr]): FunctionInvocation = Constructors.functionInvocation(this, args) - def applied = FunctionInvocation(this.typed, this.paramIds map Variable) + def applied(args: Seq[Expr])(implicit s: Symbols): FunctionInvocation = s.functionInvocation(this, args) + def applied = FunctionInvocation(id, typeArgs, params map (_.toVariable)) } // Wrapper for typing function according to valuations for type parameters - case class TypedFunDef(fd: FunDef, tps: Seq[Type])(implicit program: Program) extends Tree { + case class TypedFunDef(fd: FunDef, tps: Seq[Type])(implicit symbols: Symbols) extends Tree { val id = fd.id def signature = { @@ -306,17 +328,17 @@ trait Definitions { self: Trees => (fd.typeArgs zip tps).toMap.filter(tt => tt._1 != tt._2) } - def translated(t: Type): Type = instantiateType(t, typesMap) + def translated(t: Type): Type = symbols.instantiateType(t, typesMap) - def translated(e: Expr): Expr = instantiateType(e, typesMap, paramsMap) + def translated(e: Expr): Expr = symbols.instantiateType(e, typesMap) /** A mapping from this [[TypedFunDef]]'s formal parameters to real arguments * * @param realArgs The arguments to which the formal argumentas are mapped - * */ + */ def paramSubst(realArgs: Seq[Expr]) = { require(realArgs.size == params.size) - (paramIds zip realArgs).toMap + (params zip realArgs).toMap } /** Substitute this [[TypedFunDef]]'s formal parameters with real arguments in some expression @@ -325,33 +347,21 @@ trait Definitions { self: Trees => * @param e The expression in which the substitution will take place */ def withParamSubst(realArgs: Seq[Expr], e: Expr) = { - replaceFromIDs(paramSubst(realArgs), e) + exprOps.replaceFromSymbols(paramSubst(realArgs), e) } def applied(realArgs: Seq[Expr]): FunctionInvocation = { - FunctionInvocation(fd, tps, realArgs) + FunctionInvocation(id, tps, realArgs) } - def applied: FunctionInvocation = - applied(params map { _.toVariable }) + def applied: FunctionInvocation = applied(params map { _.toVariable }) - /** - * Params will return ValDefs instantiated with the correct types - * For such a ValDef(id,tp) it may hold that (id.getType != tp) - */ - lazy val (params: Seq[ValDef], paramsMap: Map[Identifier, Identifier]) = { + /** Params will contain ValDefs instantiated with the correct types */ + lazy val params: Seq[ValDef] = { if (typesMap.isEmpty) { - (fd.params, Map()) + fd.params } else { - val newParams = fd.params.map { vd => - val newTpe = translated(vd.getType) - val newId = FreshIdentifier(vd.id.name, newTpe, true).copiedFrom(vd.id) - vd.copy(id = newId).setPos(vd) - } - - val paramsMap: Map[Identifier, Identifier] = (fd.params zip newParams).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap - - (newParams, paramsMap) + fd.params.map(vd => vd.copy(tpe = translated(vd.getType))) } } @@ -359,8 +369,6 @@ trait Definitions { self: Trees => lazy val returnType: Type = translated(fd.returnType) - lazy val paramIds = params map { _.id } - lazy val fullBody = translated(fd.fullBody) lazy val body = fd.body map translated lazy val precondition = fd.precondition map translated diff --git a/src/main/scala/inox/ast/ExprOps.scala b/src/main/scala/inox/ast/ExprOps.scala new file mode 100644 index 000000000..81555ad82 --- /dev/null +++ b/src/main/scala/inox/ast/ExprOps.scala @@ -0,0 +1,163 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package ast + +import utils._ + +/** Provides functions to manipulate [[purescala.Expressions]]. + * + * This object provides a few generic operations on Leon expressions, + * as well as some common operations. + * + * The generic operations lets you apply operations on a whole tree + * expression. You can look at: + * - [[GenTreeOps.fold foldRight]] + * - [[GenTreeOps.preTraversal preTraversal]] + * - [[GenTreeOps.postTraversal postTraversal]] + * - [[GenTreeOps.preMap preMap]] + * - [[GenTreeOps.postMap postMap]] + * - [[GenTreeOps.genericTransform genericTransform]] + * + * These operations usually take a higher order function that gets applied to the + * expression tree in some strategy. They provide an expressive way to build complex + * operations on Leon expressions. + * + */ +trait ExprOps extends GenTreeOps { + val trees: Trees + import trees._ + + type SubTree = Expr + + val Deconstructor = Operator + + /** Replaces bottom-up variables by looking up for them in a map */ + def replaceFromSymbols(substs: Map[Variable, Expr], expr: Expr): Expr = postMap { + case v: Variable => substs.get(v) + case _ => None + } (expr) + + /** Replaces bottom-up variables by looking them up in a map from [[ValDef]] to expressions */ + def replaceFromSymbols(substs: Map[ValDef, Expr], expr: Expr): Expr = postMap { + case v: Variable => substs.get(v.toVal) + case _ => None + } (expr) + + /** Returns the set of free variables in an expression */ + def variablesOf(expr: Expr): Set[Variable] = { + fold[Set[Variable]] { + case (e, subs) => + val subvs = subs.flatten.toSet + e match { + case v: Variable => subvs + v + case Let(vd, _, _) => subvs - vd.toVariable + case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders).map(_.toVariable) + case Lambda(args, _) => subvs -- args.map(_.toVariable) + case Forall(args, _) => subvs -- args.map(_.toVariable) + case _ => subvs + } + }(expr) + } + + /** Returns true if the expression contains a function call */ + def containsFunctionCalls(expr: Expr): Boolean = { + exists{ + case _: FunctionInvocation => true + case _ => false + }(expr) + } + + /** Returns all Function calls found in the expression */ + def functionCallsOf(expr: Expr): Set[FunctionInvocation] = { + collect[FunctionInvocation] { + case f: FunctionInvocation => Set(f) + case _ => Set() + }(expr) + } + + /** Returns '''true''' if the formula is Ground, + * which means that it does not contain any variables + * ([[purescala.ExprOps#variablesOf]] e is empty) + */ + def isGround(e: Expr): Boolean = variablesOf(e).isEmpty + + /** Returns '''true''' if the formula is simple, + * which means that it requires no special encoding for an + * unrolling solver. See implementation for what this means exactly. + */ + def isSimple(e: Expr): Boolean = !exists { + case (_: Assert) | (_: Ensuring) | + (_: Forall) | (_: Lambda) | + (_: FunctionInvocation) | (_: Application) => true + case _ => false + } (e) + + /* Checks if a given expression is 'real' and does not contain generic + * values. */ + def isRealExpr(v: Expr): Boolean = { + !exists { + case gv: GenericValue => true + case _ => false + }(v) + } + + override def formulaSize(e: Expr): Int = e match { + case ml: MatchExpr => + super.formulaSize(e) + ml.cases.map(cs => patternOps.formulaSize(cs.pattern)).sum + case _ => + super.formulaSize(e) + } + + /** Returns if this expression behaves as a purely functional construct, + * i.e. always returns the same value (for the same environment) and has no side-effects + */ + def isPurelyFunctional(e: Expr): Boolean = exists { + case _ : Error => false + case _ => true + }(e) + + /** Extracts the body without its specification + * + * [[Expressions.Expr]] trees contain its specifications as part of certain nodes. + * This function helps extracting only the body part of an expression + * + * @return An option type with the resulting expression if not [[Expressions.NoTree]] + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withoutSpec(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => withoutSpec(b).map(Let(i, e, _)) + case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(Require(pre, b), post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(b, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case b => Option(b).filterNot(_.isInstanceOf[NoTree]) + } + + /** Returns the precondition of an expression wrapped in Option */ + def preconditionOf(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => preconditionOf(b).map(Let(i, e, _).copiedFrom(expr)) + case Require(pre, _) => Some(pre) + case Ensuring(Require(pre, _), _) => Some(pre) + case b => None + } + + /** Returns the postcondition of an expression wrapped in Option */ + def postconditionOf(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => postconditionOf(b).map(Let(i, e, _).copiedFrom(expr)) + case Ensuring(_, post) => Some(post) + case _ => None + } + + /** Returns a tuple of precondition, the raw body and the postcondition of an expression */ + def breakDownSpecs(e: Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) + + def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = { + val rec = preTraversalWithParent(f, Some(e)) _ + + f(e, initParent) + + val Deconstructor(es, _) = e + es foreach rec + } +} diff --git a/src/main/scala/inox/trees/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala similarity index 66% rename from src/main/scala/inox/trees/Expressions.scala rename to src/main/scala/inox/ast/Expressions.scala index 1215fd7c0..1bfd1b342 100644 --- a/src/main/scala/inox/trees/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -1,14 +1,7 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon.purescala - -import Common._ -import Types._ -import TypeOps._ -import Definitions._ -import Extractors._ -import Constructors._ -import ExprOps.replaceFromIDs +package inox +package ast /** Expression definitions for Pure Scala. * @@ -19,7 +12,7 @@ import ExprOps.replaceFromIDs * case classes, with no behaviour. In particular, they do not perform smart * rewriting. What you build is what you get. For example, * {{{ - * And(BooleanLiteral(true), Variable(id)) != Variable(id) + * And(BooleanLiteral(true), Variable(id, BooleanType)) != Variable(id, BooleanType) * }}} * because the ``And`` constructor will simply build a tree without checking for * optimization opportunities. Unless you need exact control on the structure @@ -32,8 +25,8 @@ import ExprOps.replaceFromIDs */ trait Expressions { self: Trees => - private def checkParamTypes(real: Seq[Type], formal: Seq[Type], result: Type): Type = { - if (real zip formal forall { case (real, formal) => isSubtypeOf(real, formal)} ) { + private def checkParamTypes(real: Seq[Type], formal: Seq[Type], result: Type)(implicit s: Symbols): Type = { + if (real zip formal forall { case (real, formal) => s.isSubtypeOf(real, formal)} ) { result.unveilUntyped } else { //println(s"Failed to type as $result") @@ -52,6 +45,12 @@ trait Expressions { self: Trees => } + /** Stands for an undefined Expr, similar to `???` or `null` */ + case class NoTree(tpe: Type) extends Expr with Terminal { + val getType = tpe + } + + /* Specifications */ /** Computational errors (unmatched case, taking min of an empty set, @@ -63,7 +62,44 @@ trait Expressions { self: Trees => * @param description The description of the error */ case class Error(tpe: Type, description: String) extends Expr with Terminal { - def getType(implicit p: Program): Type = tpe + def getType(implicit s: Symbols): Type = tpe + } + + /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* + * + * @param pred The precondition formula inside ``require(...)`` + * @param body The body following the ``require(...)`` + */ + case class Require(pred: Expr, body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = { + if (pred.getType == BooleanType) body.getType + else Untyped + } + } + + /** Postcondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *ensuring* + * + * @param body The body of the expression. It can contain at most one [[Expressions.Require]] sub-expression. + * @param pred The predicate to satisfy. It should be a function whose argument's type can handle the type of the body + */ + case class Ensuring(body: Expr, pred: Expr) extends Expr with CachingTyped { + require(pred.isInstanceOf[Lambda]) + + protected def computeType(implicit s: Symbols) = pred.getType match { + case FunctionType(Seq(bodyType), BooleanType) if s.isSubtypeOf(body.getType, bodyType) => + body.getType + case _ => + Untyped + } + + /** Converts this ensuring clause to the body followed by an assert statement */ + def toAssert(implicit s: Symbols): Expr = { + val res = ValDef(FreshIdentifier("res", true), getType) + Let(res, body, Assert( + s.application(pred, Seq(res.toVariable)), + Some("Postcondition failed @" + this.getPos), res.toVariable + )) + } } /** Local assertions with customizable error message @@ -73,7 +109,7 @@ trait Expressions { self: Trees => * @param body The expression following `assert(..., ...)` */ case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (pred.getType == BooleanType) body.getType else Untyped } @@ -83,23 +119,22 @@ trait Expressions { self: Trees => /** Variable * @param id The identifier of this variable */ - case class Variable(id: Identifier) extends Expr with Terminal with CachingTyped { - protected def computeType(implicit p: Program): Type = id.getType + case class Variable(id: Identifier, tpe: Type) extends Expr with Terminal with VariableSymbol { + /** Transforms this [[Variable]] into a [[Definitions.ValDef ValDef]] */ + def toVal = ValDef(id, tpe) } /** $encodingof `val ... = ...; ...` * - * @param binder The identifier used in body, defined just after '''val''' + * @param vd The ValDef used in body, defined just after '''val''' * @param value The value assigned to the identifier, after the '''=''' sign * @param body The expression following the ``val ... = ... ;`` construct * @see [[purescala.Constructors#let purescala's constructor let]] */ - case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { - // We can't demand anything sticter here, because some binders are - // typed context-wise - if (typesCompatible(value.getType, binder.getType)) + case class Let(vd: ValDef, value: Expr, body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = { + if (s.isSubtypeOf(value.getType, vd.tpe)) body.getType else { Untyped @@ -111,9 +146,9 @@ trait Expressions { self: Trees => /** $encodingof `callee(args...)`, where [[callee]] is an expression of a function type (not a method) */ case class Application(callee: Expr, args: Seq[Expr]) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = callee.getType match { + protected def computeType(implicit s: Symbols): Type = callee.getType match { case FunctionType(from, to) => - checkParamTypes(args, from, to) + checkParamTypes(args.map(_.getType), from, to) case _ => Untyped } @@ -121,30 +156,30 @@ trait Expressions { self: Trees => /** $encodingof `(args) => body` */ case class Lambda(args: Seq[ValDef], body: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = FunctionType(args.map(_.getType), body.getType).unveilUntyped def paramSubst(realArgs: Seq[Expr]) = { require(realArgs.size == args.size) - (args map { _.id } zip realArgs).toMap + (args zip realArgs).toMap } def withParamSubst(realArgs: Seq[Expr], e: Expr) = { - replaceFromIDs(paramSubst(realArgs), e) + exprOps.replaceFromSymbols(paramSubst(realArgs), e) } } /* Universal Quantification */ case class Forall(args: Seq[ValDef], body: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = body.getType + protected def computeType(implicit s: Symbols): Type = body.getType } /* Control flow */ /** $encodingof `function(...)` (function invocation) */ case class FunctionInvocation(id: Identifier, tps: Seq[Type], args: Seq[Expr]) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = p.lookupFunction(id) match { + protected def computeType(implicit s: Symbols): Type = s.lookupFunction(id) match { case Some(fd) => val tfd = fd.typed(tps) require(args.size == tfd.params.size) @@ -155,8 +190,8 @@ trait Expressions { self: Trees => /** $encodingof `if(...) ... else ...` */ case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped } /** $encodingof `... match { ... }` @@ -169,8 +204,8 @@ trait Expressions { self: Trees => */ case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr with CachingTyped { require(cases.nonEmpty) - protected def computeType(implicit p: Program): Type = - leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `case pattern [if optGuard] => rhs` @@ -190,12 +225,12 @@ trait Expressions { self: Trees => */ sealed abstract class Pattern extends Tree { val subPatterns: Seq[Pattern] - val binder: Option[Identifier] + val binder: Option[ValDef] private def subBinders = subPatterns.flatMap(_.binders).toSet - def binders: Set[Identifier] = subBinders ++ binder.toSet + def binders: Set[ValDef] = subBinders ++ binder.toSet - def withBinder(b : Identifier) = { this match { + def withBinder(b: ValDef) = { this match { case Pattern(None, subs, builder) => builder(Some(b), subs) case other => other }}.copiedFrom(this) @@ -205,12 +240,12 @@ trait Expressions { self: Trees => * * If [[binder]] is empty, consider a wildcard `_` in its place. */ - case class InstanceOfPattern(binder: Option[Identifier], ct: ClassType) extends Pattern { + case class InstanceOfPattern(binder: Option[ValDef], ct: ClassType) extends Pattern { val subPatterns = Seq() } /** Pattern encoding `case _ => `, or `case binder => ` if identifier [[binder]] is present */ - case class WildcardPattern(binder: Option[Identifier]) extends Pattern { // c @ _ + case class WildcardPattern(binder: Option[ValDef]) extends Pattern { // c @ _ val subPatterns = Seq() } @@ -218,28 +253,33 @@ trait Expressions { self: Trees => * * If [[binder]] is empty, consider a wildcard `_` in its place. */ - case class CaseClassPattern(binder: Option[Identifier], ct: CaseClassType, subPatterns: Seq[Pattern]) extends Pattern + case class CaseClassPattern(binder: Option[ValDef], ct: ClassType, subPatterns: Seq[Pattern]) extends Pattern /** Pattern encoding tuple pattern `case binder @ (subPatterns...) =>` * * If [[binder]] is empty, consider a wildcard `_` in its place. */ - case class TuplePattern(binder: Option[Identifier], subPatterns: Seq[Pattern]) extends Pattern + case class TuplePattern(binder: Option[ValDef], subPatterns: Seq[Pattern]) extends Pattern /** Pattern encoding like `case binder @ 0 => ...` or `case binder @ "Foo" => ...` * * If [[binder]] is empty, consider a wildcard `_` in its place. */ - case class LiteralPattern[+T](binder: Option[Identifier], lit : Literal[T]) extends Pattern { + case class LiteralPattern[+T](binder: Option[ValDef], lit: Literal[T]) extends Pattern { val subPatterns = Seq() } /** A custom pattern defined through an object's `unapply` function */ - case class UnapplyPattern(binder: Option[Identifier], unapplyFun: TypedFunDef, subPatterns: Seq[Pattern]) extends Pattern { + case class UnapplyPattern(binder: Option[ValDef], id: Identifier, tps: Seq[Type], subPatterns: Seq[Pattern]) extends Pattern { // Hacky, but ok - lazy val optionType = unapplyFun.returnType.asInstanceOf[AbstractClassType] - lazy val Seq(noneType, someType) = optionType.knownCCDescendants.sortBy(_.fields.size) - lazy val someValue = someType.classDef.fields.head + def optionType(implicit s: Symbols) = s.getFunction(id, tps).returnType.asInstanceOf[ClassType] + def someType(implicit s: Symbols): ClassType = { + val optionChildren = optionType.tcd.asInstanceOf[TypedAbstractClassDef].ccDescendants.sortBy(_.fields.size) + val someTcd = optionChildren(1) + ClassType(someTcd.id, someTcd.tps) + } + + def someValue(implicit s: Symbols): ValDef = someType.tcd.asInstanceOf[TypedCaseClassDef].fields.head /** Construct a pattern matching against unapply(scrut) (as an if-expression) * @@ -247,23 +287,23 @@ trait Expressions { self: Trees => * @param noneCase The expression that will happen if unapply(scrut) is None * @param someCase How unapply(scrut).get will be handled in case it exists */ - def patternMatch(scrut: Expr, noneCase: Expr, someCase: Expr => Expr): Expr = { + def patternMatch(scrut: Expr, noneCase: Expr, someCase: Expr => Expr)(implicit s: Symbols): Expr = { // We use this hand-coded if-then-else because we don't want to generate // match exhaustiveness checks in the program - val binder = FreshIdentifier("unap", optionType, true) + val vd = ValDef(FreshIdentifier("unap", true), optionType) Let( - binder, - FunctionInvocation(unapplyFun, Seq(scrut)), + vd, + FunctionInvocation(id, tps, Seq(scrut)), IfExpr( - IsInstanceOf(Variable(binder), someType), - someCase(CaseClassSelector(someType, Variable(binder), someValue.id)), + IsInstanceOf(vd.toVariable, someType), + someCase(CaseClassSelector(someType, vd.toVariable, someValue.id)), noneCase ) ) } /** Inlined .get method */ - def get(scrut: Expr) = patternMatch( + def get(scrut: Expr)(implicit s: Symbols) = patternMatch( scrut, Error(optionType.tps.head, "None.get"), e => e @@ -272,31 +312,37 @@ trait Expressions { self: Trees => /** Selects Some.v field without type-checking. * Use in a context where scrut.isDefined returns true. */ - def getUnsafe(scrut: Expr) = CaseClassSelector( + def getUnsafe(scrut: Expr)(implicit s: Symbols) = CaseClassSelector( someType, - FunctionInvocation(unapplyFun, Seq(scrut)), + FunctionInvocation(id, tps, Seq(scrut)), someValue.id ) - def isSome(scrut: Expr) = IsInstanceOf(FunctionInvocation(unapplyFun, Seq(scrut)), someType) + def isSome(scrut: Expr)(implicit s: Symbols) = + IsInstanceOf(FunctionInvocation(id, tps, Seq(scrut)), someType) } // Extracts without taking care of the binder. (contrary to Extractos.Pattern) - object PatternExtractor extends TreeExtractor[Pattern] { + object PatternExtractor extends TreeExtractor { + val trees: self.type = self + type SubTree = Pattern + def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => Some(Seq(), es => e) - case CaseClassPattern(binder, ct, subpatterns) => - Some(subpatterns, es => CaseClassPattern(binder, ct, es)) - case TuplePattern(binder, subpatterns) => - Some(subpatterns, es => TuplePattern(binder, es)) - case UnapplyPattern(binder, unapplyFun, subpatterns) => - Some(subpatterns, es => UnapplyPattern(binder, unapplyFun, es)) + case CaseClassPattern(vd, ct, subpatterns) => + Some(subpatterns, es => CaseClassPattern(vd, ct, es)) + case TuplePattern(vd, subpatterns) => + Some(subpatterns, es => TuplePattern(vd, es)) + case UnapplyPattern(vd, id, tps, subpatterns) => + Some(subpatterns, es => UnapplyPattern(vd, id, tps, es)) case _ => None } } - object PatternOps extends GenTreeOps[Pattern] { + object patternOps extends GenTreeOps { + val trees: self.type = self + type SubTree = Pattern val Deconstructor = PatternExtractor } @@ -308,44 +354,44 @@ trait Expressions { self: Trees => /** $encodingof a character literal */ case class CharLiteral(value: Char) extends Literal[Char] { - def getType(implicit p: Program): Type = CharType + def getType(implicit s: Symbols): Type = CharType } /** $encodingof a 32-bit integer literal */ case class IntLiteral(value: Int) extends Literal[Int] { - def getType(implicit p: Program): Type = Int32Type + def getType(implicit s: Symbols): Type = Int32Type } /** $encodingof a n-bit bitvector literal */ case class BVLiteral(value: BigInt, size: Int) extends Literal[BigInt] { - def getType(implicit p: Program): Type = BVType(size) + def getType(implicit s: Symbols): Type = BVType(size) } /** $encodingof an infinite precision integer literal */ case class IntegerLiteral(value: BigInt) extends Literal[BigInt] { - def getType(implicit p: Program): Type = IntegerType + def getType(implicit s: Symbols): Type = IntegerType } /** $encodingof a fraction literal */ case class FractionalLiteral(numerator: BigInt, denominator: BigInt) extends Literal[(BigInt, BigInt)] { val value = (numerator, denominator) - def getType(implicit p: Program): Type = RealType + def getType(implicit s: Symbols): Type = RealType } /** $encodingof a boolean literal '''true''' or '''false''' */ case class BooleanLiteral(value: Boolean) extends Literal[Boolean] { - def getType(implicit p: Program): Type = BooleanType + def getType(implicit s: Symbols): Type = BooleanType } /** $encodingof the unit literal `()` */ - case object UnitLiteral extends Literal[Unit] { + case class UnitLiteral() extends Literal[Unit] { val value = () - def getType(implicit p: Program): Type = UnitType + def getType(implicit s: Symbols): Type = UnitType } /** $encodingof a string literal */ case class StringLiteral(value: String) extends Literal[String] { - def getType(implicit p: Program): Type = StringType + def getType(implicit s: Symbols): Type = StringType } @@ -353,7 +399,7 @@ trait Expressions { self: Trees => * This is useful e.g. to present counterexamples of generic types. */ case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { - def getType(implicit p: Program): Type = tp + def getType(implicit s: Symbols): Type = tp } @@ -362,17 +408,17 @@ trait Expressions { self: Trees => * @param ct The case class name and inherited attributes * @param args The arguments of the case class */ - case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = ct.lookupClass match { - case Some(tcd) => checkParamTypes(args.map(_.getType), tcd.fieldsTypes, ct) + case class CaseClass(ct: ClassType, args: Seq[Expr]) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = ct.lookupClass match { + case Some(tcd: TypedCaseClassDef) => checkParamTypes(args.map(_.getType), tcd.fieldsTypes, ct) case _ => Untyped } } /** $encodingof `.isInstanceOf[...]` */ case class IsInstanceOf(expr: Expr, classType: ClassType) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - if (isSubtypeOf(expr.getType, classType)) BooleanType else Untyped + protected def computeType(implicit s: Symbols): Type = + if (s.isSubtypeOf(expr.getType, classType)) BooleanType else Untyped } /** $encodingof `expr.asInstanceOf[tpe]` @@ -381,8 +427,8 @@ trait Expressions { self: Trees => * if bodies. */ case class AsInstanceOf(expr: Expr, tpe: ClassType) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - if (typesCompatible(tpe, expr.getType)) tpe else Untyped + protected def computeType(implicit s: Symbols): Type = + if (s.typesCompatible(tpe, expr.getType)) tpe else Untyped } /** $encodingof `value.selector` where value is of a case class type @@ -390,10 +436,10 @@ trait Expressions { self: Trees => * If you are not sure about the requirement you should use * [[purescala.Constructors#caseClassSelector purescala's constructor caseClassSelector]] */ - case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = classType.lookupClass match { + case class CaseClassSelector(classType: ClassType, caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = classType.lookupClass match { case Some(tcd: TypedCaseClassDef) => - val index = tcd.selectorID2Index(selector) + val index = tcd.cd.selectorID2Index(selector) if (classType == caseClass.getType) { tcd.fieldsTypes(index) } else { @@ -405,8 +451,8 @@ trait Expressions { self: Trees => /** $encodingof `... == ...` */ case class Equals(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { - if (typesCompatible(lhs.getType, rhs.getType)) BooleanType + protected def computeType(implicit s: Symbols): Type = { + if (s.typesCompatible(lhs.getType, rhs.getType)) BooleanType else { //println(s"Incompatible argument types: arguments: ($lhs, $rhs) types: ${lhs.getType}, ${rhs.getType}") Untyped @@ -425,9 +471,9 @@ trait Expressions { self: Trees => */ case class And(exprs: Seq[Expr]) extends Expr with CachingTyped { require(exprs.size >= 2) - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (exprs forall (_.getType == BooleanType)) BooleanType - else checkBVCompatible(exprs.map(_.getType) : _*) + else bitVectorType(exprs.head.getType, exprs.tail.map(_.getType) : _*) } } @@ -443,9 +489,9 @@ trait Expressions { self: Trees => */ case class Or(exprs: Seq[Expr]) extends Expr { require(exprs.size >= 2) - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (exprs forall (_.getType == BooleanType)) BooleanType - else checkBVCompatible(exprs.map(_.getType) : _*) + else bitVectorType(exprs.head.getType, exprs.tail.map(_.getType) : _*) } } @@ -461,7 +507,7 @@ trait Expressions { self: Trees => * @see [[leon.purescala.Constructors.implies]] */ case class Implies(lhs: Expr, rhs: Expr) extends Expr { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if(lhs.getType == BooleanType && rhs.getType == BooleanType) BooleanType else Untyped } @@ -472,7 +518,7 @@ trait Expressions { self: Trees => * @see [[leon.purescala.Constructors.not]] */ case class Not(expr: Expr) extends Expr { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (expr.getType == BooleanType) BooleanType else bitVectorType(expr.getType) } @@ -483,7 +529,7 @@ trait Expressions { self: Trees => abstract class ConverterToString(fromType: Type, toType: Type) extends Expr with CachingTyped { val expr: Expr - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = if (expr.getType == fromType) toType else Untyped } @@ -504,7 +550,7 @@ trait Expressions { self: Trees => /** $encodingof `lhs + rhs` for strings */ case class StringConcat(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (lhs.getType == StringType && rhs.getType == StringType) StringType else Untyped } @@ -512,7 +558,7 @@ trait Expressions { self: Trees => /** $encodingof `lhs.subString(start, end)` for strings */ case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { val ext = expr.getType val st = start.getType val et = end.getType @@ -523,7 +569,7 @@ trait Expressions { self: Trees => /** $encodingof `lhs.subString(start, end)` for strings */ case class BigSubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { val ext = expr.getType val st = start.getType val et = end.getType @@ -534,7 +580,7 @@ trait Expressions { self: Trees => /** $encodingof `lhs.length` for strings */ case class StringLength(expr: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (expr.getType == StringType) Int32Type else Untyped } @@ -542,7 +588,7 @@ trait Expressions { self: Trees => /** $encodingof `lhs.length` for strings */ case class StringBigLength(expr: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { if (expr.getType == StringType) IntegerType else Untyped } @@ -551,46 +597,46 @@ trait Expressions { self: Trees => /* General arithmetic */ - def numericType(tpe: TypeTree, tpes: TypeTree*): TypeTree = { - lazy val intType = integerType(tpe, tpes) - lazy val bvType = bitVectorType(tpe, tpes) - lazy val realType = realType(tpe, tpes) - if (intType.isTyped) intType else if (bvType.isTyped) bvType else realType + def numericType(tpe: Type, tpes: Type*)(implicit s: Symbols): Type = { + lazy val intType = integerType(tpe, tpes : _*) + lazy val bvType = bitVectorType(tpe, tpes : _*) + lazy val rlType = realType(tpe, tpes : _*) + if (intType.isTyped) intType else if (bvType.isTyped) bvType else rlType } - def integerType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { - case IntegerType if typesCompatible(tpe, tpes : _*) => tpe + def integerType(tpe: Type, tpes: Type*)(implicit s: Symbols): Type = tpe match { + case IntegerType if s.typesCompatible(tpe, tpes : _*) => tpe case _ => Untyped } - def bitVectorType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { - case _: BVType if typesCompatible(tpe, tpes: _*) => tpe + def bitVectorType(tpe: Type, tpes: Type*)(implicit s: Symbols): Type = tpe match { + case _: BVType if s.typesCompatible(tpe, tpes: _*) => tpe case _ => Untyped } - def realType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { - case RealType if typesCompatible(tpe, tpes : _*) => tpe + def realType(tpe: Type, tpes: Type*)(implicit s: Symbols): Type = tpe match { + case RealType if s.typesCompatible(tpe, tpes : _*) => tpe case _ => Untyped } /** $encodingof `... + ...` for BigInts */ case class Plus(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = numericType(lhs.getType, rhs.getType) } /** $encodingof `... - ...` */ case class Minus(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = numericType(lhs.getType, rhs.getType) } /** $encodingof `- ... for BigInts`*/ case class UMinus(expr: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = numericType(expr.getType) + protected def computeType(implicit s: Symbols): Type = numericType(expr.getType) } /** $encodingof `... * ...` */ case class Times(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = numericType(lhs.getType, rhs.getType) } /** $encodingof `... / ...` @@ -605,7 +651,7 @@ trait Expressions { self: Trees => * Division(x, y) * y + Remainder(x, y) == x */ case class Division(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = numericType(lhs.getType, rhs.getType) } /** $encodingof `... % ...` (can return negative numbers) @@ -613,7 +659,7 @@ trait Expressions { self: Trees => * @see [[Expressions.Division]] */ case class Remainder(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = integerType(lhs.getType, rhs.getType) match { + protected def computeType(implicit s: Symbols): Type = integerType(lhs.getType, rhs.getType) match { case Untyped => bitVectorType(lhs.getType, rhs.getType) case tpe => tpe } @@ -624,7 +670,7 @@ trait Expressions { self: Trees => * @see [[Expressions.Division]] */ case class Modulo(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = integerType(lhs.getType, rhs.getType) match { + protected def computeType(implicit s: Symbols): Type = integerType(lhs.getType, rhs.getType) match { case Untyped => bitVectorType(lhs.getType, rhs.getType) case tpe => tpe } @@ -632,25 +678,25 @@ trait Expressions { self: Trees => /** $encodingof `... < ...`*/ case class LessThan(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } /** $encodingof `... > ...`*/ case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } /** $encodingof `... <= ...`*/ case class LessEquals(lhs: Expr, rhs: Expr) extends Expr { - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } /** $encodingof `... >= ...`*/ case class GreaterEquals(lhs: Expr, rhs: Expr) extends Expr { - protected def computeType(implicit p: Program): Type = + protected def computeType(implicit s: Symbols): Type = if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } @@ -659,22 +705,22 @@ trait Expressions { self: Trees => /** $encodingof `... ^ ...` $noteBitvector */ case class BVXOr(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = bitVectorType(lhs.getType, rhs.getType) } /** $encodingof `... << ...` $noteBitvector */ case class BVShiftLeft(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = bitVectorType(lhs.getType, rhs.getType) } /** $encodingof `... >> ...` $noteBitvector (arithmetic shift, sign-preserving) */ case class BVAShiftRight(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = bitVectorType(lhs.getType, rhs.getType) } /** $encodingof `... >>> ...` $noteBitvector (logical shift) */ case class BVLShiftRight(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) + protected def computeType(implicit s: Symbols): Type = bitVectorType(lhs.getType, rhs.getType) } @@ -690,7 +736,7 @@ trait Expressions { self: Trees => */ case class Tuple(exprs: Seq[Expr]) extends Expr with CachingTyped { require(exprs.size >= 2) - protected def computeType(implicit p: Program): Type = TupleType(exprs.map(_.getType)).unveilUntyped + protected def computeType(implicit s: Symbols): Type = TupleType(exprs.map(_.getType)).unveilUntyped } /** $encodingof `(tuple)._i` @@ -702,7 +748,7 @@ trait Expressions { self: Trees => case class TupleSelect(tuple: Expr, index: Int) extends Expr with CachingTyped { require(index >= 1) - protected def computeType(implicit p: Program): Type = tuple.getType match { + protected def computeType(implicit s: Symbols): Type = tuple.getType match { case tp @ TupleType(ts) => require(index <= ts.size, s"Got index $index for '$tuple' of type '$tp") ts(index - 1) @@ -716,12 +762,12 @@ trait Expressions { self: Trees => /** $encodingof `Set[base](elements)` */ case class FiniteSet(elements: Seq[Expr], base: Type) extends Expr { private lazy val tpe = SetType(base).unveilUntyped - def getType(implicit p: Program): Type = tpe + def getType(implicit s: Symbols): Type = tpe } /** $encodingof `set + elem` */ case class SetAdd(set: Expr, elem: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { val base = set.getType match { case SetType(base) => base case _ => Untyped @@ -732,7 +778,7 @@ trait Expressions { self: Trees => /** $encodingof `set.contains(element)` or `set(element)` */ case class ElementOfSet(element: Expr, set: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program) = checkParamTypes(Seq(element.getType), Seq(set.getType match { + protected def computeType(implicit s: Symbols) = checkParamTypes(Seq(element.getType), Seq(set.getType match { case SetType(base) => base case _ => Untyped }), BooleanType) @@ -740,7 +786,7 @@ trait Expressions { self: Trees => /** $encodingof `set.length` */ case class SetCardinality(set: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = set.getType match { + protected def computeType(implicit s: Symbols): Type = set.getType match { case SetType(_) => IntegerType case _ => Untyped } @@ -748,7 +794,7 @@ trait Expressions { self: Trees => /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = (set1.getType, set2.getType) match { + protected def computeType(implicit s: Symbols): Type = (set1.getType, set2.getType) match { case (SetType(b1), SetType(b2)) if b1 == b2 => BooleanType case _ => Untyped } @@ -756,34 +802,34 @@ trait Expressions { self: Trees => /** $encodingof `set & set2` */ case class SetIntersection(set1: Expr, set2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `set ++ set2` */ case class SetUnion(set1: Expr, set2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `set -- set2` */ case class SetDifference(set1: Expr, set2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /* Bag operations */ /** $encodingof `Bag[base](elements)` */ - case class FiniteBag(elements: Seq[(Expr, Expr)], base: TypeTree) extends Expr { + case class FiniteBag(elements: Seq[(Expr, Expr)], base: Type) extends Expr { lazy val tpe = BagType(base).unveilUntyped - def getType(implicit p: Program): Type = tpe + def getType(implicit s: Symbols): Type = tpe } /** $encodingof `bag + elem` */ case class BagAdd(bag: Expr, elem: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = { + protected def computeType(implicit s: Symbols): Type = { val base = bag.getType match { case BagType(base) => base case _ => Untyped @@ -794,7 +840,7 @@ trait Expressions { self: Trees => /** $encodingof `bag.get(element)` or `bag(element)` */ case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = checkParamTypes(Seq(element.getType), Seq(bag.getType match { + protected def computeType(implicit s: Symbols): Type = checkParamTypes(Seq(element.getType), Seq(bag.getType match { case BagType(base) => base case _ => Untyped }), IntegerType) @@ -802,34 +848,34 @@ trait Expressions { self: Trees => /** $encodingof `bag1 & bag2` */ case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `bag1 ++ bag2` */ case class BagUnion(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `bag1 -- bag2` */ case class BagDifference(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = - leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit s: Symbols): Type = + s.leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /* Total map operations */ /** $encodingof `Map[keyType, valueType](key1 -> value1, key2 -> value2 ...)` */ - case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: TypeTree) extends Expr { - lazy val tpe = MapType(keyType, default.getType).unveilUntyped - def getType(implicit p: Program): Type = tpe + case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: Type) extends Expr with CachingTyped { + protected def computeType(implicit s: Symbols): Type = + MapType(keyType, default.getType).unveilUntyped } /** $encodingof `map.apply(key)` (or `map(key)`)*/ case class MapApply(map: Expr, key: Expr) extends Expr with CachingTyped { - protected def computeType(implicit p: Program): Type = map.getType match { + protected def computeType(implicit s: Symbols): Type = map.getType match { case MapType(from, to) => checkParamTypes(Seq(key.getType), Seq(from), to) case _ => Untyped } diff --git a/src/main/scala/inox/trees/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala similarity index 84% rename from src/main/scala/inox/trees/Extractors.scala rename to src/main/scala/inox/ast/Extractors.scala index 19998d88b..a02aec449 100644 --- a/src/main/scala/inox/trees/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -1,9 +1,9 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast -trait Extractors { self: Expressions => +trait Extractors { self: Trees => /** Operator Extractor to extract any Expression in a consistent way. * @@ -19,7 +19,7 @@ trait Extractors { self: Expressions => * function that would simply apply the corresponding constructor for each node. */ object Operator extends TreeExtractor { - val trees: Extractors.this.trees = Extractors.this.trees + val trees: Extractors.this.type = Extractors.this type SubTree = trees.Expr def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { @@ -134,14 +134,14 @@ trait Extractors { self: Expressions => case SubString(t1, a, b) => Some((t1::a::b::Nil, es => SubString(es(0), es(1), es(2)))) case BigSubString(t1, a, b) => Some((t1::a::b::Nil, es => BigSubString(es(0), es(1), es(2)))) case FiniteSet(els, base) => - Some((els.toSeq, els => FiniteSet(els.toSet, base))) + Some((els, els => FiniteSet(els, base))) case FiniteBag(els, base) => val subArgs = els.flatMap { case (k, v) => Seq(k, v) }.toSeq val builder = (as: Seq[Expr]) => { - def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match { + def rec(kvs: Seq[Expr]): Seq[(Expr, Expr)] = kvs match { case Seq(k, v, t @ _*) => - Map(k -> v) ++ rec(t) - case Seq() => Map() + Seq(k -> v) ++ rec(t) + case Seq() => Seq() case _ => sys.error("odd number of key/value expressions") } FiniteBag(rec(as), base) @@ -150,10 +150,10 @@ trait Extractors { self: Expressions => case FiniteMap(args, f, t) => { val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq val builder = (as: Seq[Expr]) => { - def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match { + def rec(kvs: Seq[Expr]): Seq[(Expr, Expr)] = kvs match { case Seq(k, v, t @ _*) => - Map(k -> v) ++ rec(t) - case Seq() => Map() + Seq(k -> v) ++ rec(t) + case Seq() => Seq() case _ => sys.error("odd number of key/value expressions") } FiniteMap(rec(as), f, t) @@ -198,7 +198,7 @@ trait Extractors { self: Expressions => object TopLevelOrs { // expr1 OR (expr2 OR (expr3 OR ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { - case Let(i, e, TopLevelOrs(bs)) => Some(bs map (let(i,e,_))) + case Let(i, e, TopLevelOrs(bs)) => Some(bs map (Let(i,e,_))) case Or(exprs) => Some(exprs.flatMap(unapply).flatten) case e => @@ -208,7 +208,7 @@ trait Extractors { self: Expressions => object TopLevelAnds { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { - case Let(i, e, TopLevelAnds(bs)) => Some(bs map (let(i,e,_))) + case Let(i, e, TopLevelAnds(bs)) => Some(bs map (Let(i,e,_))) case And(exprs) => Some(exprs.flatMap(unapply).flatten) case e => @@ -217,7 +217,7 @@ trait Extractors { self: Expressions => } object IsTyped { - def unapply[T <: Typed](e: T): Option[(T, Type)] = Some((e, e.getType)) + def unapply[T <: Typed](e: T)(implicit p: Symbols): Option[(T, Type)] = Some((e, e.getType)) } object WithStringconverter { @@ -232,7 +232,7 @@ trait Extractors { self: Expressions => } object SimpleCase { - def apply(p : Pattern, rhs : Expr) = MatchCase(p, None, rhs) + def apply(p: Pattern, rhs: Expr) = MatchCase(p, None, rhs) def unapply(c : MatchCase) = c match { case MatchCase(p, None, rhs) => Some((p, rhs)) case _ => None @@ -240,7 +240,7 @@ trait Extractors { self: Expressions => } object GuardedCase { - def apply(p : Pattern, g: Expr, rhs : Expr) = MatchCase(p, Some(g), rhs) + def apply(p: Pattern, g: Expr, rhs: Expr) = MatchCase(p, Some(g), rhs) def unapply(c : MatchCase) = c match { case MatchCase(p, Some(g), rhs) => Some((p, g, rhs)) case _ => None @@ -248,28 +248,28 @@ trait Extractors { self: Expressions => } object Pattern { - def unapply(p : Pattern) : Option[( - Option[Identifier], + def unapply(p: Pattern) : Option[( + Option[ValDef], Seq[Pattern], - (Option[Identifier], Seq[Pattern]) => Pattern + (Option[ValDef], Seq[Pattern]) => Pattern )] = Option(p) map { - case InstanceOfPattern(b, ct) => (b, Seq(), (b, _) => InstanceOfPattern(b,ct)) - case WildcardPattern(b) => (b, Seq(), (b, _) => WildcardPattern(b)) - case CaseClassPattern(b, ct, subs) => (b, subs, (b, sp) => CaseClassPattern(b, ct, sp)) - case TuplePattern(b,subs) => (b, subs, (b, sp) => TuplePattern(b, sp)) - case LiteralPattern(b, l) => (b, Seq(), (b, _) => LiteralPattern(b, l)) - case UnapplyPattern(b, fd, subs) => (b, subs, (b, sp) => UnapplyPattern(b, fd, sp)) + case InstanceOfPattern(b, ct) => (b, Seq(), (b, _) => InstanceOfPattern(b,ct)) + case WildcardPattern(b) => (b, Seq(), (b, _) => WildcardPattern(b)) + case CaseClassPattern(b, ct, subs) => (b, subs, (b, sp) => CaseClassPattern(b, ct, sp)) + case TuplePattern(b,subs) => (b, subs, (b, sp) => TuplePattern(b, sp)) + case LiteralPattern(b, l) => (b, Seq(), (b, _) => LiteralPattern(b, l)) + case UnapplyPattern(b, id, tps, subs) => (b, subs, (b, sp) => UnapplyPattern(b, id, tps, sp)) } } - def unwrapTuple(e: Expr, isTuple: Boolean): Seq[Expr] = e.getType match { + def unwrapTuple(e: Expr, isTuple: Boolean)(implicit s: Symbols): Seq[Expr] = e.getType match { case TupleType(subs) if isTuple => - for (ind <- 1 to subs.size) yield { tupleSelect(e, ind, isTuple) } + for (ind <- 1 to subs.size) yield { s.tupleSelect(e, ind, isTuple) } case _ if !isTuple => Seq(e) case tp => sys.error(s"Calling unwrapTuple on non-tuple $e of type $tp") } - def unwrapTuple(e: Expr, expectedSize: Int): Seq[Expr] = unwrapTuple(e, expectedSize > 1) + def unwrapTuple(e: Expr, expectedSize: Int)(implicit p: Symbols): Seq[Expr] = unwrapTuple(e, expectedSize > 1) def unwrapTupleType(tp: Type, isTuple: Boolean): Seq[Type] = tp match { case TupleType(subs) if isTuple => subs @@ -290,25 +290,25 @@ trait Extractors { self: Expressions => unwrapTuplePattern(p, expectedSize > 1) object LetPattern { - def apply(patt : Pattern, value: Expr, body: Expr) : Expr = { + def apply(patt: Pattern, value: Expr, body: Expr) : Expr = { patt match { case WildcardPattern(Some(binder)) => Let(binder, value, body) case _ => MatchExpr(value, List(SimpleCase(patt, body))) } } - def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { + def unapply(me: MatchExpr) : Option[(Pattern, Expr, Expr)] = { Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, ExprOps.variablesOf(scrut)) => + case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, exprOps.variablesOf(scrut)) => ( pattern, scrut, body ) } } } object LetTuple { - def unapply(me : MatchExpr) : Option[(Seq[Identifier], Expr, Expr)] = { + def unapply(me: MatchExpr) : Option[(Seq[ValDef], Expr, Expr)] = { Option(me) collect { - case LetPattern(TuplePattern(None,subPatts), value, body) if + case LetPattern(TuplePattern(None, subPatts), value, body) if subPatts forall { case WildcardPattern(Some(_)) => true; case _ => false } => (subPatts map { _.binder.get }, value, body ) } diff --git a/src/main/scala/inox/trees/GenTreeOps.scala b/src/main/scala/inox/ast/GenTreeOps.scala similarity index 98% rename from src/main/scala/inox/trees/GenTreeOps.scala rename to src/main/scala/inox/ast/GenTreeOps.scala index 35cde2198..3217b76b4 100644 --- a/src/main/scala/inox/trees/GenTreeOps.scala +++ b/src/main/scala/inox/ast/GenTreeOps.scala @@ -1,7 +1,9 @@ /* Copyright 2009-2015 EPFL, Lausanne */ package inox -package trees +package ast + +import utils._ /** A type that pattern matches agains a type of [[Tree]] and extracts it subtrees, * and a builder that reconstructs a tree of the same type from subtrees. @@ -28,7 +30,7 @@ trait GenTreeOps { /** An extractor for [[SubTree]]*/ val Deconstructor: TreeExtractor { - val trees: GenTreeOps.this.trees + val trees: GenTreeOps.this.trees.type type SubTree <: GenTreeOps.this.SubTree } @@ -284,8 +286,8 @@ trait GenTreeOps { rec(expr, init) } - protected def noCombiner(e: SubTree, subCs: Seq[Unit]) = () - protected def noTransformer[C](e: SubTree, c: C) = (e, c) + def noCombiner(e: SubTree, subCs: Seq[Unit]) = () + def noTransformer[C](e: SubTree, c: C) = (e, c) /** A [[genericTransform]] with the trivial combiner that returns () */ def simpleTransform(pre: SubTree => SubTree, post: SubTree => SubTree)(tree: SubTree) = { diff --git a/src/main/scala/inox/trees/Paths.scala b/src/main/scala/inox/ast/Paths.scala similarity index 79% rename from src/main/scala/inox/trees/Paths.scala rename to src/main/scala/inox/ast/Paths.scala index a5c0d7633..0531ff055 100644 --- a/src/main/scala/inox/trees/Paths.scala +++ b/src/main/scala/inox/ast/Paths.scala @@ -1,13 +1,13 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast -trait Paths { self: ExprOps => +trait Paths { self: TypeOps with Constructors => import trees._ object Path { - final type Element = Either[(Identifier, Expr), Expr] + final type Element = Either[(ValDef, Expr), Expr] def empty: Path = new Path(Seq.empty) @@ -29,19 +29,19 @@ trait Paths { self: ExprOps => * not defined, whereas an encoding of let-bindings with equalities * could introduce non-sensical equations. */ - class Path private[purescala]( - private[purescala] val elements: Seq[Path.Element]) extends Printable { + class Path private(private[ast] val elements: Seq[Path.Element]) + extends inox.Printable { import Path.Element - + /** Add a binding to this [[Path]] */ - def withBinding(p: (Identifier, Expr)) = { + def withBinding(p: (ValDef, Expr)) = { def exprOf(e: Element) = e match { case Right(e) => e; case Left((_, e)) => e } - val (before, after) = elements span (el => !variablesOf(exprOf(el)).contains(p._1)) + val (before, after) = elements span (el => !exprOps.variablesOf(exprOf(el)).contains(p._1.toVariable)) new Path(before ++ Seq(Left(p)) ++ after) } - def withBindings(ps: Iterable[(Identifier, Expr)]) = { + def withBindings(ps: Iterable[(ValDef, Expr)]) = { ps.foldLeft(this)( _ withBinding _ ) } @@ -56,7 +56,7 @@ trait Paths { self: ExprOps => /** Remove bound variables from this [[Path]] * @param ids the bound variables to remove */ - def --(ids: Set[Identifier]) = new Path(elements.filterNot(_.left.exists(p => ids(p._1)))) + def --(ids: Set[Identifier]) = new Path(elements.filterNot(_.left.exists(p => ids(p._1.id)))) /** Appends `that` path at the end of `this` */ def merge(that: Path): Path = new Path(elements ++ that.elements) @@ -103,7 +103,7 @@ trait Paths { self: ExprOps => */ def negate: Path = { val (outers, rest) = elements.span(_.isLeft) - new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest)))) + new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, self.and(_, _))(rest)))) } /** Returns a new path which depends ONLY on provided ids. @@ -119,35 +119,35 @@ trait Paths { self: ExprOps => * @see [[leon.purescala.FunctionClosure.close]] for an example usecase. */ def filterByIds(ids: Set[Identifier]): Path = { - def containsIds(ids: Set[Identifier])(e: Expr): Boolean = exists{ - case Variable(id) => ids.contains(id) + def containsIds(ids: Set[Identifier])(e: Expr): Boolean = exprOps.exists { + case Variable(id, _) => ids.contains(id) case _ => false }(e) val newElements = elements.filter{ - case Left((id, e)) => ids.contains(id) || containsIds(ids)(e) + case Left((vd, e)) => ids.contains(vd.id) || containsIds(ids)(e) case Right(e) => containsIds(ids)(e) } new Path(newElements) } /** Free variables within the path */ - lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty, - (id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e) + lazy val variables: Set[Variable] = fold[Set[Variable]](Set.empty, + (vd, e, res) => res - vd.toVariable ++ exprOps.variablesOf(e), (e, res) => res ++ exprOps.variablesOf(e) )(elements) - lazy val bindings: Seq[(Identifier, Expr)] = elements.collect { case Left(p) => p } + lazy val bindings: Seq[(ValDef, Expr)] = elements.collect { case Left(p) => p } lazy val boundIds = bindings map (_._1) lazy val conditions: Seq[Expr] = elements.collect { case Right(e) => e } - def isBound(id: Identifier): Boolean = bindings.exists(p => p._1 == id) + def isBound(id: Identifier): Boolean = bindings.exists(p => p._1.id == id) /** Fold the path elements * * This function takes two combiner functions, one for let bindings and one * for proposition expressions. */ - private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T) + private def fold[T](base: T, combineLet: (ValDef, Expr, T) => T, combineCond: (Expr, T) => T) (elems: Seq[Element]): T = elems.foldRight(base) { case (Left((id, e)), res) => combineLet(id, e, res) case (Right(e), res) => combineCond(e, res) @@ -168,10 +168,10 @@ trait Paths { self: ExprOps => } /** Folds the path into a conjunct with the expression `base` */ - def and(base: Expr) = distributiveClause(base, Constructors.and(_, _)) + def and(base: Expr) = distributiveClause(base, self.and(_, _)) /** Fold the path into an implication of `base`, namely `path ==> base` */ - def implies(base: Expr) = distributiveClause(base, Constructors.implies) + def implies(base: Expr) = distributiveClause(base, self.implies) /** Folds the path into a `require` wrapping the expression `body` * @@ -182,17 +182,16 @@ trait Paths { self: ExprOps => */ def specs(body: Expr, pre: Expr = BooleanLiteral(true), post: Expr = NoTree(BooleanType)) = { val (outers, rest) = elements.span(_.isLeft) - val cond = fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest) + val cond = fold[Expr](BooleanLiteral(true), let, self.and(_, _))(rest) def wrap(e: Expr): Expr = { - val bindings = rest.collect { case Left((id, e)) => id -> e } - val idSubst = bindings.map(p => p._1 -> p._1.freshen).toMap - val substMap = idSubst.mapValues(_.toVariable) - val replace = replaceFromIDs(substMap, _: Expr) - bindings.foldRight(replace(e)) { case ((id, e), b) => let(idSubst(id), replace(e), b) } + val bindings = rest.collect { case Left((vd, e)) => vd -> e } + val vdSubst = bindings.map(p => p._1 -> p._1.freshen).toMap + val replace = exprOps.replaceFromSymbols(vdSubst.mapValues(_.toVariable), _: Expr) + bindings.foldRight(replace(e)) { case ((vd, e), b) => let(vdSubst(vd), replace(e), b) } } - val req = Require(Constructors.and(cond, wrap(pre)), wrap(body)) + val req = Require(self.and(cond, wrap(pre)), wrap(body)) val full = post match { case l @ Lambda(args, body) => Ensuring(req, Lambda(args, wrap(body)).copiedFrom(l)) case _ => req @@ -225,8 +224,8 @@ trait Paths { self: ExprOps => override def hashCode: Int = elements.hashCode - override def toString = asString(LeonContext.printNames) - def asString(implicit ctx: LeonContext): String = fullClause.asString - def asString(pgm: Program)(implicit ctx: LeonContext): String = fullClause.asString(pgm) + override def toString = asString(Context.printNames) + def asString(implicit ctx: Context): String = fullClause.asString + def asString(pgm: Program)(implicit ctx: Context): String = fullClause.asString(pgm) } } diff --git a/src/main/scala/inox/trees/PrinterOptions.scala b/src/main/scala/inox/ast/PrinterOptions.scala similarity index 100% rename from src/main/scala/inox/trees/PrinterOptions.scala rename to src/main/scala/inox/ast/PrinterOptions.scala diff --git a/src/main/scala/inox/trees/Printers.scala b/src/main/scala/inox/ast/Printers.scala similarity index 53% rename from src/main/scala/inox/trees/Printers.scala rename to src/main/scala/inox/ast/Printers.scala index 460b1bdb2..6b3e243a3 100644 --- a/src/main/scala/inox/trees/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -1,8 +1,9 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast +import utils._ import org.apache.commons.lang3.StringEscapeUtils trait Printers { self: Trees => @@ -34,19 +35,6 @@ trait Printers { self: Trees => } } - protected def getScope(implicit ctx: PrinterContext) = - ctx.parents.collectFirst { case (d: Definition) if !d.isInstanceOf[ValDef] => d } - - protected def printNameWithPath(df: Definition)(implicit ctx: PrinterContext) { - (opgm, getScope) match { - case (Some(pgm), Some(scope)) => - sb.append(fullNameFrom(df, scope, opts.printUniqueIds)(pgm)) - - case _ => - p"${df.id}" - } - } - private val dbquote = "\"" def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { @@ -75,25 +63,19 @@ trait Printers { self: Trees => } p"$name" - case Variable(id) => + case Variable(id, _) => p"$id" - case Let(id, expr, SubString(Variable(id2), start, StringLength(Variable(id3)))) if id == id2 && id2 == id3 => + case Let(vd, expr, SubString(v2: Variable, start, StringLength(v3: Variable))) if vd == v2 && v2 == v3 => p"$expr.substring($start)" - - case Let(id, expr, BigSubString(Variable(id2), start, StringLength(Variable(id3)))) if id == id2 && id2 == id3 => + + case Let(vd, expr, BigSubString(v2: Variable, start, StringLength(v3: Variable))) if vd == v2 && v2 == v3 => p"$expr.bigSubstring($start)" case Let(b,d,e) => p"""|val $b = $d |$e""" - case LetDef(a::q,body) => - p"""|$a - |${letDef(q, body)}""" - case LetDef(Nil,body) => - p"""$body""" - case Require(pre, body) => p"""|require($pre) |$body""" @@ -113,69 +95,17 @@ trait Printers { self: Trees => | $post |}""" - case p @ Passes(in, out, tests) => - tests match { - case Seq(MatchCase(_, Some(BooleanLiteral(false)), NoTree(_))) => - p"""|byExample($in, $out)""" - case _ => - optP { - p"""|($in, $out) passes { - | ${nary(tests, "\n")} - |}""" - } - } - - - case c @ WithOracle(vars, pred) => - p"""|withOracle { (${typed(vars)}) => - | $pred - |}""" - - case h @ Hole(tpe, es) => - if (es.isEmpty) { - val hole = (for{scope <- getScope - program <- opgm } - yield simplifyPath("leon" :: "lang" :: "synthesis" :: "???" :: Nil, scope, false)(program)) - .getOrElse("leon.lang.synthesis.???") - p"$hole[$tpe]" - } else { - p"?($es)" - } - case Forall(args, e) => - p"\u2200${typed(args.map(_.id))}. $e" + p"\u2200${typed(args)}. $e" case e @ CaseClass(cct, args) => - opgm.flatMap { pgm => isListLiteral(e)(pgm) } match { - case Some((tpe, elems)) => - val chars = elems.collect{ case CharLiteral(ch) => ch } - if (chars.length == elems.length && tpe == CharType) { - // String literal - val str = chars mkString "" - val q = '"' - p"$q$str$q" - } else { - val lclass = AbstractClassType(opgm.get.library.List.get, cct.tps) - - p"$lclass($elems)" - } - - case None => - if (cct.classDef.isCaseObject) { - p"$cct" - } else { - p"$cct($args)" - } - } + p"$cct($args)" case And(exprs) => optP { p"${nary(exprs, " && ")}" } case Or(exprs) => optP { p"${nary(exprs, "| || ")}" } // Ugliness award! The first | is there to shield from stripMargin() case Not(Equals(l, r)) => optP { p"$l \u2260 $r" } case Implies(l,r) => optP { p"$l ==> $r" } - case BVNot(expr) => p"~$expr" case UMinus(expr) => p"-$expr" - case BVUMinus(expr) => p"-$expr" - case RealUMinus(expr) => p"-$expr" case Equals(l,r) => optP { p"$l == $r" } @@ -192,7 +122,7 @@ trait Printers { self: Trees => case StringBigLength(expr) => p"$expr.bigLength" case IntLiteral(v) => p"$v" - case InfiniteIntegerLiteral(v) => p"$v" + case IntegerLiteral(v) => p"$v" case FractionalLiteral(n, d) => if (d == 1) p"$n" else p"$n/$d" @@ -210,98 +140,26 @@ trait Printers { self: Trees => case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"$t._$i" case NoTree(tpe) => p"<empty tree>[$tpe]" - case Choose(pred) => - val choose = (for{scope <- getScope - program <- opgm } - yield simplifyPath("leon" :: "lang" :: "synthesis" :: "choose" :: Nil, scope, false)(program)) - .getOrElse("leon.lang.synthesis.choose") - p"$choose($pred)" case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" case AsInstanceOf(e, ct) => p"""$e.asInstanceOf[$ct]""" - case IsInstanceOf(e, cct) => - if (cct.classDef.isCaseObject) { - p"($e == $cct)" - } else { - p"$e.isInstanceOf[$cct]" - } + case IsInstanceOf(e, cct) => p"$e.isInstanceOf[$cct]" case CaseClassSelector(_, e, id) => p"$e.$id" - case MethodInvocation(rec, _, tfd, args) => - p"$rec.${tfd.id}${nary(tfd.tps, ", ", "[", "]")}" - // No () for fields - if (tfd.fd.isRealFunction) { - // The non-present arguments are synthetic function invocations - val presentArgs = args filter { - case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false - case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false - case other => true - } - - val requireParens = presentArgs.nonEmpty || args.nonEmpty - if (requireParens) { - p"($presentArgs)" - } - } - - case BinaryMethodCall(a, op, b) => - optP { p"$a $op $b" } - - case FcallMethodInvocation(rec, fd, id, tps, args) => - - p"$rec.$id${nary(tps, ", ", "[", "]")}" - - if (fd.isRealFunction) { - // The non-present arguments are synthetic function invocations - val presentArgs = args filter { - case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false - case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false - case other => true - } - - val requireParens = presentArgs.nonEmpty || args.nonEmpty - if (requireParens) { - p"($presentArgs)" - } - } - - case FunctionInvocation(TypedFunDef(fd, tps), args) => - printNameWithPath(fd) - p"${nary(tps, ", ", "[", "]")}" - - if (fd.isRealFunction) { - // The non-present arguments are synthetic function invocations - val presentArgs = args filter { - case MethodInvocation(_, _, tfd, _) if tfd.fd.isSynthetic => false - case FunctionInvocation(tfd, _) if tfd.fd.isSynthetic => false - case other => true - } - val requireParens = presentArgs.nonEmpty || args.nonEmpty - if (requireParens) { - p"($presentArgs)" - } + case FunctionInvocation(id, tps, args) => + p"${id}${nary(tps, ", ", "[", "]")}" + if (args.nonEmpty) { + p"($args)" } case Application(caller, args) => p"$caller($args)" - case Lambda(Seq(ValDef(id)), FunctionInvocation(TypedFunDef(fd, Seq()), Seq(Variable(idArg)))) if id == idArg => - printNameWithPath(fd) - + case Lambda(Seq(vd), FunctionInvocation(id, Seq(), Seq(arg))) if vd == arg => + p"${id}" + case Lambda(args, body) => optP { p"($args) => $body" } - case FiniteLambda(mapping, dflt, _) => - optP { - def pm(p: (Seq[Expr], Expr)): PrinterHelpers.Printable = - (pctx: PrinterContext) => p"${purescala.Constructors.tupleWrap(p._1)} => ${p._2}"(pctx) - - if (mapping.isEmpty) { - p"{ * => ${dflt} }" - } else { - p"{ ${nary(mapping map pm)}, * => ${dflt} }" - } - } - case Plus(l,r) => optP { p"$l + $r" } case Minus(l,r) => optP { p"$l - $r" } case Times(l,r) => optP { p"$l * $r" } @@ -312,24 +170,13 @@ trait Printers { self: Trees => case GreaterThan(l,r) => optP { p"$l > $r" } case LessEquals(l,r) => optP { p"$l <= $r" } case GreaterEquals(l,r) => optP { p"$l >= $r" } - case BVPlus(l,r) => optP { p"$l + $r" } - case BVMinus(l,r) => optP { p"$l - $r" } - case BVTimes(l,r) => optP { p"$l * $r" } - case BVDivision(l,r) => optP { p"$l / $r" } - case BVRemainder(l,r) => optP { p"$l % $r" } - case BVAnd(l,r) => optP { p"$l & $r" } - case BVOr(l,r) => optP { p"$l | $r" } case BVXOr(l,r) => optP { p"$l ^ $r" } case BVShiftLeft(l,r) => optP { p"$l << $r" } case BVAShiftRight(l,r) => optP { p"$l >> $r" } case BVLShiftRight(l,r) => optP { p"$l >>> $r" } - case RealPlus(l,r) => optP { p"$l + $r" } - case RealMinus(l,r) => optP { p"$l - $r" } - case RealTimes(l,r) => optP { p"$l * $r" } - case RealDivision(l,r) => optP { p"$l / $r" } - case fs @ FiniteSet(rs, _) => p"{${rs.toSeq}}" - case fs @ FiniteBag(rs, _) => p"{$rs}" - case fm @ FiniteMap(rs, _, _) => p"{${rs.toSeq}}" + case fs @ FiniteSet(rs, _) => p"{${rs.distinct}}" + case fs @ FiniteBag(rs, _) => p"{${rs.toMap.toSeq}}" + case fm @ FiniteMap(rs, _, _) => p"{${rs.toMap.toSeq}}" case Not(ElementOfSet(e,s)) => p"$e \u2209 $s" case ElementOfSet(e,s) => p"$e \u2208 $s" case SubsetOf(l,r) => p"$l \u2286 $r" @@ -337,7 +184,6 @@ trait Printers { self: Trees => case SetAdd(s,e) => p"$s \u222A {$e}" case SetUnion(l,r) => p"$l \u222A $r" case BagUnion(l,r) => p"$l \u222A $r" - case MapUnion(l,r) => p"$l \u222A $r" case SetDifference(l,r) => p"$l \\ $r" case BagDifference(l,r) => p"$l \\ $r" case SetIntersection(l,r) => p"$l \u2229 $r" @@ -346,47 +192,12 @@ trait Printers { self: Trees => case BagAdd(b,e) => p"$b + $e" case MultiplicityInBag(e, b) => p"$b($e)" case MapApply(m,k) => p"$m($k)" - case MapIsDefinedAt(m,k) => p"$m.isDefinedAt($k)" - case ArrayLength(a) => p"$a.length" - case ArraySelect(a, i) => p"$a($i)" - case ArrayUpdated(a, i, v) => p"$a.updated($i, $v)" - case a@FiniteArray(es, d, s) => { - val ArrayType(underlying) = a.getType - val default = d.getOrElse(simplestValue(underlying)) - def ppBigArray(): Unit = { - if(es.isEmpty) { - p"Array($default, $default, $default, ..., $default) (of size $s)" - } else { - p"Array(_) (of size $s)" - } - } - s match { - case IntLiteral(length) => { - if(es.size == length) { - val orderedElements = es.toSeq.sortWith((e1, e2) => e1._1 < e2._1).map(el => el._2) - p"Array($orderedElements)" - } else if(length < 10) { - val elems = (0 until length).map(i => - es.find(el => el._1 == i).map(el => el._2).getOrElse(d.get) - ) - p"Array($elems)" - } else { - ppBigArray() - } - } - case _ => ppBigArray() - } - } case Not(expr) => p"\u00AC$expr" - case vd @ ValDef(id) => - if(vd.isVar) - p"var " - p"$id : ${vd.getType}" - vd.defaultValue.foreach { fd => p" = ${fd.body.get}" } + case vd @ ValDef(id, tpe) => + p"$id : ${tpe}" - case This(_) => p"this" case (tfd: TypedFunDef) => p"typed def ${tfd.id}[${tfd.tps}]" case TypeParameterDef(tp) => p"$tp" case TypeParameter(id) => p"$id" @@ -428,40 +239,24 @@ trait Printers { self: Trees => case WildcardPattern(None) => p"_" case WildcardPattern(Some(id)) => p"$id" - case CaseClassPattern(ob, cct, subps) => + case CaseClassPattern(ob, ct, subps) => ob.foreach { b => p"$b @ " } // Print only the classDef because we don't want type parameters in patterns - printNameWithPath(cct.classDef) - if (!cct.classDef.isCaseObject) p"($subps)" + p"${ct.id}" + p"($subps)" case InstanceOfPattern(ob, cct) => - if (cct.classDef.isCaseObject) { - ob.foreach { b => p"$b @ " } - } else { - ob.foreach { b => p"$b : " } - } - // It's ok to print the whole type because there are no type parameters for case objects + ob.foreach { b => p"$b : " } + // It's ok to print the whole type although scalac will complain about erasure p"$cct" case TuplePattern(ob, subps) => ob.foreach { b => p"$b @ " } p"($subps)" - case UnapplyPattern(ob, tfd, subps) => + case UnapplyPattern(ob, id, tps, subps) => ob.foreach { b => p"$b @ " } - - // @mk: I admit this is pretty ugly - (for { - p <- opgm - mod <- p.modules.find( _.definedFunctions contains tfd.fd ) - } yield mod) match { - case Some(obj) => - printNameWithPath(obj) - case None => - p"<unknown object>" - } - - p"(${nary(subps)})" + p"$id(${nary(subps)})" case LiteralPattern(ob, lit) => ob foreach { b => p"$b @ " } @@ -476,77 +271,29 @@ trait Printers { self: Trees => case CharType => p"Char" case BooleanType => p"Boolean" case StringType => p"String" - case ArrayType(bt) => p"Array[$bt]" case SetType(bt) => p"Set[$bt]" case BagType(bt) => p"Bag[$bt]" case MapType(ft,tt) => p"Map[$ft, $tt]" case TupleType(tpes) => p"($tpes)" case FunctionType(fts, tt) => p"($fts) => $tt" case c: ClassType => - printNameWithPath(c.classDef) - p"${nary(c.tps, ", ", "[", "]")}" + p"${c.id}${nary(c.tps, ", ", "[", "]")}" // Definitions case Program(units) => p"""${nary(units filter { /*opts.printUniqueIds ||*/ _.isMainUnit }, "\n\n")}""" - case UnitDef(id,pack, imports, defs,_) => - if (pack.nonEmpty){ - p"""|package ${pack mkString "."} - |""" - } - p"""|${nary(imports,"\n")} - | - |${nary(defs,"\n\n")} - |""" - - case Import(path, isWild) => - if (isWild) { - p"import ${nary(path,".")}._" - } else { - p"import ${nary(path,".")}" - } - - case ModuleDef(id, defs, _) => - p"""|object $id { - | ${nary(defs, "\n\n")} - |}""" - - case acd : AbstractClassDef => + case acd: AbstractClassDef => p"abstract class ${acd.id}${nary(acd.tparams, ", ", "[", "]")}" - acd.parent.foreach{ par => - p" extends ${par.id}" - } - - if (acd.methods.nonEmpty) { - p"""| { - | ${nary(acd.methods, "\n\n")} - |}""" - } - case ccd : CaseClassDef => - if (ccd.isCaseObject) { - p"case object ${ccd.id}" - } else { - p"case class ${ccd.id}" - } - + p"case class ${ccd.id}" p"${nary(ccd.tparams, ", ", "[", "]")}" - - if (!ccd.isCaseObject) { - p"(${ccd.fields})" - } + p"(${ccd.fields})" ccd.parent.foreach { par => // Remember child and parents tparams are simple bijection - p" extends ${par.id}${nary(ccd.tparams, ", ", "[", "]")}" - } - - if (ccd.methods.nonEmpty) { - p"""| { - | ${nary(ccd.methods, "\n\n") } - |}""" + p" extends ${par}${nary(ccd.tparams, ", ", "[", "]")}" } case fd: FunDef => @@ -555,12 +302,9 @@ trait Printers { self: Trees => |""" } - if (fd.canBeStrictField) { - p"val ${fd.id} : " - } else if (fd.canBeLazyField) { - p"lazy val ${fd.id} : " - } else { - p"def ${fd.id}${nary(fd.tparams, ", ", "[", "]")}(${fd.params}): " + p"def ${fd.id}${nary(fd.tparams, ", ", "[", "]")}" + if (fd.params.nonEmpty) { + p"(${fd.params}): " } p"${fd.returnType} = ${fd.fullBody}" @@ -597,41 +341,8 @@ trait Printers { self: Trees => } } - protected object FcallMethodInvocation { - def unapply(fi: FunctionInvocation): Option[(Expr, FunDef, Identifier, Seq[TypeTree], Seq[Expr])] = { - val FunctionInvocation(tfd, args) = fi - tfd.fd.methodOwner.map { cd => - val (rec, rargs) = (args.head, args.tail) - - val fid = tfd.fd.id - - val realtps = tfd.tps.drop(cd.tparams.size) - - (rec, tfd.fd, fid, realtps, rargs) - } - } - } - - protected object BinaryMethodCall { - val makeBinary = Set("+", "-", "*", "::", "++", "--", "&&", "||", "/") - - def unapply(fi: FunctionInvocation): Option[(Expr, String, Expr)] = fi match { - case FcallMethodInvocation(rec, _, id, Nil, List(a)) => - val name = id.name - if (makeBinary contains name) { - if(name == "::") - Some((a, name, rec)) - else - Some((rec, name, a)) - } else { - None - } - case _ => None - } - } - protected def isSimpleExpr(e: Expr): Boolean = e match { - case _: LetDef | _: Let | LetPattern(_, _, _) | _: Assert | _: Require => false + case _: Let | LetPattern(_, _, _) | _: Assert | _: Require => false case p: PrettyPrintable => p.isSimpleExpr case _ => true } @@ -639,8 +350,6 @@ trait Printers { self: Trees => protected def noBracesSub(e: Expr): Seq[Expr] = e match { case Assert(_, _, bd) => Seq(bd) case Let(_, _, bd) => Seq(bd) - case xlang.Expressions.LetVar(_, _, bd) => Seq(bd) - case LetDef(_, bd) => Seq(bd) case LetPattern(_, _, bd) => Seq(bd) case Require(_, bd) => Seq(bd) case IfExpr(_, t, e) => Seq(t, e) // if-else always has braces anyway @@ -652,11 +361,6 @@ trait Printers { self: Trees => case (e: Expr, _) if isSimpleExpr(e) => false case (e: Expr, Some(within: Expr)) if noBracesSub(within) contains e => false case (_: Expr, Some(_: MatchCase)) => false - case (_: LetDef, Some(_: LetDef)) => false - case (_: Expr, Some(_: xlang.Expressions.Block)) => false - case (_: xlang.Expressions.Block, Some(_: xlang.Expressions.While)) => false - case (_: xlang.Expressions.Block, Some(_: FunDef)) => false - case (_: xlang.Expressions.Block, Some(_: LetDef)) => false case (e: Expr, Some(_)) => true case _ => false } @@ -665,12 +369,12 @@ trait Printers { self: Trees => case (pa: PrettyPrintable) => pa.printPrecedence case (_: ElementOfSet) => 0 case (_: Modulo) => 1 - case (_: Or | BinaryMethodCall(_, "||", _)) => 2 - case (_: And | BinaryMethodCall(_, "&&", _)) => 3 + case (_: Or) => 2 + case (_: And) => 3 case (_: GreaterThan | _: GreaterEquals | _: LessEquals | _: LessThan | _: Implies) => 4 case (_: Equals | _: Not) => 5 - case (_: Plus | _: BVPlus | _: Minus | _: BVMinus | _: SetUnion| _: SetDifference | BinaryMethodCall(_, "+" | "-", _)) => 7 - case (_: Times | _: BVTimes | _: Division | _: BVDivision | _: Remainder | _: BVRemainder | BinaryMethodCall(_, "*" | "/", _)) => 8 + case (_: Plus | _: Minus | _: SetUnion| _: SetDifference) => 7 + case (_: Times | _: Division | _: Remainder) => 8 case _ => 9 } @@ -679,12 +383,10 @@ trait Printers { self: Trees => case (_, None) => false case (_, Some( _: Ensuring | _: Assert | _: Require | _: Definition | _: MatchExpr | _: MatchCase | - _: Let | _: LetDef | _: IfExpr | _ : CaseClass | _ : Lambda | _ : Choose | _ : Tuple + _: Let | _: IfExpr | _ : CaseClass | _ : Lambda | _ : Tuple )) => false case (_:Pattern, _) => false case (ex: StringConcat, Some(_: StringConcat)) => false - case (b1 @ BinaryMethodCall(_, _, _), Some(b2 @ BinaryMethodCall(_, _, _))) if precedence(b1) > precedence(b2) => false - case (BinaryMethodCall(_, _, _), Some(_: FunctionInvocation)) => true case (_, Some(_: FunctionInvocation)) => false case (ie: IfExpr, _) => true case (me: MatchExpr, _ ) => true @@ -693,11 +395,11 @@ trait Printers { self: Trees => } } - implicit class Printable(val f: PrinterContext => Any) extends AnyVal { + implicit class Printable(val f: PrinterContext => Any) { def print(ctx: PrinterContext) = f(ctx) } - implicit class PrintingHelper(val sc: StringContext) extends AnyVal { + implicit class PrintingHelper(val sc: StringContext) { def p(args: Any*)(implicit ctx: PrinterContext): Unit = { val printer = ctx.printer @@ -814,12 +516,12 @@ trait Printers { self: Trees => printer.toString } - def apply(tree: Tree, ctx: LeonContext): String = { + def apply(tree: Tree, ctx: Context): String = { val opts = PrinterOptions.fromContext(ctx) apply(tree, opts, None) } - def apply(tree: Tree, ctx: LeonContext, pgm: Program): String = { + def apply(tree: Tree, ctx: Context, pgm: Program): String = { val opts = PrinterOptions.fromContext(ctx) apply(tree, opts, Some(pgm)) } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala new file mode 100644 index 000000000..0595a6adf --- /dev/null +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -0,0 +1,1050 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package ast + +import utils._ + +/** Provides functions to manipulate [[purescala.Expressions]]. + * + * This object provides a few generic operations on Leon expressions, + * as well as some common operations. + * + * The generic operations lets you apply operations on a whole tree + * expression. You can look at: + * - [[GenTreeOps.fold foldRight]] + * - [[GenTreeOps.preTraversal preTraversal]] + * - [[GenTreeOps.postTraversal postTraversal]] + * - [[GenTreeOps.preMap preMap]] + * - [[GenTreeOps.postMap postMap]] + * - [[GenTreeOps.genericTransform genericTransform]] + * + * These operations usually take a higher order function that gets applied to the + * expression tree in some strategy. They provide an expressive way to build complex + * operations on Leon expressions. + * + */ +trait SymbolOps extends TreeOps { + import trees._ + import trees.exprOps._ + implicit val symbols: Symbols + import symbols._ + + /** Computes the negation of a boolean formula, with some simplifications. */ + def negate(expr: Expr) : Expr = { + (expr match { + case Let(i,b,e) => Let(i,b,negate(e)) + case Not(e) => e + case Implies(e1,e2) => and(e1, negate(e2)) + case Or(exs) => and(exs map negate: _*) + case And(exs) => or(exs map negate: _*) + case LessThan(e1,e2) => GreaterEquals(e1,e2) + case LessEquals(e1,e2) => GreaterThan(e1,e2) + case GreaterThan(e1,e2) => LessEquals(e1,e2) + case GreaterEquals(e1,e2) => LessThan(e1,e2) + case IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)) + case BooleanLiteral(b) => BooleanLiteral(!b) + case e => Not(e) + }).setPos(expr) + } + + /** Replace each node by its constructor + * + * Remap the expression by calling the corresponding constructor + * for each node of the expression. The constructor will perfom + * some local simplifications, resulting in a simplified expression. + */ + def simplifyByConstructors(expr: Expr): Expr = { + def step(e: Expr): Option[Expr] = e match { + case Not(t) => Some(not(t)) + case UMinus(t) => Some(uminus(t)) + case CaseClassSelector(cd, e, sel) => Some(caseClassSelector(cd, e, sel)) + case AsInstanceOf(e, ct) => Some(asInstOf(e, ct)) + case Equals(t1, t2) => Some(equality(t1, t2)) + case Implies(t1, t2) => Some(implies(t1, t2)) + case Plus(t1, t2) => Some(plus(t1, t2)) + case Minus(t1, t2) => Some(minus(t1, t2)) + case Times(t1, t2) => Some(times(t1, t2)) + case And(args) => Some(andJoin(args)) + case Or(args) => Some(orJoin(args)) + case Tuple(args) => Some(tupleWrap(args)) + case MatchExpr(scrut, cases) => Some(matchExpr(scrut, cases)) + case _ => None + } + postMap(step)(expr) + } + + /** Normalizes the expression expr */ + def normalizeExpression(expr: Expr): Expr = { + def rec(e: Expr): Option[Expr] = e match { + case TupleSelect(Let(id, v, b), ts) => + Some(Let(id, v, tupleSelect(b, ts, true))) + + case TupleSelect(LetTuple(ids, v, b), ts) => + Some(letTuple(ids, v, tupleSelect(b, ts, true))) + + case CaseClassSelector(cct, cc: CaseClass, id) => + Some(caseClassSelector(cct, cc, id).copiedFrom(e)) + + case IfExpr(c, thenn, elze) if (thenn == elze) && isPurelyFunctional(c) => + Some(thenn) + + case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) => + Some(c) + + case IfExpr(Not(c), thenn, elze) => + Some(IfExpr(c, elze, thenn).copiedFrom(e)) + + case IfExpr(c, BooleanLiteral(false), BooleanLiteral(true)) => + Some(Not(c).copiedFrom(e)) + + case FunctionInvocation(id, tps, List(IfExpr(c, thenn, elze))) => + Some(IfExpr(c, FunctionInvocation(id, tps, List(thenn)), FunctionInvocation(id, tps, List(elze))).copiedFrom(e)) + + case _ => + None + } + + fixpoint(postMap(rec))(expr) + } + + private val typedIds: scala.collection.mutable.Map[Type, List[Identifier]] = + scala.collection.mutable.Map.empty.withDefaultValue(List.empty) + + /** Normalizes identifiers in an expression to enable some notion of structural + * equality between expressions on which usual equality doesn't make sense + * (i.e. closures). + * + * This function relies on the static map `typedIds` to ensure identical + * structures and must therefore be synchronized. + * + * The optional argument [[onlySimple]] determines whether non-simple expressions + * (see [[isSimple]]) should be normalized into a dependency or recursed into + * (when they don't depend on [[args]]). This distinction is used in the + * unrolling solver to provide geenral equality checks between functions even when + * they have complex closures. + */ + def normalizeStructure(args: Seq[ValDef], expr: Expr, onlySimple: Boolean = true): + (Seq[ValDef], Expr, Map[Variable, Expr]) = synchronized { + val vars = args.map(_.toVariable).toSet + + class Normalizer extends TreeTransformer { + var subst: Map[Variable, Expr] = Map.empty + var remainingIds: Map[Type, List[Identifier]] = typedIds.toMap + + def getId(e: Expr): Identifier = { + val tpe = bestRealType(e.getType) + val newId = remainingIds.get(tpe) match { + case Some(x :: xs) => + remainingIds += tpe -> xs + x + case _ => + val x = FreshIdentifier("x", true) + typedIds(tpe) = typedIds(tpe) :+ x + x + } + subst += Variable(newId, tpe) -> e + newId + } + + override def transform(id: Identifier, tpe: Type): (Identifier, Type) = subst.get(Variable(id, tpe)) match { + case Some(Variable(newId, tpe)) => (newId, tpe) + case Some(_) => scala.sys.error("Should never happen!") + case None => (getId(Variable(id, tpe)), tpe) + } + + override def transform(e: Expr): Expr = e match { + case expr if (isSimple(expr) || !onlySimple) && (variablesOf(expr) & vars).isEmpty => + Variable(getId(expr), expr.getType) + case f: Forall => + val (args, body, newSubst) = normalizeStructure(f.args, f.body, onlySimple) + subst ++= newSubst + Forall(args, body) + case l: Lambda => + val (args, body, newSubst) = normalizeStructure(l.args, l.body, onlySimple) + subst ++= newSubst + Lambda(args, body) + case _ => super.transform(e) + } + } + + val n = new Normalizer + // this registers the argument images into n.subst + val bindings = args map n.transform + val normalized = n.transform(matchToIfThenElse(expr)) + + val freeVars = variablesOf(normalized) -- bindings.map(_.toVariable) + val bodySubst = n.subst.filter(p => freeVars(p._1)) + + (bindings, normalized, bodySubst) + } + + def normalizeStructure(lambda: Lambda): (Lambda, Map[Variable, Expr]) = { + val (args, body, subst) = normalizeStructure(lambda.args, lambda.body, onlySimple = false) + (Lambda(args, body), subst) + } + + def normalizeStructure(forall: Forall): (Forall, Map[Variable, Expr]) = { + val (args, body, subst) = normalizeStructure(forall.args, forall.body) + (Forall(args, body), subst) + } + + /** Fully expands all let expressions. */ + def expandLets(expr: Expr): Expr = { + def rec(ex: Expr, s: Map[Variable,Expr]) : Expr = ex match { + case v: Variable if s.isDefinedAt(v) => rec(s(v), s) + case l @ Let(i,e,b) => rec(b, s + (i.toVariable -> rec(e, s))) + case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).copiedFrom(i) + case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).copiedFrom(m) + case n @ Deconstructor(args, recons) => + var change = false + val rargs = args.map(a => { + val ra = rec(a, s) + if(ra != a) { + change = true + ra + } else { + a + } + }) + if(change) + recons(rargs).copiedFrom(n) + else + n + case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) + } + + def inCase(cse: MatchCase, s: Map[Variable,Expr]) : MatchCase = { + import cse._ + MatchCase(pattern, optGuard map { rec(_, s) }, rec(rhs,s)) + } + + rec(expr, Map.empty) + } + + /** Lifts lets to top level. + * + * Does not push any used variable out of scope. + * Assumes no match expressions (i.e. matchToIfThenElse has been called on e) + */ + def liftLets(e: Expr): Expr = { + + type C = Seq[(ValDef, Expr)] + + def combiner(e: Expr, defs: Seq[C]): C = (e, defs) match { + case (Let(v, ex, b), Seq(inDef, inBody)) => + inDef ++ ((v, ex) +: inBody) + case _ => + defs.flatten + } + + def noLet(e: Expr, defs: C) = e match { + case Let(_, _, b) => (b, defs) + case _ => (e, defs) + } + + val (bd, defs) = genericTransform[C](noTransformer, noLet, combiner)(Seq())(e) + + defs.foldRight(bd){ case ((vd, e), body) => Let(vd, e, body) } + } + + /** Recursively transforms a pattern on a boolean formula expressing the conditions for the input expression, possibly including name binders + * + * For example, the following pattern on the input `i` + * {{{ + * case m @ MyCaseClass(t: B, (_, 7)) => + * }}} + * will yield the following condition before simplification (to give some flavour) + * + * {{{and(IsInstanceOf(MyCaseClass, i), and(Equals(m, i), InstanceOfClass(B, i.t), equals(i.k.arity, 2), equals(i.k._2, 7))) }}} + * + * Pretty-printed, this would be: + * {{{ + * i.instanceOf[MyCaseClass] && m == i && i.t.instanceOf[B] && i.k.instanceOf[Tuple2] && i.k._2 == 7 + * }}} + * + * @see [[purescala.Expressions.Pattern]] + */ + def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Path = { + def bind(ob: Option[ValDef], to: Expr): Path = { + if (!includeBinders) { + Path.empty + } else { + ob.map(v => Path.empty withBinding (v -> to)).getOrElse(Path.empty) + } + } + + def rec(in: Expr, pattern: Pattern): Path = { + pattern match { + case WildcardPattern(ob) => + bind(ob, in) + + case InstanceOfPattern(ob, ct) => + val tcd = ct.tcd + if (tcd.root == tcd) { + bind(ob, in) + } else { + Path(IsInstanceOf(in, ct)) merge bind(ob, in) + } + + case CaseClassPattern(ob, cct, subps) => + assert(cct.tcd.fields.size == subps.size) + val pairs = cct.tcd.fields.map(_.id).toList zip subps.toList + val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) + Path(IsInstanceOf(in, cct)) merge bind(ob, in) merge subTests + + case TuplePattern(ob, subps) => + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + val subTests = subps.zipWithIndex.map { + case (p, i) => rec(tupleSelect(in, i+1, subps.size), p) + } + bind(ob, in) merge subTests + + case up @ UnapplyPattern(ob, id, tps, subps) => + val subs = unwrapTuple(up.get(in), subps.size).zip(subps) map (rec _).tupled + bind(ob, in) withCond up.isSome(in) merge subs + + case LiteralPattern(ob, lit) => + Path(Equals(in, lit)) merge bind(ob, in) + } + } + + rec(in, pattern) + } + + /** Converts the pattern applied to an input to a map between identifiers and expressions */ + def mapForPattern(in: Expr, pattern: Pattern): Map[Variable,Expr] = { + def bindIn(ov: Option[ValDef], cast: Option[ClassType] = None): Map[Variable, Expr] = ov match { + case None => Map() + case Some(v) => Map(v.toVariable -> cast.map(asInstOf(in, _)).getOrElse(in)) + } + + pattern match { + case CaseClassPattern(b, ct, subps) => + val tcd = ct.tcd + assert(tcd.fields.size == subps.size) + val pairs = tcd.fields.map(_.id).toList zip subps.toList + val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ct, asInstOf(in, ct), p._1), p._2)) + val together = subMaps.flatten.toMap + bindIn(b, Some(ct)) ++ together + + case TuplePattern(b, subps) => + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} + val map = maps.flatten.toMap + bindIn(b) ++ map + + case up @ UnapplyPattern(b, _, _, subps) => + bindIn(b) ++ unwrapTuple(up.getUnsafe(in), subps.size).zip(subps).flatMap { + case (e, p) => mapForPattern(e, p) + }.toMap + + case InstanceOfPattern(b, ct) => + bindIn(b, Some(ct)) + + case other => + bindIn(other.binder) + } + } + + /** Rewrites all pattern-matching expressions into if-then-else expressions + * Introduces additional error conditions. Does not introduce additional variables. + */ + def matchToIfThenElse(expr: Expr): Expr = { + + def rewritePM(e: Expr): Option[Expr] = e match { + case m @ MatchExpr(scrut, cases) => + // println("Rewriting the following PM: " + e) + + val condsAndRhs = for (cse <- cases) yield { + val map = mapForPattern(scrut, cse.pattern) + val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) + val realCond = cse.optGuard match { + case Some(g) => patCond withCond replaceFromSymbols(map, g) + case None => patCond + } + val newRhs = replaceFromSymbols(map, cse.rhs) + (realCond.toClause, newRhs, cse) + } + + val bigIte = condsAndRhs.foldRight[Expr](Error(m.getType, "Match is non-exhaustive").copiedFrom(m))((p1, ex) => { + if(p1._1 == BooleanLiteral(true)) { + p1._2 + } else { + IfExpr(p1._1, p1._2, ex).copiedFrom(p1._3) + } + }) + + Some(bigIte) + + case _ => None + } + + preMap(rewritePM)(expr) + } + + /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], concatenates the path condition with the newly induced conditions. + * + * Each case holds the conditions on other previous cases as negative. + * + * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] + * @see [[purescala.ExprOps#mapForPattern mapForPattern]] + */ + def matchExprCaseConditions(m: MatchExpr, path: Path): Seq[Path] = { + val MatchExpr(scrut, cases) = m + var pcSoFar = path + + for (c <- cases) yield { + val g = c.optGuard getOrElse BooleanLiteral(true) + val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) + val localCond = pcSoFar merge (cond withCond g) + + // These contain no binders defined in this MatchCase + val condSafe = conditionForPattern(scrut, c.pattern) + val gSafe = replaceFromSymbols(mapForPattern(scrut, c.pattern), g) + pcSoFar = pcSoFar merge (condSafe withCond gSafe).negate + + localCond + } + } + + /** Condition to pass this match case, expressed w.r.t scrut only */ + def matchCaseCondition(scrut: Expr, c: MatchCase): Path = { + + val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false) + + c.optGuard match { + case Some(g) => + // guard might refer to binders + val map = mapForPattern(scrut, c.pattern) + patternC withCond replaceFromSymbols(map, g) + + case None => + patternC + } + } + + private def hasInstance(tcd: TypedClassDef): Boolean = { + val ancestors = tcd.ancestors.toSet + + def isRecursive(tpe: Type, seen: Set[TypedClassDef]): Boolean = tpe match { + case ct: ClassType => + val ctcd = ct.tcd + if (seen(ctcd)) { + false + } else if (ancestors(ctcd)) { + true + } else { + ctcd.fieldsTypes.exists(isRecursive(_, seen + ctcd)) + } + case _ => false + } + + tcd match { + case tacd: TypedAbstractClassDef => + tacd.ccDescendants.filterNot(tccd => isRecursive(tccd.toType, Set.empty)).nonEmpty + case tccd: TypedCaseClassDef => + !isRecursive(tccd.toType, Set.empty) + } + } + + /** Returns simplest value of a given type */ + def simplestValue(tpe: Type): Expr = tpe match { + case StringType => StringLiteral("") + case Int32Type => IntLiteral(0) + case RealType => FractionalLiteral(0, 1) + case IntegerType => IntegerLiteral(0) + case CharType => CharLiteral('a') + case BooleanType => BooleanLiteral(false) + case UnitType => UnitLiteral() + case SetType(baseType) => FiniteSet(Seq(), baseType) + case BagType(baseType) => FiniteBag(Seq(), baseType) + case MapType(fromType, toType) => FiniteMap(Seq(), simplestValue(toType), fromType) + case TupleType(tpes) => Tuple(tpes.map(simplestValue)) + + case ct @ ClassType(id, tps) => + val tcd = ct.lookupClass.getOrElse(throw ClassLookupException(id)) + if (!hasInstance(tcd)) scala.sys.error(ct +" does not seem to be well-founded") + + val tccd @ TypedCaseClassDef(cd, tps) = tcd match { + case tacd: TypedAbstractClassDef => + tacd.ccDescendants.filter(hasInstance(_)).sortBy(_.fields.size).head + case tccd: TypedCaseClassDef => tccd + } + + CaseClass(ClassType(cd.id, tps), tccd.fieldsTypes.map(simplestValue)) + + case tp: TypeParameter => + GenericValue(tp, 0) + + case ft @ FunctionType(from, to) => + Lambda(from.map(tpe => ValDef(FreshIdentifier("x", true), tpe)), simplestValue(to)) + + case _ => scala.sys.error("I can't choose simplest value for type " + tpe) + } + + def valuesOf(tp: Type): Stream[Expr] = { + import utils.StreamUtils._ + tp match { + case BooleanType => + Stream(BooleanLiteral(false), BooleanLiteral(true)) + case BVType(size) => + val count = BigInt(2).pow(size - 1) + def rec(i: BigInt): Stream[BigInt] = + if (i <= count) Stream.cons(i, Stream.cons(-i - 1, rec(i + 1))) + else Stream.empty + rec(0) map (BVLiteral(_, size)) + case IntegerType => + Stream.iterate(BigInt(0)) { prev => + if (prev > 0) -prev else -prev + 1 + } map IntegerLiteral + case UnitType => + Stream(UnitLiteral()) + case tp: TypeParameter => + Stream.from(0) map (GenericValue(tp, _)) + case TupleType(stps) => + cartesianProduct(stps map (tp => valuesOf(tp))) map Tuple + case SetType(base) => + def elems = valuesOf(base) + elems.scanLeft(Stream(FiniteSet(Seq(), base): Expr)){ (prev, curr) => + prev flatMap { case fs @ FiniteSet(elems, tp) => Stream(fs, FiniteSet(elems :+ curr, tp)) } + }.flatten + case BagType(base) => + def elems = valuesOf(base) + def counts = Stream.iterate(BigInt(1))(prev => prev + 1) map IntegerLiteral + val pairs = interleave(elems.map(e => counts.map(c => e -> c))) + pairs.scanLeft(Stream(FiniteBag(Seq(), base): Expr)) { (prev, curr) => + prev flatMap { case fs @ FiniteBag(elems, tp) => Stream(fs, FiniteBag(elems :+ curr, tp)) } + }.flatten + case MapType(from, to) => + def elems = cartesianProduct(valuesOf(from), valuesOf(to)) + val seqs = elems.scanLeft(Stream(Seq[(Expr, Expr)]())) { (prev, curr) => + prev flatMap { case seq => Stream(seq, seq :+ curr) } + }.flatten + cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from) } + case ct: ClassType => ct.lookupClass match { + case Some(tccd: TypedCaseClassDef) => + cartesianProduct(tccd.fieldsTypes map valuesOf) map (CaseClass(ct, _)) + case Some(accd: TypedAbstractClassDef) => + interleave(accd.ccDescendants.map(tccd => valuesOf(tccd.toType))) + case None => throw ClassLookupException(ct.id) + } + } + } + + + /** Hoists all IfExpr at top level. + * + * Guarantees that all IfExpr will be at the top level and as soon as you + * encounter a non-IfExpr, then no more IfExpr can be found in the + * sub-expressions + * + * Assumes no match expressions + */ + def hoistIte(expr: Expr): Expr = { + def transform(expr: Expr): Option[Expr] = expr match { + case IfExpr(c, t, e) => None + + case nop@Deconstructor(ts, op) => { + val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } + if(iteIndex == -1) None else { + val (beforeIte, startIte) = ts.splitAt(iteIndex) + val afterIte = startIte.tail + val IfExpr(c, t, e) = startIte.head + Some(IfExpr(c, + op(beforeIte ++ Seq(t) ++ afterIte).copiedFrom(nop), + op(beforeIte ++ Seq(e) ++ afterIte).copiedFrom(nop) + )) + } + } + case _ => None + } + + postMap(transform, applyRec = true)(expr) + } + + def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { + + def rec(expr: Expr, path: Path): Seq[(T, Path)] = { + val seq = if (f.isDefinedAt(expr)) { + Seq(f(expr) -> path) + } else { + Seq.empty[(T, Path)] + } + + val rseq = expr match { + case Let(i, v, b) => + rec(v, path) ++ + rec(b, path withBinding (i -> v)) + + case Ensuring(Require(pre, body), Lambda(Seq(arg), post)) => + rec(pre, path) ++ + rec(body, path withCond pre) ++ + rec(post, path withCond pre withBinding (arg -> body)) + + case Ensuring(body, Lambda(Seq(arg), post)) => + rec(body, path) ++ + rec(post, path withBinding (arg -> body)) + + case Require(pre, body) => + rec(pre, path) ++ + rec(body, path withCond pre) + + case Assert(pred, err, body) => + rec(pred, path) ++ + rec(body, path withCond pred) + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + var soFar = path + + rs ++ cases.flatMap { c => + val patternPathPos = conditionForPattern(scrut, c.pattern, includeBinders = true) + val patternPathNeg = conditionForPattern(scrut, c.pattern, includeBinders = false) + val map = mapForPattern(scrut, c.pattern) + val guardOrTrue = c.optGuard.getOrElse(BooleanLiteral(true)) + val guardMapped = replaceFromSymbols(map, guardOrTrue) + + val rc = rec((patternPathPos withCond guardOrTrue).fullClause, soFar) + val subPath = soFar merge (patternPathPos withCond guardOrTrue) + val rrhs = rec(c.rhs, subPath) + + soFar = soFar merge (patternPathNeg withCond guardMapped).negate + rc ++ rrhs + } + + case IfExpr(cond, thenn, elze) => + rec(cond, path) ++ + rec(thenn, path withCond cond) ++ + rec(elze, path withCond Not(cond)) + + case And(es) => + var soFar = path + es.flatMap { e => + val re = rec(e, soFar) + soFar = soFar withCond e + re + } + + case Or(es) => + var soFar = path + es.flatMap { e => + val re = rec(e, soFar) + soFar = soFar withCond Not(e) + re + } + + case Implies(lhs, rhs) => + rec(lhs, path) ++ + rec(rhs, path withCond lhs) + + case Operator(es, _) => + es.flatMap(rec(_, path)) + + case _ => sys.error("Expression " + expr + "["+expr.getClass+"] is not extractable") + } + + seq ++ rseq + } + + rec(expr, Path.empty) + } + + /** Returns the value for an identifier given a model. */ + def valuateWithModel(model: Model)(vd: ValDef): Expr = { + model.getOrElse(vd, simplestValue(vd.getType)) + } + + /** Substitute (free) variables in an expression with values form a model. + * + * Complete with simplest values in case of incomplete model. + */ + def valuateWithModelIn(expr: Expr, vars: Set[ValDef], model: Model): Expr = { + val valuator = valuateWithModel(model) _ + replace(vars.map(vd => vd.toVariable -> valuator(vd)).toMap, expr) + } + + /** Simple, local optimization on string */ + def simplifyString(expr: Expr): Expr = { + def simplify0(expr: Expr): Expr = (expr match { + case StringConcat(StringLiteral(""), b) => b + case StringConcat(b, StringLiteral("")) => b + case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a + b) + case StringLength(StringLiteral(a)) => IntLiteral(a.length) + case StringBigLength(StringLiteral(a)) => IntegerLiteral(a.length) + case SubString(StringLiteral(a), IntLiteral(start), IntLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) + case BigSubString(StringLiteral(a), IntegerLiteral(start), IntegerLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) + case _ => expr + }).copiedFrom(expr) + simplify0(expr) + fixpoint(simplePostTransform(simplify0))(expr) + } + + /** Simple, local simplification on arithmetic + * + * You should not assume anything smarter than some constant folding and + * simple cancellation. To avoid infinite cycle we only apply simplification + * that reduce the size of the tree. The only guarantee from this function is + * to not augment the size of the expression and to be sound. + */ + def simplifyArithmetic(expr: Expr): Expr = { + def simplify0(expr: Expr): Expr = (expr match { + case Plus(IntegerLiteral(i1), IntegerLiteral(i2)) => IntegerLiteral(i1 + i2) + case Plus(IntegerLiteral(zero), e) if zero == BigInt(0) => e + case Plus(e, IntegerLiteral(zero)) if zero == BigInt(0) => e + case Plus(e1, UMinus(e2)) => Minus(e1, e2) + case Plus(Plus(e, IntegerLiteral(i1)), IntegerLiteral(i2)) => Plus(e, IntegerLiteral(i1+i2)) + case Plus(Plus(IntegerLiteral(i1), e), IntegerLiteral(i2)) => Plus(IntegerLiteral(i1+i2), e) + + case Minus(e, IntegerLiteral(zero)) if zero == BigInt(0) => e + case Minus(IntegerLiteral(zero), e) if zero == BigInt(0) => UMinus(e) + case Minus(IntegerLiteral(i1), IntegerLiteral(i2)) => IntegerLiteral(i1 - i2) + case Minus(e1, UMinus(e2)) => Plus(e1, e2) + case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3)) + + case UMinus(IntegerLiteral(x)) => IntegerLiteral(-x) + case UMinus(UMinus(x)) => x + case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2)) + case UMinus(Minus(e1, e2)) => Minus(e2, e1) + + case Times(IntegerLiteral(i1), IntegerLiteral(i2)) => IntegerLiteral(i1 * i2) + case Times(IntegerLiteral(one), e) if one == BigInt(1) => e + case Times(IntegerLiteral(mone), e) if mone == BigInt(-1) => UMinus(e) + case Times(e, IntegerLiteral(one)) if one == BigInt(1) => e + case Times(IntegerLiteral(zero), _) if zero == BigInt(0) => IntegerLiteral(0) + case Times(_, IntegerLiteral(zero)) if zero == BigInt(0) => IntegerLiteral(0) + case Times(IntegerLiteral(i1), Times(IntegerLiteral(i2), t)) => Times(IntegerLiteral(i1*i2), t) + case Times(IntegerLiteral(i1), Times(t, IntegerLiteral(i2))) => Times(IntegerLiteral(i1*i2), t) + case Times(IntegerLiteral(i), UMinus(e)) => Times(IntegerLiteral(-i), e) + case Times(UMinus(e), IntegerLiteral(i)) => Times(e, IntegerLiteral(-i)) + case Times(IntegerLiteral(i1), Division(e, IntegerLiteral(i2))) if i2 != BigInt(0) && i1 % i2 == BigInt(0) => Times(IntegerLiteral(i1/i2), e) + + case Division(IntegerLiteral(i1), IntegerLiteral(i2)) if i2 != BigInt(0) => IntegerLiteral(i1 / i2) + case Division(e, IntegerLiteral(one)) if one == BigInt(1) => e + + //here we put more expensive rules + //btw, I know those are not the most general rules, but they lead to good optimizations :) + case Plus(UMinus(Plus(e1, e2)), e3) if e1 == e3 => UMinus(e2) + case Plus(UMinus(Plus(e1, e2)), e3) if e2 == e3 => UMinus(e1) + case Minus(e1, e2) if e1 == e2 => IntegerLiteral(0) + case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => IntegerLiteral(0) + case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5) + + case StringConcat(StringLiteral(""), a) => a + case StringConcat(a, StringLiteral("")) => a + case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a+b) + case StringConcat(StringLiteral(a), StringConcat(StringLiteral(b), c)) => StringConcat(StringLiteral(a+b), c) + case StringConcat(StringConcat(c, StringLiteral(a)), StringLiteral(b)) => StringConcat(c, StringLiteral(a+b)) + case StringConcat(a, StringConcat(b, c)) => StringConcat(StringConcat(a, b), c) + //default + case e => e + }).copiedFrom(expr) + + fixpoint(simplePostTransform(simplify0))(expr) + } + + /** + * Some helper methods for FractionalLiterals + */ + def normalizeFraction(fl: FractionalLiteral) = { + val FractionalLiteral(num, denom) = fl + val modNum = if (num < 0) -num else num + val modDenom = if (denom < 0) -denom else denom + val divisor = modNum.gcd(modDenom) + val simpNum = num / divisor + val simpDenom = denom / divisor + if (simpDenom < 0) + FractionalLiteral(-simpNum, -simpDenom) + else + FractionalLiteral(simpNum, simpDenom) + } + + val realzero = FractionalLiteral(0, 1) + def floor(fl: FractionalLiteral): FractionalLiteral = { + val FractionalLiteral(n, d) = normalizeFraction(fl) + if (d == 0) throw new IllegalStateException("denominator zero") + if (n == 0) realzero + else if (n > 0) { + //perform integer division + FractionalLiteral(n / d, 1) + } else { + //here the number is negative + if (n % d == 0) + FractionalLiteral(n / d, 1) + else { + //perform integer division and subtract 1 + FractionalLiteral(n / d - 1, 1) + } + } + } + + /* ================= + * Body manipulation + * ================= + */ + + /** Returns whether a particular [[Expressions.Expr]] contains specification + * constructs, namely [[Expressions.Require]] and [[Expressions.Ensuring]]. + */ + def hasSpec(e: Expr): Boolean = exists { + case Require(_, _) => true + case Ensuring(_, _) => true + case Let(i, e, b) => hasSpec(b) + case _ => false + } (e) + + /** Merges the given [[Path]] into the provided [[Expressions.Expr]]. + * + * This method expects to run on a [[Definitions.FunDef.fullBody]] and merges into + * existing pre- and postconditions. + * + * @param expr The current body + * @param path The path that should be wrapped around the given body + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withPath(expr: Expr, path: Path): Expr = expr match { + case Let(i, e, b) => withPath(b, path withBinding (i -> e)) + case Require(pre, b) => path specs (b, pre) + case Ensuring(Require(pre, b), post) => path specs (b, pre, post) + case Ensuring(b, post) => path specs (b, post = post) + case b => path specs b + } + + /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one. + * + * If no precondition is provided, removes any existing precondition. + * Else, wraps the expression with a [[Expressions.Require]] clause referring to the new precondition. + * + * @param expr The current expression + * @param pred An optional precondition. Setting it to None removes any precondition. + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = (pred, expr) match { + case (Some(newPre), Require(pre, b)) => req(newPre, b) + case (Some(newPre), Ensuring(Require(pre, b), p)) => Ensuring(req(newPre, b), p) + case (Some(newPre), Ensuring(b, p)) => Ensuring(req(newPre, b), p) + case (Some(newPre), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) + case (Some(newPre), b) => req(newPre, b) + case (None, Require(pre, b)) => b + case (None, Ensuring(Require(pre, b), p)) => Ensuring(b, p) + case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) + case (None, b) => b + } + + /** Replaces the postcondition of an existing [[Expressions.Expr]] with a new one. + * + * If no postcondition is provided, removes any existing postcondition. + * Else, wraps the expression with a [[Expressions.Ensuring]] clause referring to the new postcondition. + * + * @param expr The current expression + * @param oie An optional postcondition. Setting it to None removes any postcondition. + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withPostcondition(expr: Expr, oie: Option[Expr]): Expr = (oie, expr) match { + case (Some(npost), Ensuring(b, post)) => ensur(b, npost) + case (Some(npost), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) + case (Some(npost), b) => ensur(b, npost) + case (None, Ensuring(b, p)) => b + case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) + case (None, b) => b + } + + /** Adds a body to a specification + * + * @param expr The specification expression [[Expressions.Ensuring]] or [[Expressions.Require]]. If none of these, the argument is discarded. + * @param body An option of [[Expressions.Expr]] possibly containing an expression body. + * @return The post/pre condition with the body. If no body is provided, returns [[Expressions.NoTree]] + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withBody(expr: Expr, body: Option[Expr]): Expr = expr match { + case Let(i, e, b) if hasSpec(b) => Let(i, e, withBody(b, body)) + case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) + case Ensuring(Require(pre, _), post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), post) + case Ensuring(_, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), post) + case _ => body.getOrElse(NoTree(expr.getType)) + } + + object InvocationExtractor { + private def flatInvocation(expr: Expr): Option[(Identifier, Seq[Type], Seq[Expr])] = expr match { + case fi @ FunctionInvocation(id, tps, args) => Some((id, tps, args)) + case Application(caller, args) => flatInvocation(caller) match { + case Some((id, tps, prevArgs)) => Some((id, tps, prevArgs ++ args)) + case None => None + } + case _ => None + } + + def unapply(expr: Expr): Option[(Identifier, Seq[Type], Seq[Expr])] = expr match { + case IsTyped(f: FunctionInvocation, ft: FunctionType) => None + case IsTyped(f: Application, ft: FunctionType) => None + case FunctionInvocation(id, tps, args) => Some((id, tps, args)) + case f: Application => flatInvocation(f) + case _ => None + } + } + + def firstOrderCallsOf(expr: Expr): Set[(Identifier, Seq[Type], Seq[Expr])] = + collect { e => InvocationExtractor.unapply(e).toSet[(Identifier, Seq[Type], Seq[Expr])] }(expr) + + object ApplicationExtractor { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, _) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case IsTyped(f: Application, ft: FunctionType) => None + case f: Application => flatApplication(f) + case _ => None + } + } + + def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = + collect[(Expr, Seq[Expr])] { + case ApplicationExtractor(caller, args) => Set(caller -> args) + case _ => Set.empty + } (expr) + + def simplifyHOFunctions(expr: Expr): Expr = { + + def liftToLambdas(expr: Expr) = { + def apply(expr: Expr, args: Seq[Expr]): Expr = expr match { + case IfExpr(cond, thenn, elze) => + IfExpr(cond, apply(thenn, args), apply(elze, args)) + case Let(i, e, b) => + Let(i, e, apply(b, args)) + case LetTuple(is, es, b) => + letTuple(is, es, apply(b, args)) + //case l @ Lambda(params, body) => + // l.withParamSubst(args, body) + case _ => Application(expr, args) + } + + def lift(expr: Expr): Expr = expr.getType match { + case FunctionType(from, to) => expr match { + case _ : Lambda => expr + case _ : Variable => expr + case e => + val args = from.map(tpe => ValDef(FreshIdentifier("x", true), tpe)) + val application = apply(expr, args.map(_.toVariable)) + Lambda(args, lift(application)) + } + case _ => expr + } + + def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr + + def rec(expr: Expr, build: Boolean): Expr = expr match { + case Application(caller, args) => + val newArgs = args.map(rec(_, true)) + val newCaller = rec(caller, false) + extract(Application(newCaller, newArgs), build) + case FunctionInvocation(id, tps, args) => + val newArgs = args.map(rec(_, true)) + extract(FunctionInvocation(id, tps, newArgs), build) + case l @ Lambda(args, body) => + val newBody = rec(body, true) + extract(Lambda(args, newBody), build) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) + } + + rec(lift(expr), true) + } + + liftToLambdas( + matchToIfThenElse( + expr + ) + ) + } + + // Use this only to debug isValueOfType + private implicit class BooleanAdder(b: Boolean) { + @inline def <(msg: => String) = {/*if(!b) println(msg); */b} + } + + /** Returns true if expr is a value of type t */ + def isValueOfType(e: Expr, t: Type): Boolean = { + def unWrapSome(s: Expr) = s match { + case CaseClass(_, Seq(a)) => a + case _ => s + } + (e, t) match { + case (StringLiteral(_), StringType) => true + case (IntLiteral(_), Int32Type) => true + case (IntegerLiteral(_), IntegerType) => true + case (CharLiteral(_), CharType) => true + case (FractionalLiteral(_, _), RealType) => true + case (BooleanLiteral(_), BooleanType) => true + case (UnitLiteral(), UnitType) => true + case (GenericValue(t, _), tp) => t == tp + case (Tuple(elems), TupleType(bases)) => + elems zip bases forall (eb => isValueOfType(eb._1, eb._2)) + case (FiniteSet(elems, tbase), SetType(base)) => + tbase == base && + (elems forall isValue) + case (FiniteBag(elements, fbtpe), BagType(tpe)) => + fbtpe == tpe && + elements.forall{ case (key, value) => isValueOfType(key, tpe) && isValueOfType(value, IntegerType) } + case (FiniteMap(elems, tk, tv), MapType(from, to)) => + (tk == from) < s"$tk not equal to $from" && (tv == to) < s"$tv not equal to $to" && + (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type ${from}" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" )) + case (CaseClass(ct, args), ct2: ClassType) => + isSubtypeOf(ct, ct2) < s"$ct not a subtype of $ct2" && + ((args zip ct.tcd.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2) < s"${argstyped._1} not a value of type ${argstyped._2}" )) + case (Lambda(valdefs, body), FunctionType(ins, out)) => + variablesOf(e).isEmpty && + (valdefs zip ins forall (vdin => isSubtypeOf(vdin._2, vdin._1.getType) < s"${vdin._2} is not a subtype of ${vdin._1.getType}")) && + (isSubtypeOf(body.getType, out)) < s"${body.getType} is not a subtype of ${out}" + case _ => false + } + } + + /** Returns true if expr is a value. Stronger than isGround */ + def isValue(e: Expr) = isValueOfType(e, e.getType) + + /** Returns a nested string explaining why this expression is typed the way it is.*/ + def explainTyping(e: Expr): String = { + fold[String]{ (e, se) => + e match { + case FunctionInvocation(id, tps, args) => + val tfd = getFunction(id, tps) + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + s" because ${tfd.fd.id.name} was instantiated with ${tfd.fd.tparams.zip(args).map(k => k._1 +":="+k._2).mkString(",")} with type ${tfd.fd.params.map(_.getType).mkString(",")} => ${tfd.fd.returnType}" + case e => + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + } + }(e) + } + + def typeParamsOf(expr: Expr): Set[TypeParameter] = { + collect(e => typeParamsOf(e.getType))(expr) + } + + // Helpers for instantiateType + class TypeInstantiator(tps: Map[TypeParameter, Type]) extends TreeTransformer { + override def transform(tpe: Type): Type = tpe match { + case tp: TypeParameter => tps.getOrElse(tp, tpe) + case _ => tpe + } + } + + def instantiateType(e: Expr, tps: Map[TypeParameter, Type]): Expr = { + if (tps.isEmpty) { + e + } else { + new TypeInstantiator(tps).transform(e) + } + } +} diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala new file mode 100644 index 000000000..cbd7b36ad --- /dev/null +++ b/src/main/scala/inox/ast/TreeOps.scala @@ -0,0 +1,200 @@ + +package inox +package ast + +trait TreeOps { + val trees: Trees + import trees._ + + trait TreeTransformer { + def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, transform(tpe)) + + def transform(v: Variable): Variable = { + val (id, tpe) = transform(v.id, v.tpe) + Variable(id, tpe).copiedFrom(v) + } + + def transform(vd: ValDef): ValDef = { + val (id, tpe) = transform(vd.id, vd.tpe) + ValDef(id, tpe).copiedFrom(vd) + } + + def transform(e: Expr): Expr = e match { + case v: Variable => transform(v) + + case Lambda(args, body) => + Lambda(args map transform, transform(body)).copiedFrom(e) + + case Forall(args, body) => + Forall(args map transform, transform(body)).copiedFrom(e) + + case Let(vd, expr, body) => + Let(transform(vd), transform(expr), transform(body)).copiedFrom(e) + + case CaseClass(cct, args) => + CaseClass(transform(cct).asInstanceOf[ClassType], args map transform).copiedFrom(e) + + case CaseClassSelector(cct, caseClass, selector) => + CaseClassSelector(transform(cct).asInstanceOf[ClassType], transform(caseClass), selector).copiedFrom(e) + + case FunctionInvocation(id, tps, args) => + FunctionInvocation(id, tps map transform, args map transform).copiedFrom(e) + + case IsInstanceOf(expr, ct) => + IsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + + case AsInstanceOf(expr, ct) => + AsInstanceOf(transform(expr), transform(ct).asInstanceOf[ClassType]).copiedFrom(e) + + case MatchExpr(scrutinee, cases) => + MatchExpr(transform(scrutinee), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(transform(pattern), guard.map(transform), transform(rhs)).copiedFrom(cse) + }).copiedFrom(e) + + case FiniteSet(es, tpe) => + FiniteSet(es map transform, transform(tpe)).copiedFrom(e) + + case FiniteBag(es, tpe) => + FiniteBag(es map { case (k, v) => transform(k) -> v }, transform(tpe)).copiedFrom(e) + + case FiniteMap(pairs, from, to) => + FiniteMap(pairs map { case (k, v) => transform(k) -> transform(v) }, transform(from), transform(to)).copiedFrom(e) + + case NoTree(tpe) => + NoTree(transform(tpe)).copiedFrom(e) + + case Error(tpe, desc) => + Error(transform(tpe), desc).copiedFrom(e) + + case Operator(es, builder) => + val newEs = es map transform + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + + case e => e + } + + def transform(pat: Pattern): Pattern = pat match { + case InstanceOfPattern(binder, ct) => + InstanceOfPattern(binder map transform, transform(ct).asInstanceOf[ClassType]).copiedFrom(pat) + + case CaseClassPattern(binder, ct, subs) => + CaseClassPattern(binder map transform, transform(ct).asInstanceOf[ClassType], subs map transform).copiedFrom(pat) + + case TuplePattern(binder, subs) => + TuplePattern(binder map transform, subs map transform).copiedFrom(pat) + + case UnapplyPattern(binder, id, tps, subs) => + UnapplyPattern(binder map transform, id, tps map transform, subs map transform).copiedFrom(pat) + + case PatternExtractor(subs, builder) => + builder(subs map transform).copiedFrom(pat) + + case p => p + } + + def transform(tpe: Type): Type = tpe match { + case NAryType(ts, builder) => builder(ts map transform).copiedFrom(tpe) + } + } + + trait TreeTraverser { + def traverse(e: Expr): Unit = e match { + case Variable(_, tpe) => + traverse(tpe) + + case Lambda(args, body) => + args foreach (vd => traverse(vd.tpe)) + traverse(body) + + case Forall(args, body) => + args foreach (vd => traverse(vd.tpe)) + traverse(body) + + case Let(a, expr, body) => + traverse(expr) + traverse(body) + + case CaseClass(cct, args) => + traverse(cct) + args foreach traverse + + case CaseClassSelector(cct, caseClass, selector) => + traverse(cct) + traverse(caseClass) + + case FunctionInvocation(id, tps, args) => + tps foreach traverse + args foreach traverse + + case IsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + + case AsInstanceOf(expr, ct) => + traverse(expr) + traverse(ct) + + case MatchExpr(scrutinee, cases) => + traverse(scrutinee) + for (cse @ MatchCase(pattern, guard, rhs) <- cases) { + traverse(pattern) + guard foreach traverse + traverse(rhs) + } + + case FiniteSet(es, tpe) => + es foreach traverse + traverse(tpe) + + case FiniteBag(es, tpe) => + es foreach { case (k, _) => traverse(k) } + traverse(tpe) + + case FiniteMap(pairs, from, to) => + pairs foreach { case (k, v) => traverse(k); traverse(v) } + traverse(from) + traverse(to) + + case NoTree(tpe) => + traverse(tpe) + + case Error(tpe, desc) => + traverse(tpe) + + case Operator(es, builder) => + es foreach traverse + + case e => + } + + def traverse(pat: Pattern): Unit = pat match { + case InstanceOfPattern(binder, ct) => + traverse(ct) + + case CaseClassPattern(binder, ct, subs) => + traverse(ct) + subs foreach traverse + + case TuplePattern(binder, subs) => + subs foreach traverse + + case UnapplyPattern(binder, id, tps, subs) => + tps foreach traverse + subs foreach traverse + + case PatternExtractor(subs, builder) => + subs foreach traverse + + case pat => + } + + def traverse(tpe: Type): Unit = tpe match { + case NAryType(ts, builder) => + ts foreach traverse + } + } +} diff --git a/src/main/scala/inox/trees/Trees.scala b/src/main/scala/inox/ast/Trees.scala similarity index 68% rename from src/main/scala/inox/trees/Trees.scala rename to src/main/scala/inox/ast/Trees.scala index 5b7874b18..b02c30a74 100644 --- a/src/main/scala/inox/trees/Trees.scala +++ b/src/main/scala/inox/ast/Trees.scala @@ -1,19 +1,17 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast -trait Trees extends Expressions with Extractors with Types with Definitions { +import utils._ +import scala.language.implicitConversions - object exprOps extends { - val trees: Trees.this.type = Trees.this - } with ExprOps with Constructors +trait Trees extends Expressions with Extractors with Types with Definitions with Printers { - object typeOps extends { - val trees: Trees.this.type = Trees.this - } with TypeOps + class Unsupported(t: Tree, msg: String)(implicit ctx: Context) + extends Exception(s"${t.asString(ctx)}@${t.getPos} $msg") - abstract class Tree extends Positioned with Serializable with Printable { + abstract class Tree extends utils.Positioned with Serializable with inox.Printable { def copiedFrom(o: Tree): this.type = { setPos(o) this @@ -21,22 +19,26 @@ trait Trees extends Expressions with Extractors with Types with Definitions { // @EK: toString is considered harmful for non-internal things. Use asString(ctx) instead. - def asString(implicit pgm: Program, ctx: Context): String = { - ScalaPrinter(this, ctx, pgm) + def asString(implicit symbols: Symbols, ctx: Context): String = { + ScalaPrinter(this, ctx, symbols) } - override def toString = asString(LeonContext.printNames) + override def toString = asString(Context.printNames) } + object exprOps extends { + val trees: Trees.this.type = Trees.this + } with ExprOps + /** Represents a unique symbol in Inox. * * The name is stored in the decoded (source code) form rather than encoded (JVM) form. * The type may be left blank (Untyped) for Identifiers that are not variables. */ - class Identifier private[Common]( + class Identifier private[Trees]( val name: String, val globalId: Int, - private[Common] val id: Int, + private[Trees] val id: Int, private val alwaysShowUniqueID: Boolean = false ) extends Tree with Ordered[Identifier] { @@ -56,13 +58,7 @@ trait Trees extends Expressions with Extractors with Types with Definitions { def uniqueName: String = uniqueNameDelimited("$") - def toVariable: Variable = Variable(this) - - def freshen: Identifier = FreshIdentifier(name, tpe, alwaysShowUniqueID).copiedFrom(this) - - def duplicate(name: String = name, tpe: TypeTree = tpe, alwaysShowUniqueID: Boolean = alwaysShowUniqueID) = { - FreshIdentifier(name, tpe, alwaysShowUniqueID) - } + def freshen: Identifier = FreshIdentifier(name, alwaysShowUniqueID).copiedFrom(this) override def compare(that: Identifier): Int = { val ord = implicitly[Ordering[(String, Int, Int)]] @@ -87,9 +83,9 @@ trait Trees extends Expressions with Extractors with Types with Definitions { * @param tpe The type of the identifier * @param alwaysShowUniqueID If the unique ID should always be shown */ - def apply(name: String, tpe: TypeTree = Untyped, alwaysShowUniqueID: Boolean = false) : Identifier = { + def apply(name: String, alwaysShowUniqueID: Boolean = false) : Identifier = { val decoded = decode(name) - new Identifier(decoded, uniqueCounter.nextGlobal, uniqueCounter.next(decoded), tpe, alwaysShowUniqueID) + new Identifier(decoded, uniqueCounter.nextGlobal, uniqueCounter.next(decoded), alwaysShowUniqueID) } /** Builds a fresh identifier, whose ID is always shown @@ -98,10 +94,8 @@ trait Trees extends Expressions with Extractors with Types with Definitions { * @param forceId The forced ID of the identifier * @param tpe The type of the identifier */ - object forceId { - def apply(name: String, forceId: Int, tpe: TypeTree, alwaysShowUniqueID: Boolean = false): Identifier = - new Identifier(decode(name), uniqueCounter.nextGlobal, forceId, tpe, alwaysShowUniqueID) - } + def forceId(name: String, forceId: Int, alwaysShowUniqueID: Boolean = false): Identifier = + new Identifier(decode(name), uniqueCounter.nextGlobal, forceId, alwaysShowUniqueID) } def aliased(id1: Identifier, id2: Identifier) = { @@ -115,4 +109,7 @@ trait Trees extends Expressions with Extractors with Types with Definitions { (s1 & s2).nonEmpty } + def aliased[T1 <: VariableSymbol,T2 <: VariableSymbol](vs1: Set[T1], vs2: Set[T2]): Boolean = { + aliased(vs1.map(_.id), vs2.map(_.id)) + } } diff --git a/src/main/scala/inox/trees/TypeOps.scala b/src/main/scala/inox/ast/TypeOps.scala similarity index 53% rename from src/main/scala/inox/trees/TypeOps.scala rename to src/main/scala/inox/ast/TypeOps.scala index 0fd4bf588..5f294f6c7 100644 --- a/src/main/scala/inox/trees/TypeOps.scala +++ b/src/main/scala/inox/ast/TypeOps.scala @@ -1,27 +1,35 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast -trait TypeOps extends GenTreeOps { +trait TypeOps { val trees: Trees import trees._ + implicit val symbols: Symbols - type SubTree = Type + object typeOps extends GenTreeOps { + val trees: TypeOps.this.trees.type = TypeOps.this.trees + import trees._ - val Deconstructor = NAryType + type SubTree = Type + val Deconstructor = NAryType + } + + class TypeErrorException(msg: String) extends Exception(msg) - def typeParamsOf(expr: Expr): Set[TypeParameter] = { - exprOps.collect(e => typeParamsOf(e.getType))(expr) + object TypeErrorException { + def apply(obj: Expr, tpes: Seq[Type]): TypeErrorException = + new TypeErrorException(s"Type error: $obj, expected ${tpes.mkString(" or ")}, found ${obj.getType}") + def apply(obj: Expr, tpe: Type): TypeErrorException = apply(obj, Seq(tpe)) } - def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { + def typeParamsOf(t: Type): Set[TypeParameter] = t match { case tp: TypeParameter => Set(tp) case NAryType(subs, _) => subs.flatMap(typeParamsOf).toSet } - /** Generic type bounds between two types. Serves as a base for a set of subtyping/unification functions. * It will allow subtyping between classes (but type parameters are invariant). * It will also allow a set of free parameters to be unified if needed. @@ -35,10 +43,10 @@ trait TypeOps extends GenTreeOps { * Result is empty if types are incompatible. * @see [[leastUpperBound]], [[greatestLowerBound]], [[isSubtypeOf]], [[typesCompatible]], [[unify]] */ - def typeBound(t1: TypeTree, t2: TypeTree, isLub: Boolean, allowSub: Boolean) - (implicit freeParams: Seq[TypeParameter]): Option[(TypeTree, Map[TypeParameter, TypeTree])] = { + def typeBound(t1: Type, t2: Type, isLub: Boolean, allowSub: Boolean) + (implicit freeParams: Seq[TypeParameter]): Option[(Type, Map[TypeParameter, Type])] = { - def flatten(res: Seq[Option[(TypeTree, Map[TypeParameter, TypeTree])]]): Option[(Seq[TypeTree], Map[TypeParameter, TypeTree])] = { + def flatten(res: Seq[Option[(Type, Map[TypeParameter, Type])]]): Option[(Seq[Type], Map[TypeParameter, Type])] = { val (tps, subst) = res.map(_.getOrElse(return None)).unzip val flat = subst.flatMap(_.toSeq).groupBy(_._1) Some((tps, flat.mapValues { vs => @@ -66,11 +74,11 @@ trait TypeOps extends GenTreeOps { None case (ct1: ClassType, ct2: ClassType) => - val cd1 = ct1.classDef - val cd2 = ct2.classDef + val cd1 = ct1.tcd.cd + val cd2 = ct2.tcd.cd val bound: Option[ClassDef] = if (allowSub) { - val an1 = cd1 +: cd1.ancestors - val an2 = cd2 +: cd2.ancestors + val an1 = Seq(cd1, cd1.root) + val an2 = Seq(cd2, cd2.root) if (isLub) { (an1.reverse zip an2.reverse) .takeWhile(((_: ClassDef) == (_: ClassDef)).tupled) @@ -90,7 +98,7 @@ trait TypeOps extends GenTreeOps { // Class types are invariant! typeBound(tp1, tp2, isLub, allowSub = false) }) - } yield (cd.typed(subs), map) + } yield (cd.typed(subs).toType, map) case (FunctionType(from1, to1), FunctionType(from2, to2)) => if (from1.size != from2.size) None @@ -105,7 +113,7 @@ trait TypeOps extends GenTreeOps { } } - case Same(t1, t2) => + case typeOps.Same(t1, t2) => // Only tuples are covariant def allowVariance = t1 match { case _ : TupleType => true @@ -124,34 +132,45 @@ trait TypeOps extends GenTreeOps { } } - def unify(tp1: TypeTree, tp2: TypeTree, freeParams: Seq[TypeParameter]) = + def instantiateType(tpe: Type, tps: Map[TypeParameter, Type]): Type = { + if (tps.isEmpty) { + tpe + } else { + typeOps.postMap { + case tp: TypeParameter => tps.get(tp) + case _ => None + } (tpe) + } + } + + def unify(tp1: Type, tp2: Type, freeParams: Seq[TypeParameter]) = typeBound(tp1, tp2, isLub = true, allowSub = false)(freeParams).map(_._2) /** Will try to instantiate subT and superT so that subT <: superT * * @return Mapping of instantiations */ - private def subtypingInstantiation(subT: TypeTree, superT: TypeTree, free: Seq[TypeParameter]) = + private def subtypingInstantiation(subT: Type, superT: Type, free: Seq[TypeParameter]) = typeBound(subT, superT, isLub = true, allowSub = true)(free) collect { case (tp, map) if instantiateType(superT, map) == tp => map } - def canBeSubtypeOf(subT: TypeTree, superT: TypeTree) = { + def canBeSubtypeOf(subT: Type, superT: Type) = { subtypingInstantiation(subT, superT, (typeParamsOf(subT) -- typeParamsOf(superT)).toSeq) } - def canBeSupertypeOf(superT: TypeTree, subT: TypeTree) = { + def canBeSupertypeOf(superT: Type, subT: Type) = { subtypingInstantiation(subT, superT, (typeParamsOf(superT) -- typeParamsOf(subT)).toSeq) } - def leastUpperBound(tp1: TypeTree, tp2: TypeTree): Option[TypeTree] = + def leastUpperBound(tp1: Type, tp2: Type): Option[Type] = typeBound(tp1, tp2, isLub = true, allowSub = true)(Seq()).map(_._1) - def greatestLowerBound(tp1: TypeTree, tp2: TypeTree): Option[TypeTree] = + def greatestLowerBound(tp1: Type, tp2: Type): Option[Type] = typeBound(tp1, tp2, isLub = false, allowSub = true)(Seq()).map(_._1) - def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = { - def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match { + def leastUpperBound(ts: Seq[Type]): Option[Type] = { + def olub(ot1: Option[Type], t2: Option[Type]): Option[Type] = ot1 match { case Some(t1) => leastUpperBound(t1, t2.get) case None => None } @@ -163,15 +182,15 @@ trait TypeOps extends GenTreeOps { } } - def isSubtypeOf(t1: TypeTree, t2: TypeTree): Boolean = { + def isSubtypeOf(t1: Type, t2: Type): Boolean = { leastUpperBound(t1, t2) == Some(t2) } - def typesCompatible(t1: TypeTree, t2s: TypeTree*) = { + def typesCompatible(t1: Type, t2s: Type*) = { leastUpperBound(t1 +: t2s).isDefined } - def typeCheck(obj: Expr, exps: TypeTree*) { + def typeCheck(obj: Expr, exps: Type*) { val res = exps.exists(e => isSubtypeOf(obj.getType, e)) if (!res) { @@ -179,97 +198,68 @@ trait TypeOps extends GenTreeOps { } } - def bestRealType(t: TypeTree) : TypeTree = t match { - case (c: ClassType) => c.root + def bestRealType(t: Type): Type = t match { + case (c: ClassType) => c.tcd.root.toType case NAryType(tps, builder) => builder(tps.map(bestRealType)) } - def isParametricType(tpe: TypeTree): Boolean = tpe match { + def isParametricType(tpe: Type): Boolean = tpe match { case (tp: TypeParameter) => true case NAryType(tps, builder) => tps.exists(isParametricType) } - // Helpers for instantiateType - private def typeParamSubst(map: Map[TypeParameter, TypeTree])(tpe: TypeTree): TypeTree = tpe match { - case (tp: TypeParameter) => map.getOrElse(tp, tp) - case NAryType(tps, builder) => builder(tps.map(typeParamSubst(map))) - } - - private def freshId(id: Identifier, newTpe: TypeTree) = { - if (id.getType != newTpe) { - FreshIdentifier(id.name, newTpe).copiedFrom(id) - } else { - id - } - } - - def instantiateType(id: Identifier, tps: Map[TypeParameter, TypeTree]): Identifier = { - freshId(id, typeParamSubst(tps)(id.getType)) - } - - def instantiateType(tpe: TypeTree, tps: Map[TypeParameter, TypeTree]): TypeTree = { - if (tps.isEmpty) { - tpe - } else { - typeParamSubst(tps)(tpe) - } - } - - def instantiateType(e: Expr, tps: Map[TypeParameter, TypeTree], ids: Map[Identifier, Identifier]): Expr = { - if (tps.isEmpty && ids.isEmpty) { - e - } else { - val tpeSub = if (tps.isEmpty) { - { (tpe: TypeTree) => tpe } + def typeCardinality(tp: Type): Option[Int] = { + def cards(tps: Seq[Type]): Option[Seq[Int]] = { + val cardinalities = tps.map(typeCardinality).flatten + if (cardinalities.size == tps.size) { + Some(cardinalities) } else { - typeParamSubst(tps) _ - } - - val transformer = new TreeTransformer { - override def transform(id: Identifier): Identifier = freshId(id, transform(id.getType)) - override def transform(tpe: TypeTree): TypeTree = tpeSub(tpe) + None } - - transformer.transform(e)(ids) } - } - def typeCardinality(tp: TypeTree): Option[Int] = tp match { - case Untyped => Some(0) - case BooleanType => Some(2) - case UnitType => Some(1) - case TupleType(tps) => - Some(tps.map(typeCardinality).map(_.getOrElse(return None)).product) - case SetType(base) => - typeCardinality(base).map(b => Math.pow(2, b).toInt) - case FunctionType(from, to) => - val t = typeCardinality(to).getOrElse(return None) - val f = from.map(typeCardinality).map(_.getOrElse(return None)).product - Some(Math.pow(t, f).toInt) - case MapType(from, to) => - for { - t <- typeCardinality(to) - f <- typeCardinality(from) - } yield { - Math.pow(t + 1, f).toInt + tp match { + case Untyped => Some(0) + case BooleanType => Some(2) + case UnitType => Some(1) + case TupleType(tps) => cards(tps).map(_.product) + case SetType(base) => + typeCardinality(base).map(b => Math.pow(2, b).toInt) + case FunctionType(from, to) => + for { + t <- typeCardinality(to) + f <- cards(from).map(_.product) + } yield Math.pow(t, f).toInt + case MapType(from, to) => + for { + t <- typeCardinality(to) + f <- typeCardinality(from) + } yield Math.pow(t + 1, f).toInt + case ct: ClassType => ct.tcd match { + case tccd: TypedCaseClassDef => + cards(tccd.fieldsTypes).map(_.product) + + case accd: TypedAbstractClassDef => + val possibleChildTypes = utils.fixpoint((tpes: Set[Type]) => { + tpes.flatMap(tpe => + Set(tpe) ++ (tpe match { + case ct: ClassType => ct.tcd match { + case tccd: TypedCaseClassDef => tccd.fieldsTypes + case tacd: TypedAbstractClassDef => (Set(tacd) ++ tacd.ccDescendants).map(_.toType) + } + case _ => Set.empty + }) + ) + })(accd.ccDescendants.map(_.toType).toSet) + + if (possibleChildTypes(accd.toType)) { + None + } else { + cards(accd.ccDescendants.map(_.toType)).map(_.sum) + } } - case cct: CaseClassType => - Some(cct.fieldsTypes.map { tpe => - typeCardinality(tpe).getOrElse(return None) - }.product) - case act: AbstractClassType => - val possibleChildTypes = leon.utils.fixpoint((tpes: Set[TypeTree]) => { - tpes.flatMap(tpe => - Set(tpe) ++ (tpe match { - case cct: CaseClassType => cct.fieldsTypes - case act: AbstractClassType => Set(act) ++ act.knownCCDescendants - case _ => Set.empty - }) - ) - })(act.knownCCDescendants.toSet) - if(possibleChildTypes(act)) return None - Some(act.knownCCDescendants.map(typeCardinality).map(_.getOrElse(return None)).sum) - case _ => None + case _ => None + } } } diff --git a/src/main/scala/inox/trees/Types.scala b/src/main/scala/inox/ast/Types.scala similarity index 78% rename from src/main/scala/inox/trees/Types.scala rename to src/main/scala/inox/ast/Types.scala index fe9d6a0bb..706c217b5 100644 --- a/src/main/scala/inox/trees/Types.scala +++ b/src/main/scala/inox/ast/Types.scala @@ -1,32 +1,32 @@ /* Copyright 2009-2016 EPFL, Lausanne */ package inox -package trees +package ast trait Types { self: Trees => - trait Typed extends Printable { - def getType(implicit p: Program): Type - def isTyped(implicit p: Program): Boolean = getType != Untyped + trait Typed extends utils.Printable { + def getType(implicit s: Symbols): Type + def isTyped(implicit s: Symbols): Boolean = getType != Untyped } - private[trees] trait CachingTyped extends Typed { - private var lastProgram: Program = null + private[ast] trait CachingTyped extends Typed { + private var lastSymbols: Symbols = null private var lastType: Type = null - final def getType(implicit p: Program): Type = - if (p eq lastProgram) lastType else { + final def getType(implicit s: Symbols): Type = + if (s eq lastSymbols) lastType else { val tpe = computeType - lastProgram = p + lastSymbols = s lastType = tpe tpe } - protected def computeType(implicit p: Program): Type + protected def computeType(implicit s: Symbols): Type } abstract class Type extends Tree with Typed { - def getType(implicit p: Program): Type = this + def getType(implicit s: Symbols): Type = this // Checks whether the subtypes of this type contain Untyped, // and if so sets this to Untyped. @@ -47,10 +47,12 @@ trait Types { self: Trees => case object StringType extends Type case class BVType(size: Int) extends Type - case object Int32Type extends BVType(32) + object Int32Type extends BVType(32) { + override def toString = "Int32Type" + } class TypeParameter private (name: String) extends Type { - val id = FreshIdentifier(name, this) + val id = FreshIdentifier(name) def freshen = new TypeParameter(name) override def equals(that: Any) = that match { @@ -81,12 +83,12 @@ trait Types { self: Trees => case class FunctionType(from: Seq[Type], to: Type) extends Type case class ClassType(id: Identifier, tps: Seq[Type]) extends Type { - def lookupClass(implicit p: Program): Option[ClassDef] = p.lookupClass(id, tps) + def lookupClass(implicit s: Symbols): Option[TypedClassDef] = p.lookupClass(id, tps) + def tcd(implicit s: Symbols): TypedClassDef = s.getClass(id, tps) } object NAryType extends TreeExtractor { val trees: Types.this.type = Types.this - import trees._ type SubTree = Type diff --git a/src/main/scala/inox/trees/package.scala b/src/main/scala/inox/ast/package.scala similarity index 97% rename from src/main/scala/inox/trees/package.scala rename to src/main/scala/inox/ast/package.scala index db1812595..b81058118 100644 --- a/src/main/scala/inox/trees/package.scala +++ b/src/main/scala/inox/ast/package.scala @@ -19,6 +19,4 @@ package inox * a [[leon.purescala.ScalaPrinter]] that outputs a valid Scala program from a Leon * representation. */ -package object trees { - -} +package object ast {} diff --git a/src/main/scala/inox/package.scala b/src/main/scala/inox/package.scala index 69f9be0b3..368316aa0 100644 --- a/src/main/scala/inox/package.scala +++ b/src/main/scala/inox/package.scala @@ -4,8 +4,10 @@ * * Provides the basic types and definitions for the Leon system. */ -package object leon { +package object inox { implicit class BooleanToOption(cond: Boolean) { def option[A](v: => A) = if (cond) Some(v) else None } + + case class FatalError(msg: String) extends Exception(msg) } diff --git a/src/main/scala/inox/trees/ExprOps.scala b/src/main/scala/inox/trees/ExprOps.scala deleted file mode 100644 index 50c9de76f..000000000 --- a/src/main/scala/inox/trees/ExprOps.scala +++ /dev/null @@ -1,2337 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -/** Provides functions to manipulate [[purescala.Expressions]]. - * - * This object provides a few generic operations on Leon expressions, - * as well as some common operations. - * - * The generic operations lets you apply operations on a whole tree - * expression. You can look at: - * - [[GenTreeOps.fold foldRight]] - * - [[GenTreeOps.preTraversal preTraversal]] - * - [[GenTreeOps.postTraversal postTraversal]] - * - [[GenTreeOps.preMap preMap]] - * - [[GenTreeOps.postMap postMap]] - * - [[GenTreeOps.genericTransform genericTransform]] - * - * These operations usually take a higher order function that gets applied to the - * expression tree in some strategy. They provide an expressive way to build complex - * operations on Leon expressions. - * - */ -trait ExprOps extends GenTreeOps[Expr] with Constructors with Paths { - val trees: Trees - import trees._ - - val Deconstructor = Operator - - /** Replaces bottom-up sub-identifiers by looking up for them in a map */ - def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = { - postMap({ - case Variable(i) => substs.get(i) - case _ => None - })(expr) - } - - def preTransformWithBinders(f: (Expr, Set[Identifier]) => Expr, initBinders: Set[Identifier] = Set())(e: Expr) = { - import xlang.Expressions.LetVar - def rec(binders: Set[Identifier], e: Expr): Expr = f(e, binders) match { - case ld@LetDef(fds, bd) => - fds.foreach(fd => { - fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody) - }) - LetDef(fds, rec(binders, bd)).copiedFrom(ld) - case l@Let(i, v, b) => - Let(i, rec(binders + i, v), rec(binders + i, b)).copiedFrom(l) - case lv@LetVar(i, v, b) => - LetVar(i, rec(binders + i, v), rec(binders + i, b)).copiedFrom(lv) - case m@MatchExpr(scrut, cses) => - MatchExpr(rec(binders, scrut), cses map { case mc@MatchCase(pat, og, rhs) => - val newBs = binders ++ pat.binders - MatchCase(pat, og map (rec(newBs, _)), rec(newBs, rhs)).copiedFrom(mc) - }).copiedFrom(m) - case p@Passes(in, out, cses) => - Passes(rec(binders, in), rec(binders, out), cses map { case mc@MatchCase(pat, og, rhs) => - val newBs = binders ++ pat.binders - MatchCase(pat, og map (rec(newBs, _)), rec(newBs, rhs)).copiedFrom(mc) - }).copiedFrom(p) - case l@Lambda(args, bd) => - Lambda(args, rec(binders ++ args.map(_.id), bd)).copiedFrom(l) - case f@Forall(args, bd) => - Forall(args, rec(binders ++ args.map(_.id), bd)).copiedFrom(f) - case d@Deconstructor(subs, builder) => - builder(subs map (rec(binders, _))).copiedFrom(d) - } - - rec(initBinders, e) - } - - /** Returns the set of free variables in an expression */ - def variablesOf(expr: Expr): Set[Identifier] = { - import leon.xlang.Expressions._ - fold[Set[Identifier]] { - case (e, subs) => - val subvs = subs.flatten.toSet - e match { - case Variable(i) => subvs + i - case Old(i) => subvs + i - case LetDef(fds, _) => subvs -- fds.flatMap(_.params.map(_.id)) - case Let(i, _, _) => subvs - i - case LetVar(i, _, _) => subvs - i - case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) - case Passes(_, _, cses) => subvs -- cses.flatMap(_.pattern.binders) - case Lambda(args, _) => subvs -- args.map(_.id) - case Forall(args, _) => subvs -- args.map(_.id) - case _ => subvs - } - }(expr) - } - - /** Returns true if the expression contains a function call */ - def containsFunctionCalls(expr: Expr): Boolean = { - exists{ - case _: FunctionInvocation => true - case _ => false - }(expr) - } - - /** Returns all Function calls found in the expression */ - def functionCallsOf(expr: Expr): Set[FunctionInvocation] = { - collect[FunctionInvocation] { - case f: FunctionInvocation => Set(f) - case _ => Set() - }(expr) - } - - def nestedFunDefsOf(expr: Expr): Set[FunDef] = { - collect[FunDef] { - case LetDef(fds, _) => fds.toSet - case _ => Set() - }(expr) - } - - /** Returns functions in directly nested LetDefs */ - def directlyNestedFunDefs(e: Expr): Set[FunDef] = { - fold[Set[FunDef]]{ - case (LetDef(fds,_), fromFdsFromBd) => fromFdsFromBd.last ++ fds - case (_, subs) => subs.flatten.toSet - }(e) - } - - /** Computes the negation of a boolean formula, with some simplifications. */ - def negate(expr: Expr) : Expr = { - //require(expr.getType == BooleanType) - (expr match { - case Let(i,b,e) => Let(i,b,negate(e)) - case Not(e) => e - case Implies(e1,e2) => and(e1, negate(e2)) - case Or(exs) => and(exs map negate: _*) - case And(exs) => or(exs map negate: _*) - case LessThan(e1,e2) => GreaterEquals(e1,e2) - case LessEquals(e1,e2) => GreaterThan(e1,e2) - case GreaterThan(e1,e2) => LessEquals(e1,e2) - case GreaterEquals(e1,e2) => LessThan(e1,e2) - case IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)) - case BooleanLiteral(b) => BooleanLiteral(!b) - case e => Not(e) - }).setPos(expr) - } - - def replacePatternBinders(pat: Pattern, subst: Map[Identifier, Identifier]): Pattern = { - def rec(p: Pattern): Pattern = p match { - case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map subst, ctd) - case WildcardPattern(ob) => WildcardPattern(ob map subst) - case TuplePattern(ob, sps) => TuplePattern(ob map subst, sps map rec) - case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob map subst, ccd, sps map rec) - case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob map subst, obj, sps map rec) - case LiteralPattern(ob, lit) => LiteralPattern(ob map subst, lit) - } - - rec(pat) - } - - - /** Replace each node by its constructor - * - * Remap the expression by calling the corresponding constructor - * for each node of the expression. The constructor will perfom - * some local simplifications, resulting in a simplified expression. - */ - def simplifyByConstructors(expr: Expr): Expr = { - def step(e: Expr): Option[Expr] = e match { - case Not(t) => Some(not(t)) - case UMinus(t) => Some(uminus(t)) - case BVUMinus(t) => Some(uminus(t)) - case RealUMinus(t) => Some(uminus(t)) - case CaseClassSelector(cd, e, sel) => Some(caseClassSelector(cd, e, sel)) - case AsInstanceOf(e, ct) => Some(asInstOf(e, ct)) - case Equals(t1, t2) => Some(equality(t1, t2)) - case Implies(t1, t2) => Some(implies(t1, t2)) - case Plus(t1, t2) => Some(plus(t1, t2)) - case Minus(t1, t2) => Some(minus(t1, t2)) - case Times(t1, t2) => Some(times(t1, t2)) - case BVPlus(t1, t2) => Some(plus(t1, t2)) - case BVMinus(t1, t2) => Some(minus(t1, t2)) - case BVTimes(t1, t2) => Some(times(t1, t2)) - case RealPlus(t1, t2) => Some(plus(t1, t2)) - case RealMinus(t1, t2) => Some(minus(t1, t2)) - case RealTimes(t1, t2) => Some(times(t1, t2)) - case And(args) => Some(andJoin(args)) - case Or(args) => Some(orJoin(args)) - case Tuple(args) => Some(tupleWrap(args)) - case MatchExpr(scrut, cases) => Some(matchExpr(scrut, cases)) - case Passes(in, out, cases) => Some(passes(in, out, cases)) - case _ => None - } - postMap(step)(expr) - } - - /** ATTENTION: Unused, and untested - * rewrites pattern-matching expressions to use fresh variables for the binders - */ - def freshenLocals(expr: Expr) : Expr = { - def freshenCase(cse: MatchCase) : MatchCase = { - val allBinders: Set[Identifier] = cse.pattern.binders - val subMap: Map[Identifier,Identifier] = - Map(allBinders.map(i => (i, FreshIdentifier(i.name, i.getType, true))).toSeq : _*) - val subVarMap: Map[Expr,Expr] = subMap.map(kv => Variable(kv._1) -> Variable(kv._2)) - - MatchCase( - replacePatternBinders(cse.pattern, subMap), - cse.optGuard map { replace(subVarMap, _)}, - replace(subVarMap,cse.rhs) - ) - } - - postMap{ - case m @ MatchExpr(s, cses) => - Some(matchExpr(s, cses.map(freshenCase)).copiedFrom(m)) - - case p @ Passes(in, out, cses) => - Some(Passes(in, out, cses.map(freshenCase)).copiedFrom(p)) - - case l @ Let(i,e,b) => - val newID = FreshIdentifier(i.name, i.getType, alwaysShowUniqueID = true).copiedFrom(i) - Some(Let(newID, e, replaceFromIDs(Map(i -> Variable(newID)), b)).copiedFrom(l)) - - case _ => None - }(expr) - } - - /** Applies the function to the I/O constraint and simplifies the resulting constraint */ - def applyAsMatches(p : Passes, f : Expr => Expr) = { - f(p.asConstraint) match { - case Equals(newOut, MatchExpr(newIn, newCases)) => - val filtered = newCases flatMap { - case MatchCase(p, g, `newOut`) => None - case other => Some(other) - } - Passes(newIn, newOut, filtered) - case other => - other - } - } - - /** Normalizes the expression expr */ - def normalizeExpression(expr: Expr) : Expr = { - def rec(e: Expr): Option[Expr] = e match { - case TupleSelect(Let(id, v, b), ts) => - Some(Let(id, v, tupleSelect(b, ts, true))) - - case TupleSelect(LetTuple(ids, v, b), ts) => - Some(letTuple(ids, v, tupleSelect(b, ts, true))) - - case CaseClassSelector(cct, cc: CaseClass, id) => - Some(caseClassSelector(cct, cc, id).copiedFrom(e)) - - case IfExpr(c, thenn, elze) if (thenn == elze) && isPurelyFunctional(c) => - Some(thenn) - - case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) => - Some(c) - - case IfExpr(Not(c), thenn, elze) => - Some(IfExpr(c, elze, thenn).copiedFrom(e)) - - case IfExpr(c, BooleanLiteral(false), BooleanLiteral(true)) => - Some(Not(c).copiedFrom(e)) - - case FunctionInvocation(tfd, List(IfExpr(c, thenn, elze))) => - Some(IfExpr(c, FunctionInvocation(tfd, List(thenn)), FunctionInvocation(tfd, List(elze))).copiedFrom(e)) - - case _ => - None - } - - fixpoint(postMap(rec))(expr) - } - - private val typedIds: scala.collection.mutable.Map[TypeTree, List[Identifier]] = - scala.collection.mutable.Map.empty.withDefaultValue(List.empty) - - /** Normalizes identifiers in an expression to enable some notion of structural - * equality between expressions on which usual equality doesn't make sense - * (i.e. closures). - * - * This function relies on the static map `typedIds` to ensure identical - * structures and must therefore be synchronized. - * - * The optional argument [[onlySimple]] determines whether non-simple expressions - * (see [[isSimple]]) should be normalized into a dependency or recursed into - * (when they don't depend on [[args]]). This distinction is used in the - * unrolling solver to provide geenral equality checks between functions even when - * they have complex closures. - */ - def normalizeStructure(args: Seq[Identifier], expr: Expr, onlySimple: Boolean = true): (Seq[Identifier], Expr, Map[Identifier, Expr]) = synchronized { - val vars = args.toSet - - class Normalizer extends TreeTransformer { - var subst: Map[Identifier, Expr] = Map.empty - var remainingIds: Map[TypeTree, List[Identifier]] = typedIds.toMap - - def getId(e: Expr): Identifier = { - val tpe = TypeOps.bestRealType(e.getType) - val newId = remainingIds.get(tpe) match { - case Some(x :: xs) => - remainingIds += tpe -> xs - x - case _ => - val x = FreshIdentifier("x", tpe, true) - typedIds(tpe) = typedIds(tpe) :+ x - x - } - subst += newId -> e - newId - } - - override def transform(id: Identifier): Identifier = subst.get(id) match { - case Some(Variable(newId)) => newId - case Some(_) => scala.sys.error("Should never happen!") - case None => getId(id.toVariable) - } - - override def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = e match { - case expr if (isSimple(expr) || !onlySimple) && (variablesOf(expr) & vars).isEmpty => getId(expr).toVariable - case f: Forall => - val (args, body, newSubst) = normalizeStructure(f.args.map(_.id), f.body, onlySimple) - subst ++= newSubst - Forall(args.map(ValDef(_)), body) - case l: Lambda => - val (args, body, newSubst) = normalizeStructure(l.args.map(_.id), l.body, onlySimple) - subst ++= newSubst - Lambda(args.map(ValDef(_)), body) - case _ => super.transform(e) - } - } - - val n = new Normalizer - val bindings = args.map(id => id -> n.getId(id.toVariable)).toMap - val normalized = n.transform(matchToIfThenElse(expr))(bindings) - - val argsImgSet = bindings.map(_._2).toSet - val bodySubst = n.subst.filter(p => !argsImgSet(p._1)) - - (args.map(bindings), normalized, bodySubst) - } - - def normalizeStructure(lambda: Lambda): (Lambda, Map[Identifier, Expr]) = { - val (args, body, subst) = normalizeStructure(lambda.args.map(_.id), lambda.body, onlySimple = false) - (Lambda(args.map(ValDef(_)), body), subst) - } - - def normalizeStructure(forall: Forall): (Forall, Map[Identifier, Expr]) = { - val (args, body, subst) = normalizeStructure(forall.args.map(_.id), forall.body) - (Forall(args.map(ValDef(_)), body), subst) - } - - /** Returns '''true''' if the formula is Ground, - * which means that it does not contain any variable ([[purescala.ExprOps#variablesOf]] e is empty) - * and [[purescala.ExprOps#isDeterministic isDeterministic]] - */ - def isGround(e: Expr): Boolean = { - variablesOf(e).isEmpty && isDeterministic(e) - } - - /** Returns '''true''' if the formula is simple, - * which means that it requires no special encoding for an - * unrolling solver. See implementation for what this means exactly. - */ - def isSimple(e: Expr): Boolean = !exists { - case (_: Choose) | (_: Hole) | - (_: Assert) | (_: Ensuring) | - (_: Forall) | (_: Lambda) | (_: FiniteLambda) | - (_: FunctionInvocation) | (_: Application) => true - case _ => false - } (e) - - /** Returns a function which can simplify all ground expressions which appear in a program context. - */ - def evalGround(ctx: LeonContext, program: Program): Expr => Expr = { - import evaluators._ - - val eval = new DefaultEvaluator(ctx, program) - - def rec(e: Expr): Option[Expr] = e match { - case l: Terminal => None - case e if isGround(e) => eval.eval(e).result // returns None if eval fails - case _ => None - } - - preMap(rec) - } - - /** Simplifies let expressions - * - * - removes lets when expression never occurs - * - simplifies when expressions occurs exactly once - * - expands when expression is just a variable. - * - * @note the code is simple but far from optimal (many traversals...) - */ - def simplifyLets(expr: Expr) : Expr = { - - def freeComputable(e: Expr) = e match { - case TupleSelect(Variable(_), _) => true - case CaseClassSelector(_, Variable(_), _) => true - case FiniteSet(els, _) => els.isEmpty - case FiniteMap(els, _, _) => els.isEmpty - case _: Terminal => true - case _ => false - } - - def inlineLetDefs(fds: Seq[FunDef], body: Expr, toInline: Set[FunDef]): Expr = { - def inline(e: Expr) = leon.utils.fixpoint( - ExprOps.preMap{ - case FunctionInvocation(TypedFunDef(f, targs), args) if toInline(f) => - val substs = f.paramIds.zip(args).toMap - Some(replaceFromIDs(substs, f.fullBody)) - case _ => None - } , 64)(e) - - val inlined = fds.filter(x => !toInline(x)).map{fd => - val newFd = fd.duplicate() - newFd.fullBody = inline(fd.fullBody) - fd -> newFd - } - val inlined_body = inline(body) - val inlineMap = inlined.toMap - def useUpdatedFunctions(e: Expr): Expr = ExprOps.preMap{ - case fi@FunctionInvocation(tfd@TypedFunDef(f, targs), args) => - inlineMap.get(f).map{ newF => - FunctionInvocation(TypedFunDef(newF, targs).copiedFrom(tfd), args).copiedFrom(fi) - } - case _ => None - }(e) - val updatedCalls = inlined.map{ case (f, newF) => - newF.fullBody = useUpdatedFunctions(newF.fullBody) - newF - } - val updatedBody = useUpdatedFunctions(inlined_body) - if(updatedCalls.isEmpty) { - updatedBody - } else - LetDef(updatedCalls, updatedBody) - } - - def simplerLet(t: Expr): Option[Expr] = t match { - case LetDef(fds, body) => // Inline simple functions called only once, or calling another function. - - def collectCalls(e: Expr): Set[(FunDef, Int)] = { - var i = 1 - ExprOps.collect[(FunDef, Int)]{ - case FunctionInvocation(TypedFunDef(fd, _), _) => Set((fd, { i += 1; i})) - case _ => Set()}(e) - } - val calledGraph = ((for{ - fd <- fds - (callee, id) <- collectCalls(fd.fullBody).toSeq - if fds.contains(callee) - } yield ((fd: Tree) -> callee)) ++ collectCalls(body).toSeq.map((body: Tree) -> _._1)).groupBy(_._2).mapValues(_.size) - - val toInline = fds.filter{ fd => fd.fullBody match { - case Int32ToString(Variable(id)) if fd.paramIds.headOption == Some(id) => true - case BooleanToString(Variable(id)) if fd.paramIds.headOption == Some(id) => true - case IntegerToString(Variable(id)) if fd.paramIds.headOption == Some(id) => true - case Variable(id) if fd.paramIds.headOption == Some(id) => true - case _ if calledGraph.getOrElse(fd, 0) <= 1 => true - case FunctionInvocation(TypedFunDef(f, _), _) if calledGraph.getOrElse(f, 0) > 1 => true - case _ => false - }} - - if(toInline.length > 0) { - Some(inlineLetDefs(fds, body, toInline.toSet)) - } else None - - /* Untangle */ - case Let(i1, Let(i2, e2, b2), b1) => - Some(Let(i2, e2, Let(i1, b2, b1))) - - case Let(i1, LetTuple(is2, e2, b2), b1) => - Some(letTuple(is2, e2, Let(i1, b2, b1))) - - case LetTuple(ids1, Let(id2, e2, b2), b1) => - Some(Let(id2, e2, letTuple(ids1, b2, b1))) - - case LetTuple(ids1, LetTuple(ids2, e2, b2), b1) => - Some(letTuple(ids2, e2, letTuple(ids1, b2, b1))) - - // Untuple - case Let(id, Tuple(es), b) => - val ids = es.zipWithIndex.map { case (e, ind) => - FreshIdentifier(id + (ind + 1).toString, e.getType, true) - } - val theMap: Map[Expr, Expr] = es.zip(ids).zipWithIndex.map { - case ((e, subId), ind) => TupleSelect(Variable(id), ind + 1) -> Variable(subId) - }.toMap - - val replaced0 = replace(theMap, b) - val replaced = replace(Map(Variable(id) -> Tuple(ids map Variable)), replaced0) - - Some(letTuple(ids, Tuple(es), replaced)) - - case Let(i, e, b) if freeComputable(e) && isPurelyFunctional(e) => - // computation is very quick and code easy to read, always inline - Some(replaceFromIDs(Map(i -> e), b)) - - case Let(i,e,b) if isPurelyFunctional(e) => - // computation may be slow, or code complicated to read, inline at most once - val occurrences = count { - case Variable(`i`) => 1 - case _ => 0 - }(b) - - if(occurrences == 0) { - Some(b) - } else if(occurrences == 1) { - Some(replaceFromIDs(Map(i -> e), b)) - } else { - None - } - - case LetTuple(ids, Tuple(elems), body) => - Some(ids.zip(elems).foldRight(body) { case ((id, elem), bd) => Let(id, elem, bd) }) - - /*case LetPattern(patt, e0, body) if isPurelyFunctional(e0) => - // Will turn the match-expression with a single case into a list of lets. - // @mk it is not clear at all that we want this - - // Just extra safety... - val e = (e0.getType, patt) match { - case (_:AbstractClassType, CaseClassPattern(_, cct, _)) => - asInstOf(e0, cct) - case (at: AbstractClassType, InstanceOfPattern(_, ct)) if at != ct => - asInstOf(e0, ct) - case _ => - e0 - } - - // Sort lets in dependency order - val lets = mapForPattern(e, patt).toSeq.sortWith { - case ((id1, e1), (id2, e2)) => exists{ _ == Variable(id1) }(e2) - } - - Some(lets.foldRight(body) { - case ((id, e), bd) => Let(id, e, bd) - })*/ - - case MatchExpr(scrut, cases) => - // Merge match within match - var changed = false - val newCases = cases map { - case MatchCase(patt, g, LetPattern(innerPatt, Variable(id), body)) if patt.binders contains id => - changed = true - val newPatt = PatternOps.preMap { - case WildcardPattern(Some(`id`)) => Some(innerPatt.withBinder(id)) - case _ => None - }(patt) - MatchCase(newPatt, g, body) - case other => - other - } - if(changed) Some(MatchExpr(scrut, newCases)) else None - - case _ => None - } - - postMap(simplerLet, applyRec = true)(expr) - } - - /** Fully expands all let expressions. */ - def expandLets(expr: Expr) : Expr = { - def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { - case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) - case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) - case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) - case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) - case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Deconstructor(args, recons) => - var change = false - val rargs = args.map(a => { - val ra = rec(a, s) - if(ra != a) { - change = true - ra - } else { - a - } - }) - if(change) - recons(rargs) - else - n - case unhandled => throw LeonFatalError("Unhandled case in expandLets: " + unhandled) - } - - def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = { - import cse._ - MatchCase(pattern, optGuard map { rec(_, s) }, rec(rhs,s)) - } - - rec(expr, Map.empty) - } - - /** Lifts lets to top level. - * - * Does not push any used variable out of scope. - * Assumes no match expressions (i.e. matchToIfThenElse has been called on e) - */ - def liftLets(e: Expr): Expr = { - - type C = Seq[(Identifier, Expr)] - - def combiner(e: Expr, defs: Seq[C]): C = (e, defs) match { - case (Let(i, ex, b), Seq(inDef, inBody)) => - inDef ++ ((i, ex) +: inBody) - case _ => - defs.flatten - } - - def noLet(e: Expr, defs: C) = e match { - case Let(_, _, b) => (b, defs) - case _ => (e, defs) - } - - val (bd, defs) = genericTransform[C](noTransformer, noLet, combiner)(Seq())(e) - - defs.foldRight(bd){ case ((id, e), body) => Let(id, e, body) } - } - - /** Generates substitutions necessary to transform scrutinee to equivalent - * specialized cases - * - * {{{ - * e match { - * case CaseClass((a, 42), c) => expr - * } - * }}} - * will return, for the first pattern: - * {{{ - * Map( - * e -> CaseClass(t, c), - * t -> (a, b2), - * b2 -> 42, - * ) - * }}} - * - * @note UNUSED, is not maintained - */ - def patternSubstitutions(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] ={ - def rec(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] = pattern match { - case InstanceOfPattern(ob, cct: CaseClassType) => - val pt = CaseClassPattern(ob, cct, cct.fields.map { f => - WildcardPattern(Some(FreshIdentifier(f.id.name, f.getType))) - }) - rec(in, pt) - - case TuplePattern(_, subps) => - val TupleType(subts) = in.getType - val subExprs = (subps zip subts).zipWithIndex map { - case ((p, t), index) => p.binder.map(_.toVariable).getOrElse(tupleSelect(in, index+1, subps.size)) - } - - // Special case to get rid of (a,b) match { case (c,d) => .. } - val subst0 = in match { - case Tuple(ts) => - ts zip subExprs - case _ => - Seq(in -> tupleWrap(subExprs)) - } - - subst0 ++ ((subExprs zip subps) flatMap { - case (e, p) => recBinder(e, p) - }) - - case CaseClassPattern(_, cct, subps) => - val subExprs = (subps zip cct.classDef.fields) map { - case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id)) - } - - // Special case to get rid of Cons(a,b) match { case Cons(c,d) => .. } - val subst0 = in match { - case CaseClass(`cct`, args) => - args zip subExprs - case _ => - Seq(in -> CaseClass(cct, subExprs)) - } - - subst0 ++ ((subExprs zip subps) flatMap { - case (e, p) => recBinder(e, p) - }) - - case LiteralPattern(_, v) => - Seq(in -> v) - - case _ => - Seq() - } - - def recBinder(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] = { - (pattern, pattern.binder) match { - case (_: WildcardPattern, Some(b)) => - Seq(in -> b.toVariable) - case (p, Some(b)) => - val bv = b.toVariable - Seq(in -> bv) ++ rec(bv, pattern) - case _ => - rec(in, pattern) - } - } - - recBinder(in, pattern).filter{ case (a, b) => a != b } - } - - /** Recursively transforms a pattern on a boolean formula expressing the conditions for the input expression, possibly including name binders - * - * For example, the following pattern on the input `i` - * {{{ - * case m @ MyCaseClass(t: B, (_, 7)) => - * }}} - * will yield the following condition before simplification (to give some flavour) - * - * {{{and(IsInstanceOf(MyCaseClass, i), and(Equals(m, i), InstanceOfClass(B, i.t), equals(i.k.arity, 2), equals(i.k._2, 7))) }}} - * - * Pretty-printed, this would be: - * {{{ - * i.instanceOf[MyCaseClass] && m == i && i.t.instanceOf[B] && i.k.instanceOf[Tuple2] && i.k._2 == 7 - * }}} - * - * @see [[purescala.Expressions.Pattern]] - */ - def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Path = { - def bind(ob: Option[Identifier], to: Expr): Path = { - if (!includeBinders) { - Path.empty - } else { - ob.map(id => Path.empty withBinding (id -> to)).getOrElse(Path.empty) - } - } - - def rec(in: Expr, pattern: Pattern): Path = { - pattern match { - case WildcardPattern(ob) => - bind(ob, in) - - case InstanceOfPattern(ob, ct) => - if (ct.parent.isEmpty) { - bind(ob, in) - } else { - Path(IsInstanceOf(in, ct)) merge bind(ob, in) - } - - case CaseClassPattern(ob, cct, subps) => - assert(cct.classDef.fields.size == subps.size) - val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) - Path(IsInstanceOf(in, cct)) merge bind(ob, in) merge subTests - - case TuplePattern(ob, subps) => - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map { - case (p, i) => rec(tupleSelect(in, i+1, subps.size), p) - } - bind(ob, in) merge subTests - - case up @ UnapplyPattern(ob, fd, subps) => - val subs = unwrapTuple(up.get(in), subps.size).zip(subps) map (rec _).tupled - bind(ob, in) withCond up.isSome(in) merge subs - - case LiteralPattern(ob, lit) => - Path(Equals(in, lit)) merge bind(ob, in) - } - } - - rec(in, pattern) - } - - /** Converts the pattern applied to an input to a map between identifiers and expressions */ - def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = { - def bindIn(id: Option[Identifier], cast: Option[ClassType] = None): Map[Identifier,Expr] = id match { - case None => Map() - case Some(id) => Map(id -> cast.map(asInstOf(in, _)).getOrElse(in)) - } - pattern match { - case CaseClassPattern(b, cct, subps) => - assert(cct.fields.size == subps.size) - val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(caseClassSelector(cct, asInstOf(in, cct), p._1), p._2)) - val together = subMaps.flatten.toMap - bindIn(b, Some(cct)) ++ together - - case TuplePattern(b, subps) => - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} - val map = maps.flatten.toMap - bindIn(b) ++ map - - case up@UnapplyPattern(b, _, subps) => - bindIn(b) ++ unwrapTuple(up.getUnsafe(in), subps.size).zip(subps).flatMap { - case (e, p) => mapForPattern(e, p) - }.toMap - - case InstanceOfPattern(b, ct) => - bindIn(b, Some(ct)) - - case other => - bindIn(other.binder) - } - } - - /** Rewrites all pattern-matching expressions into if-then-else expressions - * Introduces additional error conditions. Does not introduce additional variables. - */ - def matchToIfThenElse(expr: Expr): Expr = { - - def rewritePM(e: Expr): Option[Expr] = e match { - case m @ MatchExpr(scrut, cases) => - // println("Rewriting the following PM: " + e) - - val condsAndRhs = for (cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) - val realCond = cse.optGuard match { - case Some(g) => patCond withCond replaceFromIDs(map, g) - case None => patCond - } - val newRhs = replaceFromIDs(map, cse.rhs) - (realCond.toClause, newRhs, cse) - } - - val bigIte = condsAndRhs.foldRight[Expr](Error(m.getType, "Match is non-exhaustive").copiedFrom(m))((p1, ex) => { - if(p1._1 == BooleanLiteral(true)) { - p1._2 - } else { - IfExpr(p1._1, p1._2, ex).copiedFrom(p1._3) - } - }) - - Some(bigIte) - - case p: Passes => - // This introduces a MatchExpr - Some(p.asConstraint) - - case _ => None - } - - preMap(rewritePM)(expr) - } - - /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], concatenates the path condition with the newly induced conditions. - * - * Each case holds the conditions on other previous cases as negative. - * - * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] - * @see [[purescala.ExprOps#mapForPattern mapForPattern]] - */ - def matchExprCaseConditions(m: MatchExpr, path: Path) : Seq[Path] = { - val MatchExpr(scrut, cases) = m - var pcSoFar = path - - for (c <- cases) yield { - val g = c.optGuard getOrElse BooleanLiteral(true) - val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar merge (cond withCond g) - - // These contain no binders defined in this MatchCase - val condSafe = conditionForPattern(scrut, c.pattern) - val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern), g) - pcSoFar = pcSoFar merge (condSafe withCond gSafe).negate - - localCond - } - } - - /** Condition to pass this match case, expressed w.r.t scrut only */ - def matchCaseCondition(scrut: Expr, c: MatchCase): Path = { - - val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false) - - c.optGuard match { - case Some(g) => - // guard might refer to binders - val map = mapForPattern(scrut, c.pattern) - patternC withCond replaceFromIDs(map, g) - - case None => - patternC - } - } - - /** Returns the path conditions for each of the case passes. - * - * Each case holds the conditions on other previous cases as negative. - */ - def passesPathConditions(p: Passes, pathCond: Path) : Seq[Path] = { - matchExprCaseConditions(MatchExpr(p.in, p.cases), pathCond) - } - - /** - * Returns a pattern from an expression, and a guard if any. - */ - def expressionToPattern(e : Expr) : (Pattern, Expr) = { - var guard : Expr = BooleanLiteral(true) - def rec(e : Expr) : Pattern = e match { - case CaseClass(cct, fields) => CaseClassPattern(None, cct, fields map rec) - case Tuple(subs) => TuplePattern(None, subs map rec) - case l : Literal[_] => LiteralPattern(None, l) - case Variable(i) => WildcardPattern(Some(i)) - case other => - val id = FreshIdentifier("other", other.getType, true) - guard = and(guard, Equals(Variable(id), other)) - WildcardPattern(Some(id)) - } - (rec(e), guard) - } - - /** - * Takes a pattern and returns an expression that corresponds to it. - * Also returns a sequence of `Identifier -> Expr` pairs which - * represent the bindings for intermediate binders (from outermost to innermost) - */ - def patternToExpression(p: Pattern, expectedType: TypeTree): (Expr, Seq[(Identifier, Expr)]) = { - def fresh(tp : TypeTree) = FreshIdentifier("binder", tp, true) - var ieMap = Seq[(Identifier, Expr)]() - def addBinding(b : Option[Identifier], e : Expr) = b foreach { ieMap +:= (_, e) } - def rec(p : Pattern, expectedType : TypeTree) : Expr = p match { - case WildcardPattern(b) => - Variable(b getOrElse fresh(expectedType)) - case LiteralPattern(b, lit) => - addBinding(b,lit) - lit - case InstanceOfPattern(b, ct) => ct match { - case act: AbstractClassType => - // @mk: This seems dubious, in the sense that it just binds the expression - // of the AbstractClassType to an id instead of going case-wise. - // I think this is sufficient for the use of this function though: - // it is only used to generate examples so it is followed by a type-aware enumerator. - val e = Variable(fresh(act)) - addBinding(b, e) - e - - case cct: CaseClassType => - val fields = cct.fields map { f => Variable(fresh(f.getType)) } - val e = CaseClass(cct, fields) - addBinding(b, e) - e - } - case TuplePattern(b, subs) => - val TupleType(subTypes) = expectedType - val e = Tuple(subs zip subTypes map { - case (sub, subType) => rec(sub, subType) - }) - addBinding(b, e) - e - case CaseClassPattern(b, cct, subs) => - val e = CaseClass(cct, subs zip cct.fieldsTypes map { case (sub,tp) => rec(sub,tp) }) - addBinding(b, e) - e - case up@UnapplyPattern(b, fd, subs) => - // TODO: Support this - NoTree(expectedType) - } - - (rec(p, expectedType), ieMap) - - } - - - /** Rewrites all map accesses with additional error conditions. */ - def mapGetWithChecks(expr: Expr): Expr = { - postMap({ - case mg @ MapApply(m,k) => - val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error(mg.getType, "Key not found for map access").copiedFrom(mg)).copiedFrom(mg)) - - case _=> - None - })(expr) - } - - /** Returns simplest value of a given type */ - def simplestValue(tpe: TypeTree) : Expr = tpe match { - case StringType => StringLiteral("") - case Int32Type => IntLiteral(0) - case RealType => FractionalLiteral(0, 1) - case IntegerType => InfiniteIntegerLiteral(0) - case CharType => CharLiteral('a') - case BooleanType => BooleanLiteral(false) - case UnitType => UnitLiteral() - case SetType(baseType) => FiniteSet(Set(), baseType) - case BagType(baseType) => FiniteBag(Map(), baseType) - case MapType(fromType, toType) => FiniteMap(Map(), fromType, toType) - case TupleType(tpes) => Tuple(tpes.map(simplestValue)) - case ArrayType(tpe) => EmptyArray(tpe) - - case act @ AbstractClassType(acd, tpe) => - val ccDesc = act.knownCCDescendants - - def isRecursive(cct: CaseClassType): Boolean = { - cct.fieldsTypes.exists{ - case AbstractClassType(fieldAcd, _) => acd.root == fieldAcd.root - case CaseClassType(fieldCcd, _) => acd.root == fieldCcd.root - case _ => false - } - } - - val nonRecChildren = ccDesc.filterNot(isRecursive).sortBy(_.fields.size) - - nonRecChildren.headOption match { - case Some(cct) => - simplestValue(cct) - - case None => - throw LeonFatalError(act +" does not seem to be well-founded") - } - - case cct: CaseClassType => - CaseClass(cct, cct.fieldsTypes.map(t => simplestValue(t))) - - case tp: TypeParameter => - GenericValue(tp, 0) - - case ft @ FunctionType(from, to) => - FiniteLambda(Seq.empty, simplestValue(to), ft) - - case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe) - } - - /* Checks if a given expression is 'real' and does not contain generic - * values. */ - def isRealExpr(v: Expr): Boolean = { - !exists { - case gv: GenericValue => true - case _ => false - }(v) - } - - def valuesOf(tp: TypeTree): Stream[Expr] = { - import utils.StreamUtils._ - tp match { - case BooleanType => - Stream(BooleanLiteral(false), BooleanLiteral(true)) - case Int32Type => - Stream.iterate(0) { prev => - if (prev > 0) -prev else -prev + 1 - } map IntLiteral - case IntegerType => - Stream.iterate(BigInt(0)) { prev => - if (prev > 0) -prev else -prev + 1 - } map InfiniteIntegerLiteral - case UnitType => - Stream(UnitLiteral()) - case tp: TypeParameter => - Stream.from(0) map (GenericValue(tp, _)) - case TupleType(stps) => - cartesianProduct(stps map (tp => valuesOf(tp))) map Tuple - case SetType(base) => - def elems = valuesOf(base) - elems.scanLeft(Stream(FiniteSet(Set(), base): Expr)){ (prev, curr) => - prev flatMap { - case fs@FiniteSet(elems, tp) => - Stream(fs, FiniteSet(elems + curr, tp)) - } - }.flatten // FIXME Need cp οr is this fine? - case cct: CaseClassType => - cartesianProduct(cct.fieldsTypes map valuesOf) map (CaseClass(cct, _)) - case act: AbstractClassType => - interleave(act.knownCCDescendants.map(cct => valuesOf(cct))) - } - } - - - /** Hoists all IfExpr at top level. - * - * Guarantees that all IfExpr will be at the top level and as soon as you - * encounter a non-IfExpr, then no more IfExpr can be found in the - * sub-expressions - * - * Assumes no match expressions - */ - def hoistIte(expr: Expr): Expr = { - def transform(expr: Expr): Option[Expr] = expr match { - case IfExpr(c, t, e) => None - - case nop@Deconstructor(ts, op) => { - val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } - if(iteIndex == -1) None else { - val (beforeIte, startIte) = ts.splitAt(iteIndex) - val afterIte = startIte.tail - val IfExpr(c, t, e) = startIte.head - Some(IfExpr(c, - op(beforeIte ++ Seq(t) ++ afterIte).copiedFrom(nop), - op(beforeIte ++ Seq(e) ++ afterIte).copiedFrom(nop) - )) - } - } - case _ => None - } - - postMap(transform, applyRec = true)(expr) - } - - def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { - - def rec(expr: Expr, path: Path): Seq[(T, Path)] = { - val seq = if (f.isDefinedAt(expr)) { - Seq(f(expr) -> path) - } else { - Seq.empty[(T, Path)] - } - - val rseq = expr match { - case Let(i, v, b) => - rec(v, path) ++ - rec(b, path withBinding (i -> v)) - - case Ensuring(Require(pre, body), Lambda(Seq(arg), post)) => - rec(pre, path) ++ - rec(body, path withCond pre) ++ - rec(post, path withCond pre withBinding (arg.toVariable -> body)) - - case Ensuring(body, Lambda(Seq(arg), post)) => - rec(body, path) ++ - rec(post, path withBinding (arg.toVariable -> body)) - - case Require(pre, body) => - rec(pre, path) ++ - rec(body, path withCond pre) - - case Assert(pred, err, body) => - rec(pred, path) ++ - rec(body, path withCond pred) - - case MatchExpr(scrut, cases) => - val rs = rec(scrut, path) - var soFar = path - - rs ++ cases.flatMap { c => - val patternPathPos = conditionForPattern(scrut, c.pattern, includeBinders = true) - val patternPathNeg = conditionForPattern(scrut, c.pattern, includeBinders = false) - val map = mapForPattern(scrut, c.pattern) - val guardOrTrue = c.optGuard.getOrElse(BooleanLiteral(true)) - val guardMapped = replaceFromIDs(map, guardOrTrue) - - val rc = rec((patternPathPos withCond guardOrTrue).fullClause, soFar) - val subPath = soFar merge (patternPathPos withCond guardOrTrue) - val rrhs = rec(c.rhs, subPath) - - soFar = soFar merge (patternPathNeg withCond guardMapped).negate - rc ++ rrhs - } - - case IfExpr(cond, thenn, elze) => - rec(cond, path) ++ - rec(thenn, path withCond cond) ++ - rec(elze, path withCond Not(cond)) - - case And(es) => - var soFar = path - es.flatMap { e => - val re = rec(e, soFar) - soFar = soFar withCond e - re - } - - case Or(es) => - var soFar = path - es.flatMap { e => - val re = rec(e, soFar) - soFar = soFar withCond Not(e) - re - } - - case Implies(lhs, rhs) => - rec(lhs, path) ++ - rec(rhs, path withCond lhs) - - case Operator(es, _) => - es.flatMap(rec(_, path)) - - case _ => sys.error("Expression " + e + "["+e.getClass+"] is not extractable") - } - - seq ++ rseq - } - - rec(expr, Path.empty) - } - - override def formulaSize(e: Expr): Int = e match { - case ml: MatchExpr => - super.formulaSize(e) + ml.cases.map(cs => PatternOps.formulaSize(cs.pattern)).sum - case _ => - super.formulaSize(e) - } - - /** Returns true if the expression is deterministic / - * does not contain any [[purescala.Expressions.Choose Choose]] - * or [[purescala.Expressions.Hole Hole]] or [[purescala.Expressions.WithOracle WithOracle]] - */ - def isDeterministic(e: Expr): Boolean = { - exists { - case _ : Choose | _: Hole | _: WithOracle => false - case _ => true - }(e) - } - - /** Returns if this expression behaves as a purely functional construct, - * i.e. always returns the same value (for the same environment) and has no side-effects - */ - def isPurelyFunctional(e: Expr): Boolean = { - exists { - case _ : Error | _ : Choose | _: Hole | _: WithOracle => false - case _ => true - }(e) - } - - /** Returns the value for an identifier given a model. */ - def valuateWithModel(model: Model)(id: Identifier): Expr = { - model.getOrElse(id, simplestValue(id.getType)) - } - - /** Substitute (free) variables in an expression with values form a model. - * - * Complete with simplest values in case of incomplete model. - */ - def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Model): Expr = { - val valuator = valuateWithModel(model) _ - replace(vars.map(id => Variable(id) -> valuator(id)).toMap, expr) - } - - /** Simple, local optimization on string */ - def simplifyString(expr: Expr): Expr = { - def simplify0(expr: Expr): Expr = (expr match { - case StringConcat(StringLiteral(""), b) => b - case StringConcat(b, StringLiteral("")) => b - case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a + b) - case StringLength(StringLiteral(a)) => IntLiteral(a.length) - case StringBigLength(StringLiteral(a)) => InfiniteIntegerLiteral(a.length) - case SubString(StringLiteral(a), IntLiteral(start), IntLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) - case BigSubString(StringLiteral(a), InfiniteIntegerLiteral(start), InfiniteIntegerLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) - case _ => expr - }).copiedFrom(expr) - simplify0(expr) - fixpoint(simplePostTransform(simplify0))(expr) - } - - /** Simple, local simplification on arithmetic - * - * You should not assume anything smarter than some constant folding and - * simple cancellation. To avoid infinite cycle we only apply simplification - * that reduce the size of the tree. The only guarantee from this function is - * to not augment the size of the expression and to be sound. - */ - def simplifyArithmetic(expr: Expr): Expr = { - def simplify0(expr: Expr): Expr = (expr match { - case Plus(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 + i2) - case Plus(InfiniteIntegerLiteral(zero), e) if zero == BigInt(0) => e - case Plus(e, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => e - case Plus(e1, UMinus(e2)) => Minus(e1, e2) - case Plus(Plus(e, InfiniteIntegerLiteral(i1)), InfiniteIntegerLiteral(i2)) => Plus(e, InfiniteIntegerLiteral(i1+i2)) - case Plus(Plus(InfiniteIntegerLiteral(i1), e), InfiniteIntegerLiteral(i2)) => Plus(InfiniteIntegerLiteral(i1+i2), e) - - case Minus(e, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => e - case Minus(InfiniteIntegerLiteral(zero), e) if zero == BigInt(0) => UMinus(e) - case Minus(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 - i2) - case Minus(e1, UMinus(e2)) => Plus(e1, e2) - case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3)) - - case UMinus(InfiniteIntegerLiteral(x)) => InfiniteIntegerLiteral(-x) - case UMinus(UMinus(x)) => x - case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2)) - case UMinus(Minus(e1, e2)) => Minus(e2, e1) - - case Times(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 * i2) - case Times(InfiniteIntegerLiteral(one), e) if one == BigInt(1) => e - case Times(InfiniteIntegerLiteral(mone), e) if mone == BigInt(-1) => UMinus(e) - case Times(e, InfiniteIntegerLiteral(one)) if one == BigInt(1) => e - case Times(InfiniteIntegerLiteral(zero), _) if zero == BigInt(0) => InfiniteIntegerLiteral(0) - case Times(_, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => InfiniteIntegerLiteral(0) - case Times(InfiniteIntegerLiteral(i1), Times(InfiniteIntegerLiteral(i2), t)) => Times(InfiniteIntegerLiteral(i1*i2), t) - case Times(InfiniteIntegerLiteral(i1), Times(t, InfiniteIntegerLiteral(i2))) => Times(InfiniteIntegerLiteral(i1*i2), t) - case Times(InfiniteIntegerLiteral(i), UMinus(e)) => Times(InfiniteIntegerLiteral(-i), e) - case Times(UMinus(e), InfiniteIntegerLiteral(i)) => Times(e, InfiniteIntegerLiteral(-i)) - case Times(InfiniteIntegerLiteral(i1), Division(e, InfiniteIntegerLiteral(i2))) if i2 != BigInt(0) && i1 % i2 == BigInt(0) => Times(InfiniteIntegerLiteral(i1/i2), e) - - case Division(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) if i2 != BigInt(0) => InfiniteIntegerLiteral(i1 / i2) - case Division(e, InfiniteIntegerLiteral(one)) if one == BigInt(1) => e - - //here we put more expensive rules - //btw, I know those are not the most general rules, but they lead to good optimizations :) - case Plus(UMinus(Plus(e1, e2)), e3) if e1 == e3 => UMinus(e2) - case Plus(UMinus(Plus(e1, e2)), e3) if e2 == e3 => UMinus(e1) - case Minus(e1, e2) if e1 == e2 => InfiniteIntegerLiteral(0) - case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => InfiniteIntegerLiteral(0) - case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5) - - case StringConcat(StringLiteral(""), a) => a - case StringConcat(a, StringLiteral("")) => a - case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a+b) - case StringConcat(StringLiteral(a), StringConcat(StringLiteral(b), c)) => StringConcat(StringLiteral(a+b), c) - case StringConcat(StringConcat(c, StringLiteral(a)), StringLiteral(b)) => StringConcat(c, StringLiteral(a+b)) - case StringConcat(a, StringConcat(b, c)) => StringConcat(StringConcat(a, b), c) - //default - case e => e - }).copiedFrom(expr) - - fixpoint(simplePostTransform(simplify0))(expr) - } - - /** - * Some helper methods for FractionalLiterals - */ - def normalizeFraction(fl: FractionalLiteral) = { - val FractionalLiteral(num, denom) = fl - val modNum = if (num < 0) -num else num - val modDenom = if (denom < 0) -denom else denom - val divisor = modNum.gcd(modDenom) - val simpNum = num / divisor - val simpDenom = denom / divisor - if (simpDenom < 0) - FractionalLiteral(-simpNum, -simpDenom) - else - FractionalLiteral(simpNum, simpDenom) - } - - val realzero = FractionalLiteral(0, 1) - def floor(fl: FractionalLiteral): FractionalLiteral = { - val FractionalLiteral(n, d) = normalizeFraction(fl) - if (d == 0) throw new IllegalStateException("denominator zero") - if (n == 0) realzero - else if (n > 0) { - //perform integer division - FractionalLiteral(n / d, 1) - } else { - //here the number is negative - if (n % d == 0) - FractionalLiteral(n / d, 1) - else { - //perform integer division and subtract 1 - FractionalLiteral(n / d - 1, 1) - } - } - } - - /** Checks whether a predicate is inductive on a certain identifier. - * - * isInductive(foo(a, b), a) where a: List will check whether - * foo(Nil, b) and - * foo(t, b) => foo(Cons(h,t), b) - */ - def isInductiveOn(sf: SolverFactory[Solver])(path: Path, on: Identifier): Boolean = on match { - case IsTyped(origId, AbstractClassType(cd, tps)) => - - val toCheck = cd.knownDescendants.collect { - case ccd: CaseClassDef => - val cct = CaseClassType(ccd, tps) - - val isType = IsInstanceOf(Variable(on), cct) - - val recSelectors = (cct.classDef.fields zip cct.fieldsTypes).collect { - case (vd, tpe) if tpe == on.getType => vd.id - } - - if (recSelectors.isEmpty) { - Seq() - } else { - val v = Variable(on) - - recSelectors.map { s => - and(path and isType, not(replace(Map(v -> caseClassSelector(cct, v, s)), path.toClause))) - } - } - }.flatten - - val solver = SimpleSolverAPI(sf) - - toCheck.forall { cond => - solver.solveSAT(cond)._1 match { - case Some(false) => - true - case Some(true) => - false - case None => - // Should we be optimistic here? - false - } - } - case _ => - false - } - - type Apriori = Map[Identifier, Identifier] - - /** Checks whether two expressions can be homomorphic and returns the corresponding mapping */ - def canBeHomomorphic(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = { - val freeT1Variables = ExprOps.variablesOf(t1) - val freeT2Variables = ExprOps.variablesOf(t2) - - def mergeContexts( - a: Option[Apriori], - b: Apriori => Option[Apriori]): - Option[Apriori] = a.flatMap(b) - - implicit class AugmentedContext(c: Option[Apriori]) { - def &&(other: Apriori => Option[Apriori]): Option[Apriori] = mergeContexts(c, other) - def --(other: Seq[Identifier]) = - c.map(_ -- other) - } - implicit class AugmentedBoolean(c: Boolean) { - def &&(other: => Option[Apriori]) = if(c) other else None - } - implicit class AugmentedFilter(c: Apriori => Option[Apriori]) { - def &&(other: Apriori => Option[Apriori]): - Apriori => Option[Apriori] - = (m: Apriori) => c(m).flatMap(mp => other(mp)) - } - implicit class AugmentedSeq[T](c: Seq[T]) { - def mergeall(p: T => Apriori => Option[Apriori])(apriori: Apriori) = - (Option(apriori) /: c) { - case (s, c) => s.flatMap(apriori => p(c)(apriori)) - } - } - implicit def noneToContextTaker(c: None.type) = { - (m: Apriori) => None - } - - - def idHomo(i1: Identifier, i2: Identifier)(apriori: Apriori): Option[Apriori] = { - if(!(freeT1Variables(i1) || freeT2Variables(i2)) || i1 == i2 || apriori.get(i1) == Some(i2)) Some(Map(i1 -> i2)) else None - } - def idOptionHomo(i1: Option[Identifier], i2: Option[Identifier])(apriori: Apriori): Option[Apriori] = { - (i1.size == i2.size) && (i1 zip i2).headOption.flatMap(i => idHomo(i._1, i._2)(apriori)) - } - - def fdHomo(fd1: FunDef, fd2: FunDef)(apriori: Apriori): Option[Apriori] = { - if(fd1.params.size == fd2.params.size) { - val newMap = Map(( - (fd1.id -> fd2.id) +: - (fd1.paramIds zip fd2.paramIds)): _*) - Option(newMap) && isHomo(fd1.fullBody, fd2.fullBody) - } else None - } - - def isHomo(t1: Expr, t2: Expr)(apriori: Apriori): Option[Apriori] = { - def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase])(apriori: Apriori) : Option[Apriori] = { - def patternHomo(p1: Pattern, p2: Pattern)(apriori: Apriori): Option[Apriori] = (p1, p2) match { - case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) => - cd1 == cd2 && idOptionHomo(ob1, ob2)(apriori) - - case (WildcardPattern(ob1), WildcardPattern(ob2)) => - idOptionHomo(ob1, ob2)(apriori) - - case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) => - val m = idOptionHomo(ob1, ob2)(apriori) - - (ccd1 == ccd2 && subs1.size == subs2.size) && m && - ((subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) }) - - case (UnapplyPattern(ob1, TypedFunDef(fd1, ts1), subs1), UnapplyPattern(ob2, TypedFunDef(fd2, ts2), subs2)) => - val m = idOptionHomo(ob1, ob2)(apriori) - - (subs1.size == subs2.size && ts1 == ts2) && m && fdHomo(fd1, fd2) && ( - (subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) }) - - case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => - val m = idOptionHomo(ob1, ob2)(apriori) - - (ob1.size == ob2.size && subs1.size == subs2.size) && m && ( - (subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) }) - - case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) => - lit1 == lit2 && idOptionHomo(ob1, ob2)(apriori) - - case _ => - None - } - - (cs1 zip cs2).mergeall { - case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) => - val h = patternHomo(p1, p2) _ - val g: Apriori => Option[Apriori] = (g1, g2) match { - case (Some(g1), Some(g2)) => isHomo(g1, g2)(_) - case (None, None) => (m: Apriori) => Some(m) - case _ => None - } - val e = isHomo(e1, e2) _ - - h && g && e - }(apriori) - } - - val res: Option[Apriori] = (t1, t2) match { - case (Variable(i1), Variable(i2)) => - idHomo(i1, i2)(apriori) - - case (Let(id1, v1, e1), Let(id2, v2, e2)) => - - isHomo(v1, v2)(apriori + (id1 -> id2)) && - isHomo(e1, e2) - - case (Hole(_, _), Hole(_, _)) => - None - - case (LetDef(fds1, e1), LetDef(fds2, e2)) => - fds1.size == fds2.size && - { - val zipped = fds1.zip(fds2) - (zipped mergeall (fds => fdHomo(fds._1, fds._2)))(apriori) && - isHomo(e1, e2) - } - - case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => - cs1.size == cs2.size && casesMatch(cs1,cs2)(apriori) && isHomo(s1, s2) - - case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) => - (cs1.size == cs2.size && casesMatch(cs1,cs2)(apriori)) && isHomo(in1,in2) && isHomo(out1,out2) - - case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => - (if(tfd1 == tfd2) Some(apriori) else (apriori.get(tfd1.fd.id) match { - case None => - isHomo(tfd1.fd.fullBody, tfd2.fd.fullBody)(apriori + (tfd1.fd.id -> tfd2.fd.id)) - case Some(fdid2) => - if(fdid2 == tfd2.fd.id) Some(apriori) else None - })) && - tfd1.tps.zip(tfd2.tps).mergeall{ - case (t1, t2) => if(t1 == t2) - (m: Apriori) => Option(m) - else (m: Apriori) => None} && - (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) } - - case (Lambda(defs, body), Lambda(defs2, body2)) => - // We remove variables introduced by lambdas. - ((defs zip defs2).mergeall{ case (ValDef(a1), ValDef(a2)) => - (m: Apriori) => - Some(m + (a1 -> a2)) }(apriori) - && isHomo(body, body2) - ) -- (defs.map(_.id)) - - case (v1, v2) if isValue(v1) && isValue(v2) => - v1 == v2 && Some(apriori) - - case Same(Operator(es1, _), Operator(es2, _)) => - (es1.size == es2.size) && - (es1 zip es2).mergeall{ case (e1, e2) => isHomo(e1, e2) }(apriori) - - case _ => - None - } - - res - } - - isHomo(t1,t2)(Map()) - } // ensuring (res => res.isEmpty || isHomomorphic(t1, t2)(res.get)) - - /** Checks whether two trees are homomoprhic modulo an identifier map. - * - * Used for transformation tests. - */ - def isHomomorphic(t1: Expr, t2: Expr)(implicit map: Map[Identifier, Identifier]): Boolean = { - object Same { - def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = { - if (tt._1.getClass == tt._2.getClass) { - Some(tt) - } else { - None - } - } - } - - def idHomo(i1: Identifier, i2: Identifier)(implicit map: Map[Identifier, Identifier]) = { - i1 == i2 || map.get(i1).contains(i2) - } - - def fdHomo(fd1: FunDef, fd2: FunDef)(implicit map: Map[Identifier, Identifier]) = { - (fd1.params.size == fd2.params.size) && { - val newMap = map + - (fd1.id -> fd2.id) ++ - (fd1.paramIds zip fd2.paramIds) - isHomo(fd1.fullBody, fd2.fullBody)(newMap) - } - } - - def isHomo(t1: Expr, t2: Expr)(implicit map: Map[Identifier,Identifier]): Boolean = { - - def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase]) : Boolean = { - def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match { - case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) => - (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*)) - - case (WildcardPattern(ob1), WildcardPattern(ob2)) => - (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*)) - - case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) => - val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) - - if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) { - (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { - case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) - } - } else { - (false, Map()) - } - - case (UnapplyPattern(ob1, fd1, subs1), UnapplyPattern(ob2, fd2, subs2)) => - val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) - - if (ob1.size == ob2.size && fd1 == fd2 && subs1.size == subs2.size) { - (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { - case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) - } - } else { - (false, Map()) - } - - case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => - val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) - - if (ob1.size == ob2.size && subs1.size == subs2.size) { - (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { - case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) - } - } else { - (false, Map()) - } - - case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) => - (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap) - - case _ => - (false, Map()) - } - - (cs1 zip cs2).forall { - case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) => - val (h, nm) = patternHomo(p1, p2) - val g = (g1, g2) match { - case (Some(g1), Some(g2)) => isHomo(g1,g2)(map ++ nm) - case (None, None) => true - case _ => false - } - val e = isHomo(e1, e2)(map ++ nm) - - g && e && h - } - - } - - val res = (t1, t2) match { - case (Variable(i1), Variable(i2)) => - idHomo(i1, i2) - - case (Let(id1, v1, e1), Let(id2, v2, e2)) => - isHomo(v1, v2) && - isHomo(e1, e2)(map + (id1 -> id2)) - - case (LetDef(fds1, e1), LetDef(fds2, e2)) => - fds1.size == fds2.size && - { - val zipped = fds1.zip(fds2) - zipped.forall( fds => - fdHomo(fds._1, fds._2) - ) && - isHomo(e1, e2)(map ++ zipped.map(fds => fds._1.id -> fds._2.id)) - } - - case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => - cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2) - - case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) => - cs1.size == cs2.size && isHomo(in1,in2) && isHomo(out1,out2) && casesMatch(cs1,cs2) - - case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => - // TODO: Check type params - fdHomo(tfd1.fd, tfd2.fd) && - (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } - - case Same(Deconstructor(es1, _), Deconstructor(es2, _)) => - (es1.size == es2.size) && - (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } - - case _ => - false - } - - res - } - - isHomo(t1,t2) - } - - /** Checks whether the match cases cover all possible inputs. - * - * Used when reconstructing pattern matching from ITE. - * - * e.g. The following: - * {{{ - * list match { - * case Cons(_, Cons(_, a)) => - * - * case Cons(_, Nil) => - * - * case Nil => - * - * } - * }}} - * is exaustive. - * - * @note Unused and unmaintained - */ - def isMatchExhaustive(m: MatchExpr): Boolean = { - - /* - * Takes the matrix of the cases per position/types: - * e.g. - * e match { // Where e: (T1, T2, T3) - * case (P1, P2, P3) => - * case (P4, P5, P6) => - * - * becomes checked as: - * Seq( (T1, Seq(P1, P4)), (T2, Seq(P2, P5)), (T3, Seq(p3, p6))) - * - * We then check that P1+P4 covers every T1, etc.. - * - * TODO: We ignore type parameters here, we might want to make sure it's - * valid. What's Leon's semantics w.r.t. erasure? - */ - def areExhaustive(pss: Seq[(TypeTree, Seq[Pattern])]): Boolean = pss.forall { case (tpe, ps) => - - tpe match { - case TupleType(tpes) => - val subs = ps.collect { - case TuplePattern(_, bs) => - bs - } - - areExhaustive(tpes zip subs.transpose) - - case _: ClassType => - - def typesOf(tpe: TypeTree): Set[CaseClassDef] = tpe match { - case AbstractClassType(ctp, _) => - ctp.knownDescendants.collect { case c: CaseClassDef => c }.toSet - - case CaseClassType(ctd, _) => - Set(ctd) - - case _ => - Set() - } - - var subChecks = typesOf(tpe).map(_ -> Seq[Seq[Pattern]]()).toMap - - for (p <- ps) p match { - case w: WildcardPattern => - // (a) Wildcard covers everything, no type left to check - subChecks = Map.empty - - case InstanceOfPattern(_, cct) => - // (a: B) covers all Bs - subChecks --= typesOf(cct) - - case CaseClassPattern(_, cct, subs) => - val ccd = cct.classDef - // We record the patterns per types, if they still need to be checked - if (subChecks contains ccd) { - subChecks += (ccd -> (subChecks(ccd) :+ subs)) - } - - case _ => - sys.error("Unexpected case: "+p) - } - - subChecks.forall { case (ccd, subs) => - val tpes = ccd.fields.map(_.getType) - - if (subs.isEmpty) { - false - } else { - areExhaustive(tpes zip subs.transpose) - } - } - - case BooleanType => - // make sure ps contains either - // - Wildcard or - // - both true and false - (ps exists { _.isInstanceOf[WildcardPattern] }) || { - var found = Set[Boolean]() - ps foreach { - case LiteralPattern(_, BooleanLiteral(b)) => found += b - case _ => () - } - (found contains true) && (found contains false) - } - - case UnitType => - // Anything matches () - ps.nonEmpty - - case StringType => - // Can't possibly pattern match against all Strings one by one - ps exists (_.isInstanceOf[WildcardPattern]) - - case Int32Type => - // Can't possibly pattern match against all Ints one by one - ps exists (_.isInstanceOf[WildcardPattern]) - - case _ => - true - } - } - - val patterns = m.cases.map { - case SimpleCase(p, _) => - p - case GuardedCase(p, _, _) => - return false - } - - areExhaustive(Seq((m.scrutinee.getType, patterns))) - } - - /** Flattens a function that contains a LetDef with a direct call to it - * - * Used for merging synthesis results. - * - * {{{ - * def foo(a, b) { - * def bar(c, d) { - * if (..) { bar(c, d) } else { .. } - * } - * bar(b, a) - * } - * }}} - * becomes - * {{{ - * def foo(a, b) { - * if (..) { foo(b, a) } else { .. } - * } - * }}} - */ - def flattenFunctions(fdOuter: FunDef, ctx: LeonContext, p: Program): FunDef = { - fdOuter.body match { - case Some(LetDef(fdsInner, FunctionInvocation(tfdInner2, args))) if fdsInner.size == 1 && fdsInner.head == tfdInner2.fd => - val fdInner = fdsInner.head - val argsDef = fdOuter.paramIds - val argsCall = args.collect { case Variable(id) => id } - - if (argsDef.toSet == argsCall.toSet) { - val defMap = argsDef.zipWithIndex.toMap - val rewriteMap = argsCall.map(defMap) - - val innerIdsToOuterIds = (fdInner.paramIds zip argsCall).toMap - - def pre(e: Expr) = e match { - case FunctionInvocation(tfd, args) if tfd.fd == fdInner => - val newArgs = (args zip rewriteMap).sortBy(_._2) - FunctionInvocation(fdOuter.typed(tfd.tps), newArgs.map(_._1)) - case Variable(id) => - Variable(innerIdsToOuterIds.getOrElse(id, id)) - case _ => - e - } - - def mergePre(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match { - case (None, Some(ie)) => - Some(simplePreTransform(pre)(ie)) - case (Some(oe), None) => - Some(oe) - case (None, None) => - None - case (Some(oe), Some(ie)) => - Some(and(oe, simplePreTransform(pre)(ie))) - } - - def mergePost(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match { - case (None, Some(ie)) => - Some(simplePreTransform(pre)(ie)) - case (Some(oe), None) => - Some(oe) - case (None, None) => - None - case (Some(oe), Some(ie)) => - val res = FreshIdentifier("res", fdOuter.returnType, true) - Some(Lambda(Seq(ValDef(res)), and( - application(oe, Seq(Variable(res))), - application(simplePreTransform(pre)(ie), Seq(Variable(res))) - ))) - } - - val newFd = fdOuter.duplicate() - - val simp = Simplifiers.bestEffort(ctx, p)((_: Expr)) - - newFd.body = fdInner.body.map(b => simplePreTransform(pre)(b)) - newFd.precondition = mergePre(fdOuter.precondition, fdInner.precondition).map(simp) - newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition).map(simp) - - newFd - } else { - fdOuter - } - case _ => - fdOuter - } - } - - def expandAndSimplifyArithmetic(expr: Expr): Expr = { - val expr0 = try { - val freeVars: Array[Identifier] = variablesOf(expr).toArray - val coefs: Array[Expr] = TreeNormalizations.linearArithmeticForm(expr, freeVars) - coefs.toList.zip(InfiniteIntegerLiteral(1) :: freeVars.toList.map(Variable)).foldLeft[Expr](InfiniteIntegerLiteral(0))((acc, t) => { - if(t._1 == InfiniteIntegerLiteral(0)) acc else Plus(acc, Times(t._1, t._2)) - }) - } catch { - case _: Throwable => - expr - } - simplifyArithmetic(expr0) - } - - /* ================= - * Body manipulation - * ================= - */ - - /** Returns whether a particular [[Expressions.Expr]] contains specification - * constructs, namely [[Expressions.Require]] and [[Expressions.Ensuring]]. - */ - def hasSpec(e: Expr): Boolean = exists { - case Require(_, _) => true - case Ensuring(_, _) => true - case Let(i, e, b) => hasSpec(b) - case _ => false - } (e) - - /** Merges the given [[Path]] into the provided [[Expressions.Expr]]. - * - * This method expects to run on a [[Definitions.FunDef.fullBody]] and merges into - * existing pre- and postconditions. - * - * @param expr The current body - * @param path The path that should be wrapped around the given body - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPath(expr: Expr, path: Path): Expr = expr match { - case Let(i, e, b) => withPath(b, path withBinding (i -> e)) - case Require(pre, b) => path specs (b, pre) - case Ensuring(Require(pre, b), post) => path specs (b, pre, post) - case Ensuring(b, post) => path specs (b, post = post) - case b => path specs b - } - - /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one. - * - * If no precondition is provided, removes any existing precondition. - * Else, wraps the expression with a [[Expressions.Require]] clause referring to the new precondition. - * - * @param expr The current expression - * @param pred An optional precondition. Setting it to None removes any precondition. - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = (pred, expr) match { - case (Some(newPre), Require(pre, b)) => req(newPre, b) - case (Some(newPre), Ensuring(Require(pre, b), p)) => Ensuring(req(newPre, b), p) - case (Some(newPre), Ensuring(b, p)) => Ensuring(req(newPre, b), p) - case (Some(newPre), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) - case (Some(newPre), b) => req(newPre, b) - case (None, Require(pre, b)) => b - case (None, Ensuring(Require(pre, b), p)) => Ensuring(b, p) - case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) - case (None, b) => b - } - - /** Replaces the postcondition of an existing [[Expressions.Expr]] with a new one. - * - * If no postcondition is provided, removes any existing postcondition. - * Else, wraps the expression with a [[Expressions.Ensuring]] clause referring to the new postcondition. - * - * @param expr The current expression - * @param oie An optional postcondition. Setting it to None removes any postcondition. - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withPostcondition(expr: Expr, oie: Option[Expr]): Expr = (oie, expr) match { - case (Some(npost), Ensuring(b, post)) => ensur(b, npost) - case (Some(npost), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) - case (Some(npost), b) => ensur(b, npost) - case (None, Ensuring(b, p)) => b - case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) - case (None, b) => b - } - - /** Adds a body to a specification - * - * @param expr The specification expression [[Expressions.Ensuring]] or [[Expressions.Require]]. If none of these, the argument is discarded. - * @param body An option of [[Expressions.Expr]] possibly containing an expression body. - * @return The post/pre condition with the body. If no body is provided, returns [[Expressions.NoTree]] - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withBody(expr: Expr, body: Option[Expr]): Expr = expr match { - case Let(i, e, b) if hasSpec(b) => Let(i, e, withBody(b, body)) - case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) - case Ensuring(Require(pre, _), post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), post) - case Ensuring(_, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), post) - case _ => body.getOrElse(NoTree(expr.getType)) - } - - /** Extracts the body without its specification - * - * [[Expressions.Expr]] trees contain its specifications as part of certain nodes. - * This function helps extracting only the body part of an expression - * - * @return An option type with the resulting expression if not [[Expressions.NoTree]] - * @see [[Expressions.Ensuring]] - * @see [[Expressions.Require]] - */ - def withoutSpec(expr: Expr): Option[Expr] = expr match { - case Let(i, e, b) => withoutSpec(b).map(Let(i, e, _)) - case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(Require(pre, b), post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(b, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case b => Option(b).filterNot(_.isInstanceOf[NoTree]) - } - - /** Returns the precondition of an expression wrapped in Option */ - def preconditionOf(expr: Expr): Option[Expr] = expr match { - case Let(i, e, b) => preconditionOf(b).map(Let(i, e, _).copiedFrom(expr)) - case Require(pre, _) => Some(pre) - case Ensuring(Require(pre, _), _) => Some(pre) - case b => None - } - - /** Returns the postcondition of an expression wrapped in Option */ - def postconditionOf(expr: Expr): Option[Expr] = expr match { - case Let(i, e, b) => postconditionOf(b).map(Let(i, e, _).copiedFrom(expr)) - case Ensuring(_, post) => Some(post) - case _ => None - } - - /** Returns a tuple of precondition, the raw body and the postcondition of an expression */ - def breakDownSpecs(e : Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) - - def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = { - val rec = preTraversalWithParent(f, Some(e)) _ - - f(e, initParent) - - val Deconstructor(es, _) = e - es foreach rec - } - - object InvocationExtractor { - private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) - case Application(caller, args) => flatInvocation(caller) match { - case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) - case None => None - } - case _ => None - } - - def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case IsTyped(f: FunctionInvocation, ft: FunctionType) => None - case IsTyped(f: Application, ft: FunctionType) => None - case FunctionInvocation(tfd, args) => Some(tfd -> args) - case f: Application => flatInvocation(f) - case _ => None - } - } - - def firstOrderCallsOf(expr: Expr): Set[(TypedFunDef, Seq[Expr])] = - collect[(TypedFunDef, Seq[Expr])] { - case InvocationExtractor(tfd, args) => Set(tfd -> args) - case _ => Set.empty - }(expr) - - object ApplicationExtractor { - private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None - case Application(caller: Application, args) => flatApplication(caller) match { - case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) - case None => None - } - case Application(caller, args) => Some((caller, args)) - case _ => None - } - - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case IsTyped(f: Application, ft: FunctionType) => None - case f: Application => flatApplication(f) - case _ => None - } - } - - def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = - collect[(Expr, Seq[Expr])] { - case ApplicationExtractor(caller, args) => Set(caller -> args) - case _ => Set.empty - } (expr) - - def simplifyHOFunctions(expr: Expr) : Expr = { - - def liftToLambdas(expr: Expr) = { - def apply(expr: Expr, args: Seq[Expr]): Expr = expr match { - case IfExpr(cond, thenn, elze) => - IfExpr(cond, apply(thenn, args), apply(elze, args)) - case Let(i, e, b) => - Let(i, e, apply(b, args)) - case LetTuple(is, es, b) => - letTuple(is, es, apply(b, args)) - //case l @ Lambda(params, body) => - // l.withParamSubst(args, body) - case _ => Application(expr, args) - } - - def lift(expr: Expr): Expr = expr.getType match { - case FunctionType(from, to) => expr match { - case _ : Lambda => expr - case _ : Variable => expr - case e => - val args = from.map(tpe => ValDef(FreshIdentifier("x", tpe, true))) - val application = apply(expr, args.map(_.toVariable)) - Lambda(args, lift(application)) - } - case _ => expr - } - - def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr - - def rec(expr: Expr, build: Boolean): Expr = expr match { - case Application(caller, args) => - val newArgs = args.map(rec(_, true)) - val newCaller = rec(caller, false) - extract(Application(newCaller, newArgs), build) - case FunctionInvocation(fd, args) => - val newArgs = args.map(rec(_, true)) - extract(FunctionInvocation(fd, newArgs), build) - case l @ Lambda(args, body) => - val newBody = rec(body, true) - extract(Lambda(args, newBody), build) - case Deconstructor(es, recons) => recons(es.map(rec(_, build))) - } - - rec(lift(expr), true) - } - - liftToLambdas( - matchToIfThenElse( - expr - ) - ) - } - - /** lift closures introduced by synthesis. - * - * Closures already define all - * the necessary information as arguments, no need to close them. - */ - def liftClosures(e: Expr): (Set[FunDef], Expr) = { - var fds: Map[FunDef, FunDef] = Map() - - val res1 = preMap({ - case LetDef(lfds, b) => - val nfds = lfds.map(fd => fd -> fd.duplicate()) - - fds ++= nfds - - Some(letDef(nfds.map(_._2), b)) - - case FunctionInvocation(tfd, args) => - if (fds contains tfd.fd) { - Some(FunctionInvocation(fds(tfd.fd).typed(tfd.tps), args)) - } else { - None - } - - case _ => - None - })(e) - - // we now remove LetDefs - val res2 = preMap({ - case LetDef(fds, b) => - Some(b) - case _ => - None - }, applyRec = true)(res1) - - (fds.values.toSet, res2) - } - - def isListLiteral(e: Expr)(implicit pgm: Program): Option[(TypeTree, List[Expr])] = e match { - case CaseClass(CaseClassType(classDef, Seq(tp)), Nil) => - for { - leonNil <- pgm.library.Nil - if classDef == leonNil - } yield { - (tp, Nil) - } - case CaseClass(CaseClassType(classDef, Seq(tp)), Seq(hd, tl)) => - for { - leonCons <- pgm.library.Cons - if classDef == leonCons - (_, tlElems) <- isListLiteral(tl) - } yield { - (tp, hd :: tlElems) - } - case _ => - None - } - - - /** Collects from within an expression all conditions under which the evaluation of the expression - * will not fail (e.g. by violating a function precondition or evaluating to an error). - * - * Collection of preconditions of function invocations can be disabled - * (mainly for [[leon.verification.Tactic]]). - * - * @param e The expression for which correctness conditions are calculated. - * @param collectFIs Whether we also want to collect preconditions for function invocations - * @return A sequence of pairs (expression, condition) - */ - def collectCorrectnessConditions(e: Expr, collectFIs: Boolean = false): Seq[(Expr, Expr)] = { - val conds = collectWithPC { - - case m @ MatchExpr(scrut, cases) => - (m, orJoin(cases map (matchCaseCondition(scrut, _).toClause))) - - case e @ Error(_, _) => - (e, BooleanLiteral(false)) - - case a @ Assert(cond, _, _) => - (a, cond) - - /*case e @ Ensuring(body, post) => - (e, application(post, Seq(body))) - - case r @ Require(pred, e) => - (r, pred)*/ - - case fi @ FunctionInvocation(tfd, args) if tfd.hasPrecondition && collectFIs => - (fi, tfd.withParamSubst(args, tfd.precondition.get)) - }(e) - - conds map { - case ((e, cond), path) => - (e, path implies cond) - } - } - - - def simpleCorrectnessCond(e: Expr, path: Path, sf: SolverFactory[Solver]): Expr = { - simplifyPaths(sf, path)( - andJoin( collectCorrectnessConditions(e) map { _._2 } ) - ) - } - - def tupleWrapArg(fun: Expr) = fun.getType match { - case FunctionType(args, res) if args.size > 1 => - val newArgs = fun match { - case Lambda(args, _) => args map (_.id) - case _ => args map (arg => FreshIdentifier("x", arg.getType, alwaysShowUniqueID = true)) - } - val res = FreshIdentifier("res", TupleType(args map (_.getType)), alwaysShowUniqueID = true) - val patt = TuplePattern(None, newArgs map (arg => WildcardPattern(Some(arg)))) - Lambda(Seq(ValDef(res)), MatchExpr(res.toVariable, Seq(SimpleCase(patt, application(fun, newArgs map (_.toVariable)))))) - case _ => - fun - } - - // Use this only to debug isValueOfType - private implicit class BooleanAdder(b: Boolean) { - @inline def <(msg: String) = {/*if(!b) println(msg); */b} - } - - /** Returns true if expr is a value of type t */ - def isValueOfType(e: Expr, t: TypeTree): Boolean = { - def unWrapSome(s: Expr) = s match { - case CaseClass(_, Seq(a)) => a - case _ => s - } - (e, t) match { - case (StringLiteral(_), StringType) => true - case (IntLiteral(_), Int32Type) => true - case (InfiniteIntegerLiteral(_), IntegerType) => true - case (CharLiteral(_), CharType) => true - case (FractionalLiteral(_, _), RealType) => true - case (BooleanLiteral(_), BooleanType) => true - case (UnitLiteral(), UnitType) => true - case (GenericValue(t, _), tp) => t == tp - case (Tuple(elems), TupleType(bases)) => - elems zip bases forall (eb => isValueOfType(eb._1, eb._2)) - case (FiniteSet(elems, tbase), SetType(base)) => - tbase == base && - (elems forall isValue) - case (FiniteMap(elems, tk, tv), MapType(from, to)) => - (tk == from) < s"$tk not equal to $from" && (tv == to) < s"$tv not equal to $to" && - (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type ${from}" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" )) - case (NonemptyArray(elems, defaultValues), ArrayType(base)) => - elems.values forall (x => isValueOfType(x, base)) - case (EmptyArray(tpe), ArrayType(base)) => - tpe == base - case (CaseClass(ct, args), ct2@AbstractClassType(classDef, tps)) => - TypeOps.isSubtypeOf(ct, ct2) < s"$ct not a subtype of $ct2" && - ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2) < s"${argstyped._1} not a value of type ${argstyped._2}" )) - case (CaseClass(ct, args), ct2@CaseClassType(classDef, tps)) => - (ct == ct2) < s"$ct not equal to $ct2" && - ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2))) - case (FiniteLambda(mapping, default, tpe), exTpe@FunctionType(ins, out)) => - variablesOf(e).isEmpty && - tpe == exTpe - case (Lambda(valdefs, body), FunctionType(ins, out)) => - variablesOf(e).isEmpty && - (valdefs zip ins forall (vdin => TypeOps.isSubtypeOf(vdin._2, vdin._1.getType) < s"${vdin._2} is not a subtype of ${vdin._1.getType}")) && - (TypeOps.isSubtypeOf(body.getType, out)) < s"${body.getType} is not a subtype of ${out}" - case (FiniteBag(elements, fbtpe), BagType(tpe)) => - fbtpe == tpe && elements.forall{ case (key, value) => isValueOfType(key, tpe) && isValueOfType(value, IntegerType) } - case _ => false - } - } - - /** Returns true if expr is a value. Stronger than isGround */ - val isValue = (e: Expr) => isValueOfType(e, e.getType) - - /** Returns a nested string explaining why this expression is typed the way it is.*/ - def explainTyping(e: Expr): String = { - leon.purescala.ExprOps.fold[String]{ (e, se) => - e match { - case FunctionInvocation(tfd, args) => - s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + s" because ${tfd.fd.id.name} was instantiated with ${tfd.fd.tparams.zip(args).map(k => k._1 +":="+k._2).mkString(",")} with type ${tfd.fd.params.map(_.getType).mkString(",")} => ${tfd.fd.returnType}" - case e => - s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString - } - }(e) - } -} diff --git a/src/main/scala/inox/utils/Benchmarks.scala b/src/main/scala/inox/utils/Benchmarks.scala deleted file mode 100644 index 92c0e398c..000000000 --- a/src/main/scala/inox/utils/Benchmarks.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import java.io.{File, PrintWriter} -import scala.io.Source - -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.experimental.ScalaObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule - -import java.text._ -import java.util.Date - -class BenchmarksHistory(file: File) { - - def this(name: String) = { - this(new File(name)) - } - - val mapper = new ObjectMapper() with ScalaObjectMapper - mapper.registerModule(DefaultScalaModule) - - private[this] var entries: List[BenchmarkEntry] = { - if (file.exists) { - val str = Source.fromFile(file).mkString - - mapper.readValue[List[Map[String, Any]]](str).map(BenchmarkEntry(_)) - } else { - Nil - } - } - - def write(): Unit = { - val json = mapper.writeValueAsString(entries.map(_.fields)) - val pw = new PrintWriter(file) - try { - pw.write(json) - } finally { - pw.close() - } - } - - def +=(be: BenchmarkEntry) { - entries :+= be - } - -} - -case class BenchmarkEntry(fields: Map[String, Any]) { - def +(s: String, v: Any) = { - copy(fields + (s -> v)) - } - - def ++(es: Map[String, Any]) = { - copy(fields ++ es) - } -} - -object BenchmarkEntry { - def fromContext(ctx: LeonContext) = { - val date = new Date() - - BenchmarkEntry(Map( - "files" -> ctx.files.map(_.getAbsolutePath).mkString(" "), - "options" -> ctx.options.mkString(" "), - "date" -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date), - "ts" -> (System.currentTimeMillis() / 1000L) - )) - } -} diff --git a/src/main/scala/inox/utils/DebugSections.scala b/src/main/scala/inox/utils/DebugSections.scala deleted file mode 100644 index 7e454e7ea..000000000 --- a/src/main/scala/inox/utils/DebugSections.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import scala.annotation.implicitNotFound - -@implicitNotFound("No implicit debug section found in scope. You need define an implicit DebugSection to use debug/ifDebug") -sealed abstract class DebugSection(val name: String, val mask: Int) - -case object DebugSectionSolver extends DebugSection("solver", 1 << 0) -case object DebugSectionSynthesis extends DebugSection("synthesis", 1 << 1) -case object DebugSectionTimers extends DebugSection("timers", 1 << 2) -case object DebugSectionOptions extends DebugSection("options", 1 << 3) -case object DebugSectionVerification extends DebugSection("verification", 1 << 4) -case object DebugSectionTermination extends DebugSection("termination", 1 << 5) -case object DebugSectionTrees extends DebugSection("trees", 1 << 6) -case object DebugSectionPositions extends DebugSection("positions", 1 << 7) -case object DebugSectionDataGen extends DebugSection("datagen", 1 << 8) -case object DebugSectionEvaluation extends DebugSection("eval", 1 << 9) -case object DebugSectionRepair extends DebugSection("repair", 1 << 10) -case object DebugSectionLeon extends DebugSection("leon", 1 << 11) -case object DebugSectionXLang extends DebugSection("xlang", 1 << 12) -case object DebugSectionTypes extends DebugSection("types", 1 << 13) -case object DebugSectionIsabelle extends DebugSection("isabelle", 1 << 14) -case object DebugSectionReport extends DebugSection("report", 1 << 15) -case object DebugSectionGenC extends DebugSection("genc", 1 << 16) - -object DebugSections { - val all = Set[DebugSection]( - DebugSectionSolver, - DebugSectionSynthesis, - DebugSectionTimers, - DebugSectionOptions, - DebugSectionVerification, - DebugSectionTermination, - DebugSectionTrees, - DebugSectionPositions, - DebugSectionDataGen, - DebugSectionEvaluation, - DebugSectionRepair, - DebugSectionLeon, - DebugSectionXLang, - DebugSectionTypes, - DebugSectionIsabelle, - DebugSectionReport, - DebugSectionGenC - ) -} diff --git a/src/main/scala/inox/utils/FileOutputPhase.scala b/src/main/scala/inox/utils/FileOutputPhase.scala deleted file mode 100644 index 140a863ac..000000000 --- a/src/main/scala/inox/utils/FileOutputPhase.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.utils - -import leon._ -import purescala.Definitions.Program -import java.io.File - -object FileOutputPhase extends UnitPhase[Program] { - - val name = "File output" - val description = "Output parsed/generated program to the specified directory (default: leon.out)" - - val optOutputDirectory = LeonStringOptionDef("o", "Output directory", "leon.out", "dir") - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optOutputDirectory) - - def apply(ctx:LeonContext, p : Program) { - // Get the output file name from command line, or use default - val outputFolder = ctx.findOptionOrDefault(optOutputDirectory) - try { - new File(outputFolder).mkdir() - } catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not create directory " + outputFolder) - } - - for (u <- p.units if u.isMainUnit) { - val outputFile = s"$outputFolder${File.separator}${u.id.toString}.scala" - try { u.writeScalaFile(outputFile, Some(p)) } - catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) - } - } - ctx.reporter.info("Output written on " + outputFolder) - } - -} diff --git a/src/main/scala/inox/utils/FilesWatcher.scala b/src/main/scala/inox/utils/FilesWatcher.scala deleted file mode 100644 index b4e579f5a..000000000 --- a/src/main/scala/inox/utils/FilesWatcher.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import java.io.File -import java.nio.file._ -import scala.collection.JavaConversions._ -import java.security.MessageDigest -import java.math.BigInteger - -case class FilesWatcher(ctx: LeonContext, files: Seq[File]) { - val toWatch = files.map(_.getAbsoluteFile).toSet - val dirs = toWatch.map(_.getParentFile) - - def onChange(body: => Unit): Unit = { - val watcher = FileSystems.getDefault.newWatchService() - dirs foreach (_.toPath.register(watcher, StandardWatchEventKinds.ENTRY_MODIFY)) - var lastHashes = toWatch.map(md5file) - - body - ctx.reporter.info("Waiting for source changes...") - - while (true) { - val key = watcher.take() - - val events = key.pollEvents() - - if (events.exists{ _.context match { - case (p: Path) => - dirs exists { dir => toWatch(new File(dir, p.toFile.getName))} - case e => false - }}) { - val currentHashes = toWatch.map(md5file) - if (currentHashes != lastHashes) { - lastHashes = currentHashes - - ctx.reporter.info("Detected new version!") - body - ctx.reporter.info("Waiting for source changes...") - } - } - - key.reset() - } - } - - def md5file(f: File): String = { - val buffer = new Array[Byte](1024) - val md = MessageDigest.getInstance("MD5") - try { - val is = Files.newInputStream(f.toPath) - var read = 0 - do { - read = is.read(buffer) - if (read > 0) { - md.update(buffer, 0, read) - } - } while (read != -1) - is.close() - - val bytes = md.digest() - ("%0" + (bytes.length << 1) + "X").format(new BigInteger(1, bytes)); - } catch { - case _: RuntimeException => - "" - } - } -} diff --git a/src/main/scala/inox/utils/GraphOps.scala b/src/main/scala/inox/utils/GraphOps.scala index 576694d68..bf97bfd33 100644 --- a/src/main/scala/inox/utils/GraphOps.scala +++ b/src/main/scala/inox/utils/GraphOps.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon.utils +package inox.utils object GraphOps { diff --git a/src/main/scala/inox/utils/GraphPrinters.scala b/src/main/scala/inox/utils/GraphPrinters.scala index c5fb26a09..e4c04a8fc 100644 --- a/src/main/scala/inox/utils/GraphPrinters.scala +++ b/src/main/scala/inox/utils/GraphPrinters.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils import Graphs._ diff --git a/src/main/scala/inox/utils/Graphs.scala b/src/main/scala/inox/utils/Graphs.scala index 9ac99b670..fb205e556 100644 --- a/src/main/scala/inox/utils/Graphs.scala +++ b/src/main/scala/inox/utils/Graphs.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils object Graphs { diff --git a/src/main/scala/inox/utils/InliningPhase.scala b/src/main/scala/inox/utils/InliningPhase.scala deleted file mode 100644 index 22b75e162..000000000 --- a/src/main/scala/inox/utils/InliningPhase.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.utils - -import leon._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.TypeOps.instantiateType -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.Constructors.{caseClassSelector, application} - -object InliningPhase extends TransformationPhase { - - val name = "Inline @inline functions" - val description = "Inline functions marked as @inline and remove their definitions" - - def apply(ctx: LeonContext, p: Program): Program = { - - // Detect inlined functions that are recursive - val doNotInline = (for (fd <- p.definedFunctions.filter(_.flags(IsInlined)) if p.callGraph.isRecursive(fd)) yield { - ctx.reporter.warning("Refusing to inline recursive function '"+fd.id.asString(ctx)+"'!") - fd - }).toSet - - def doInline(fd: FunDef) = fd.flags(IsInlined) && !doNotInline(fd) - - for (fd <- p.definedFunctions) { - fd.fullBody = preMap ({ - case FunctionInvocation(tfd, args) if doInline(tfd.fd) => - Some(replaceFromIDs((tfd.params.map(_.id) zip args).toMap, tfd.fullBody)) - - case CaseClassSelector(cct, cc: CaseClass, id) => - Some(caseClassSelector(cct, cc, id)) - - case Application(caller: Lambda, args) => - Some(application(caller, args)) - - case _ => - None - }, applyRec = true)(fd.fullBody) - } - - filterFunDefs(p, fd => !doInline(fd)) - } - -} diff --git a/src/main/scala/inox/utils/InterruptManager.scala b/src/main/scala/inox/utils/InterruptManager.scala index 175018750..5377c0c6c 100644 --- a/src/main/scala/inox/utils/InterruptManager.scala +++ b/src/main/scala/inox/utils/InterruptManager.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils import scala.collection.JavaConversions._ diff --git a/src/main/scala/inox/utils/Interruptible.scala b/src/main/scala/inox/utils/Interruptible.scala index 8132c004c..ecb972315 100644 --- a/src/main/scala/inox/utils/Interruptible.scala +++ b/src/main/scala/inox/utils/Interruptible.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils trait Interruptible { diff --git a/src/main/scala/inox/utils/Library.scala b/src/main/scala/inox/utils/Library.scala deleted file mode 100644 index 3e9b99c88..000000000 --- a/src/main/scala/inox/utils/Library.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import purescala.Definitions._ -import purescala.Types._ -import purescala.DefOps._ - -import scala.reflect._ - -case class Library(pgm: Program) { - lazy val List = lookup("leon.collection.List").collectFirst { case acd : AbstractClassDef => acd } - lazy val Cons = lookup("leon.collection.Cons").collectFirst { case ccd : CaseClassDef => ccd } - lazy val Nil = lookup("leon.collection.Nil").collectFirst { case ccd : CaseClassDef => ccd } - - lazy val Option = lookup("leon.lang.Option").collectFirst { case acd : AbstractClassDef => acd } - lazy val Some = lookup("leon.lang.Some").collectFirst { case ccd : CaseClassDef => ccd } - lazy val None = lookup("leon.lang.None").collectFirst { case ccd : CaseClassDef => ccd } - - lazy val StrOps = lookup("leon.lang.StrOps").collectFirst { case md: ModuleDef => md } - - lazy val Dummy = lookup("leon.lang.Dummy").collectFirst { case ccd : CaseClassDef => ccd } - - lazy val setToList = lookup("leon.collection.setToList").collectFirst { case fd : FunDef => fd } - - lazy val escape = lookup("leon.lang.StrOps.escape").collectFirst { case fd : FunDef => fd } - - lazy val mapMkString = lookup("leon.lang.Map.mkString").collectFirst { case fd : FunDef => fd } - - lazy val setMkString = lookup("leon.lang.Set.mkString").collectFirst { case fd : FunDef => fd } - - lazy val bagMkString = lookup("leon.lang.Bag.mkString").collectFirst { case fd : FunDef => fd } - - def lookup(name: String): Seq[Definition] = { - pgm.lookupAll(name) - } - - def lookupUnique[D <: Definition : ClassTag](name: String): D = { - val ct = classTag[D] - val all = pgm.lookupAll(name).filter(d => ct.runtimeClass.isInstance(d)) - assert(all.size == 1, "lookupUnique(\"name\") returned results " + all.map(_.id.uniqueName)) - all.head.asInstanceOf[D] - } - - def optionType(tp: TypeTree) = AbstractClassType(Option.get, Seq(tp)) - def someType(tp: TypeTree) = CaseClassType(Some.get, Seq(tp)) - def noneType(tp: TypeTree) = CaseClassType(None.get, Seq(tp)) -} diff --git a/src/main/scala/inox/utils/Positions.scala b/src/main/scala/inox/utils/Positions.scala index 74f999656..49fd7212b 100644 --- a/src/main/scala/inox/utils/Positions.scala +++ b/src/main/scala/inox/utils/Positions.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils import java.io.File diff --git a/src/main/scala/inox/utils/StreamUtils.scala b/src/main/scala/inox/utils/StreamUtils.scala index eddc35a0f..da89263c5 100644 --- a/src/main/scala/inox/utils/StreamUtils.scala +++ b/src/main/scala/inox/utils/StreamUtils.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon.utils +package inox.utils object StreamUtils { @@ -24,7 +24,7 @@ object StreamUtils { } /** Applies the interleaving to a finite sequence of streams. */ - def interleave[T](streams : Seq[Stream[T]]) : Stream[T] = { + def interleave[T](streams: Seq[Stream[T]]) : Stream[T] = { if (streams.isEmpty) Stream() else { val nonEmpty = streams filter (_.nonEmpty) nonEmpty.toStream.map(_.head) #::: interleave(nonEmpty.map(_.tail)) @@ -54,7 +54,7 @@ object StreamUtils { combineRec(sa, sb)(0) } - def cartesianProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = { + def cartesianProduct[T](streams: Seq[Stream[T]]) : Stream[List[T]] = { val dimensions = streams.size val vectorizedStreams = streams.map(new VectorizedStream(_)) diff --git a/src/main/scala/inox/utils/TemporaryInputPhase.scala b/src/main/scala/inox/utils/TemporaryInputPhase.scala deleted file mode 100644 index bf5101b4c..000000000 --- a/src/main/scala/inox/utils/TemporaryInputPhase.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import java.io.{File, BufferedWriter, FileWriter} - -object TemporaryInputPhase extends SimpleLeonPhase[(List[String], List[String]), List[String]] { - - val name = "Temporary Input" - val description = "Create source files from string content" - - def apply(ctx: LeonContext, data: (List[String], List[String])): List[String] = { - val (contents, opts) = data - - val files = contents.map { content => - val file : File = File.createTempFile("leon", ".scala") - file.deleteOnExit() - val out = new BufferedWriter(new FileWriter(file)) - out.write(content) - out.close() - file.getAbsolutePath - } - - - files ::: opts - } -} diff --git a/src/main/scala/inox/utils/Timer.scala b/src/main/scala/inox/utils/Timer.scala index db5070e0d..c334abe9c 100644 --- a/src/main/scala/inox/utils/Timer.scala +++ b/src/main/scala/inox/utils/Timer.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package utils import scala.language.dynamics diff --git a/src/main/scala/inox/utils/TypingPhase.scala b/src/main/scala/inox/utils/TypingPhase.scala deleted file mode 100644 index ca4e3a72b..000000000 --- a/src/main/scala/inox/utils/TypingPhase.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import purescala.Common._ -import purescala.ExprOps.preTraversal -import purescala.Types._ -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Constructors._ - -object TypingPhase extends SimpleLeonPhase[Program, Program] { - - val name = "Typing" - val description = "Ensure and enforce certain Leon typing rules" - - /** - * This phase does 2 different things: - * - * 1) Ensure that functions that take and/or return a specific ADT subtype - * have this encoded explicitly in pre/postconditions. Solvers such as Z3 - * unify types, which means that we need to ensure models are consistent - * with the original type constraints. - * - * 2) Report warnings in case parts of the tree are not correctly typed - * (Untyped). - * - * 3) Make sure that abstract classes have at least one descendant - */ - def apply(ctx: LeonContext, pgm: Program): Program = { - pgm.definedFunctions.foreach(fd => { - - // Part (1) - val argTypesPreconditions = fd.params.flatMap(arg => arg.getType match { - case cct: ClassType if cct.parent.isDefined => - Seq(IsInstanceOf(arg.id.toVariable, cct)) - case at: ArrayType => - Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0))) - case _ => - Seq() - }) - argTypesPreconditions match { - case Nil => () - case xs => fd.precondition = { - fd.precondition match { - case Some(p) => Some(andJoin(xs :+ p).copiedFrom(p)) - case None => Some(andJoin(xs)) - } - } - } - - fd.postcondition = fd.returnType match { - case ct: ClassType if ct.parent.isDefined => { - val resId = FreshIdentifier("res", ct) - fd.postcondition match { - case Some(p) => - Some(Lambda(Seq(ValDef(resId)), and( - application(p, Seq(Variable(resId))), - IsInstanceOf(Variable(resId), ct) - ).setPos(p)).setPos(p)) - - case None => - val pos = fd.body.map{ _.getPos } match { - case Some(df: DefinedPosition) => df.focusEnd - case _ => NoPosition - } - Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(Variable(resId), ct)).setPos(pos)) - } - } - case _ => fd.postcondition - } - - // Part (2) - fd.body.foreach { - preTraversal{ - case t if !t.isTyped => - ctx.reporter.warning(t.getPos, "Tree "+t.asString(ctx)+" is not properly typed ("+t.getPos.fullString+")") - case _ => - } - } - - - }) - - // Part (3) - pgm.definedClasses.foreach { - case acd: AbstractClassDef => - if (acd.knownCCDescendants.isEmpty) { - ctx.reporter.error(acd.getPos, "Class "+acd.id.asString(ctx)+" has no concrete descendant!") - } - case _ => - } - - - pgm - } - -} - diff --git a/src/main/scala/inox/utils/UniqueCounter.scala b/src/main/scala/inox/utils/UniqueCounter.scala index e3c6ae0a1..9d2ff4554 100644 --- a/src/main/scala/inox/utils/UniqueCounter.scala +++ b/src/main/scala/inox/utils/UniqueCounter.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon.utils +package inox.utils class UniqueCounter[K] { diff --git a/src/main/scala/inox/utils/UnitElimination.scala b/src/main/scala/inox/utils/UnitElimination.scala deleted file mode 100644 index d5aa0069e..000000000 --- a/src/main/scala/inox/utils/UnitElimination.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ - -// FIXME: Unused and untested -object UnitElimination extends TransformationPhase { - - val name = "Unit Elimination" - val description = "Remove all usage of the Unit type and value" - - private var fun2FreshFun: Map[FunDef, FunDef] = Map() - private var id2FreshId: Map[Identifier, Identifier] = Map() - - def apply(ctx: LeonContext, pgm: Program): Program = { - val newUnits = pgm.units map { u => u.copy(defs = u.defs.map { - case m: ModuleDef => - fun2FreshFun = Map() - val allFuns = m.definedFunctions - //first introduce new signatures without Unit parameters - allFuns.foreach(fd => { - if(fd.returnType != UnitType && fd.params.exists(vd => vd.getType == UnitType)) { - val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) - fun2FreshFun += (fd -> freshFunDef) - } else { - fun2FreshFun += (fd -> fd) //this will make the next step simpler - } - }) - - //then apply recursively to the bodies - val newFuns = allFuns.collect{ case fd if fd.returnType != UnitType => - val newFd = fun2FreshFun(fd) - newFd.fullBody = removeUnit(fd.fullBody) - newFd - } - - ModuleDef(m.id, m.definedClasses ++ newFuns, m.isPackageObject ) - case d => - d - })} - - - Program(newUnits) - } - - private def simplifyType(tpe: TypeTree): TypeTree = tpe match { - case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ _ == UnitType }) - case t => t - } - - //remove unit value as soon as possible, so expr should never be equal to a unit - private def removeUnit(expr: Expr): Expr = { - assert(expr.getType != UnitType) - expr match { - case fi@FunctionInvocation(tfd, args) => - val newArgs = args.filterNot(arg => arg.getType == UnitType) - FunctionInvocation(fun2FreshFun(tfd.fd).typed(tfd.tps), newArgs).setPos(fi) - - case IsTyped(Tuple(args), TupleType(tpes)) => - val newArgs = tpes.zip(args).collect { - case (tp, arg) if tp != UnitType => arg - } - tupleWrap(newArgs.map(removeUnit)) // @mk: FIXME this may actually return a Unit, is that cool? - - case ts@TupleSelect(t, index) => - val TupleType(tpes) = t.getType - val simpleTypes = tpes map simplifyType - val newArity = tpes.count(_ != UnitType) - val newIndex = simpleTypes.take(index).count(_ != UnitType) - tupleSelect(removeUnit(t), newIndex, newArity) - - case Let(id, e, b) => - if(id.getType == UnitType) - removeUnit(b) - else { - id.getType match { - case TupleType(tpes) if tpes.contains(UnitType) => { - val newTupleType = tupleTypeWrap(tpes.filterNot(_ == UnitType)) - val freshId = FreshIdentifier(id.name, newTupleType) - id2FreshId += (id -> freshId) - val newBody = removeUnit(b) - id2FreshId -= id - Let(freshId, removeUnit(e), newBody) - } - case _ => Let(id, removeUnit(e), removeUnit(b)) - } - } - - case LetDef(fds, b) => - val nonUnits = fds.filter(fd => fd.returnType != UnitType) - if(nonUnits.isEmpty) { - removeUnit(b) - } else { - val fdtoFreshFd = for(fd <- nonUnits) yield { - val m = if(fd.params.exists(vd => vd.getType == UnitType)) { - val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) - fd -> freshFunDef - } else { - fd -> fd - } - fun2FreshFun += m - m - } - for((fd, freshFunDef) <- fdtoFreshFd) { - if(fd.params.exists(vd => vd.getType == UnitType)) { - freshFunDef.fullBody = removeUnit(fd.fullBody) - } else { - fd.body = fd.body.map(b => removeUnit(b)) - } - } - val rest = removeUnit(b) - val newFds = for((fd, freshFunDef) <- fdtoFreshFd) yield { - fun2FreshFun -= fd - if(fd.params.exists(vd => vd.getType == UnitType)) { - freshFunDef - } else { - fd - } - } - - letDef(newFds, rest) - } - - case ite@IfExpr(cond, tExpr, eExpr) => - val thenRec = removeUnit(tExpr) - val elseRec = removeUnit(eExpr) - IfExpr(removeUnit(cond), thenRec, elseRec) - - case v @ Variable(id) => - if(id2FreshId.isDefinedAt(id)) - Variable(id2FreshId(id)) - else v - - case m @ MatchExpr(scrut, cses) => - val scrutRec = removeUnit(scrut) - val csesRec = cses.map{ cse => - MatchCase(cse.pattern, cse.optGuard map removeUnit, removeUnit(cse.rhs)) - } - matchExpr(scrutRec, csesRec).setPos(m) - - case Operator(args, recons) => - recons(args.map(removeUnit)) - - case _ => sys.error("not supported: " + expr) - } - } - -} diff --git a/src/main/scala/inox/utils/package.scala b/src/main/scala/inox/utils/package.scala index 09dc22742..c5a2e3943 100644 --- a/src/main/scala/inox/utils/package.scala +++ b/src/main/scala/inox/utils/package.scala @@ -1,8 +1,8 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox -/** Various utilities used throughout the Leon system */ +/** Various utilities used throughout the Inox system */ package object utils { /** compute the fixpoint of a function. -- GitLab