diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 75c691edf1f7e5a4ca056bf8c8448bff9b5ead77..897c1ba4794a628f4389e480a6cd02dc822b0ab4 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -645,9 +645,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case l : Literal[_] => l case other => - context.reporter.error(other.getPos, "Error: don't know how to handle " + other.asString + " in Evaluator ("+other.getClass+").") - println("RecursiveEvaluator error:" + other.asString) - throw EvalError("Unhandled case in Evaluator : " + other.asString) + throw EvalError("Unhandled case in Evaluator : [" + other.getClass + "] " + other.asString) } def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, Expr])] = { diff --git a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala index ab8b14055d47533fb9d4d30f5bca66a3289db75a..be3ef7a062531e1a0c90abd57206efec8688761a 100644 --- a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala @@ -8,78 +8,18 @@ import purescala.Expressions._ import purescala.Types._ import purescala.Definitions.{TypedFunDef, Program} import purescala.DefOps +import purescala.TypeOps +import purescala.ExprOps import purescala.Expressions.Expr import leon.utils.DebugSectionSynthesis import org.apache.commons.lang3.StringEscapeUtils class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext with HasDefaultRecContext { - - val underlying = new DefaultEvaluator(ctx, prog) { - override protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { - - case FunctionInvocation(TypedFunDef(fd, Nil), Seq(input)) if fd == prog.library.escape.get => - e(input) match { - case StringLiteral(s) => - StringLiteral(StringEscapeUtils.escapeJava(s)) - case _ => throw EvalError(typeErrorMsg(input, StringType)) - } - - case FunctionInvocation(tfd, args) => - if (gctx.stepsLeft < 0) { - throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") - } - gctx.stepsLeft -= 1 - - val evArgs = args map e - - // build a mapping for the function... - val frame = rctx.withNewVars(tfd.paramSubst(evArgs)) - - val callResult = if (tfd.fd.annotations("extern") && ctx.classDir.isDefined) { - scalaEv.call(tfd, evArgs) - } else { - if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { - throw EvalError("Evaluation of function with unknown implementation.") - } + lazy val scalaEv = new ScalacEvaluator(underlying, ctx, prog) - val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) - e(body)(frame, gctx) - } - - callResult - - case Variable(id) => - rctx.mappings.get(id) match { - case Some(v) if v != expr => - e(v) - case Some(v) => - v - case None => - expr - } - case StringConcat(s1, s2) => - val es1 = e(s1) - val es2 = e(s2) - (es1, es2) match { - case (StringLiteral(_), StringLiteral(_)) => - (super.e(StringConcat(es1, es2))) - case _ => - StringConcat(es1, es2) - } - case Application(caller, args) => - val ecaller = e(caller) - ecaller match { - case l @ Lambda(params, body) => - super.e(Application(ecaller, args)) - case PartialLambda(mapping, dflt, _) => - super.e(Application(ecaller, args)) - case f => - Application(f, args.map(e)) - } - case expr => - super.e(expr) - } - } + /** Evaluates resuts which can be evaluated directly + * For example, concatenation of two string literals */ + val underlying = new DefaultEvaluator(ctx, prog) override type Value = (Expr, Expr) override val description: String = "Evaluates string programs but keeps the formula which generated the string" @@ -90,32 +30,12 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual rctx.mappings.get(id) match { case Some(v) if v != expr => e(v) - case Some(v) => - (v, expr) - case None => - (expr, expr) - } - - case StringConcat(s1, s2) => - val (es1, t1) = e(s1) - val (es2, t2) = e(s2) - (es1, es2) match { - case (StringLiteral(_), StringLiteral(_)) => - (underlying.e(StringConcat(es1, es2)), StringConcat(t1, t2)) - case _ => - (StringConcat(es1, es2), StringConcat(t1, t2)) - } - case StringLength(s1) => - val (es1, t1) = e(s1) - es1 match { - case StringLiteral(_) => - (underlying.e(StringLength(es1)), StringLength(t1)) case _ => - (StringLength(es1), StringLength(t1)) + (expr, expr) } - case expr@StringLiteral(s) => - (expr, expr) + case e if ExprOps.isValue(e) => + (e, e) case IfExpr(cond, thenn, elze) => val first = underlying.e(cond) @@ -126,11 +46,51 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual case BooleanLiteral(false) => e(elze) case _ => throw EvalError(typeErrorMsg(first, BooleanType)) } + + case MatchExpr(scrut, cases) => + val rscrut = if(ExprOps.isValue(scrut)) scrut else underlying.e(scrut) + cases.toStream.map(c => underlying.matchesCase(rscrut, c)).find(_.nonEmpty) match { + case Some(Some((c, mappings))) => + e(c.rhs)(rctx.withNewVars(mappings), gctx) + case _ => + throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases") + } + case FunctionInvocation(tfd, args) => + if (gctx.stepsLeft < 0) { + throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") + } + gctx.stepsLeft -= 1 + + val evArgs = args map e + val evArgsValues = evArgs.map(_._1) + val evArgsOrigin = evArgs.map(_._2) + if(evArgsValues forall ExprOps.isValue) { + // build a mapping for the function... + val frame = rctx.withNewVars(tfd.paramSubst(evArgsValues)) + + val callResult = if (tfd.fd.annotations("extern") && ctx.classDir.isDefined) { + (scalaEv.call(tfd, evArgsValues), FunctionInvocation(tfd, evArgsOrigin)) + } else { + if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { + throw EvalError("Evaluation of function with unknown implementation.") + } + + val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) + e(body)(frame, gctx) + } + + callResult + } else { + (FunctionInvocation(tfd, evArgsValues), FunctionInvocation(tfd, evArgsOrigin)) + } case Operator(es, builder) => val (ees, ts) = es.map(e).unzip - (underlying.e(builder(ees)), builder(ts)) - + if(ees forall ExprOps.isValue) { + (underlying.e(builder(ees)), builder(ts)) + } else { + (builder(ees), builder(ts)) + } } diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 4f6c968f37f230ff3c69487c6adc27397e5ab1aa..51dd63d33adef97224d1269ef0270df0f6a069d7 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -2217,6 +2217,42 @@ object ExprOps { ) } - - + /** Returns true if expr is a value of type t */ + def isValueOfType(e: Expr, t: TypeTree): Boolean = { + (e, t) match { + case (StringLiteral(_), StringType) => true + case (IntLiteral(_), Int32Type) => true + case (InfiniteIntegerLiteral(_), IntegerType) => true + case (CharLiteral(_), CharType) => true + case (FractionalLiteral(_, _), RealType) => true + case (BooleanLiteral(_), BooleanType) => true + case (UnitLiteral(), UnitType) => true + case (GenericValue(t, _), tp) => t == tp + case (Tuple(elems), TupleType(bases)) => + elems zip bases forall (eb => isValueOfType(eb._1, eb._2)) + case (FiniteSet(elems, tbase), SetType(base)) => + tbase == base && + (elems forall isValue) + case (FiniteMap(elems, tk, tv), MapType(from, to)) => + tk == from && tv == to && + (elems forall (kv => isValueOfType(kv._1, from) && isValueOfType(kv._2, to) )) + case (NonemptyArray(elems, defaultValues), ArrayType(base)) => + elems.values forall (x => isValueOfType(x, base)) + case (EmptyArray(tpe), ArrayType(base)) => + tpe == base + case (CaseClass(ct, args), ct2@AbstractClassType(classDef, tps)) => + TypeOps.isSubtypeOf(ct, ct2) && + ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2))) + case (CaseClass(ct, args), ct2@CaseClassType(classDef, tps)) => + ct == ct2 && + ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2))) + case (Lambda(valdefs, body), FunctionType(ins, out)) => + (valdefs zip ins forall (vdin => vdin._1.getType == vdin._2)) && + body.getType == out + case _ => false + } + } + + /** Returns true if expr is a value. Stronger than isGround */ + val isValue = (e: Expr) => isValueOfType(e, e.getType) } diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index d2b4686b6358bd5eb306b54ebf27a0599adb6d3f..01e048699ec1654420ba059873dc4299724b842f 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -42,7 +42,7 @@ class ScopeSimplifier extends Transformer { case LetDef(fds, body: Expr) => var newScope: Scope = scope // First register all functions - val fds_newIds = for(fd <- fds) yield { // Problem if some functions use the same ID for a ValDef + val fds_newIds = for(fd <- fds) yield { val newId = genId(fd.id, scope) newScope = newScope.register(fd.id -> newId) (fd, newId) @@ -52,7 +52,7 @@ class ScopeSimplifier extends Transformer { val localScopeToRegister = ListBuffer[(Identifier, Identifier)]() // We record the mapping of these variables only for the function. val newArgs = for(ValDef(id, tpe) <- fd.params) yield { val newArg = genId(id, newScope.register(localScopeToRegister)) - localScopeToRegister += (id -> newArg) // This should happen only inside the function. + localScopeToRegister += (id -> newArg) // This renaming happens only inside the function. ValDef(newArg, tpe) } diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index 2a2b7307cea7dbfbf5814bf39dc2b043105dc8c4..e9859b7c84d8fed8b385563bd3d3c0bc2f04aea4 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -25,24 +25,37 @@ import purescala.Definitions._ import scala.collection.mutable.ListBuffer import leon.evaluators.DefaultEvaluator +object SelfPrettyPrinter { + def prettyPrintersForType(inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { + (new SelfPrettyPrinter).prettyPrintersForType(inputType) + } + def print(v: Expr, orElse: =>String, excluded: FunDef => Boolean = Set())(implicit ctx: LeonContext, program: Program): String = { + (new SelfPrettyPrinter).print(v, orElse, excluded) + } +} /** This pretty-printer uses functions defined in Leon itself. * If not pretty printing function is defined, return the default value instead * @param The list of functions which should be excluded from pretty-printing (to avoid rendering counter-examples of toString methods using the method itself) * @return a user defined string for the given typed expression. */ -object SelfPrettyPrinter { +class SelfPrettyPrinter { + private var allowedFunctions = Set[FunDef]() + + def allowFunction(fd: FunDef) = { allowedFunctions += fd; this } + /** Returns a list of possible lambdas that can transform the input type to a String*/ def prettyPrintersForType(inputType: TypeTree/*, existingPp: Map[TypeTree, List[Lambda]] = Map()*/)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { // Use the other argument if you need recursive typing (?) - (program.definedFunctions flatMap { + program.definedFunctions.toStream flatMap { fd => val isCandidate = fd.returnType == StringType && fd.params.length >= 1 && + allowedFunctions(fd) || ( //TypeOps.isSubtypeOf(v.getType, fd.params.head.getType) && fd.id.name.toLowerCase().endsWith("tostring") && program.callGraph.transitiveCallees(fd).forall { fde => !purescala.ExprOps.exists( _.isInstanceOf[Choose])(fde.fullBody) - } + }) if(isCandidate) { // InputType is concrete, the types of params may be abstract. TypeOps.canBeSubtypeOf(inputType, fd.tparams.map(_.tp), fd.params.head.getType) match { @@ -69,12 +82,12 @@ object SelfPrettyPrinter { case None => Nil } } else Nil - }).toStream + } } /** Actually prints the expression with as alternative the given orElse */ def print(v: Expr, orElse: =>String, excluded: FunDef => Boolean = Set())(implicit ctx: LeonContext, program: Program): String = { - val s = prettyPrintersForType(v.getType) + val s = prettyPrintersForType(v.getType) // TODO: Included the variable excluded if necessary. if(s.isEmpty) { orElse } else { diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index b3239492211950f692e7c1a956e2d18e3619274c..b476a134d2757e8bf4ba39540a00d3daa08caf66 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -6,7 +6,6 @@ package rules import scala.annotation.tailrec import scala.collection.mutable.ListBuffer - import bonsai.enumerators.MemoizedEnumerator import leon.evaluators.DefaultEvaluator import leon.evaluators.StringTracingEvaluator @@ -30,6 +29,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.TypeOps import purescala.Types._ +import leon.purescala.SelfPrettyPrinter /** A template generator for a given type tree. @@ -328,12 +328,10 @@ case object StringRender extends Rule("StringRender") { /** Companion object to create a StringSynthesisContext */ object StringSynthesisContext { def empty( - definedStringConverters: StringConverters, abstractStringConverters: StringConverters, originalInputs: Set[Identifier], provided_functions: Seq[Identifier])(implicit hctx: SearchContext) = new StringSynthesisContext(None, new StringSynthesisResult(Map(), Set()), - definedStringConverters, abstractStringConverters, originalInputs, provided_functions) @@ -343,21 +341,18 @@ case object StringRender extends Rule("StringRender") { class StringSynthesisContext( val currentCaseClassParent: Option[TypeTree], val result: StringSynthesisResult, - val definedStringConverters: StringConverters, val abstractStringConverters: StringConverters, val originalInputs: Set[Identifier], val provided_functions: Seq[Identifier] )(implicit hctx: SearchContext) { def add(d: DependentType, f: FunDef, s: Stream[WithIds[Expr]]): StringSynthesisContext = { new StringSynthesisContext(currentCaseClassParent, result.add(d, f, s), - definedStringConverters, abstractStringConverters, originalInputs, provided_functions) } def copy(currentCaseClassParent: Option[TypeTree]=currentCaseClassParent, result: StringSynthesisResult = result): StringSynthesisContext = new StringSynthesisContext(currentCaseClassParent, result, - definedStringConverters, abstractStringConverters, originalInputs, provided_functions) @@ -458,13 +453,11 @@ case object StringRender extends Rule("StringRender") { case Some(fd) => gatherInputs(ctx, q, result += Stream((functionInvocation(fd._1, input::ctx.provided_functions.toList.map(Variable)), Nil))) case None => // No function can render the current type. - val exprs1 = ctx.definedStringConverters.getOrElse(input.getType, Nil).flatMap(f => - f(input) match { - case FunctionInvocation(fd, Variable(id)::_) if ctx.originalInputs(id) => None - case e => Some((e, Nil)) - }) - val exprs2 = ctx.abstractStringConverters.getOrElse(input.getType, Nil).map(f => (f(input), Nil)) - val converters = (exprs1 ++ exprs2).toStream + + val exprs1s = (new SelfPrettyPrinter).allowFunction(hctx.sctx.functionContext).prettyPrintersForType(input.getType)(hctx.context, hctx.program).map(l => (application(l, Seq(input)), List[Identifier]())) // Use already pre-defined pretty printers. + val exprs1 = exprs1s.toList.sortBy{ case (Lambda(_, FunctionInvocation(fd, _)), _) if fd == hctx.sctx.functionContext => 0 case _ => 1} + val exprs2 = ctx.abstractStringConverters.getOrElse(input.getType, Nil).map(f => (f(input), List[Identifier]())) + val converters: Stream[WithIds[Expr]] = (exprs1.toStream #::: exprs2.toStream) def mergeResults(defaultconverters: =>Stream[WithIds[Expr]]): Stream[WithIds[Expr]] = { if(converters.isEmpty) defaultconverters else if(enforceDefaultStringMethodsIfAvailable) converters @@ -575,13 +568,6 @@ case object StringRender extends Rule("StringRender") { val examplesFinder = new ExamplesFinder(hctx.context, hctx.program).setKeepAbstractExamples(true) val examples = examplesFinder.extractFromProblem(p) - val definedStringConverters: StringConverters = - hctx.program.definedFunctions.filter(fd => - fd.returnType == StringType && fd.params.length == 1 - && fd != hctx.program.library.escape.get - ) - .groupBy({ fd => fd.paramIds.head.getType }).mapValues(fds => - fds.map((fd : FunDef) => ((x: Expr) => functionInvocation(fd, Seq(x))))) val abstractStringConverters: StringConverters = (p.as.flatMap { case x => x.getType match { case FunctionType(Seq(aType), StringType) => List((aType, (arg: Expr) => application(Variable(x), Seq(arg)))) @@ -597,7 +583,6 @@ case object StringRender extends Rule("StringRender") { ruleInstantiations += RuleInstantiation("String conversion") { val (expr, synthesisResult) = createFunDefsTemplates( StringSynthesisContext.empty( - definedStringConverters, abstractStringConverters, p.as.toSet, functionVariables