diff --git a/src/main/scala/lesynth/Evaluation.scala b/src/main/scala/lesynth/Evaluation.scala new file mode 100644 index 0000000000000000000000000000000000000000..93093979d158648310f749d733b33c9009862ba6 --- /dev/null +++ b/src/main/scala/lesynth/Evaluation.scala @@ -0,0 +1,69 @@ +package lesynth + +import scala.util.Random + +import leon.purescala.Trees.{ Variable => LeonVariable, _ } +import leon.purescala.Common.Identifier + +case class Evaluation(examples: Seq[Map[Identifier, Expr]], exampleFun: (Expr, Map[Identifier, Expr])=>Boolean, candidates: Seq[Expr], + exampleRunner: ExampleRunner) { + + val random: Random = new Random(System.currentTimeMillis) + + // keep track of evaluations + var nextExamples: Map[Int, Int] = Map() + + var evaluations = Map[Int, Array[Boolean]]() + +// def evalAvailable(expr: Int) = { +// val nextExample = nextExamples.getOrElse(expr, 0) +// if (nextExample >= examples.size) false +// else true +// } + + def evaluate(exprInd: Int) = { + numberOfEvaluationCalls += 1 + + val nextExample = nextExamples.getOrElse(exprInd, 0) + if (nextExample >= examples.size) throw new RuntimeException("Exhausted examples for " + exprInd) + + nextExamples += (exprInd -> (nextExample + 1)) + + val example = examples(nextExample) + val expressionToCheck = candidates(exprInd) + + val result = exampleFun(expressionToCheck, example) + val evalArray = evaluations.getOrElse(exprInd, Array.ofDim[Boolean](examples.size)) + evalArray(nextExample) = result + evaluations += (exprInd -> evalArray) + result + } + +// def evaluate(expr: Int, exampleInd: Int) = { +// val nextExample = nextExamples.getOrElse(expr, 0) +// assert(exampleInd <= nextExample) +// +// if (exampleInd >= nextExample) { +// nextExamples += (expr -> (nextExample + 1)) +// val example = examples(nextExample) +// val result = example(expr) +// val evalArray = evaluations.getOrElse(expr, Array.ofDim[Boolean](examples.size)) +// evalArray(nextExample) = result +// evaluations += (expr -> evalArray) +// result +// } else { +// assert(evaluations.contains(expr)) +// evaluations.get(expr).get(exampleInd) +// } +// } + + def evaluate(expression: Int, example: Int => Boolean) = { + example(expression) + } + + def getNumberOfExamples = examples.size + + var numberOfEvaluationCalls = 0 + def getEfficiencyRatio = numberOfEvaluationCalls.toFloat / (examples.size * evaluations.size) + +} \ No newline at end of file diff --git a/src/main/scala/lesynth/Ranker.scala b/src/main/scala/lesynth/Ranker.scala new file mode 100644 index 0000000000000000000000000000000000000000..13742f2c2a94c8e800519a0ca7a7cf56e928507e --- /dev/null +++ b/src/main/scala/lesynth/Ranker.scala @@ -0,0 +1,135 @@ +package lesynth + +import util.control.Breaks._ +import scala.collection._ + +import leon.purescala.Trees.{ Variable => LeonVariable, _ } + +class Ranker(candidates: Seq[Expr], evaluation: Evaluation, printStep: Boolean = false) { + + var rankings: Array[Int] = (0 until candidates.size).toArray + + // keep track of intervals + var tuples: Map[Int, (Int, Int)] = + (for (i <- 0 until candidates.size) + yield (i, (0, evaluation.getNumberOfExamples))) toMap + + def getKMax(k: Int) = { + + } + + def evaluate(ind: Int) { + val tuple = tuples(ind) + val expr = ind + + tuples += ( ind -> + { + if (evaluation.evaluate(expr)) { + (tuple._1 + 1, tuple._2) + } else { + (tuple._1, tuple._2 - 1) + } + } + ) + + } + + def swap(ind1: Int, ind2: Int) = { + val temp = rankings(ind1) + rankings(ind1) = rankings(ind2) + rankings(ind2) = temp + } + + def bubbleDown(ind: Int): Unit = { + if (compare(rankings(ind), rankings(ind + 1))) { + swap(ind, ind + 1) + if (ind < candidates.size-2) + bubbleDown(ind + 1) + } + } + + var numberLeft = candidates.size + + def getMax = { + + numberLeft = candidates.size + + while (numberLeft > 1) { + + evaluate(rankings(0)) + + if (printStep) { + println(printTuples) + println("*** left: " + numberLeft) + } + + bubbleDown(0) + + val topRank = rankings(0) + var secondRank = rankings(1) + + while (strictCompare(secondRank, topRank) && numberLeft > 1) { + numberLeft -= 1 + swap(1, numberLeft) + secondRank = rankings(1) + } + + } + + if (printStep) { + println(printTuples) + println("***: " + numberLeft) + } + + candidates(rankings(0)) + + } + +// def getVerifiedMax = { +// val results = (for (candidate <- 0 until candidates.size) +// yield (candidate, +// (0 /: (0 until evaluation.getNumberOfExamples)) { +// (res, exampleInd) => +// if (evaluation.evaluate(candidate, exampleInd)) { +// res + 1 +// } else { +// res +// } +// })) +// +// val maxPassed = results.sortWith((r1, r2) => r1._2 < r2._2)(candidates.size - 1)._2 +// +// (results.filter(_._2 == maxPassed).map(x => candidates(x._1)), results)//.map(x => candidates(x._1)) +// } + + def strictCompare(x: Int, y: Int) = { + val tuple1 = tuples(x) + val tuple2 = tuples(y) + + tuple1._2 <= tuple2._1 + } + + def compare(x: Int, y: Int) = { + val tuple1 = tuples(x) + val tuple2 = tuples(y) + + val median1 = (tuple1._1 + tuple1._2).toFloat/2 + val median2 = (tuple2._1 + tuple2._2).toFloat/2 + + /*median1 < median2 || median1 == median2 && */ + tuple1._2 < tuple2._2 || tuple1._2 == tuple2._2 && median1 < median2 + } + + def rankOf(expr: Int) = + rankings.indexOf(expr) + + def printTuples = + (for ((tuple, ind) <- + tuples.toList.sortWith((tp1, tp2) => rankOf(tp1._1) <= rankOf(tp2._1)).zipWithIndex) + yield (if (tuple._1 == rankings(0)) "->" else if (ind >= numberLeft) "/\\" else " ") + tuple._1 + + ": " + + ((0 until evaluation.getNumberOfExamples) map { + x => if (x < tuple._2._1) '+' else if (x >= tuple._2._2) '-' else '_' + }).mkString).mkString("\n") + +} \ No newline at end of file diff --git a/src/main/scala/lesynth/SynthesizerExamples.scala b/src/main/scala/lesynth/SynthesizerExamples.scala index 4c4fd1f731bae6f60ada0ada2baeec6a6ce77e0d..c88f9521c27f2c5e8732fcd1c46c51e7787f196b 100755 --- a/src/main/scala/lesynth/SynthesizerExamples.scala +++ b/src/main/scala/lesynth/SynthesizerExamples.scala @@ -294,7 +294,7 @@ class SynthesizerForRuleExamples( refiner = new Refiner(program, hole, holeFunDef) fine("Refiner initialized. Recursive call: " + refiner.recurentExpression) - exampleRunner = new ExampleRunner(program) + exampleRunner = new ExampleRunner(program, 4000) exampleRunner.counterExamples ++= //examples introduceExamples(holeFunDef.args.map(_.id), loader) @@ -336,6 +336,43 @@ class SynthesizerForRuleExamples( count } + + + def evaluateCandidate(snippet: Expr, mapping: Map[Identifier, Expr]) = { + val oldPreconditionSaved = holeFunDef.precondition + val oldBodySaved = holeFunDef.body + + // restore initial precondition + holeFunDef.precondition = Some(initialPrecondition) + + // get the whole body (if else...) + val accumulatedExpression = accumulatingExpression(snippet) + // set appropriate body to the function for the correct evaluation + holeFunDef.body = Some(accumulatedExpression) + + + import TreeOps._ + val expressionToCheck = + //Globals.bodyAndPostPlug(exp) + { + val resFresh = FreshIdentifier("result", true).setType(accumulatedExpression.getType) + Let( + resFresh, accumulatedExpression, + replace(Map(ResultVariable() -> LeonVariable(resFresh)), matchToIfThenElse(holeFunDef.getPostcondition))) + } + + fine("going to evaluate candidate for: " + holeFunDef) + fine("going to evaluate candidate for: " + expressionToCheck) + + val count = exampleRunner.evaluate(expressionToCheck, mapping) +// if (snippet.toString == "Cons(l1.head, concat(l1.tail, l2))") +// interactivePause + + holeFunDef.precondition = oldPreconditionSaved + holeFunDef.body = oldBodySaved + + count + } def synthesize: Report = { reporter.info("Synthesis called on file: " + fileName) @@ -404,82 +441,62 @@ class SynthesizerForRuleExamples( reporter.info("Going into a enumeration/testing phase.") fine("evaluating examples: " + exampleRunner.counterExamples.mkString("\n")) - - // found precondition? - found = false - // try to find it - breakable { - // go through all snippets - for ( - snippet <- snippetsIterator; val snippetTree = snippet.getSnippet; - // filter if seen - if ! (seenBranchExpressions contains snippetTree.toString) - ) { - finest("snippetTree is: " + snippetTree) - // note that we do not add snippets to the set of seen if enqueued - if (checkTimeout) break - - // skip avoidable calls - if (!refiner.isAvoidable(snippetTree, problem.as)) { - - // passed example pairs - val passCount = countPassedExamples(snippetTree) - - if (passCount == exampleRunner.counterExamples.size) { - info("All examples passed. Testing snippet " + snippetTree + " right away") - if (tryToSynthesizeBranch(snippetTree)) { - // will set found if correct body is found - noBranchFoundIteration = 0 - break - } - } else { - if (passCount > 0) { - finest("Snippet with pass count goes into queue: " + (snippetTree, passCount)) - pq.enqueue((snippetTree, 100 + (passCount * iteration) - snippet.getWeight.toInt)) - } - else { - fine("Snippet with pass count was dropped: " + (snippetTree, passCount) + - " while number of examples was: " + exampleRunner.counterExamples.size) - // add to seen if branch was not found for it - seenBranchExpressions += snippetTree.toString - } - } - - } else { - fine("Refiner filtered this snippet: " + snippetTree) - seenBranchExpressions += snippetTree.toString - } // if (!refiner.isAvoidable(snippetTree)) { - - // check if we this makes one test iteration - if (numberOfTested >= numberOfTestsInIteration * noBranchFoundIteration) { - reporter.info("Finalizing enumeration/testing phase.") - fine("Queue contents: " + pq.toList.take(10).mkString("\n")) - fine({ if (pq.isEmpty) "queue is empty" else "head of queue is: " + pq.head }) - - //interactivePause - // go and check the topmost numberOfCheckInIteration - for (i <- 1 to math.min(numberOfCheckInIteration, pq.size)) { - val nextSnippet = pq.dequeue._1 - fine("dequeued nextSnippet: " + nextSnippet) - //interactivePause - - if (tryToSynthesizeBranch(nextSnippet)) { - noBranchFoundIteration = 0 - break - } - - // dont drop snippets that were on top of queue (they may be good for else ... part) - //seenBranchExpressions += nextSnippet.toString - } - - - numberOfTested = 0 - } else - numberOfTested += 1 - - } // for (snippet <- snippets - } // breakable { for (snippet <- snippets + breakable { + while(true) { + val batchSize = numberOfTestsInIteration * (1 << noBranchFoundIteration) + + reporter.info("numberOfTested: " + numberOfTested) + // ranking of candidates + val candidates = { + val (it1, it2) = snippetsIterator.duplicate + snippetsIterator = it2.drop(batchSize) + it1.take(batchSize). + map(_.getSnippet).filterNot( + snip => { + if (snip.toString == "merge(sort(split(list).fst), sort(split(list).snd))") println("AAA") + + (seenBranchExpressions contains snip.toString) || refiner.isAvoidable(snip, problem.as) + } + ).toSeq + } + info("got candidates of size: " + candidates.size) + //interactivePause + + if (candidates.size > 0) { + val ranker = new Ranker(candidates, + Evaluation(exampleRunner.counterExamples, this.evaluateCandidate _, candidates, exampleRunner), + false) + + val maxCandidate = ranker.getMax + info("maxCandidate is: " + maxCandidate) + numberOfTested += batchSize + +// if (candidates.exists(_.toString == "merge(sort(split(list).fst), sort(split(list).snd))")) { +// println(ranker.printTuples) +// println("AAA2") +// println("Candidates: " + candidates.zipWithIndex.map({ +// case (cand, ind) => "[" + ind + "]" + cand.toString +// }).mkString(", ")) +// println("Examples: " + exampleRunner.counterExamples.zipWithIndex.map({ +// case (example, ind) => "[" + ind + "]" + example.toString +// }).mkString(", ")) +// interactivePause +// } + + //interactivePause + if (tryToSynthesizeBranch(maxCandidate)) { + noBranchFoundIteration = 0 + break + } + + noBranchFoundIteration += 1 + } + } + } + + // add to seen if branch was not found for it + //seenBranchExpressions += snippetTree.toString // if did not found for any of the branch expressions if (found) { diff --git a/testcases/lesynth/BinarySearchTree.scala b/testcases/lesynth/BinarySearchTree.scala new file mode 100644 index 0000000000000000000000000000000000000000..cada54218590ecdb1798d2f2227f0d42e38afc33 --- /dev/null +++ b/testcases/lesynth/BinarySearchTree.scala @@ -0,0 +1,94 @@ +import scala.collection.immutable.Set + +import leon.Annotations._ +import leon.Utils._ + +object BinarySearchTree { + sealed abstract class Tree + case class Node(left: Tree, value: Int, right: Tree) extends Tree + case class Leaf() extends Tree + + def contents(tree: Tree): Set[Int] = tree match { + case Leaf() => Set.empty[Int] + case Node(l, v, r) => contents(l) ++ Set(v) ++ contents(r) + } + + def isSorted(tree: Tree): Boolean = tree match { + case Leaf() => true + case Node(Leaf(), v, Leaf()) => true + case Node(l@Node(_, vIn, _), v, Leaf()) => v > vIn && isSorted(l) + case Node(Leaf(), v, r@Node(_, vIn, _)) => v < vIn && isSorted(r) + case Node(l@Node(_, vInLeft, _), v, r@Node(_, vInRight, _)) => + v > vInLeft && v < vInRight && isSorted(l) && isSorted(r) + } + + def member(tree: Tree, value: Int): Boolean = { + require(isSorted(tree)) + tree.isInstanceOf[Node] + } ensuring (res => res && contents(tree).contains(value) || + (!res && !contents(tree).contains(value))) + +// def member(tree: Tree, value: Int): Boolean = { +// require(isSorted(tree)) +// tree match { +// case Leaf() => false +// case n @ Node(l, v, r) => if (v < value) { +// member(r, value) +// } else if (v > value) { +// member(l, value) +// } else { +// true +// } +// } +// } ensuring (_ || !(contents(tree) == contents(tree) ++ Set(value))) +// +// def insert(tree: Tree, value: Int): Node = { +// require(isSorted(tree)) +// tree match { +// case Leaf() => Node(Leaf(), value, Leaf()) +// case n @ Node(l, v, r) => if (v < value) { +// Node(l, v, insert(r, value)) +// } else if (v > value) { +// Node(insert(l, value), v, r) +// } else { +// n +// } +// } +// } ensuring (res => contents(res) == contents(tree) ++ Set(value) && isSorted(res)) + + // def treeMin(tree: Node): Int = { + // require(isSorted(tree).sorted) + // tree match { + // case Node(left, v, _) => left match { + // case Leaf() => v + // case n@Node(_, _, _) => treeMin(n) + // } + // } + // } + // + // def treeMax(tree: Node): Int = { + // require(isSorted(tree).sorted) + // tree match { + // case Node(_, v, right) => right match { + // case Leaf() => v + // case n@Node(_, _, _) => treeMax(n) + // } + // } + // } + +// def remove(tree: Tree, value: Int): Node = { +// require(isSorted(tree)) +// tree match { +// case l @ Leaf() => l +// case n @ Node(l, v, r) => if (v < value) { +// Node(l, v, insert(r, value)) +// } else if (v > value) { +// Node(insert(l, value), v, r) +// } else { +// n +// } +// } +// } ensuring (contents(_) == contents(tree) -- Set(value)) + +} +