diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 40aa96e95b67b45452290e4a1a7fa4809fea2c01..0915a0243adc2739b0cd65cdf59a7e38d9b18cff 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -370,9 +370,9 @@ object Expressions { * [[cases]] should be nonempty. If you are not sure about this, you should use * [[purescala.Constructors#passes purescala's constructor passes]] * - * @param in - * @param out - * @param cases + * @param in The input expression + * @param out The output expression + * @param cases The cases to compare against */ case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr { require(cases.nonEmpty) diff --git a/src/main/scala/leon/synthesis/disambiguation/Question.scala b/src/main/scala/leon/synthesis/disambiguation/Question.scala new file mode 100644 index 0000000000000000000000000000000000000000..4074b0abe949ea8adfe0cfc946f18d7ce7c7c25e --- /dev/null +++ b/src/main/scala/leon/synthesis/disambiguation/Question.scala @@ -0,0 +1,9 @@ +package leon +package synthesis.disambiguation + +import purescala.Expressions.Expr + +/** + * @author Mikael + */ +case class Question[T <: Expr](inputs: List[Expr], current_output: T, other_outputs: List[T]) \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala new file mode 100644 index 0000000000000000000000000000000000000000..b3f13d4fc07b143dca2febe5724530eca698d7fd --- /dev/null +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -0,0 +1,144 @@ +package leon +package synthesis.disambiguation + +import synthesis.RuleClosed +import synthesis.Solution +import evaluators.DefaultEvaluator +import purescala.Expressions._ +import purescala.ExprOps +import purescala.Constructors._ +import purescala.Extractors._ +import purescala.Types.TypeTree +import purescala.Common.Identifier +import purescala.Definitions.Program +import purescala.DefOps +import grammars.ValueGrammar +import bonsai.enumerators.MemoizedEnumerator +import solvers.Model +import solvers.ModelBuilder +import scala.collection.mutable.ListBuffer + +object QuestionBuilder { + /** Sort methods for questions. You can build your own */ + trait QuestionSortingType { + def apply[T <: Expr](e: Question[T]): Int + } + object QuestionSortingType { + case object IncreasingInputSize extends QuestionSortingType { + def apply[T <: Expr](q: Question[T]) = q.inputs.map(i => ExprOps.count(e => 1)(i)).sum + } + case object DecreasingInputSize extends QuestionSortingType{ + def apply[T <: Expr](q: Question[T]) = -IncreasingInputSize(q) + } + } + // Add more if needed. + + /** Sort methods for question's answers. You can (and should) build your own. */ + abstract class AlternativeSortingType[-T <: Expr] extends Ordering[T] { self => + /** Prioritizes this comparison operator agains the second one. */ + def &&(other: AlternativeSortingType[T]): AlternativeSortingType[T] = new AlternativeSortingType[T] { + def compare(e: T, f: T): Int = { + val ce = self.compare(e, f) + if(ce == 0) other.compare(e, f) else ce + } + } + } + object AlternativeSortingType { + /** Presents shortest alternatives first */ + case class ShorterIsBetter()(implicit c: LeonContext) extends AlternativeSortingType[Expr] { + def compare(e: Expr, f: Expr) = e.asString.length - f.asString.length + } + /** Presents balanced alternatives first */ + case class BalancedParenthesisIsBetter[T <: Expr]()(implicit c: LeonContext) extends AlternativeSortingType[T] { + def convert(e: T): Int = { + val s = e.asString + var openP, openB, openC = 0 + for(c <- s) c match { + case '(' if openP >= 0 => openP += 1 + case ')' => openP -= 1 + case '{' if openB >= 0 => openB += 1 + case '}' => openB -= 1 + case '[' if openC >= 0 => openC += 1 + case ']' => openC -= 1 + case _ => + } + Math.abs(openP) + Math.abs(openB) + Math.abs(openC) + } + def compare(e: T, f: T): Int = convert(e) - convert(f) + } + } +} + +/** + * Builds a set of disambiguating questions for the problem + * + * {{{ + * def f(input: input.getType): T = + * [element of r.solution] + * }}} + * + * @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType + * @param ruleApplication The set of solutions for the body of f + * @param filter A function filtering which outputs should be considered for comparison. + * @return An ordered + * + */ +class QuestionBuilder[T <: Expr](input: List[Identifier], ruleApplication: RuleClosed, filter: Expr => Option[T])(implicit c: LeonContext, p: Program) { + import QuestionBuilder._ + private var _argTypes = input.map(_.getType) + private var _questionSorMethod: QuestionSortingType = QuestionSortingType.IncreasingInputSize + private var _alternativeSortMethod: AlternativeSortingType[T] = AlternativeSortingType.BalancedParenthesisIsBetter() && AlternativeSortingType.ShorterIsBetter() + private var solutionsToTake = 15 + private var expressionsToTake = 15 + + /** Sets the way to sort questions. See [[QuestionSortingType]] */ + def setSortQuestionBy(questionSorMethod: QuestionSortingType) = _questionSorMethod = questionSorMethod + /** Sets the way to sort alternatives. See [[AlternativeSortingType]] */ + def setSortAlternativesBy(alternativeSortMethod: AlternativeSortingType[T]) = _alternativeSortMethod = alternativeSortMethod + /** Sets the argument type. Not needed if the input identifier is already assigned a type. */ + def setArgumentType(argTypes: List[TypeTree]) = _argTypes = argTypes + /** Sets the number of solutions to consider. Default is 15 */ + def setSolutionsToTake(n: Int) = solutionsToTake = n + /** Sets the number of expressions to consider. Default is 15 */ + def setExpressionsToTake(n: Int) = expressionsToTake = n + + private def run(s: Solution, elems: Seq[(Identifier, Expr)]): Option[Expr] = { + val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head) + val e = new DefaultEvaluator(c, newProgram) + val model = new ModelBuilder + model ++= elems + val modelResult = model.result() + e.eval(s.term, modelResult).result + } + + /** Returns a list of input/output questions to ask to the user. */ + def result(): List[Question[T]] = { + if(ruleApplication.solutions.isEmpty) return Nil + + val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions) + val values = enum.iterator(tupleTypeWrap(_argTypes)) + val instantiations = values.map { + v => input.zip(unwrapTuple(v, input.size)) + } + + val enumerated_inputs = instantiations.take(expressionsToTake).toList + + val solution = ruleApplication.solutions.head + val alternatives = ruleApplication.solutions.drop(1).take(solutionsToTake).toList + val questions = ListBuffer[Question[T]]() + for{possible_input <- enumerated_inputs + current_output_nonfiltered <- run(solution, possible_input) + current_output <- filter(current_output_nonfiltered)} { + val alternative_outputs = ( + for{alternative <- alternatives + alternative_output <- run(alternative, possible_input) + alternative_output_filtered <- filter(alternative_output) + if alternative_output != current_output + } yield alternative_output_filtered).distinct + if(alternative_outputs.nonEmpty) { + questions += Question(possible_input.map(_._2), current_output, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) < 0)) + } + } + questions.toList.sortBy(_questionSorMethod(_)) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index 507c974c766343ef0b78cf9bac69974b2ca8aba9..c088c201870892823141394cf5703f9975b72847 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -4,32 +4,33 @@ package leon package synthesis package rules -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.TypeOps -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ -import purescala.Definitions._ -import leon.utils.DebugSectionSynthesis -import leon.purescala.Common.{Identifier, FreshIdentifier} +import scala.annotation.tailrec +import scala.collection.mutable.ListBuffer + +import bonsai.enumerators.MemoizedEnumerator +import leon.evaluators.DefaultEvaluator +import leon.evaluators.StringTracingEvaluator +import leon.grammars.ValueGrammar +import leon.programsets.DirectProgramSet +import leon.programsets.JoinProgramSet +import leon.purescala.Common.FreshIdentifier +import leon.purescala.Common.Identifier +import leon.purescala.DefOps import leon.purescala.Definitions.FunDef -import leon.utils.IncrementalMap import leon.purescala.Definitions.FunDef import leon.purescala.Definitions.ValDef -import scala.collection.mutable.ListBuffer import leon.purescala.ExprOps -import leon.evaluators.Evaluator -import leon.evaluators.DefaultEvaluator -import leon.evaluators.StringTracingEvaluator import leon.solvers.Model import leon.solvers.ModelBuilder import leon.solvers.string.StringSolver -import scala.annotation.tailrec -import leon.purescala.DefOps -import leon.programsets.{UnionProgramSet, DirectProgramSet, JoinProgramSet} -import bonsai.enumerators.MemoizedEnumerator -import leon.grammars.ValueGrammar +import leon.utils.DebugSectionSynthesis +import purescala.Constructors._ +import purescala.Definitions._ +import purescala.ExprOps._ +import purescala.Expressions._ +import purescala.Extractors._ +import purescala.TypeOps +import purescala.Types._ /** A template generator for a given type tree. @@ -72,6 +73,8 @@ abstract class TypedTemplateGenerator(t: TypeTree) { case object StringRender extends Rule("StringRender") { type WithIds[T] = (T, List[Identifier]) + var EDIT_ME = "_edit_me_" + var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None def defaultMapTypeToString()(implicit hctx: SearchContext): Map[TypeTree, FunDef] = { @@ -183,46 +186,15 @@ case object StringRender extends Rule("StringRender") { solutionStreamToRuleApplication(p, leon.utils.StreamUtils.interleave(tagged_solutions)) } - case class Question(input: Expr, current_output: String, other_outputs: Set[String]) - /** Find ambiguities not containing _edit_me_ to ask to the user */ - def askQuestion(input: Identifier, argType: TypeTree, r: RuleClosed)(implicit c: LeonContext, p: Program): List[Question] = { - if(r.solutions.isEmpty) return Nil - val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions) - val iter = enum.iterator(argType) - - def run(s: Solution, elem: Expr): Option[String] = { - val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head) - val e = new StringTracingEvaluator(c, newProgram) - val model = new ModelBuilder - model ++= List(input -> elem) - val modelResult = model.result() - e.eval(s.term, modelResult).result match { - case Some((StringLiteral(s), _)) => Some(s) - case _ => None - } - } - - val iterated = iter.take(15).toList - - val solution = r.solutions.head - val alternatives = r.solutions.drop(1).take(15) - val questions = ListBuffer[Question]() - for(elem <- iterated) { - val current_output = run(solution, elem).get - var possible_outputs = Set[String]() - for(alternative <- alternatives) { - run(alternative, elem) match { - case Some(s) if !s.contains("_edit_me_") && s != current_output => possible_outputs += s - case _ => - } - } - if(possible_outputs.nonEmpty) { - questions += Question(elem, current_output, possible_outputs) - } - } - questions.toList.sortBy(question => ExprOps.count(e => 1)(question.input)) - } //TODO: Need to ask these questions to the user, but in the background. Loop by adding new examples. Test with list. + def askQuestion(input: List[Identifier], r: RuleClosed)(implicit c: LeonContext, p: Program): List[disambiguation.Question[StringLiteral]] = { + //if !s.contains(EDIT_ME) + val qb = new disambiguation.QuestionBuilder(input, r, (expr: Expr) => expr match { + case s@StringLiteral(slv) if !slv.contains(EDIT_ME) => Some(s) + case _ => None + }) + qb.result() + } /** Converts the stream of solutions to a RuleApplication */ def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)]): RuleApplication = { @@ -230,11 +202,11 @@ case object StringRender extends Rule("StringRender") { RuleClosed( for((funDefsBodies, (singleTemplate, ids), assignment) <- solutions) yield { val fds = for((fd, (body, ids)) <- funDefsBodies) yield { - val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap + val initMap = ids.map(_ -> StringLiteral(EDIT_ME)).toMap fd.body = Some(ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), body))) fd } - val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap + val initMap = ids.map(_ -> StringLiteral(EDIT_ME)).toMap val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate)) val (finalTerm, finalDefs) = makeFunctionsUnique(term, fds.toSet) @@ -456,9 +428,9 @@ case object StringRender extends Rule("StringRender") { val res = findSolutions(examples, expr, funDefs.values.toSeq) res match { case r: RuleClosed => - val questions = askQuestion(p.as(0), p.as(0).getType, r)(hctx.context, hctx.program) + val questions = askQuestion(p.as, r)(hctx.context, hctx.program) println("Questions:") - println(questions.map(q => "For " + q.input + ", res = " + q.current_output + ", could also be " + q.other_outputs.toSet.map((s: String) => "\""+ s +"\"").mkString(",")).mkString("\n")) + println(questions.map(q => "For (" + q.inputs.mkString(", ") + "), res = " + q.current_output + ", could also be " + q.other_outputs.toSet.map((s: StringLiteral) => s.asString).mkString(",")).mkString("\n")) case _ => } res