From c3faef4e7a7d2e7bfa565e3d00785e9daa742531 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Thu, 13 Dec 2012 18:27:00 +0100 Subject: [PATCH] Implement WeightedBranches cost model This cost model penalizes outer branches. Branches lose their weight as the nesting increases. LetDefs are assumed recursive and account for a static number of branches, 2. --- src/main/scala/leon/synthesis/CostModel.scala | 58 ++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index e997d9976..11d035724 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -3,6 +3,7 @@ package synthesis import purescala.Trees._ import purescala.TreeOps._ +import leon.xlang.Trees.LetDef import synthesis.search.Cost @@ -19,7 +20,10 @@ abstract class CostModel(val name: String) { } object CostModel { - def all: Set[CostModel] = Set(NaiveCostModel) + def all: Set[CostModel] = Set( + NaiveCostModel, + WeightedBranchesCostModel + ) } case object NaiveCostModel extends CostModel("Naive") { @@ -37,3 +41,55 @@ case object NaiveCostModel extends CostModel("Naive") { } } + +case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { + + def branchesCost(e: Expr): Int = { + case class BC(cost: Int, nesting: Int) + + def pre(e: Expr, c: BC) = { + (e, c.copy(nesting = c.nesting + 1)) + } + + def costOfBranches(alts: Int, nesting: Int) = { + if (nesting > 10) { + alts + } else { + (10-nesting)*alts + } + } + + def post(e: Expr, bc: BC) = e match { + case ie : IfExpr => + (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) + case ie : LetDef => + (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) + case ie : MatchExpr => + (e, bc.copy(cost = bc.cost + costOfBranches(ie.cases.size, bc.nesting))) + case _ => + (e, bc) + } + + def combiner(cs: Seq[BC]) = { + cs.foldLeft(BC(0,0))((bc1, bc2) => BC(bc1.cost + bc2.cost, 0)) + } + + val (_, bc) = genericTransform[BC](pre, post, combiner)(BC(0, 0))(e) + + bc.cost + } + + def solutionCost(s: Solution): Cost = new Cost { + val value = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) + + formulaSize(s.toExpr) + branchesCost(s.toExpr) + chooseCost + } + } + + def problemCost(p: Problem): Cost = new Cost { + val value = p.xs.size + } + +} -- GitLab