diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 23dcdd43cf0adb469e65a86925a17ac20fb9613c..7b0a88a4e4fe08d113543121c4040ffaceed69cd 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -32,9 +32,12 @@ trait CodeExtraction extends ASTExtractors { if (p.isRange) { val start = p.focusStart val end = p.focusEnd - LeonRangePosition(start.line, start.column, end.line, end.column, p.source.file.file) + LeonRangePosition(start.line, start.column, start.point, + end.line, end.column, end.point, + p.source.file.file) } else { - LeonOffsetPosition(p.line, p.column, p.source.file.file) + LeonOffsetPosition(p.line, p.column, p.point, + p.source.file.file) } } diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index bc186b293d25f5e0615f916cc3638aee6d6626b7..eaa7bca9a3011abb9c684af763ea174c78bf4f62 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -341,7 +341,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb fd.args.foreach(arg => { sb.append(idToString(arg.id)) - sb.append(" : ") + sb.append(": ") pp(arg.tpe, p) if(c < sz - 1) { @@ -350,7 +350,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb c = c + 1 }) - sb.append(") : ") + sb.append("): ") pp(fd.returnType, p) sb.append(" = {\n") ind(lvl+1) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 9d7a569750af2896e5afc4a46485ebee3486ad5d..02b771757bba5f0d4866165ea67e4bfd59a539df 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -2142,4 +2142,62 @@ object TreeOps { areExaustive(Seq((m.scrutinee.getType, patterns))) } + def flattenFunctions(fdOuter: FunDef): FunDef = { + fdOuter.body match { + case Some(LetDef(fdInner, FunctionInvocation(fdInner2, args))) if fdInner == fdInner2 => + val argsDef = fdOuter.args.map(_.id) + val argsCall = args.collect { case Variable(id) => id } + + if (argsDef.toSet == argsCall.toSet) { + val defMap = argsDef.zipWithIndex.toMap + val rewriteMap = argsCall.map(defMap) + + val innerIdsToOuterIds = (fdInner.args.map(_.id) zip argsCall).toMap + + def pre(e: Expr) = e match { + case FunctionInvocation(fd, args) if fd == fdInner => + val newArgs = (args zip rewriteMap).sortBy(_._2) + FunctionInvocation(fdOuter, newArgs.map(_._1)) + case Variable(id) => + Variable(innerIdsToOuterIds.getOrElse(id, id)) + case _ => + e + } + + def mergePre(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match { + case (None, Some(ie)) => + Some(simplePreTransform(pre)(ie)) + case (Some(oe), None) => + Some(oe) + case (None, None) => + None + case (Some(oe), Some(ie)) => + Some(And(oe, simplePreTransform(pre)(ie))) + } + + def mergePost(outer: Option[(Identifier, Expr)], inner: Option[(Identifier, Expr)]): Option[(Identifier, Expr)] = (outer, inner) match { + case (None, Some((iid, ie))) => + Some((iid, simplePreTransform(pre)(ie))) + case (Some(oe), None) => + Some(oe) + case (None, None) => + None + case (Some((oid, oe)), Some((iid, ie))) => + Some((oid, And(oe, replaceFromIDs(Map(iid -> Variable(oid)), simplePreTransform(pre)(ie))))) + } + + val newFd = fdOuter.duplicate + + newFd.body = fdInner.body.map(b => simplePreTransform(pre)(b)) + newFd.precondition = mergePre(fdOuter.precondition, fdInner.precondition) + newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition) + + newFd + } else { + fdOuter + } + case _ => + fdOuter + } + } } diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala index 145b5b132adc2cb318f0d4ccc103b3c28bf6163b..49ff882540e8c27798bf7e50c96cb50427d44cdd 100644 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -8,7 +8,7 @@ import purescala.Common.Tree import purescala.Definitions.FunDef import purescala.ScalaPrinter -import leon.utils.Position +import leon.utils.RangePosition import java.io.File class FileInterface(reporter: Reporter) { @@ -33,7 +33,7 @@ class FileInterface(reporter: Reporter) { var newCode = origCode for ( (ci, e) <- solutions) { - newCode = substitute(newCode, CodePattern.forChoose(ci), e) + newCode = substitute(newCode, ci.ch, e) } val out = new BufferedWriter(new FileWriter(newFile)) @@ -44,102 +44,23 @@ class FileInterface(reporter: Reporter) { } } - case class CodePattern(startWith: String, pos: Position, blocks: Int) + def substitute(str: String, fromTree: Tree, toTree: Tree): String = { - object CodePattern { - def forChoose(ci: ChooseInfo) = CodePattern("choose", ci.ch.getPos, 1) - def forFunDef(fd: FunDef) = CodePattern("def", fd.getPos, 2) - } - - def substitute(str: String, pattern: CodePattern, subst: Tree): String = { - var lines = List[Int]() - - // Compute line positions - var lastFound = -1 - do { - lastFound = str.indexOf('\n', lastFound+1) - - if (lastFound > -1) { - lines = lastFound :: lines - } - } while(lastFound> 0) - lines = lines.reverse; - - def lineOf(offset: Int): (Int, Int) = { - lines.zipWithIndex.find(_._1 > offset) match { - case Some((off, no)) => - (no+1, if (no > 0) lines(no-1) else 0) - case None => - (lines.size+1, lines.lastOption.getOrElse(0)) - } - } + fromTree.getPos match { + case rp: RangePosition => + val from = rp.pointFrom + val to = rp.pointTo - def getLineIndentation(offset: Int): Int = { - var i = str.lastIndexOf('\n', offset)+1 + val before = str.substring(0, from) + val after = str.substring(to, str.length) - var res = 0; + val newCode = ScalaPrinter(toTree, fromTree.getPos.col/2) - while (i < str.length) { - val c = str.charAt(i) - i += 1 - - if (c == ' ') { - res += 1 - } else if (c == '\t') { - res += 4 - } else { - i = str.length - } - } + before + newCode + after - res + case _ => + sys.error("Substitution requires RangePos on the input tree: "+fromTree) } - - lastFound = -1 - - var newStr = str - var newStrOffset = 0 - - do { - lastFound = str.indexOf(pattern.startWith, lastFound+1) - - if (lastFound > -1) { - val (lineno, lineoffset) = lineOf(lastFound) - // compute scala equivalent of the position: - val scalaOffset = str.substring(lineoffset, lastFound).replaceAll("\t", " "*8).length - - val indent = getLineIndentation(lastFound) - - if (pattern.pos.line == lineno && pattern.pos.col == scalaOffset) { - var lvl = 0; - var i = lastFound + 6; - var continue = true; - do { - var blocksRemaining = pattern.blocks - val c = str.charAt(i) - if (c == '(' || c == '{') { - lvl += 1 - } else if (c == ')' || c == '}') { - lvl -= 1 - if (lvl == 0) { - blocksRemaining -= 1 - if (blocksRemaining == 0) { - continue = false - } - } - } - i += 1 - } while(continue) - - val newCode = ScalaPrinter(subst, indent/2) - newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) - - newStrOffset += -(i-lastFound)+newCode.length - } - } - } while(lastFound> 0) - - newStr } def readFile(file: File): String = { diff --git a/src/main/scala/leon/utils/Positions.scala b/src/main/scala/leon/utils/Positions.scala index b44dc9a36c7924972b114d9b9e4d463b4f59952e..1ca2f9faf279537f9bff8965d85f1c2a7d5b716a 100644 --- a/src/main/scala/leon/utils/Positions.scala +++ b/src/main/scala/leon/utils/Positions.scala @@ -16,9 +16,11 @@ abstract class Position { def isDefined = true } -case class OffsetPosition(line: Int, col: Int, file: File) extends Position +case class OffsetPosition(line: Int, col: Int, point: Int, file: File) extends Position -case class RangePosition(lineFrom: Int, colFrom: Int, lineTo: Int, colTo: Int, file: File) extends Position { +case class RangePosition(lineFrom: Int, colFrom: Int, pointFrom: Int, + lineTo: Int, colTo: Int, pointTo: Int, + file: File) extends Position { val line = lineFrom val col = colFrom