diff --git a/src/main/scala/leon/LeonContext.scala b/src/main/scala/leon/LeonContext.scala index 9037c5f9e7da0f7d2e9fbb97e1cc6df11bde2195..29cc186e20d958825dbb76ab7a3808519530ff33 100644 --- a/src/main/scala/leon/LeonContext.scala +++ b/src/main/scala/leon/LeonContext.scala @@ -4,6 +4,7 @@ import purescala.Definitions.Program case class LeonContext( val settings: Settings = Settings(), + val files: List[String] = Nil, val reporter: Reporter = new DefaultReporter ) diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index e52ea1ccf6cebb8a9ccd3b85de1a3d4d89474ec0..b815f5cb9558d70a3d9942e74d7c592eec4dfc47 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -45,6 +45,8 @@ object Main { // Detect unknown options: val options = args.filter(_.startsWith("--")) + val files = args.filterNot(_.startsWith("-")) + val leonOptions = options.flatMap { opt => val leonOpt: LeonOption = opt.substring(2, opt.length).split("=", 2).toList match { case List(name, value) => @@ -87,7 +89,7 @@ object Main { case _ => } - LeonContext(settings = settings, reporter = reporter) + LeonContext(settings = settings, reporter = reporter, files = files) } implicit def phaseToPipeline[F, T](phase: LeonPhase[F, T]): Pipeline[F, T] = new PipeCons(phase, new PipeNil()) @@ -114,18 +116,18 @@ object Main { NoopPhase[Program]() } - val pipeAnalysis: Pipeline[Program, Unit] = + val pipeAnalysis: Pipeline[Program, Program] = if (settings.analyze) { - AnalysisPhase andThen - ExitPhase[Program]() + AnalysisPhase } else { - ExitPhase[Program]() + NoopPhase[Program]() } pipeBegin followedBy pipeTransforms followedBy pipeSynthesis followedBy - pipeAnalysis + pipeAnalysis andThen + ExitPhase[Program]() } def main(args : Array[String]) { diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 385557d150caa7f989fe0350a19d406e97d6959c..8f48d8629a9019ebba5dfab6f0329f10da1acb1a 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -683,7 +683,7 @@ trait CodeExtraction extends Extractors { Epsilon(c1).setType(pstpe).setPosInfo(epsi.pos.line, epsi.pos.column) } - case chs @ ExChooseExpression(args, tpe, body) => { + case chs @ ExChooseExpression(args, tpe, body, select) => { val cTpe = scalaType2PureScala(unit, silent)(tpe) val vars = args map { case (tpe, sym) => @@ -696,7 +696,7 @@ trait CodeExtraction extends Extractors { val cBody = rec(body) - Choose(vars, cBody).setType(cTpe).setPosInfo(chs.pos.line, chs.pos.column) + Choose(vars, cBody).setType(cTpe).setPosInfo(select.pos.line, select.pos.column) } case ExWaypointExpression(tpe, i, tree) => { diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index d3879226b4ade267f7cbbfe98d5f313bccce63ed..8245195480b3ecf038e3e41fd209361c23ded5d7 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -176,12 +176,12 @@ trait Extractors { } object ExChooseExpression { - def unapply(tree: Apply) : Option[(List[(Type, Symbol)], Type, Tree)] = tree match { + def unapply(tree: Apply) : Option[(List[(Type, Symbol)], Type, Tree, Tree)] = tree match { case a @ Apply( - TypeApply(Select(Select(funcheckIdent, utilsName), chooseName), types), + TypeApply(Select(s @ Select(funcheckIdent, utilsName), chooseName), types), Function(vds, predicateBody) :: Nil) => { if (utilsName.toString == "Utils" && chooseName.toString == "choose") - Some(((types.map(_.tpe) zip vds.map(_.symbol)).toList, a.tpe, predicateBody)) + Some(((types.map(_.tpe) zip vds.map(_.symbol)).toList, a.tpe, predicateBody, s)) else None } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 62e004d22cea004b5e45eab1ee956bc2b49790a4..ac10811eb7cb361fa07a71b8fb460659cc06a312 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -17,6 +17,10 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val synth = new Synthesizer(ctx.reporter, solvers) val solutions = synth.synthesizeAll(p) + for (file <- ctx.files) { + synth.updateFile(new java.io.File(file), solutions) + } + p } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 7e53e3812bada7a33bb82b095413410adf6444b0..74771d80affd87a37d8918dda733e81134010fbe 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -8,6 +8,7 @@ import purescala.Trees.{Expr, Not} import purescala.ScalaPrinter import Extensions.Solver +import java.io.File import collection.mutable.PriorityQueue @@ -87,7 +88,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { } import purescala.Trees._ - def synthesizeAll(program: Program): List[(Choose, Solution)] = { + def synthesizeAll(program: Program): Map[Choose, Solution] = { solvers.foreach(_.setProgram(program)) @@ -95,7 +96,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { def noop(u:Expr, u2: Expr) = u - var solutions = List[(Choose, Solution)]() + var solutions = Map[Choose, Solution]() def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match { case ch @ Choose(vars, pred) => @@ -110,7 +111,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { val sol = synthesize(Problem(as, phi, xs), rules) - solutions = (ch, sol) :: solutions + solutions += ch -> sol info("Scala code:") info(ScalaPrinter(simplifyLets(sol.toExpr))) @@ -128,10 +129,110 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) { solutions } - def substitueChooses(file: String, sols: List[(Choose, Solution)]) = { - import scala.io.Source - val src = Source.fromFile(file) + def substitueChooses(str: String, solutions: Map[Choose, Solution], ignoreMissing: Boolean = false): String = { + var lines = List[Int]() + // Compute line positions + var lastFound = -1 + do { + lastFound = str.indexOf('\n', lastFound+1) + + if (lastFound > -1) { + lines = lastFound :: lines + } + } while(lastFound> 0) + lines = lines.reverse; + + def lineOf(offset: Int): (Int, Int) = { + lines.zipWithIndex.find(_._1 > offset) match { + case Some((off, no)) => + (no+1, if (no > 0) lines(no-1) else 0) + case None => + (lines.size+1, lines.lastOption.getOrElse(0)) + } + } + + lastFound = -1 + + var newStr = str + var newStrOffset = 0 + + do { + lastFound = str.indexOf("choose", lastFound+1) + + if (lastFound > -1) { + val (lineno, lineoffset) = lineOf(lastFound) + // compute scala equivalent of the position: + val scalaOffset = str.substring(lineoffset, lastFound).replaceAll("\t", " "*8).length + + solutions.find(_._1.posIntInfo == (lineno, scalaOffset)) match { + case Some((choose, solution)) => + var lvl = 0; + var i = lastFound + 6; + var continue = true; + do { + val c = str.charAt(i) + if (c == '(' || c == '{') { + lvl += 1 + } else if (c == ')' || c == '}') { + lvl -= 1 + if (lvl == 0) { + continue = false + } + } + i += 1 + } while(continue) + + val newCode = solutionToString(solution) + newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) + + newStrOffset += -(i-lastFound)+newCode.length + + case _ => + if (!ignoreMissing) { + warning("Could not find solution corresponding to choose at "+lineno+":"+scalaOffset) + } + } + } + } while(lastFound> 0) + + newStr + } + + def solutionToString(solution: Solution): String = { + ScalaPrinter(simplifyLets(solution.toExpr)) + } + + def updateFile(origFile: File, solutions: Map[Choose, Solution], ignoreMissing: Boolean = false) { + import java.io.{File, BufferedWriter, FileWriter} + val FileExt = """^(.+)\.([^.]+)$""".r + + origFile.getAbsolutePath() match { + case FileExt(path, "scala") => + var i = 0 + def savePath = path+"."+i+".scala" + while (new File(savePath).isFile()) { + i += 1 + } + + val origCode = readFile(origFile) + val backup = new File(savePath) + val newFile = new File(origFile.getAbsolutePath()) + origFile.renameTo(backup) + + + val newCode = substitueChooses(origCode, solutions, ignoreMissing) + + val out = new BufferedWriter(new FileWriter(newFile)) + out.write(newCode) + out.close + case _ => + + } + } + + def readFile(file: File): String = { + scala.io.Source.fromFile(file).mkString } } diff --git a/testcases/synthesis/ChoosePos.scala b/testcases/synthesis/ChoosePos.scala new file mode 100644 index 0000000000000000000000000000000000000000..530d20c8120a69579ff3aed4c2cd8c3d42cc84f0 --- /dev/null +++ b/testcases/synthesis/ChoosePos.scala @@ -0,0 +1,18 @@ +import leon.Utils._ + +object ChoosePos { + + +def c1(x: Int): Int = + choose { + (y: Int) => y > x + } + +def c2(x: Int): Int = + choose ( + (y: Int) => y > x + ) + def c3(x: Int): Int = choose { (y: Int) => y > x } + def c4(x: Int): Int = choose { (y: Int) => y > x }; def c5(x: Int): Int = choose { (y: Int) => y > x } + +}