diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index f8fb6d6944ef9e1f2a5e6ef473e8c32f1ede8b80..5d71c6ec85e59ac9228ea8be40ff9628ae1b63f4 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -6,12 +6,16 @@ import purescala.TreeOps._ import synthesis.search.Cost -abstract class CostModel(name: String) { +abstract class CostModel(val name: String) { def solutionCost(s: Solution): Cost def problemCost(p: Problem): Cost def ruleAppCost(r: Rule, app: RuleApplication): Cost } +object CostModel { + def all: Set[CostModel] = Set(NaiveCostModel) +} + case object NaiveCostModel extends CostModel("Naive") { def solutionCost(s: Solution): Cost = new Cost { val value = { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index ad21c6f7b074ad43860593ce81e36af31a690a41..e231e6ee323ae80bcd7dc8d28205e938fd84a5b4 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -19,6 +19,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), LeonValueOptionDef("timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), + LeonValueOptionDef("costmodel", "--costmodel=cm", "Use a specific cost model for this search"), LeonValueOptionDef("functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") ) @@ -39,6 +40,13 @@ object SynthesisPhase extends LeonPhase[Program, Program] { inPlace = true case LeonValueOption("functions", ListValue(fs)) => options = options.copy(filterFuns = Some(fs.toSet)) + case LeonValueOption("costmodel", cm) => + CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match { + case Some(model) => + options = options.copy(costModel = model) + case None => + ctx.reporter.fatalError("Unknown cost model: "+cm) + } case LeonValueOption("timeout", t) => try { options = options.copy(timeoutMs = Some(t.toLong))