From 38a639c94c6f861544addbd38a03ce869721f421 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Wed, 24 Oct 2012 21:29:40 +0200
Subject: [PATCH] Type phases so that they act like functions, define
 pipelines, process options differently, not yet finished

---
 src/main/scala/leon/Analysis.scala            |   4 +-
 src/main/scala/leon/ArrayTransformation.scala |   2 +-
 src/main/scala/leon/EpsilonElimination.scala  |   2 +-
 src/main/scala/leon/FunctionClosure.scala     |   2 +-
 src/main/scala/leon/FunctionHoisting.scala    |   2 +-
 .../leon/ImperativeCodeElimination.scala      |   2 +-
 src/main/scala/leon/LeonContext.scala         |   5 +-
 src/main/scala/leon/LeonOption.scala          |  13 ++
 src/main/scala/leon/LeonPhase.scala           |  51 ++++---
 src/main/scala/leon/Main.scala                | 139 +++++++++++++-----
 src/main/scala/leon/Pipeline.scala            |  26 ++++
 src/main/scala/leon/Settings.scala            |   6 +
 src/main/scala/leon/Simplificator.scala       |   2 +-
 src/main/scala/leon/TypeChecking.scala        |   4 +-
 src/main/scala/leon/UnitElimination.scala     |   2 +-
 .../scala/leon/plugin/AnalysisComponent.scala |   2 +-
 .../scala/leon/plugin/ExtractorPhase.scala    |  28 ++--
 .../scala/leon/synthesis/SynthesisPhase.scala |   8 +-
 18 files changed, 204 insertions(+), 96 deletions(-)
 create mode 100644 src/main/scala/leon/LeonOption.scala
 create mode 100644 src/main/scala/leon/Pipeline.scala

diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala
index da85f2438..ce98c70aa 100644
--- a/src/main/scala/leon/Analysis.scala
+++ b/src/main/scala/leon/Analysis.scala
@@ -328,11 +328,11 @@ object Analysis {
   }
 }
 
-object AnalysisPhase extends UnitPhase {
+object AnalysisPhase extends UnitPhase[Program] {
   val name = "Analysis"
   val description = "Leon Analyses"
 
-  def apply(program: Program) {
+  def apply(ctx: LeonContext, program: Program) {
     new Analysis(program).analyse
   }
 }
diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala
index e4ce14c31..9d0403039 100644
--- a/src/main/scala/leon/ArrayTransformation.scala
+++ b/src/main/scala/leon/ArrayTransformation.scala
@@ -10,7 +10,7 @@ object ArrayTransformation extends TransformationPhase {
   val name = "Array Transformation"
   val description = "Add bound checking for array access and remove array update with side effect"
 
-  def apply(pgm: Program): Program = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
 
     val allFuns = pgm.definedFunctions
     allFuns.foreach(fd => {
diff --git a/src/main/scala/leon/EpsilonElimination.scala b/src/main/scala/leon/EpsilonElimination.scala
index e624c47f6..a785ddf9e 100644
--- a/src/main/scala/leon/EpsilonElimination.scala
+++ b/src/main/scala/leon/EpsilonElimination.scala
@@ -10,7 +10,7 @@ object EpsilonElimination extends TransformationPhase {
   val name = "Epsilon Elimination"
   val description = "Remove all epsilons from the program"
 
-  def apply(pgm: Program): Program = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
 
     val allFuns = pgm.definedFunctions
     allFuns.foreach(fd => fd.body.map(body => {
diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala
index 7ce87d90b..ed5920185 100644
--- a/src/main/scala/leon/FunctionClosure.scala
+++ b/src/main/scala/leon/FunctionClosure.scala
@@ -15,7 +15,7 @@ object FunctionClosure extends TransformationPhase{
   private var newFunDefs: Map[FunDef, FunDef] = Map()
   private var topLevelFuns: Set[FunDef] = Set()
 
-  def apply(program: Program): Program = {
+  def apply(ctx: LeonContext, program: Program): Program = {
     newFunDefs = Map()
     val funDefs = program.definedFunctions
     funDefs.foreach(fd => {
diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala
index 813afd9ed..d0fd3f83f 100644
--- a/src/main/scala/leon/FunctionHoisting.scala
+++ b/src/main/scala/leon/FunctionHoisting.scala
@@ -10,7 +10,7 @@ object FunctionHoisting extends TransformationPhase {
   val name = "Function Hoisting"
   val description = "Hoist function at the top level"
 
-  def apply(program: Program): Program = {
+  def apply(ctx: LeonContext, program: Program): Program = {
     val funDefs = program.definedFunctions
     var topLevelFuns: Set[FunDef] = Set()
     funDefs.foreach(fd => {
diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala
index 6f27fc3b4..18243ea6b 100644
--- a/src/main/scala/leon/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/ImperativeCodeElimination.scala
@@ -13,7 +13,7 @@ object ImperativeCodeElimination extends TransformationPhase {
   private var varInScope = Set[Identifier]()
   private var parent: FunDef = null //the enclosing fundef
 
-  def apply(pgm: Program): Program = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
     val allFuns = pgm.definedFunctions
     allFuns.foreach(fd => fd.body.map(body => {
       parent = fd
diff --git a/src/main/scala/leon/LeonContext.scala b/src/main/scala/leon/LeonContext.scala
index a7f4dc982..9037c5f9e 100644
--- a/src/main/scala/leon/LeonContext.scala
+++ b/src/main/scala/leon/LeonContext.scala
@@ -3,8 +3,7 @@ package leon
 import purescala.Definitions.Program
 
 case class LeonContext(
-  val options: List[String] = List(),
-  val program: Option[Program] = None,
-  val reporter: Reporter = new DefaultReporter
+  val settings: Settings          = Settings(),
+  val reporter: Reporter          = new DefaultReporter
 )
 
diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala
new file mode 100644
index 000000000..f2579e5af
--- /dev/null
+++ b/src/main/scala/leon/LeonOption.scala
@@ -0,0 +1,13 @@
+package leon
+
+sealed abstract class LeonOption {
+  val name: String
+}
+
+case class LeonFlagOption(name: String) extends LeonOption
+case class LeonValueOption(name: String, value: String) extends LeonOption {
+
+  def splitList : Seq[String] = value.split(':').map(_.trim).filter(!_.isEmpty)
+}
+
+case class LeonOptionDef(name: String, isFlag: Boolean, description: String)
diff --git a/src/main/scala/leon/LeonPhase.scala b/src/main/scala/leon/LeonPhase.scala
index 8e95d973e..d81b8eea7 100644
--- a/src/main/scala/leon/LeonPhase.scala
+++ b/src/main/scala/leon/LeonPhase.scala
@@ -2,37 +2,40 @@ package leon
 
 import purescala.Definitions.Program
 
-abstract class LeonPhase {
+abstract class LeonPhase[F, T] {
   val name: String
   val description: String
 
-  def run(ac: LeonContext): LeonContext
+  def definedOptions: Set[LeonOptionDef] = Set()
+
+  def run(ac: LeonContext)(v: F): T
 }
 
-abstract class TransformationPhase extends LeonPhase {
-  def apply(p: Program): Program
-
-  override def run(ctx: LeonContext) = {
-    ctx.program match {
-      case Some(p) =>
-        ctx.copy(program = Some(apply(p)))
-      case None =>
-        ctx.reporter.fatalError("Empty program at this point?!?")
-        ctx
-    }
+abstract class TransformationPhase extends LeonPhase[Program, Program] {
+  def apply(ctx: LeonContext, p: Program): Program
+
+  override def run(ctx: LeonContext)(p: Program) = {
+    apply(ctx, p)
   }
 }
 
-abstract class UnitPhase extends LeonPhase {
-  def apply(p: Program): Unit
-
-  override def run(ctx: LeonContext) = { 
-    ctx.program match {
-      case Some(p) =>
-        apply(p)
-      case None =>
-        ctx.reporter.fatalError("Empty program at this point?!?")
-    }
-    ctx
+abstract class UnitPhase[Program] extends LeonPhase[Program, Program] {
+  def apply(ctx: LeonContext, p: Program): Unit
+
+  override def run(ctx: LeonContext)(p: Program) = {
+    apply(ctx, p)
+    p
   }
 }
+
+case class NoopPhase[T]() extends LeonPhase[T, T] {
+  val name = "noop";
+  val description = "no-op"
+  override def run(ctx: LeonContext)(v: T) = v
+}
+
+case class ExitPhase[T]() extends LeonPhase[T, Unit] {
+  val name = "end";
+  val description = "end"
+  override def run(ctx: LeonContext)(v: T) = ()
+}
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index a345b4b80..013ae1cf5 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -1,56 +1,119 @@
 package leon
 
-import synthesis.SynthesisPhase
-import plugin.ExtractionPhase
-
 object Main {
 
-  def computeLeonPhases: List[LeonPhase] = {
+  def allPhases: List[LeonPhase[_, _]] = {
     List(
-      List(
-        ExtractionPhase
-      ),
-      if (Settings.transformProgram) {
-        List(
-          ArrayTransformation,
-          EpsilonElimination,
-          ImperativeCodeElimination,
-          /*UnitElimination,*/
-          FunctionClosure,
-          /*FunctionHoisting,*/
-          Simplificator
-        )
+      plugin.ExtractionPhase,
+      ArrayTransformation,
+      EpsilonElimination,
+      ImperativeCodeElimination,
+      /*UnitElimination,*/
+      FunctionClosure,
+      /*FunctionHoisting,*/
+      Simplificator,
+      synthesis.SynthesisPhase,
+      AnalysisPhase
+    )
+  }
+
+  def processOptions(reporter: Reporter, args: List[String]) = {
+    val phases = allPhases
+
+    val allOptions = allPhases.flatMap(_.definedOptions) ++ Set(
+      LeonOptionDef("synthesis",     true,  "--synthesis          Magic!"),
+      LeonOptionDef("xlang",         true,  "--xlang              Preprocessing and transformation from extended programs")
+    )
+
+    val allOptionsMap = allOptions.map(o => o.name -> o).toMap
+
+    // Detect unknown options:
+    val options = args.filter(_.startsWith("--"))
+
+    val leonOptions = options.flatMap { opt =>
+      val leonOpt: LeonOption = opt.substring(2, opt.length).split("=", 2).toList match {
+        case List(name, value) =>
+          LeonValueOption(name, value)
+        case List(name) => name
+          LeonFlagOption(name)
+        case _ =>
+          reporter.fatalError("Woot?")
+      }
+
+      if (allOptionsMap contains leonOpt.name) {
+        (allOptionsMap(leonOpt.name).isFlag, leonOpt) match {
+          case (true,  LeonFlagOption(name)) =>
+            Some(leonOpt)
+          case (false, LeonValueOption(name, value)) =>
+            Some(leonOpt)
+          case _ =>
+            reporter.fatalError("Invalid option usage")
+            None
+        }
       } else {
-        Nil
+        None
       }
-    ,
-      if (Settings.synthesis) {
-        List(
-          SynthesisPhase
-        )
+    }
+
+    var settings  = Settings()
+
+    // Process options we understand:
+    for(opt <- leonOptions) opt match {
+      case LeonFlagOption("synthesis") =>
+        settings = settings.copy(synthesis = true, xlang = false, analyze = false)
+      case LeonFlagOption("xlang") =>
+        settings = settings.copy(synthesis = false, xlang = true)
+      case _ =>
+    }
+
+    LeonContext(settings = settings, reporter = reporter)
+  }
+
+  implicit def phaseToPipeline[F, T](phase: LeonPhase[F, T]): Pipeline[F, T] = new PipeCons(phase, new PipeNil())
+
+  def computePipeLine(settings: Settings): Pipeline[List[String], Unit] = {
+    import purescala.Definitions.Program
+
+    val pipeBegin = phaseToPipeline(plugin.ExtractionPhase)
+
+    val pipeTransforms: Pipeline[Program, Program] =
+      if (settings.xlang) {
+        ArrayTransformation andThen
+        EpsilonElimination andThen
+        ImperativeCodeElimination
       } else {
-        Nil
+        NoopPhase[Program]()
       }
-    ,
-      if (!Settings.stopAfterTransformation) {
-        List(
-          AnalysisPhase
-        )
+
+    val pipeSynthesis: Pipeline[Program, Program] =
+      if (settings.synthesis) {
+        synthesis.SynthesisPhase
       } else {
-        Nil
+        NoopPhase[Program]()
       }
-    ).flatten
+
+    val pipeAnalysis: Pipeline[Program, Program] =
+      if (settings.analyze) {
+        AnalysisPhase
+      } else {
+        NoopPhase[Program]()
+      }
+
+    pipeBegin followedBy
+    pipeTransforms followedBy
+    pipeSynthesis followedBy
+    pipeAnalysis andThen
+    ExitPhase()
   }
 
   def main(args : Array[String]) : Unit = {
-    var ctx = LeonContext(options = args.toList)
+    val reporter = new DefaultReporter()
 
-    val phases = computeLeonPhases
+    // Process options
+    val ctx = processOptions(reporter, args.toList)
 
-    for ((phase, i) <- phases.zipWithIndex) {
-      ctx.reporter.info("%2d".format(i)+": "+phase.name)
-      ctx = phase.run(ctx)
-    }
+    val pipeline = computePipeLine(ctx.settings)
+
+    pipeline.run(ctx)(args.toList)
   }
 }
-
diff --git a/src/main/scala/leon/Pipeline.scala b/src/main/scala/leon/Pipeline.scala
new file mode 100644
index 000000000..477a08c31
--- /dev/null
+++ b/src/main/scala/leon/Pipeline.scala
@@ -0,0 +1,26 @@
+package leon
+
+abstract class Pipeline[F, T] {
+  def andThen[G](then: LeonPhase[T, G]): Pipeline[F, G];
+  def followedBy[G](pipe: Pipeline[T, G]): Pipeline[F, G];
+  def run(ctx: LeonContext)(v: F): T;
+}
+
+class PipeCons[F, V, T](phase: LeonPhase[F, V], then: Pipeline[V, T]) extends Pipeline[F, T] {
+  def andThen[G](last: LeonPhase[T, G]) = new PipeCons(phase, then andThen last)
+  def followedBy[G](pipe: Pipeline[T, G]) = new PipeCons(phase, then followedBy pipe);
+  def run(ctx: LeonContext)(v: F): T = then.run(ctx)(phase.run(ctx)(v))
+
+  override def toString = {
+    phase.toString + " -> " + then.toString
+  }
+}
+
+class PipeNil[T]() extends Pipeline[T,T] {
+  def andThen[G](last: LeonPhase[T, G]) = new PipeCons(last, new PipeNil)
+  def run(ctx: LeonContext)(v: T): T = v
+  def followedBy[G](pipe: Pipeline[T, G]) = pipe;
+  override def toString = {
+    "|"
+  }
+}
diff --git a/src/main/scala/leon/Settings.scala b/src/main/scala/leon/Settings.scala
index 4c15d2cbe..c57a1e451 100644
--- a/src/main/scala/leon/Settings.scala
+++ b/src/main/scala/leon/Settings.scala
@@ -34,3 +34,9 @@ object Settings {
   var stopAfterAnalysis: Boolean             = true
   var silentlyTolerateNonPureBodies: Boolean = false
 }
+
+case class Settings(
+  val synthesis: Boolean    = false,
+  val xlang: Boolean        = false,
+  val analyze: Boolean      = true
+)
diff --git a/src/main/scala/leon/Simplificator.scala b/src/main/scala/leon/Simplificator.scala
index 6fb6f9ef3..92b54072a 100644
--- a/src/main/scala/leon/Simplificator.scala
+++ b/src/main/scala/leon/Simplificator.scala
@@ -10,7 +10,7 @@ object Simplificator extends TransformationPhase {
   val name = "Simplificator"
   val description = "Some safe and minimal simplification"
 
-  def apply(pgm: Program): Program = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
 
     val allFuns = pgm.definedFunctions
     allFuns.foreach(fd => {
diff --git a/src/main/scala/leon/TypeChecking.scala b/src/main/scala/leon/TypeChecking.scala
index d0fd72316..08291badc 100644
--- a/src/main/scala/leon/TypeChecking.scala
+++ b/src/main/scala/leon/TypeChecking.scala
@@ -5,12 +5,12 @@ import purescala.Definitions._
 import purescala.Trees._
 import purescala.TypeTrees._
 
-object TypeChecking extends UnitPhase {
+object TypeChecking extends UnitPhase[Program] {
 
   val name = "Type Checking"
   val description = "Type check the AST"
 
-  def apply(pgm: Program): Unit = {
+  def apply(ctx: LeonContext, pgm: Program): Unit = {
     val allFuns = pgm.definedFunctions
 
     allFuns.foreach(fd  => {
diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala
index 9604c3a4d..d30ac296f 100644
--- a/src/main/scala/leon/UnitElimination.scala
+++ b/src/main/scala/leon/UnitElimination.scala
@@ -13,7 +13,7 @@ object UnitElimination extends TransformationPhase {
   private var fun2FreshFun: Map[FunDef, FunDef] = Map()
   private var id2FreshId: Map[Identifier, Identifier] = Map()
 
-  def apply(pgm: Program): Program = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
     fun2FreshFun = Map()
     val allFuns = pgm.definedFunctions
 
diff --git a/src/main/scala/leon/plugin/AnalysisComponent.scala b/src/main/scala/leon/plugin/AnalysisComponent.scala
index 467c524c7..77c058a6f 100644
--- a/src/main/scala/leon/plugin/AnalysisComponent.scala
+++ b/src/main/scala/leon/plugin/AnalysisComponent.scala
@@ -39,7 +39,7 @@ class AnalysisComponent(val global: Global, val pluginInstance: LeonPlugin)
       fresh = unit.fresh
 
       
-      pluginInstance.global.ctx = pluginInstance.global.ctx.copy(program = Some(extractCode(unit)))
+      pluginInstance.global.program = Some(extractCode(unit))
     }
   }
 }
diff --git a/src/main/scala/leon/plugin/ExtractorPhase.scala b/src/main/scala/leon/plugin/ExtractorPhase.scala
index 25ef79874..ef94d03e5 100644
--- a/src/main/scala/leon/plugin/ExtractorPhase.scala
+++ b/src/main/scala/leon/plugin/ExtractorPhase.scala
@@ -1,43 +1,43 @@
 package leon
 package plugin
 
+import purescala.Definitions.Program
 import scala.tools.nsc.{Global,Settings=>NSCSettings,SubComponent,CompilerCommand}
 
-object ExtractionPhase extends LeonPhase {
+object ExtractionPhase extends LeonPhase[List[String], Program] {
 
   val name = "Extraction"
   val description = "Extraction of trees from the Scala Compiler"
 
-  def run(ctx: LeonContext): LeonContext = {
+  def run(ctx: LeonContext)(args: List[String]): Program = {
 
     val settings = new NSCSettings
-    val compilerOpts = ctx.options.filterNot(_.startsWith("--"))
+    val compilerOpts = args.filterNot(_.startsWith("--"))
 
     val command = new CompilerCommand(compilerOpts, settings) {
       override val cmdName = "leon"
     }
 
-    val newCtx = if(command.ok) {
-      val runner = new PluginRunner(settings, ctx)
+    if(command.ok) {
+      val runner = new PluginRunner(settings, ctx, None)
       val run = new runner.Run
       run.compile(command.files)
 
-      runner.ctx
+      runner.program match {
+        case Some(p) =>
+          p
+        case None =>
+          ctx.reporter.fatalError("Error while compiling.")
+      }
     } else {
-      ctx
-    }
-
-    if (newCtx.program.isDefined) {
-      newCtx
-    } else {
-      newCtx.reporter.fatalError("No input program.")
+      ctx.reporter.fatalError("No input program.")
     }
   }
 }
 
 /** This class is a compiler that will be used for running the plugin in
  * standalone mode. Original version courtesy of D. Zufferey. */
-class PluginRunner(settings : NSCSettings, var ctx: LeonContext) extends Global(settings, new SimpleReporter(settings, ctx.reporter)) {
+class PluginRunner(settings : NSCSettings, ctx: LeonContext, var program: Option[Program]) extends Global(settings, new SimpleReporter(settings, ctx.reporter)) {
   val leonPlugin = new LeonPlugin(this)
 
   protected def myAddToPhasesSet(sub : SubComponent, descr : String) : Unit = {
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index 2041aa013..f887def32 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -3,19 +3,17 @@ package synthesis
 
 import purescala.Definitions.Program
 
-object SynthesisPhase extends LeonPhase {
+object SynthesisPhase extends TransformationPhase {
   val name        = "Synthesis"
   val description = "Synthesis"
 
-  def run(ctx: LeonContext): LeonContext = {
+  def apply(ctx: LeonContext, p: Program): Program = {
     val quietReporter = new QuietReporter
     val solvers  = List(
       new TrivialSolver(quietReporter),
       new FairZ3Solver(quietReporter)
     )
 
-    val newProgram = new Synthesizer(ctx.reporter, solvers).synthesizeAll(ctx.program.get)
-
-    ctx.copy(program = Some(newProgram))
+    new Synthesizer(ctx.reporter, solvers).synthesizeAll(p)
   }
 }
-- 
GitLab