Skip to content
Snippets Groups Projects
Commit 32eb10be authored by Mikaël Mayer's avatar Mikaël Mayer
Browse files

Added disambiguation module for synthesis.

parent da7d4880
No related branches found
No related tags found
No related merge requests found
...@@ -370,9 +370,9 @@ object Expressions { ...@@ -370,9 +370,9 @@ object Expressions {
* [[cases]] should be nonempty. If you are not sure about this, you should use * [[cases]] should be nonempty. If you are not sure about this, you should use
* [[purescala.Constructors#passes purescala's constructor passes]] * [[purescala.Constructors#passes purescala's constructor passes]]
* *
* @param in * @param in The input expression
* @param out * @param out The output expression
* @param cases * @param cases The cases to compare against
*/ */
case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr { case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr {
require(cases.nonEmpty) require(cases.nonEmpty)
......
package leon
package synthesis.disambiguation
import purescala.Expressions.Expr
/**
* @author Mikael
*/
case class Question[T <: Expr](inputs: List[Expr], current_output: T, other_outputs: List[T])
\ No newline at end of file
package leon
package synthesis.disambiguation
import synthesis.RuleClosed
import synthesis.Solution
import evaluators.DefaultEvaluator
import purescala.Expressions._
import purescala.ExprOps
import purescala.Constructors._
import purescala.Extractors._
import purescala.Types.TypeTree
import purescala.Common.Identifier
import purescala.Definitions.Program
import purescala.DefOps
import grammars.ValueGrammar
import bonsai.enumerators.MemoizedEnumerator
import solvers.Model
import solvers.ModelBuilder
import scala.collection.mutable.ListBuffer
object QuestionBuilder {
/** Sort methods for questions. You can build your own */
trait QuestionSortingType {
def apply[T <: Expr](e: Question[T]): Int
}
object QuestionSortingType {
case object IncreasingInputSize extends QuestionSortingType {
def apply[T <: Expr](q: Question[T]) = q.inputs.map(i => ExprOps.count(e => 1)(i)).sum
}
case object DecreasingInputSize extends QuestionSortingType{
def apply[T <: Expr](q: Question[T]) = -IncreasingInputSize(q)
}
}
// Add more if needed.
/** Sort methods for question's answers. You can (and should) build your own. */
abstract class AlternativeSortingType[-T <: Expr] extends Ordering[T] { self =>
/** Prioritizes this comparison operator agains the second one. */
def &&(other: AlternativeSortingType[T]): AlternativeSortingType[T] = new AlternativeSortingType[T] {
def compare(e: T, f: T): Int = {
val ce = self.compare(e, f)
if(ce == 0) other.compare(e, f) else ce
}
}
}
object AlternativeSortingType {
/** Presents shortest alternatives first */
case class ShorterIsBetter()(implicit c: LeonContext) extends AlternativeSortingType[Expr] {
def compare(e: Expr, f: Expr) = e.asString.length - f.asString.length
}
/** Presents balanced alternatives first */
case class BalancedParenthesisIsBetter[T <: Expr]()(implicit c: LeonContext) extends AlternativeSortingType[T] {
def convert(e: T): Int = {
val s = e.asString
var openP, openB, openC = 0
for(c <- s) c match {
case '(' if openP >= 0 => openP += 1
case ')' => openP -= 1
case '{' if openB >= 0 => openB += 1
case '}' => openB -= 1
case '[' if openC >= 0 => openC += 1
case ']' => openC -= 1
case _ =>
}
Math.abs(openP) + Math.abs(openB) + Math.abs(openC)
}
def compare(e: T, f: T): Int = convert(e) - convert(f)
}
}
}
/**
* Builds a set of disambiguating questions for the problem
*
* {{{
* def f(input: input.getType): T =
* [element of r.solution]
* }}}
*
* @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType
* @param ruleApplication The set of solutions for the body of f
* @param filter A function filtering which outputs should be considered for comparison.
* @return An ordered
*
*/
class QuestionBuilder[T <: Expr](input: List[Identifier], ruleApplication: RuleClosed, filter: Expr => Option[T])(implicit c: LeonContext, p: Program) {
import QuestionBuilder._
private var _argTypes = input.map(_.getType)
private var _questionSorMethod: QuestionSortingType = QuestionSortingType.IncreasingInputSize
private var _alternativeSortMethod: AlternativeSortingType[T] = AlternativeSortingType.BalancedParenthesisIsBetter() && AlternativeSortingType.ShorterIsBetter()
private var solutionsToTake = 15
private var expressionsToTake = 15
/** Sets the way to sort questions. See [[QuestionSortingType]] */
def setSortQuestionBy(questionSorMethod: QuestionSortingType) = _questionSorMethod = questionSorMethod
/** Sets the way to sort alternatives. See [[AlternativeSortingType]] */
def setSortAlternativesBy(alternativeSortMethod: AlternativeSortingType[T]) = _alternativeSortMethod = alternativeSortMethod
/** Sets the argument type. Not needed if the input identifier is already assigned a type. */
def setArgumentType(argTypes: List[TypeTree]) = _argTypes = argTypes
/** Sets the number of solutions to consider. Default is 15 */
def setSolutionsToTake(n: Int) = solutionsToTake = n
/** Sets the number of expressions to consider. Default is 15 */
def setExpressionsToTake(n: Int) = expressionsToTake = n
private def run(s: Solution, elems: Seq[(Identifier, Expr)]): Option[Expr] = {
val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head)
val e = new DefaultEvaluator(c, newProgram)
val model = new ModelBuilder
model ++= elems
val modelResult = model.result()
e.eval(s.term, modelResult).result
}
/** Returns a list of input/output questions to ask to the user. */
def result(): List[Question[T]] = {
if(ruleApplication.solutions.isEmpty) return Nil
val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions)
val values = enum.iterator(tupleTypeWrap(_argTypes))
val instantiations = values.map {
v => input.zip(unwrapTuple(v, input.size))
}
val enumerated_inputs = instantiations.take(expressionsToTake).toList
val solution = ruleApplication.solutions.head
val alternatives = ruleApplication.solutions.drop(1).take(solutionsToTake).toList
val questions = ListBuffer[Question[T]]()
for{possible_input <- enumerated_inputs
current_output_nonfiltered <- run(solution, possible_input)
current_output <- filter(current_output_nonfiltered)} {
val alternative_outputs = (
for{alternative <- alternatives
alternative_output <- run(alternative, possible_input)
alternative_output_filtered <- filter(alternative_output)
if alternative_output != current_output
} yield alternative_output_filtered).distinct
if(alternative_outputs.nonEmpty) {
questions += Question(possible_input.map(_._2), current_output, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) < 0))
}
}
questions.toList.sortBy(_questionSorMethod(_))
}
}
\ No newline at end of file
...@@ -4,32 +4,33 @@ package leon ...@@ -4,32 +4,33 @@ package leon
package synthesis package synthesis
package rules package rules
import purescala.Expressions._ import scala.annotation.tailrec
import purescala.ExprOps._ import scala.collection.mutable.ListBuffer
import purescala.TypeOps
import purescala.Extractors._ import bonsai.enumerators.MemoizedEnumerator
import purescala.Constructors._ import leon.evaluators.DefaultEvaluator
import purescala.Types._ import leon.evaluators.StringTracingEvaluator
import purescala.Definitions._ import leon.grammars.ValueGrammar
import leon.utils.DebugSectionSynthesis import leon.programsets.DirectProgramSet
import leon.purescala.Common.{Identifier, FreshIdentifier} import leon.programsets.JoinProgramSet
import leon.purescala.Common.FreshIdentifier
import leon.purescala.Common.Identifier
import leon.purescala.DefOps
import leon.purescala.Definitions.FunDef import leon.purescala.Definitions.FunDef
import leon.utils.IncrementalMap
import leon.purescala.Definitions.FunDef import leon.purescala.Definitions.FunDef
import leon.purescala.Definitions.ValDef import leon.purescala.Definitions.ValDef
import scala.collection.mutable.ListBuffer
import leon.purescala.ExprOps import leon.purescala.ExprOps
import leon.evaluators.Evaluator
import leon.evaluators.DefaultEvaluator
import leon.evaluators.StringTracingEvaluator
import leon.solvers.Model import leon.solvers.Model
import leon.solvers.ModelBuilder import leon.solvers.ModelBuilder
import leon.solvers.string.StringSolver import leon.solvers.string.StringSolver
import scala.annotation.tailrec import leon.utils.DebugSectionSynthesis
import leon.purescala.DefOps import purescala.Constructors._
import leon.programsets.{UnionProgramSet, DirectProgramSet, JoinProgramSet} import purescala.Definitions._
import bonsai.enumerators.MemoizedEnumerator import purescala.ExprOps._
import leon.grammars.ValueGrammar import purescala.Expressions._
import purescala.Extractors._
import purescala.TypeOps
import purescala.Types._
/** A template generator for a given type tree. /** A template generator for a given type tree.
...@@ -72,6 +73,8 @@ abstract class TypedTemplateGenerator(t: TypeTree) { ...@@ -72,6 +73,8 @@ abstract class TypedTemplateGenerator(t: TypeTree) {
case object StringRender extends Rule("StringRender") { case object StringRender extends Rule("StringRender") {
type WithIds[T] = (T, List[Identifier]) type WithIds[T] = (T, List[Identifier])
var EDIT_ME = "_edit_me_"
var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None
def defaultMapTypeToString()(implicit hctx: SearchContext): Map[TypeTree, FunDef] = { def defaultMapTypeToString()(implicit hctx: SearchContext): Map[TypeTree, FunDef] = {
...@@ -183,46 +186,15 @@ case object StringRender extends Rule("StringRender") { ...@@ -183,46 +186,15 @@ case object StringRender extends Rule("StringRender") {
solutionStreamToRuleApplication(p, leon.utils.StreamUtils.interleave(tagged_solutions)) solutionStreamToRuleApplication(p, leon.utils.StreamUtils.interleave(tagged_solutions))
} }
case class Question(input: Expr, current_output: String, other_outputs: Set[String])
/** Find ambiguities not containing _edit_me_ to ask to the user */ /** Find ambiguities not containing _edit_me_ to ask to the user */
def askQuestion(input: Identifier, argType: TypeTree, r: RuleClosed)(implicit c: LeonContext, p: Program): List[Question] = { def askQuestion(input: List[Identifier], r: RuleClosed)(implicit c: LeonContext, p: Program): List[disambiguation.Question[StringLiteral]] = {
if(r.solutions.isEmpty) return Nil //if !s.contains(EDIT_ME)
val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions) val qb = new disambiguation.QuestionBuilder(input, r, (expr: Expr) => expr match {
val iter = enum.iterator(argType) case s@StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s)
case _ => None
def run(s: Solution, elem: Expr): Option[String] = { })
val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head) qb.result()
val e = new StringTracingEvaluator(c, newProgram) }
val model = new ModelBuilder
model ++= List(input -> elem)
val modelResult = model.result()
e.eval(s.term, modelResult).result match {
case Some((StringLiteral(s), _)) => Some(s)
case _ => None
}
}
val iterated = iter.take(15).toList
val solution = r.solutions.head
val alternatives = r.solutions.drop(1).take(15)
val questions = ListBuffer[Question]()
for(elem <- iterated) {
val current_output = run(solution, elem).get
var possible_outputs = Set[String]()
for(alternative <- alternatives) {
run(alternative, elem) match {
case Some(s) if !s.contains("_edit_me_") && s != current_output => possible_outputs += s
case _ =>
}
}
if(possible_outputs.nonEmpty) {
questions += Question(elem, current_output, possible_outputs)
}
}
questions.toList.sortBy(question => ExprOps.count(e => 1)(question.input))
} //TODO: Need to ask these questions to the user, but in the background. Loop by adding new examples. Test with list.
/** Converts the stream of solutions to a RuleApplication */ /** Converts the stream of solutions to a RuleApplication */
def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)]): RuleApplication = { def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)]): RuleApplication = {
...@@ -230,11 +202,11 @@ case object StringRender extends Rule("StringRender") { ...@@ -230,11 +202,11 @@ case object StringRender extends Rule("StringRender") {
RuleClosed( RuleClosed(
for((funDefsBodies, (singleTemplate, ids), assignment) <- solutions) yield { for((funDefsBodies, (singleTemplate, ids), assignment) <- solutions) yield {
val fds = for((fd, (body, ids)) <- funDefsBodies) yield { val fds = for((fd, (body, ids)) <- funDefsBodies) yield {
val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap val initMap = ids.map(_ -> StringLiteral(EDIT_ME)).toMap
fd.body = Some(ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), body))) fd.body = Some(ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), body)))
fd fd
} }
val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap val initMap = ids.map(_ -> StringLiteral(EDIT_ME)).toMap
val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate)) val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate))
val (finalTerm, finalDefs) = makeFunctionsUnique(term, fds.toSet) val (finalTerm, finalDefs) = makeFunctionsUnique(term, fds.toSet)
...@@ -456,9 +428,9 @@ case object StringRender extends Rule("StringRender") { ...@@ -456,9 +428,9 @@ case object StringRender extends Rule("StringRender") {
val res = findSolutions(examples, expr, funDefs.values.toSeq) val res = findSolutions(examples, expr, funDefs.values.toSeq)
res match { res match {
case r: RuleClosed => case r: RuleClosed =>
val questions = askQuestion(p.as(0), p.as(0).getType, r)(hctx.context, hctx.program) val questions = askQuestion(p.as, r)(hctx.context, hctx.program)
println("Questions:") println("Questions:")
println(questions.map(q => "For " + q.input + ", res = " + q.current_output + ", could also be " + q.other_outputs.toSet.map((s: String) => "\""+ s +"\"").mkString(",")).mkString("\n")) println(questions.map(q => "For (" + q.inputs.mkString(", ") + "), res = " + q.current_output + ", could also be " + q.other_outputs.toSet.map((s: StringLiteral) => s.asString).mkString(",")).mkString("\n"))
case _ => case _ =>
} }
res res
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment