diff --git a/src/main/scala/leon/LeonContext.scala b/src/main/scala/leon/LeonContext.scala index 7aa76a9713f7e75ad747eb7e4ba8c88173e63861..e2145ed872e5cb3c397104b30b8538cad7f4a127 100644 --- a/src/main/scala/leon/LeonContext.scala +++ b/src/main/scala/leon/LeonContext.scala @@ -1,11 +1,12 @@ package leon import purescala.Definitions.Program +import java.io.File case class LeonContext( val settings: Settings = Settings(), val options: List[LeonOption] = Nil, - val files: List[String] = Nil, + val files: List[File] = Nil, val reporter: Reporter = new DefaultReporter ) diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index ace9c2d29bdff42abcd2c4b88b3e52611580740d..fc26196b608dad2c67e0125c600c4c1197695fa7 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -45,7 +45,7 @@ object Main { // Detect unknown options: val options = args.filter(_.startsWith("--")) - val files = args.filterNot(_.startsWith("-")) + val files = args.filterNot(_.startsWith("-")).map(new java.io.File(_)) val leonOptions = options.flatMap { opt => val leonOpt: LeonOption = opt.substring(2, opt.length).split("=", 2).toList match { diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad19ba6c78528032af34f62e5125c58a288197a4 --- /dev/null +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -0,0 +1,111 @@ +package leon +package synthesis + +import purescala.Trees._ +import purescala.ScalaPrinter + +import java.io.File +class FileInterface(reporter: Reporter, origFile: File) { + + def updateFile(solutions: Map[Choose, Expr], ignoreMissing: Boolean = false) { + import java.io.{File, BufferedWriter, FileWriter} + val FileExt = """^(.+)\.([^.]+)$""".r + + origFile.getAbsolutePath() match { + case FileExt(path, "scala") => + var i = 0 + def savePath = path+"."+i+".scala" + while (new File(savePath).isFile()) { + i += 1 + } + + val origCode = readFile(origFile) + val backup = new File(savePath) + val newFile = new File(origFile.getAbsolutePath()) + origFile.renameTo(backup) + + + val newCode = substitueChooses(origCode, solutions, ignoreMissing) + + val out = new BufferedWriter(new FileWriter(newFile)) + out.write(newCode) + out.close + case _ => + + } + } + + def substitueChooses(str: String, solutions: Map[Choose, Expr], ignoreMissing: Boolean = false): String = { + var lines = List[Int]() + + // Compute line positions + var lastFound = -1 + do { + lastFound = str.indexOf('\n', lastFound+1) + + if (lastFound > -1) { + lines = lastFound :: lines + } + } while(lastFound> 0) + lines = lines.reverse; + + def lineOf(offset: Int): (Int, Int) = { + lines.zipWithIndex.find(_._1 > offset) match { + case Some((off, no)) => + (no+1, if (no > 0) lines(no-1) else 0) + case None => + (lines.size+1, lines.lastOption.getOrElse(0)) + } + } + + lastFound = -1 + + var newStr = str + var newStrOffset = 0 + + do { + lastFound = str.indexOf("choose", lastFound+1) + + if (lastFound > -1) { + val (lineno, lineoffset) = lineOf(lastFound) + // compute scala equivalent of the position: + val scalaOffset = str.substring(lineoffset, lastFound).replaceAll("\t", " "*8).length + + solutions.find(_._1.posIntInfo == (lineno, scalaOffset)) match { + case Some((choose, solution)) => + var lvl = 0; + var i = lastFound + 6; + var continue = true; + do { + val c = str.charAt(i) + if (c == '(' || c == '{') { + lvl += 1 + } else if (c == ')' || c == '}') { + lvl -= 1 + if (lvl == 0) { + continue = false + } + } + i += 1 + } while(continue) + + val newCode = ScalaPrinter(solution) + newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) + + newStrOffset += -(i-lastFound)+newCode.length + + case _ => + if (!ignoreMissing) { + reporter.warning("Could not find solution corresponding to choose at "+lineno+":"+scalaOffset) + } + } + } + } while(lastFound> 0) + + newStr + } + + def readFile(file: File): String = { + scala.io.Source.fromFile(file).mkString + } +} diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 91134fcbd17db9bd8e30059774ba0d5dc230ae3d..3138c2f4bf5ba4e9de1f23f1c4790ecaa931c511 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -4,6 +4,9 @@ package synthesis import solvers.TrivialSolver import solvers.z3.FairZ3Solver +import purescala.TreeOps.simplifyLets +import purescala.Trees.Expr +import purescala.ScalaPrinter import purescala.Definitions.Program object SynthesisPhase extends LeonPhase[Program, Program] { @@ -11,7 +14,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val description = "Synthesis" override def definedOptions = Set( - LeonFlagOptionDef("inplace", "--inplace", "Debug level") + LeonFlagOptionDef("inplace", "--inplace", "Debug level"), + LeonFlagOptionDef("derivtrees", "--derivtrees", "Generate derivation trees") ) def run(ctx: LeonContext)(p: Program): Program = { @@ -21,19 +25,37 @@ object SynthesisPhase extends LeonPhase[Program, Program] { new FairZ3Solver(quietReporter) ) - var inPlace = false + var inPlace = false + var genTrees = false for(opt <- ctx.options) opt match { case LeonFlagOption("inplace") => inPlace = true + case LeonFlagOption("derivtrees") => + genTrees = true case _ => } - val synth = new Synthesizer(ctx.reporter, solvers) + val synth = new Synthesizer(ctx.reporter, solvers, genTrees) val solutions = synth.synthesizeAll(p) + + // Simplify expressions + val simplifiers = List[Expr => Expr]( + simplifyLets _ + ) + + val chooseToExprs = solutions.mapValues(sol => simplifiers.foldLeft(sol.toExpr){ (x, sim) => sim(x) }) + if (inPlace) { for (file <- ctx.files) { - synth.updateFile(new java.io.File(file), solutions) + new FileInterface(ctx.reporter, file).updateFile(chooseToExprs) + } + } else { + for ((chs, ex) <- chooseToExprs) { + ctx.reporter.info("-"*32+" Synthesis of: "+"-"*32) + ctx.reporter.info(chs) + ctx.reporter.info("-"*35+" Result: "+"-"*35) + ctx.reporter.info(ScalaPrinter(ex)) } } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 74771d80affd87a37d8918dda733e81134010fbe..0dc4cb283f48c0674b8a3638a1e4f875f7e9affc 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -12,7 +12,7 @@ import java.io.File import collection.mutable.PriorityQueue -class Synthesizer(val r: Reporter, val solvers: List[Solver]) { +class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivationTrees: Boolean) { import r.{error,warning,info,fatalError} private[this] var solution: Option[Solution] = None @@ -28,7 +28,9 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { workList += rootTask solution = None - derivationTree = new DerivationTree(rootTask) + if (generateDerivationTrees) { + derivationTree = new DerivationTree(rootTask) + } while (!workList.isEmpty && solution.isEmpty) { val task = workList.dequeue() @@ -56,15 +58,19 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { } - derivationTree.toDotFile("derivation"+derivationCounter+".dot") - derivationCounter += 1 + if (generateDerivationTrees) { + derivationTree.toDotFile("derivation"+derivationCounter+".dot") + derivationCounter += 1 + } solution.getOrElse(Solution.none) } def onTaskSucceeded(task: Task, solution: Solution) { info(" => Solved "+task.problem+" ⊢ "+solution) - derivationTree.recordSolutionFor(task, solution) + if (generateDerivationTrees) { + derivationTree.recordSolutionFor(task, solution) + } if (task.parent eq null) { info(" SUCCESS!") @@ -104,18 +110,10 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { val as = (variablesOf(pred)--xs).toList val phi = pred - info("") - info("") - info("In Function "+f.id+":") - info("-"*80) - val sol = synthesize(Problem(as, phi, xs), rules) solutions += ch -> sol - info("Scala code:") - info(ScalaPrinter(simplifyLets(sol.toExpr))) - a case _ => a @@ -130,109 +128,9 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { } - def substitueChooses(str: String, solutions: Map[Choose, Solution], ignoreMissing: Boolean = false): String = { - var lines = List[Int]() - - // Compute line positions - var lastFound = -1 - do { - lastFound = str.indexOf('\n', lastFound+1) - - if (lastFound > -1) { - lines = lastFound :: lines - } - } while(lastFound> 0) - lines = lines.reverse; - - def lineOf(offset: Int): (Int, Int) = { - lines.zipWithIndex.find(_._1 > offset) match { - case Some((off, no)) => - (no+1, if (no > 0) lines(no-1) else 0) - case None => - (lines.size+1, lines.lastOption.getOrElse(0)) - } - } - - lastFound = -1 - - var newStr = str - var newStrOffset = 0 - - do { - lastFound = str.indexOf("choose", lastFound+1) - - if (lastFound > -1) { - val (lineno, lineoffset) = lineOf(lastFound) - // compute scala equivalent of the position: - val scalaOffset = str.substring(lineoffset, lastFound).replaceAll("\t", " "*8).length - - solutions.find(_._1.posIntInfo == (lineno, scalaOffset)) match { - case Some((choose, solution)) => - var lvl = 0; - var i = lastFound + 6; - var continue = true; - do { - val c = str.charAt(i) - if (c == '(' || c == '{') { - lvl += 1 - } else if (c == ')' || c == '}') { - lvl -= 1 - if (lvl == 0) { - continue = false - } - } - i += 1 - } while(continue) - - val newCode = solutionToString(solution) - newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) - - newStrOffset += -(i-lastFound)+newCode.length - - case _ => - if (!ignoreMissing) { - warning("Could not find solution corresponding to choose at "+lineno+":"+scalaOffset) - } - } - } - } while(lastFound> 0) - - newStr - } def solutionToString(solution: Solution): String = { ScalaPrinter(simplifyLets(solution.toExpr)) } - def updateFile(origFile: File, solutions: Map[Choose, Solution], ignoreMissing: Boolean = false) { - import java.io.{File, BufferedWriter, FileWriter} - val FileExt = """^(.+)\.([^.]+)$""".r - - origFile.getAbsolutePath() match { - case FileExt(path, "scala") => - var i = 0 - def savePath = path+"."+i+".scala" - while (new File(savePath).isFile()) { - i += 1 - } - - val origCode = readFile(origFile) - val backup = new File(savePath) - val newFile = new File(origFile.getAbsolutePath()) - origFile.renameTo(backup) - - - val newCode = substitueChooses(origCode, solutions, ignoreMissing) - - val out = new BufferedWriter(new FileWriter(newFile)) - out.write(newCode) - out.close - case _ => - - } - } - - def readFile(file: File): String = { - scala.io.Source.fromFile(file).mkString - } } diff --git a/testcases/synthesis/Matching.scala b/testcases/synthesis/Matching.scala new file mode 100644 index 0000000000000000000000000000000000000000..21a7cf442e9165130e5a0c095bc4c86b608b9e69 --- /dev/null +++ b/testcases/synthesis/Matching.scala @@ -0,0 +1,13 @@ +import leon.Utils._ + +object Matching { + def t1(a: NatList) = choose( (x: Nat) => Cons(x, Nil()) == a) + + abstract class Nat + case class Z() extends Nat + case class Succ(n: Nat) extends Nat + + abstract class NatList + case class Nil() extends NatList + case class Cons(head: Nat, tail: NatList) extends NatList +}