diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index e56eb59a894bd255701763aafab6e6bb2a286f17..a40435b0478b690e69157ea3af715c0f52a8f717 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -87,10 +87,11 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie def toIOExamples(in: Expr, out : Expr, cs : MatchCase) : Seq[(Expr,Expr)] = { import utils.ExpressionGrammars + import leon.utils.StreamUtils.cartesianProduct import bonsai._ import bonsai.enumerators._ - val examplesPerVariable = 5 + val examplesPerCase = 5 def doSubstitute(subs : Seq[(Identifier, Expr)], e : Expr) = subs.foldLeft(e) { @@ -104,36 +105,26 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie // The pattern as expression (input expression)(may contain free variables) val (pattExpr, ieMap) = patternToExpression(cs.pattern, in.getType) val freeVars = variablesOf(pattExpr).toSeq - if (freeVars.isEmpty) { // The input contains no free vars. Trivially return input-output pair Seq((pattExpr, doSubstitute(ieMap,cs.rhs))) } else { // If the input contains free variables, it does not provide concrete examples. // We will instantiate them according to a simple grammar to get them. - val grammar = ExpressionGrammars.BaseGrammar + val grammar = ExpressionGrammars.ValueGrammar val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getProductions _) val types = freeVars.map{ _.getType } - val typesWithValues = types.map { tp => (tp, enum.iterator(tp).take(examplesPerVariable).toSeq) }.toMap + val typesWithValues = types.map { tp => (tp, enum.iterator(tp).toStream) }.toMap val values = freeVars map { v => typesWithValues(v.getType) } - // Make all combinations of all possible instantiations - def combinations[A](s : Seq[Seq[A]]) : Seq[Seq[A]] = { - if (s.isEmpty) Seq(Seq()) - else for { - h <- s.head - t <- combinations(s.tail) - } yield (h +: t) - } - val instantiations = combinations(values) map { freeVars.zip(_).toMap } - instantiations map { inst => + val instantiations = cartesianProduct(values) map { freeVars.zip(_).toMap } + instantiations.map { inst => (replaceFromIDs(inst, pattExpr), replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs))) - } + }.take(examplesPerCase) } } } - val evaluator = new DefaultEvaluator(sctx.context, sctx.program) val testClusters = collect[Map[Identifier, Expr]] { diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index ea0003e5a9cd8a52c5ff6b1b137aa5a4c2b16cfa..1cf51fc5c23c10eb0f550ab7b211dc248abd428f 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -79,6 +79,11 @@ object ExpressionGrammars { Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Minus(a, b) }), Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Times(a, b) }) ) + + case tp@TypeParameter(_) => + for (ind <- (1 to 3).toList) yield + Generator[TypeTree, Expr](Nil, { _ => GenericValue(tp, ind) } ) + case TupleType(stps) => List(Generator(stps, { sub => Tuple(sub) })) @@ -105,6 +110,48 @@ object ExpressionGrammars { } } + case object ValueGrammar extends ExpressionGrammar[TypeTree] { + def computeProductions(t: TypeTree): Seq[Gen] = t match { + case BooleanType => + List( + Generator(Nil, { _ => BooleanLiteral(true) }), + Generator(Nil, { _ => BooleanLiteral(false) }) + ) + case Int32Type => + List( + Generator(Nil, { _ => IntLiteral(0) }), + Generator(Nil, { _ => IntLiteral(1) }), + Generator(Nil, { _ => IntLiteral(-1) }) + ) + + case tp@TypeParameter(_) => + for (ind <- (1 to 3).toList) yield + Generator[TypeTree, Expr](Nil, { _ => GenericValue(tp, ind) } ) + + case TupleType(stps) => + List(Generator(stps, { sub => Tuple(sub) })) + + case cct: CaseClassType => + List( + Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + ) + + case act: AbstractClassType => + act.knownCCDescendents.map { cct => + Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + } + + case st @ SetType(base) => + List( + Generator(List(base), { case elems => FiniteSet(elems.toSet).setType(st) }), + Generator(List(base, base), { case elems => FiniteSet(elems.toSet).setType(st) }) + ) + + case _ => + Nil + } + } + case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = { inputs.collect {