diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index e997d99761cc89862731de51c29194d4a452e365..11d035724b13e6ecd44fde2e7de40830dbcf3a2c 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -3,6 +3,7 @@ package synthesis import purescala.Trees._ import purescala.TreeOps._ +import leon.xlang.Trees.LetDef import synthesis.search.Cost @@ -19,7 +20,10 @@ abstract class CostModel(val name: String) { } object CostModel { - def all: Set[CostModel] = Set(NaiveCostModel) + def all: Set[CostModel] = Set( + NaiveCostModel, + WeightedBranchesCostModel + ) } case object NaiveCostModel extends CostModel("Naive") { @@ -37,3 +41,55 @@ case object NaiveCostModel extends CostModel("Naive") { } } + +case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { + + def branchesCost(e: Expr): Int = { + case class BC(cost: Int, nesting: Int) + + def pre(e: Expr, c: BC) = { + (e, c.copy(nesting = c.nesting + 1)) + } + + def costOfBranches(alts: Int, nesting: Int) = { + if (nesting > 10) { + alts + } else { + (10-nesting)*alts + } + } + + def post(e: Expr, bc: BC) = e match { + case ie : IfExpr => + (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) + case ie : LetDef => + (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) + case ie : MatchExpr => + (e, bc.copy(cost = bc.cost + costOfBranches(ie.cases.size, bc.nesting))) + case _ => + (e, bc) + } + + def combiner(cs: Seq[BC]) = { + cs.foldLeft(BC(0,0))((bc1, bc2) => BC(bc1.cost + bc2.cost, 0)) + } + + val (_, bc) = genericTransform[BC](pre, post, combiner)(BC(0, 0))(e) + + bc.cost + } + + def solutionCost(s: Solution): Cost = new Cost { + val value = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) + + formulaSize(s.toExpr) + branchesCost(s.toExpr) + chooseCost + } + } + + def problemCost(p: Problem): Cost = new Cost { + val value = p.xs.size + } + +}