diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index 5fea29a6e1ca0b8a9bb5f6ef9bdf6e65a9fd9d32..32474bcb918d6cd6a2fe2d0ec4cba53fcc5c63a3 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -27,7 +27,7 @@ object SelfPrettyPrinter { def prettyPrintersForType(inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { (new SelfPrettyPrinter).prettyPrintersForType(inputType) } - def print(v: Expr, orElse: =>String, excluded: FunDef => Boolean = Set())(implicit ctx: LeonContext, program: Program): String = { + def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { (new SelfPrettyPrinter).print(v, orElse, excluded) } } @@ -38,9 +38,13 @@ object SelfPrettyPrinter { * @return a user defined string for the given typed expression. */ class SelfPrettyPrinter { private var allowedFunctions = Set[FunDef]() - + private var excluded = Set[FunDef]() + /** Functions whose name does not need to end with `tostring` or which can be abstract, i.e. which may contain a choose construct.*/ def allowFunction(fd: FunDef) = { allowedFunctions += fd; this } + def excludeFunctions(fds: Set[FunDef]) = { excluded ++= fds; this } + def excludeFunction(fd: FunDef) = { excluded += fd; this } + /** Returns a list of possible lambdas that can transform the input type to a String*/ def prettyPrintersForType(inputType: TypeTree/*, existingPp: Map[TypeTree, List[Lambda]] = Map()*/)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { // Use the other argument if you need recursive typing (?) @@ -48,48 +52,57 @@ class SelfPrettyPrinter { fd => val isCandidate = fd.returnType == StringType && fd.params.length >= 1 && - allowedFunctions(fd) || ( + !excluded(fd) && + (allowedFunctions(fd) || ( //TypeOps.isSubtypeOf(v.getType, fd.params.head.getType) && fd.id.name.toLowerCase().endsWith("tostring") && program.callGraph.transitiveCallees(fd).forall { fde => !purescala.ExprOps.exists( _.isInstanceOf[Choose])(fde.fullBody) - }) + })) if(isCandidate) { - // InputType is concrete, the types of params may be abstract. - TypeOps.canBeSubtypeOf(inputType, fd.tparams.map(_.tp), fd.params.head.getType) match { - case Some(genericTypeMap) => - val defGenericTypeMap = genericTypeMap.map{ case (k, v) => (Definitions.TypeParameterDef(k), v) } - def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[Lambda]] = ListBuffer()): Option[Stream[List[Lambda]]] = funIds match { - case Nil => Some(StreamUtils.cartesianProduct(acc.toList)) - case funId::tail => // For each function, find an expression which could be provided if it exists. - funId.getType match { - case FunctionType(Seq(in), StringType) => // Should have one argument. - val candidates = prettyPrintersForType(in) - gatherPrettyPrinters(tail, acc += candidates) - case _ => None - } - } - val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, defGenericTypeMap)).toList - gatherPrettyPrinters(funIds) match { - case Some(l) => for(lambdas <- l) yield { - val x = FreshIdentifier("x", fd.params.head.getType) // verify the type - Lambda(Seq(ValDef(x)), functionInvocation(fd, Variable(x)::lambdas)) - } - case _ => Nil - } - case None => Nil + prettyPrinterFromCandidate(fd, inputType) + } else Stream.Empty + } + } + + + def prettyPrinterFromCandidate(fd: FunDef, inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { + TypeOps.canBeSubtypeOf(inputType, fd.tparams.map(_.tp), fd.params.head.getType) match { + case Some(genericTypeMap) => + val defGenericTypeMap = genericTypeMap.map{ case (k, v) => (Definitions.TypeParameterDef(k), v) } + def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[Lambda]] = ListBuffer()): Option[Stream[List[Lambda]]] = funIds match { + case Nil => Some(StreamUtils.cartesianProduct(acc.toList)) + case funId::tail => // For each function, find an expression which could be provided if it exists. + funId.getType match { + case FunctionType(Seq(in), StringType) => // Should have one argument. + val candidates = prettyPrintersForType(in) + gatherPrettyPrinters(tail, acc += candidates) + case _ => None + } + } + val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, defGenericTypeMap)).toList + gatherPrettyPrinters(funIds) match { + case Some(l) => for(lambdas <- l) yield { + val x = FreshIdentifier("x", fd.params.head.getType) // verify the type + Lambda(Seq(ValDef(x)), functionInvocation(fd, Variable(x)::lambdas)) } - } else Nil + case _ => Stream.empty + } + case None => Stream.empty } } + /** Actually prints the expression with as alternative the given orElse */ - def print(v: Expr, orElse: =>String, excluded: FunDef => Boolean = Set())(implicit ctx: LeonContext, program: Program): String = { + def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { + this.excluded = excluded val s = prettyPrintersForType(v.getType) // TODO: Included the variable excluded if necessary. if(s.isEmpty) { + println("Could not find pretty printer for type " + v.getType) orElse } else { val l: Lambda = s.head + println("Executing pretty printer for type " + v.getType + " : " + l + " on " + v) val ste = new DefaultEvaluator(ctx, program) try { val toEvaluate = application(l, Seq(v)) @@ -99,10 +112,12 @@ class SelfPrettyPrinter { case Some(StringLiteral(res)) if res != "" => res case res => + println("not a string literal " + res) orElse } } catch { case e: evaluators.ContextualEvaluator#EvalError => + println("Error " + e.msg) orElse } } diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index ea52b9bab1c3e0d2b9be05df1cd2ec1c1f65c23f..affa41df46292ae73a6b9a87150283ec8c0ee997 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -31,6 +31,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust } def toExpr = { + if(defs.isEmpty) guardedTerm else LetDef(defs.toList, guardedTerm) } diff --git a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala index cfe8f456aac42684be263b5594d151fd463e1623..375052dc7e5626cf635231d1c7687bba2c35f211 100644 --- a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala @@ -5,10 +5,11 @@ package disambiguation import leon.LeonContext import leon.purescala.Expressions._ +import purescala.Types.FunctionType import purescala.Common.FreshIdentifier import purescala.Constructors.{ and, tupleWrap } import purescala.Definitions.{ FunDef, Program, ValDef } -import purescala.ExprOps.expressionToPattern +import purescala.ExprOps import purescala.Expressions.{ BooleanLiteral, Equals, Expr, Lambda, MatchCase, Passes, Variable, WildcardPattern } import purescala.Extractors.TopLevelAnds import leon.purescala.Expressions._ @@ -16,7 +17,35 @@ import leon.purescala.Expressions._ /** * @author Mikael */ +object ExamplesAdder { + def replaceGenericValuesByVariable(e: Expr): (Expr, Map[Expr, Expr]) = { + var assignment = Map[Expr, Expr]() + var extension = 'a' + var id = "" + (ExprOps.postMap({ expr => expr match { + case g@GenericValue(tpe, index) => + val newIdentifier = FreshIdentifier(tpe.id.name.take(1).toLowerCase() + tpe.id.name.drop(1) + extension + id, tpe.id.getType) + if(extension != 'z' && extension != 'Z') + extension = (extension.toInt + 1).toChar + else if(extension == 'z') // No more than 52 generic variables in practice? + extension = 'A' + else { + if(id == "") id = "1" else id = (id.toInt + 1).toString + } + + val newVar = Variable(newIdentifier) + assignment += g -> newVar + Some(newVar) + case _ => None + } })(e), assignment) + } +} + class ExamplesAdder(ctx0: LeonContext, program: Program) { + import ExamplesAdder._ + var _removeFunctionParameters = false + + def setRemoveFunctionParameters(b: Boolean) = { _removeFunctionParameters = b; this } /** Accepts the nth alternative of a question (0 being the current one) */ def acceptQuestion[T <: Expr](fd: FunDef, q: Question[T], alternativeIndex: Int): Unit = { @@ -27,7 +56,8 @@ class ExamplesAdder(ctx0: LeonContext, program: Program) { /** Adds the given input/output examples to the function definitions */ def addToFunDef(fd: FunDef, examples: Seq[(Expr, Expr)]) = { - val inputVariables = tupleWrap(fd.params.map(p => Variable(p.id): Expr)) + val params = if(_removeFunctionParameters) fd.params.filter(x => !x.getType.isInstanceOf[FunctionType]) else fd.params + val inputVariables = tupleWrap(params.map(p => Variable(p.id): Expr)) val newCases = examples.map{ case (in, out) => exampleToCase(in, out) } fd.postcondition match { case Some(Lambda(Seq(ValDef(id, tpe)), post)) => @@ -68,12 +98,12 @@ class ExamplesAdder(ctx0: LeonContext, program: Program) { } private def exampleToCase(in: Expr, out: Expr): MatchCase = { - val (inPattern, inGuard) = expressionToPattern(in) - if(inGuard != BooleanLiteral(true)) { + val (inPattern, inGuard) = ExprOps.expressionToPattern(in) + if(inGuard == BooleanLiteral(true)) { + MatchCase(inPattern, None, out) + } else /*if (in == in_raw) { } *else*/ { val id = FreshIdentifier("out", in.getType, true) MatchCase(WildcardPattern(Some(id)), Some(Equals(Variable(id), in)), out) - } else { - MatchCase(inPattern, None, out) } } } \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index c66b1e846d2e055ec494b5937aecc5d6a252a8a8..fe57f4c2c25e90c3ce88318a772fe23bb11f5faa 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -19,6 +19,7 @@ import solvers.ModelBuilder import scala.collection.mutable.ListBuffer import leon.grammars.ExpressionGrammar import evaluators.AbstractEvaluator +import scala.annotation.tailrec object QuestionBuilder { /** Sort methods for questions. You can build your own */ @@ -142,13 +143,32 @@ class QuestionBuilder[T <: Expr]( yield simp } + /** Make all generic values unique. + * Duplicate generic values are not suitable for disambiguating questions since they remove an order. */ + def makeGenericValuesUnique(a: Expr): Expr = { + var genVals = Set[GenericValue]() + @tailrec @inline def freshGenericValue(g: GenericValue): GenericValue = { + if(genVals contains g) + freshGenericValue(GenericValue(g.tp, g.id + 1)) + else { + genVals += g + g + } + } + ExprOps.postMap{ e => e match { + case g@GenericValue(tpe, i) => + Some(freshGenericValue(g)) + case _ => None + }}(a) + } + /** Returns a list of input/output questions to ask to the user. */ def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil val enum = new MemoizedEnumerator[TypeTree, Expr](value_enumerator.getProductions) val values = enum.iterator(tupleTypeWrap(_argTypes)) - val instantiations = values.map { + val instantiations = values.map(makeGenericValuesUnique _).map { v => input.zip(unwrapTuple(v, input.size)) } diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index 319c3653801e373d7f9ca3eff10fa6fe13273021..baec973b388c3e27301e4b49ae328f97143d799d 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -75,38 +75,10 @@ case object StringRender extends Rule("StringRender") { var EDIT_ME = "_edit_me_" var enforceDefaultStringMethodsIfAvailable = true - - var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None - - def defaultMapTypeToString()(implicit hctx: SearchContext): Map[TypeTree, FunDef] = { - _defaultTypeToString.getOrElse{ - // Updates the cache with the functions converting standard types to string. - val res = (hctx.program.library.StrOps.toSeq.flatMap { StrOps => - StrOps.defs.collect{ case d: FunDef if d.params.length == 1 && d.returnType == StringType => d.params.head.getType -> d } - }).toMap - _defaultTypeToString = Some(res) - res - } - } - - /** Returns a toString function converter if it has been defined. */ - class WithFunDefConverter(implicit hctx: SearchContext) { - def unapply(tpe: TypeTree): Option[FunDef] = { - _defaultTypeToString.flatMap(_.get(tpe)) - } - } + var enforceSelfStringMethodsIfAvailable = false val booleanTemplate = (a: Expr) => StringTemplateGenerator(Hole => IfExpr(a, Hole, Hole)) - /** Returns a seq of expressions such as `x + y + "1" + y + "2" + z` associated to an expected result string `"1, 2"`. - * We use these equations so that we can find the values of the constants x, y, z and so on. - * This uses a custom evaluator which does not concatenate string but reminds the calculation. - */ - def createProblems(inlineFunc: Seq[FunDef], inlineExpr: Expr, examples: ExamplesBank): Seq[(Expr, String)] = ??? - - /** For each solution to the problem such as `x + "1" + y + j + "2" + z = 1, 2`, outputs all possible assignments if they exist. */ - def solveProblems(problems: Seq[(Expr, String)]): Seq[Map[Identifier, String]] = ??? - import StringSolver.{StringFormToken, StringForm, Problem => SProblem, Equation, Assignment} /** Augment the left-hand-side to have possible function calls, such as x + "const" + customToString(_) ... @@ -179,7 +151,6 @@ case object StringRender extends Rule("StringRender") { /** Returns a stream of assignments compatible with input/output examples for the given template */ def findAssignments(p: Program, inputs: Seq[Identifier], examples: ExamplesBank, template: Expr)(implicit hctx: SearchContext): Stream[Map[Identifier, String]] = { - //new Evaluator() val e = new AbstractEvaluator(hctx.context, p) @tailrec def gatherEquations(s: List[InOutExample], acc: ListBuffer[Equation] = ListBuffer()): Option[SProblem] = s match { @@ -193,6 +164,7 @@ case object StringRender extends Rule("StringRender") { evalResult.result match { case None => hctx.reporter.info("Eval = None : ["+template+"] in ["+inputs.zip(in)+"]") + hctx.reporter.info(evalResult) None case Some((sfExpr, abstractSfExpr)) => //ctx.reporter.debug("Eval = ["+sfExpr+"] (from "+abstractSfExpr+")") @@ -325,7 +297,7 @@ case object StringRender extends Rule("StringRender") { object StringSynthesisContext { def empty( abstractStringConverters: StringConverters, - originalInputs: Set[Identifier], + originalInputs: Set[Expr], provided_functions: Seq[Identifier])(implicit hctx: SearchContext) = new StringSynthesisContext(None, new StringSynthesisResult(Map(), Set()), abstractStringConverters, @@ -338,7 +310,7 @@ case object StringRender extends Rule("StringRender") { val currentCaseClassParent: Option[TypeTree], val result: StringSynthesisResult, val abstractStringConverters: StringConverters, - val originalInputs: Set[Identifier], + val originalInputs: Set[Expr], val provided_functions: Seq[Identifier] )(implicit hctx: SearchContext) { def add(d: DependentType, f: FunDef, s: Stream[WithIds[Expr]]): StringSynthesisContext = { @@ -371,7 +343,8 @@ case object StringRender extends Rule("StringRender") { val funName = funName3(0).toLower + funName3.substring(1) val funId = FreshIdentifier(ctx.freshFunName(funName), alwaysShowUniqueID = true) val argId= FreshIdentifier(tpe.typeToConvert.asString(hctx.context).toLowerCase()(0).toString, tpe.typeToConvert) - val fd = new FunDef(funId, Nil, ValDef(argId) :: ctx.provided_functions.map(ValDef(_, false)).toList, StringType) // Empty function. + val tparams = hctx.sctx.functionContext.tparams + val fd = new FunDef(funId, tparams, ValDef(argId) :: ctx.provided_functions.map(ValDef(_, false)).toList, StringType) // Empty function. fd } @@ -449,15 +422,30 @@ case object StringRender extends Rule("StringRender") { case Some(fd) => gatherInputs(ctx, q, result += Stream((functionInvocation(fd._1, input::ctx.provided_functions.toList.map(Variable)), Nil))) case None => // No function can render the current type. - - val exprs1s = (new SelfPrettyPrinter).allowFunction(hctx.sctx.functionContext).prettyPrintersForType(input.getType)(hctx.context, hctx.program).map(l => (application(l, Seq(input)), List[Identifier]())) // Use already pre-defined pretty printers. + // We should not rely on calling the original function on the first line of the body of the function itself. + val exprs1s = (new SelfPrettyPrinter) + .allowFunction(hctx.sctx.functionContext) + .excludeFunction(hctx.sctx.functionContext) + .prettyPrintersForType(input.getType)(hctx.context, hctx.program) + .map(l => (application(l, Seq(input)), List[Identifier]())) // Use already pre-defined pretty printers. val exprs1 = exprs1s.toList.sortBy{ case (Lambda(_, FunctionInvocation(fd, _)), _) if fd == hctx.sctx.functionContext => 0 case _ => 1} val exprs2 = ctx.abstractStringConverters.getOrElse(input.getType, Nil).map(f => (f(input), List[Identifier]())) - val converters: Stream[WithIds[Expr]] = (exprs1.toStream #::: exprs2.toStream) - def mergeResults(defaultconverters: =>Stream[WithIds[Expr]]): Stream[WithIds[Expr]] = { - if(converters.isEmpty) defaultconverters - else if(enforceDefaultStringMethodsIfAvailable) converters - else converters #::: defaultconverters + val defaultConverters: Stream[WithIds[Expr]] = exprs1.toStream #::: exprs2.toStream + val recursiveConverters: Stream[WithIds[Expr]] = + (new SelfPrettyPrinter) + .prettyPrinterFromCandidate(hctx.sctx.functionContext, input.getType)(hctx.context, hctx.program) + .map(l => (application(l, Seq(input)), List[Identifier]())) + + def mergeResults(templateConverters: =>Stream[WithIds[Expr]]): Stream[WithIds[Expr]] = { + if(defaultConverters.isEmpty) templateConverters + else if(enforceDefaultStringMethodsIfAvailable) { + if(enforceSelfStringMethodsIfAvailable) + recursiveConverters #::: defaultConverters + else { + defaultConverters #::: recursiveConverters + } + } + else recursiveConverters #::: defaultConverters #::: templateConverters } input.getType match { @@ -472,8 +460,8 @@ case object StringRender extends Rule("StringRender") { case WithStringconverter(converter) => // Base case gatherInputs(ctx, q, result += mergeResults(Stream((converter(input), Nil)))) case t: ClassType => - if(enforceDefaultStringMethodsIfAvailable && !converters.isEmpty) { - gatherInputs(ctx, q, result += converters) + if(enforceDefaultStringMethodsIfAvailable && !defaultConverters.isEmpty) { + gatherInputs(ctx, q, result += defaultConverters) } else { // Create the empty function body and updates the assignments parts. val fd = createEmptyFunDef(ctx, dependentType) @@ -510,7 +498,7 @@ case object StringRender extends Rule("StringRender") { } } case TypeParameter(t) => - if(converters.isEmpty) { + if(defaultConverters.isEmpty) { hctx.reporter.fatalError("Could not handle type parameter for string rendering " + t) } else { gatherInputs(ctx, q, result += mergeResults(Stream.empty)) @@ -558,9 +546,7 @@ case object StringRender extends Rule("StringRender") { p.xs match { case List(IsTyped(v, StringType)) => val description = "Creates a standard string conversion function" - - val defaultToStringFunctions = defaultMapTypeToString() - + val examplesFinder = new ExamplesFinder(hctx.context, hctx.program).setKeepAbstractExamples(true) val examples = examplesFinder.extractFromProblem(p) @@ -576,13 +562,14 @@ case object StringRender extends Rule("StringRender") { }) val ruleInstantiations = ListBuffer[RuleInstantiation]() + val originalInputs = inputVariables.map(Variable) ruleInstantiations += RuleInstantiation("String conversion") { val (expr, synthesisResult) = createFunDefsTemplates( StringSynthesisContext.empty( abstractStringConverters, - p.as.toSet, + originalInputs.toSet, functionVariables - ), inputVariables.map(Variable)) + ), originalInputs) val funDefs = synthesisResult.adtToString /*val toDebug: String = (("\nInferred functions:" /: funDefs)( (t, s) =>