From 8b8addf3c01a238f8aef131a89019bd73f83cf09 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Thu, 4 Apr 2013 15:19:55 +0200
Subject: [PATCH] Implement --manual search mode

---
 .../scala/leon/synthesis/ManualSearch.scala   | 184 ++++++++++++++++++
 .../leon/synthesis/SynthesisOptions.scala     |   1 +
 .../scala/leon/synthesis/SynthesisPhase.scala |   4 +
 .../scala/leon/synthesis/Synthesizer.scala    |   4 +-
 4 files changed, 192 insertions(+), 1 deletion(-)
 create mode 100644 src/main/scala/leon/synthesis/ManualSearch.scala

diff --git a/src/main/scala/leon/synthesis/ManualSearch.scala b/src/main/scala/leon/synthesis/ManualSearch.scala
new file mode 100644
index 000000000..faea80ddc
--- /dev/null
+++ b/src/main/scala/leon/synthesis/ManualSearch.scala
@@ -0,0 +1,184 @@
+package leon
+package synthesis
+
+import leon.purescala.ScalaPrinter
+
+class ManualSearch(synth: Synthesizer,
+                   problem: Problem,
+                   rules: Seq[Rule],
+                   costModel: CostModel) extends SimpleSearch(synth, problem, rules, costModel) {
+
+  def this(synth: Synthesizer, problem: Problem) = {
+    this(synth, problem, synth.rules, synth.options.costModel)
+  }
+
+  import synth.reporter._
+
+  var cd       = List[Int]()
+  var cmdQueue = List[String]()
+
+  def printGraph() {
+    def pathToString(path: List[Int]): String = {
+      val p = path.reverse.drop(cd.size)
+      if (p.isEmpty) {
+        ""
+      } else {
+        " "+p.mkString(" ")
+      }
+    }
+
+    def title(str: String) = "\033[1m"+str+"\033[0m"
+    def failed(str: String) = "\033[31m"+str+"\033[0m"
+    def solved(str: String) = "\033[32m"+str+"\033[0m"
+
+    def traversePathFrom(n: g.Tree, prefix: List[Int]) {
+      n match {
+        case l: g.AndLeaf =>
+          if (prefix.endsWith(cd.reverse)) {
+            println(pathToString(prefix)+" \u2508 "+l.task.app)
+          }
+        case l: g.OrLeaf =>
+          if (prefix.endsWith(cd.reverse)) {
+            println(pathToString(prefix)+" \u2508 "+l.task.p)
+          }
+        case an: g.AndNode =>
+          if (an.isSolved) {
+            if (prefix.endsWith(cd.reverse)) {
+              println(solved(pathToString(prefix)+" \u2508 "+an.task.app))
+            }
+          } else {
+            if (prefix.endsWith(cd.reverse)) {
+              println(title(pathToString(prefix)+" \u2510 "+an.task.app))
+            }
+            for ((st, i) <- an.subTasks.zipWithIndex) {
+              traversePathFrom(an.subProblems(st), i :: prefix)
+            }
+          }
+
+        case on: g.OrNode =>
+          if (on.isSolved) {
+            if (prefix.endsWith(cd.reverse)) {
+              println(solved(pathToString(prefix)+on.task.p))
+            }
+          } else {
+            if (prefix.endsWith(cd.reverse)) {
+              println(title(pathToString(prefix)+" \u2510 "+on.task.p))
+            }
+            for ((at, i) <- on.altTasks.zipWithIndex) {
+              if (on.triedAlternatives contains at) {
+                if (prefix.endsWith(cd.reverse)) {
+                  println(failed(pathToString(i :: prefix)+" \u2508 "+at.app))
+                }
+              } else {
+                traversePathFrom(on.alternatives(at), i :: prefix)
+              }
+            }
+          }
+      }
+    }
+
+    println("-"*80)
+    traversePathFrom(g.tree, List())
+    println("-"*80)
+  }
+
+
+  override def nextLeaf(): Option[g.Leaf] = {
+    g.tree match {
+      case l: g.Leaf =>
+        Some(l)
+      case _ =>
+
+        var res: Option[g.Leaf] = None
+        var continue = true
+
+        while(continue) {
+          printGraph()
+
+          try {
+
+            print("Next action? (q to quit) "+cd.mkString(" ")+" $ ")
+            val line = if (cmdQueue.isEmpty) {
+              readLine()
+            } else {
+              val n = cmdQueue.head
+              println(n)
+              cmdQueue = cmdQueue.tail
+              n
+            }
+            if (line == "q") {
+              continue = false
+              res = None
+            } else if (line startsWith "cd") {
+              val parts = line.split("\\s+").toList
+
+              parts match {
+                case List("cd") =>
+                  cd = List()
+                case List("cd", "..") =>
+                  if (cd.size > 0) {
+                    cd = cd.dropRight(1)
+                  }
+                case "cd" :: parts =>
+                  cd = cd ::: parts.map(_.toInt)
+                case _ =>
+              }
+
+            } else if (line startsWith "p") {
+              val parts = line.split("\\s+").toList.tail.map(_.toInt)
+              traversePath(cd ::: parts) match {
+                case Some(n) =>
+                  println("#"*80)
+                  println("AT:"+n.task)
+                  val sp = programAt(n)
+                  sp.acc.foreach(fd => println(ScalaPrinter(fd)))
+                  println("$"*20)
+                  println("ROOT: "+sp.fd.id)
+                case _ =>
+              }
+
+            } else {
+              val parts = line.split("\\s+").toList
+
+              val c = parts.head.toInt
+              cmdQueue = cmdQueue ::: parts.tail
+
+              traversePath(cd ::: c :: Nil) match {
+                case Some(l: g.Leaf) =>
+                  res = Some(l)
+                  cd = cd ::: c :: Nil
+                  continue = false
+                case Some(_) =>
+                  cd = cd ::: c :: Nil
+                case None =>
+                  error("Invalid path")
+              }
+            }
+          } catch {
+            case e =>
+              error("Woops: "+e.getMessage())
+              e.printStackTrace()
+          }
+        }
+        res
+    }
+  }
+
+  override def searchStep() {
+    super.searchStep()
+
+    var continue = cd.size > 0
+    while(continue) {
+      traversePath(cd) match {
+        case Some(t) if !t.isSolved =>
+          continue = false
+        case Some(t) =>
+          cd = cd.dropRight(1)
+        case None =>
+          cd = cd.dropRight(1)
+      }
+      continue = continue && (cd.size > 0)
+    }
+  }
+
+}
diff --git a/src/main/scala/leon/synthesis/SynthesisOptions.scala b/src/main/scala/leon/synthesis/SynthesisOptions.scala
index f221fb3aa..30f0de4f8 100644
--- a/src/main/scala/leon/synthesis/SynthesisOptions.scala
+++ b/src/main/scala/leon/synthesis/SynthesisOptions.scala
@@ -10,6 +10,7 @@ case class SynthesisOptions(
   timeoutMs: Option[Long]             = None,
   costModel: CostModel                = CostModel.default,
   rules: Seq[Rule]                    = Rules.all ++ Heuristics.all,
+  manualSearch: Boolean               = false,
 
   // Cegis related options
   cegisGenerateFunCalls: Boolean      = false,
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index dc15b0b2a..230c250d0 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -16,6 +16,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
   override val definedOptions : Set[LeonOptionDef] = Set(
     LeonFlagOptionDef(    "inplace",         "--inplace",         "Debug level"),
     LeonOptValueOptionDef("parallel",        "--parallel[=N]",    "Parallel synthesis search using N workers"),
+    LeonFlagOptionDef(    "manual",          "--manual",          "Manual search"),
     LeonFlagOptionDef(    "derivtrees",      "--derivtrees",      "Generate derivation trees"),
     LeonFlagOptionDef(    "firstonly",       "--firstonly",       "Stop as soon as one synthesis solution is found"),
     LeonValueOptionDef(   "timeout",         "--timeout=T",       "Timeout after T seconds when searching for synthesis solutions .."),
@@ -28,6 +29,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
     var options = SynthesisOptions()
 
     for(opt <- ctx.options) opt match {
+      case LeonFlagOption("manual") =>
+        options = options.copy(manualSearch = true)
+
       case LeonFlagOption("inplace") =>
         options = options.copy(inPlace = true)
 
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index 6b792aea2..ac2acc6f5 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -40,7 +40,9 @@ class Synthesizer(val context : LeonContext,
 
   def synthesize(): (Solution, Boolean) = {
 
-    val search = if (options.searchWorkers > 1) {
+    val search = if (options.manualSearch) {
+        new ManualSearch(this, problem)
+      } else if (options.searchWorkers > 1) {
         new ParallelSearch(this, problem, options.searchWorkers)
       } else {
         new SimpleSearch(this, problem)
-- 
GitLab