Skip to content
Snippets Groups Projects
Commit ce64ea9a authored by Ivan Kuraj's avatar Ivan Kuraj Committed by Etienne Kneuss
Browse files

Added support for candidate ranking (for minimizing #evaluations)

parent 1730c4a5
Branches
Tags
No related merge requests found
package lesynth
import scala.util.Random
import leon.purescala.Trees.{ Variable => LeonVariable, _ }
import leon.purescala.Common.Identifier
case class Evaluation(examples: Seq[Map[Identifier, Expr]], exampleFun: (Expr, Map[Identifier, Expr])=>Boolean, candidates: Seq[Expr],
exampleRunner: ExampleRunner) {
val random: Random = new Random(System.currentTimeMillis)
// keep track of evaluations
var nextExamples: Map[Int, Int] = Map()
var evaluations = Map[Int, Array[Boolean]]()
// def evalAvailable(expr: Int) = {
// val nextExample = nextExamples.getOrElse(expr, 0)
// if (nextExample >= examples.size) false
// else true
// }
def evaluate(exprInd: Int) = {
numberOfEvaluationCalls += 1
val nextExample = nextExamples.getOrElse(exprInd, 0)
if (nextExample >= examples.size) throw new RuntimeException("Exhausted examples for " + exprInd)
nextExamples += (exprInd -> (nextExample + 1))
val example = examples(nextExample)
val expressionToCheck = candidates(exprInd)
val result = exampleFun(expressionToCheck, example)
val evalArray = evaluations.getOrElse(exprInd, Array.ofDim[Boolean](examples.size))
evalArray(nextExample) = result
evaluations += (exprInd -> evalArray)
result
}
// def evaluate(expr: Int, exampleInd: Int) = {
// val nextExample = nextExamples.getOrElse(expr, 0)
// assert(exampleInd <= nextExample)
//
// if (exampleInd >= nextExample) {
// nextExamples += (expr -> (nextExample + 1))
// val example = examples(nextExample)
// val result = example(expr)
// val evalArray = evaluations.getOrElse(expr, Array.ofDim[Boolean](examples.size))
// evalArray(nextExample) = result
// evaluations += (expr -> evalArray)
// result
// } else {
// assert(evaluations.contains(expr))
// evaluations.get(expr).get(exampleInd)
// }
// }
def evaluate(expression: Int, example: Int => Boolean) = {
example(expression)
}
def getNumberOfExamples = examples.size
var numberOfEvaluationCalls = 0
def getEfficiencyRatio = numberOfEvaluationCalls.toFloat / (examples.size * evaluations.size)
}
\ No newline at end of file
package lesynth
import util.control.Breaks._
import scala.collection._
import leon.purescala.Trees.{ Variable => LeonVariable, _ }
class Ranker(candidates: Seq[Expr], evaluation: Evaluation, printStep: Boolean = false) {
var rankings: Array[Int] = (0 until candidates.size).toArray
// keep track of intervals
var tuples: Map[Int, (Int, Int)] =
(for (i <- 0 until candidates.size)
yield (i, (0, evaluation.getNumberOfExamples))) toMap
def getKMax(k: Int) = {
}
def evaluate(ind: Int) {
val tuple = tuples(ind)
val expr = ind
tuples += ( ind ->
{
if (evaluation.evaluate(expr)) {
(tuple._1 + 1, tuple._2)
} else {
(tuple._1, tuple._2 - 1)
}
}
)
}
def swap(ind1: Int, ind2: Int) = {
val temp = rankings(ind1)
rankings(ind1) = rankings(ind2)
rankings(ind2) = temp
}
def bubbleDown(ind: Int): Unit = {
if (compare(rankings(ind), rankings(ind + 1))) {
swap(ind, ind + 1)
if (ind < candidates.size-2)
bubbleDown(ind + 1)
}
}
var numberLeft = candidates.size
def getMax = {
numberLeft = candidates.size
while (numberLeft > 1) {
evaluate(rankings(0))
if (printStep) {
println(printTuples)
println("*** left: " + numberLeft)
}
bubbleDown(0)
val topRank = rankings(0)
var secondRank = rankings(1)
while (strictCompare(secondRank, topRank) && numberLeft > 1) {
numberLeft -= 1
swap(1, numberLeft)
secondRank = rankings(1)
}
}
if (printStep) {
println(printTuples)
println("***: " + numberLeft)
}
candidates(rankings(0))
}
// def getVerifiedMax = {
// val results = (for (candidate <- 0 until candidates.size)
// yield (candidate,
// (0 /: (0 until evaluation.getNumberOfExamples)) {
// (res, exampleInd) =>
// if (evaluation.evaluate(candidate, exampleInd)) {
// res + 1
// } else {
// res
// }
// }))
//
// val maxPassed = results.sortWith((r1, r2) => r1._2 < r2._2)(candidates.size - 1)._2
//
// (results.filter(_._2 == maxPassed).map(x => candidates(x._1)), results)//.map(x => candidates(x._1))
// }
def strictCompare(x: Int, y: Int) = {
val tuple1 = tuples(x)
val tuple2 = tuples(y)
tuple1._2 <= tuple2._1
}
def compare(x: Int, y: Int) = {
val tuple1 = tuples(x)
val tuple2 = tuples(y)
val median1 = (tuple1._1 + tuple1._2).toFloat/2
val median2 = (tuple2._1 + tuple2._2).toFloat/2
/*median1 < median2 || median1 == median2 && */
tuple1._2 < tuple2._2 || tuple1._2 == tuple2._2 && median1 < median2
}
def rankOf(expr: Int) =
rankings.indexOf(expr)
def printTuples =
(for ((tuple, ind) <-
tuples.toList.sortWith((tp1, tp2) => rankOf(tp1._1) <= rankOf(tp2._1)).zipWithIndex)
yield (if (tuple._1 == rankings(0)) "->" else if (ind >= numberLeft) "/\\" else " ") + tuple._1 +
": " +
((0 until evaluation.getNumberOfExamples) map {
x => if (x < tuple._2._1) '+' else if (x >= tuple._2._2) '-' else '_'
}).mkString).mkString("\n")
}
\ No newline at end of file
......@@ -294,7 +294,7 @@ class SynthesizerForRuleExamples(
refiner = new Refiner(program, hole, holeFunDef)
fine("Refiner initialized. Recursive call: " + refiner.recurentExpression)
exampleRunner = new ExampleRunner(program)
exampleRunner = new ExampleRunner(program, 4000)
exampleRunner.counterExamples ++= //examples
introduceExamples(holeFunDef.args.map(_.id), loader)
......@@ -336,6 +336,43 @@ class SynthesizerForRuleExamples(
count
}
def evaluateCandidate(snippet: Expr, mapping: Map[Identifier, Expr]) = {
val oldPreconditionSaved = holeFunDef.precondition
val oldBodySaved = holeFunDef.body
// restore initial precondition
holeFunDef.precondition = Some(initialPrecondition)
// get the whole body (if else...)
val accumulatedExpression = accumulatingExpression(snippet)
// set appropriate body to the function for the correct evaluation
holeFunDef.body = Some(accumulatedExpression)
import TreeOps._
val expressionToCheck =
//Globals.bodyAndPostPlug(exp)
{
val resFresh = FreshIdentifier("result", true).setType(accumulatedExpression.getType)
Let(
resFresh, accumulatedExpression,
replace(Map(ResultVariable() -> LeonVariable(resFresh)), matchToIfThenElse(holeFunDef.getPostcondition)))
}
fine("going to evaluate candidate for: " + holeFunDef)
fine("going to evaluate candidate for: " + expressionToCheck)
val count = exampleRunner.evaluate(expressionToCheck, mapping)
// if (snippet.toString == "Cons(l1.head, concat(l1.tail, l2))")
// interactivePause
holeFunDef.precondition = oldPreconditionSaved
holeFunDef.body = oldBodySaved
count
}
def synthesize: Report = {
reporter.info("Synthesis called on file: " + fileName)
......@@ -404,82 +441,62 @@ class SynthesizerForRuleExamples(
reporter.info("Going into a enumeration/testing phase.")
fine("evaluating examples: " + exampleRunner.counterExamples.mkString("\n"))
// found precondition?
found = false
// try to find it
breakable {
// go through all snippets
for (
snippet <- snippetsIterator; val snippetTree = snippet.getSnippet;
// filter if seen
if ! (seenBranchExpressions contains snippetTree.toString)
) {
finest("snippetTree is: " + snippetTree)
// note that we do not add snippets to the set of seen if enqueued
if (checkTimeout) break
// skip avoidable calls
if (!refiner.isAvoidable(snippetTree, problem.as)) {
// passed example pairs
val passCount = countPassedExamples(snippetTree)
if (passCount == exampleRunner.counterExamples.size) {
info("All examples passed. Testing snippet " + snippetTree + " right away")
if (tryToSynthesizeBranch(snippetTree)) {
// will set found if correct body is found
noBranchFoundIteration = 0
break
}
} else {
if (passCount > 0) {
finest("Snippet with pass count goes into queue: " + (snippetTree, passCount))
pq.enqueue((snippetTree, 100 + (passCount * iteration) - snippet.getWeight.toInt))
}
else {
fine("Snippet with pass count was dropped: " + (snippetTree, passCount) +
" while number of examples was: " + exampleRunner.counterExamples.size)
// add to seen if branch was not found for it
seenBranchExpressions += snippetTree.toString
}
}
} else {
fine("Refiner filtered this snippet: " + snippetTree)
seenBranchExpressions += snippetTree.toString
} // if (!refiner.isAvoidable(snippetTree)) {
// check if we this makes one test iteration
if (numberOfTested >= numberOfTestsInIteration * noBranchFoundIteration) {
reporter.info("Finalizing enumeration/testing phase.")
fine("Queue contents: " + pq.toList.take(10).mkString("\n"))
fine({ if (pq.isEmpty) "queue is empty" else "head of queue is: " + pq.head })
//interactivePause
// go and check the topmost numberOfCheckInIteration
for (i <- 1 to math.min(numberOfCheckInIteration, pq.size)) {
val nextSnippet = pq.dequeue._1
fine("dequeued nextSnippet: " + nextSnippet)
//interactivePause
if (tryToSynthesizeBranch(nextSnippet)) {
noBranchFoundIteration = 0
break
}
// dont drop snippets that were on top of queue (they may be good for else ... part)
//seenBranchExpressions += nextSnippet.toString
}
numberOfTested = 0
} else
numberOfTested += 1
} // for (snippet <- snippets
} // breakable { for (snippet <- snippets
breakable {
while(true) {
val batchSize = numberOfTestsInIteration * (1 << noBranchFoundIteration)
reporter.info("numberOfTested: " + numberOfTested)
// ranking of candidates
val candidates = {
val (it1, it2) = snippetsIterator.duplicate
snippetsIterator = it2.drop(batchSize)
it1.take(batchSize).
map(_.getSnippet).filterNot(
snip => {
if (snip.toString == "merge(sort(split(list).fst), sort(split(list).snd))") println("AAA")
(seenBranchExpressions contains snip.toString) || refiner.isAvoidable(snip, problem.as)
}
).toSeq
}
info("got candidates of size: " + candidates.size)
//interactivePause
if (candidates.size > 0) {
val ranker = new Ranker(candidates,
Evaluation(exampleRunner.counterExamples, this.evaluateCandidate _, candidates, exampleRunner),
false)
val maxCandidate = ranker.getMax
info("maxCandidate is: " + maxCandidate)
numberOfTested += batchSize
// if (candidates.exists(_.toString == "merge(sort(split(list).fst), sort(split(list).snd))")) {
// println(ranker.printTuples)
// println("AAA2")
// println("Candidates: " + candidates.zipWithIndex.map({
// case (cand, ind) => "[" + ind + "]" + cand.toString
// }).mkString(", "))
// println("Examples: " + exampleRunner.counterExamples.zipWithIndex.map({
// case (example, ind) => "[" + ind + "]" + example.toString
// }).mkString(", "))
// interactivePause
// }
//interactivePause
if (tryToSynthesizeBranch(maxCandidate)) {
noBranchFoundIteration = 0
break
}
noBranchFoundIteration += 1
}
}
}
// add to seen if branch was not found for it
//seenBranchExpressions += snippetTree.toString
// if did not found for any of the branch expressions
if (found) {
......
import scala.collection.immutable.Set
import leon.Annotations._
import leon.Utils._
object BinarySearchTree {
sealed abstract class Tree
case class Node(left: Tree, value: Int, right: Tree) extends Tree
case class Leaf() extends Tree
def contents(tree: Tree): Set[Int] = tree match {
case Leaf() => Set.empty[Int]
case Node(l, v, r) => contents(l) ++ Set(v) ++ contents(r)
}
def isSorted(tree: Tree): Boolean = tree match {
case Leaf() => true
case Node(Leaf(), v, Leaf()) => true
case Node(l@Node(_, vIn, _), v, Leaf()) => v > vIn && isSorted(l)
case Node(Leaf(), v, r@Node(_, vIn, _)) => v < vIn && isSorted(r)
case Node(l@Node(_, vInLeft, _), v, r@Node(_, vInRight, _)) =>
v > vInLeft && v < vInRight && isSorted(l) && isSorted(r)
}
def member(tree: Tree, value: Int): Boolean = {
require(isSorted(tree))
tree.isInstanceOf[Node]
} ensuring (res => res && contents(tree).contains(value) ||
(!res && !contents(tree).contains(value)))
// def member(tree: Tree, value: Int): Boolean = {
// require(isSorted(tree))
// tree match {
// case Leaf() => false
// case n @ Node(l, v, r) => if (v < value) {
// member(r, value)
// } else if (v > value) {
// member(l, value)
// } else {
// true
// }
// }
// } ensuring (_ || !(contents(tree) == contents(tree) ++ Set(value)))
//
// def insert(tree: Tree, value: Int): Node = {
// require(isSorted(tree))
// tree match {
// case Leaf() => Node(Leaf(), value, Leaf())
// case n @ Node(l, v, r) => if (v < value) {
// Node(l, v, insert(r, value))
// } else if (v > value) {
// Node(insert(l, value), v, r)
// } else {
// n
// }
// }
// } ensuring (res => contents(res) == contents(tree) ++ Set(value) && isSorted(res))
// def treeMin(tree: Node): Int = {
// require(isSorted(tree).sorted)
// tree match {
// case Node(left, v, _) => left match {
// case Leaf() => v
// case n@Node(_, _, _) => treeMin(n)
// }
// }
// }
//
// def treeMax(tree: Node): Int = {
// require(isSorted(tree).sorted)
// tree match {
// case Node(_, v, right) => right match {
// case Leaf() => v
// case n@Node(_, _, _) => treeMax(n)
// }
// }
// }
// def remove(tree: Tree, value: Int): Node = {
// require(isSorted(tree))
// tree match {
// case l @ Leaf() => l
// case n @ Node(l, v, r) => if (v < value) {
// Node(l, v, insert(r, value))
// } else if (v > value) {
// Node(insert(l, value), v, r)
// } else {
// n
// }
// }
// } ensuring (contents(_) == contents(tree) -- Set(value))
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment