From 662ca1032e7d5c280fe9d618af308c6295cb94a8 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Wed, 13 Jan 2016 14:20:58 +0100 Subject: [PATCH] Direct ExpressionGrammars through tags in Productions and NonTerminals Rename Label -> NonTerminal Rename Generator -> ProductionRule Adapt to latest bonsai Tag ProductionRule's Introduce TaggedGrammar, which tags non-terminals with the Tag of the ProductionRule they came from + position. Filter productions according to tags During CEGIS, filter out programs with trivial operations such as (x == x) (currently commented out) Constant grammar detects constants in current FunDef Commutative operations are now skewed to the right Cegis individually checks candidate programs if they are few compared to total programs. Add grammars.transformers package, move some files there Add sumToOrdered in SeqUtils Add some API documentation Empty grammar rules get printed Warning when no functions chosen for synthesis In CEGIS, individually test a few programs iff programs passing tests are either very few, or much fewer than total programs. In CEGIS, when we find a new counterexample, we filter the rest of remaining programs with it before verifying them. --- .../scala/leon/datagen/GrammarDataGen.scala | 2 +- .../scala/leon/grammars/BaseGrammar.scala | 55 +++-- src/main/scala/leon/grammars/Constants.scala | 33 +++ .../leon/grammars/DepthBoundedGrammar.scala | 19 -- src/main/scala/leon/grammars/Empty.scala | 3 +- .../scala/leon/grammars/EqualityGrammar.scala | 8 +- .../leon/grammars/ExpressionGrammar.scala | 51 +++-- .../scala/leon/grammars/FunctionCalls.scala | 10 +- src/main/scala/leon/grammars/Generator.scala | 16 -- src/main/scala/leon/grammars/Grammars.scala | 3 + .../scala/leon/grammars/NonTerminal.scala | 9 +- .../scala/leon/grammars/ProductionRule.scala | 18 ++ .../leon/grammars/SafeRecursiveCalls.scala | 15 +- src/main/scala/leon/grammars/SimilarTo.scala | 33 +-- .../leon/grammars/SizeBoundedGrammar.scala | 35 --- src/main/scala/leon/grammars/Tags.scala | 65 ++++++ .../scala/leon/grammars/ValueGrammar.scala | 40 ++-- .../transformers/DepthBoundedGrammar.scala | 21 ++ .../{ => transformers}/EmbeddedGrammar.scala | 9 +- .../grammars/{ => transformers}/OneOf.scala | 11 +- .../transformers/SizeBoundedGrammar.scala | 59 +++++ .../grammars/transformers/TaggedGrammar.scala | 112 ++++++++++ .../{Or.scala => transformers/Union.scala} | 5 +- src/main/scala/leon/purescala/DefOps.scala | 8 + src/main/scala/leon/purescala/TypeOps.scala | 2 +- src/main/scala/leon/repair/Repairman.scala | 2 +- .../scala/leon/synthesis/ExamplesFinder.scala | 2 +- src/main/scala/leon/synthesis/Rules.scala | 7 +- .../scala/leon/synthesis/SourceInfo.scala | 4 + .../scala/leon/synthesis/SynthesisPhase.scala | 6 +- .../scala/leon/synthesis/Synthesizer.scala | 6 +- .../disambiguation/QuestionBuilder.scala | 22 +- .../leon/synthesis/rules/BottomUpTegis.scala | 6 +- .../scala/leon/synthesis/rules/CEGIS.scala | 21 +- .../leon/synthesis/rules/CEGISLike.scala | 207 ++++++++++++------ .../scala/leon/synthesis/rules/CEGLESS.scala | 7 +- .../leon/synthesis/rules/DetupleInput.scala | 7 +- .../leon/synthesis/rules/TEGISLike.scala | 2 +- .../scala/leon/synthesis/utils/Helpers.scala | 13 +- src/main/scala/leon/utils/SeqUtils.scala | 14 ++ 40 files changed, 696 insertions(+), 272 deletions(-) create mode 100644 src/main/scala/leon/grammars/Constants.scala delete mode 100644 src/main/scala/leon/grammars/DepthBoundedGrammar.scala delete mode 100644 src/main/scala/leon/grammars/Generator.scala create mode 100644 src/main/scala/leon/grammars/ProductionRule.scala delete mode 100644 src/main/scala/leon/grammars/SizeBoundedGrammar.scala create mode 100644 src/main/scala/leon/grammars/Tags.scala create mode 100644 src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala rename src/main/scala/leon/grammars/{ => transformers}/EmbeddedGrammar.scala (74%) rename src/main/scala/leon/grammars/{ => transformers}/OneOf.scala (56%) create mode 100644 src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala create mode 100644 src/main/scala/leon/grammars/transformers/TaggedGrammar.scala rename src/main/scala/leon/grammars/{Or.scala => transformers/Union.scala} (73%) diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala index cd86c707d..23c1ed5b8 100644 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ b/src/main/scala/leon/datagen/GrammarDataGen.scala @@ -20,7 +20,7 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] implicit val ctx = evaluator.context def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](grammar.getProductions) enum.iterator(tpe) } diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/leon/grammars/BaseGrammar.scala index f11f93749..6e0a2ee5e 100644 --- a/src/main/scala/leon/grammars/BaseGrammar.scala +++ b/src/main/scala/leon/grammars/BaseGrammar.scala @@ -7,56 +7,65 @@ import purescala.Types._ import purescala.Expressions._ import purescala.Constructors._ +/** The basic grammar for Leon expressions. + * Generates the most obvious expressions for a given type, + * without regard of context (variables in scope, current function etc.) + * Also does some trivial simplifications. + */ case object BaseGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)), - nonTerminal(List(BooleanType), { case Seq(a) => not(a) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + terminal(BooleanLiteral(false), Tags.BooleanC), + terminal(BooleanLiteral(true), Tags.BooleanC), + nonTerminal(List(BooleanType), { case Seq(a) => not(a) }, Tags.Not), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }, Tags.And), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }, Tags.Or ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }, Tags.Times) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }, Tags.Times), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Modulo(a, b) }, Tags.Mod), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Division(a, b) }, Tags.Div) ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(isTerminal = false)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), nonTerminal(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -64,7 +73,7 @@ case object BaseGrammar extends ExpressionGrammar[TypeTree] { case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/leon/grammars/Constants.scala new file mode 100644 index 000000000..81c553460 --- /dev/null +++ b/src/main/scala/leon/grammars/Constants.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Expressions._ +import purescala.Types.TypeTree +import purescala.ExprOps.collect +import purescala.Extractors.IsTyped + +/** Generates constants found in an [[leon.purescala.Expressions.Expr]]. + * Some constants that are generated by other grammars (like 0, 1) will be excluded + */ +case class Constants(e: Expr) extends ExpressionGrammar[TypeTree] { + + private val excluded: Set[Expr] = Set( + InfiniteIntegerLiteral(1), + InfiniteIntegerLiteral(0), + IntLiteral(1), + IntLiteral(0), + BooleanLiteral(true), + BooleanLiteral(false) + ) + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { + val literals = collect[Expr]{ + case IsTyped(l:Literal[_], `t`) => Set(l) + case _ => Set() + }(e) + + (literals -- excluded map (terminal(_, Tags.Constant))).toSeq + } +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/DepthBoundedGrammar.scala deleted file mode 100644 index fc999be64..000000000 --- a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -case class DepthBoundedGrammar[T](g: ExpressionGrammar[NonTerminal[T]], bound: Int) extends ExpressionGrammar[NonTerminal[T]] { - def computeProductions(l: NonTerminal[T])(implicit ctx: LeonContext): Seq[Gen] = g.computeProductions(l).flatMap { - case gen => - if (l.depth == Some(bound) && gen.subTrees.nonEmpty) { - Nil - } else if (l.depth.exists(_ > bound)) { - Nil - } else { - List ( - nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) - ) - } - } -} diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/leon/grammars/Empty.scala index 70ebddc98..737f9cdf3 100644 --- a/src/main/scala/leon/grammars/Empty.scala +++ b/src/main/scala/leon/grammars/Empty.scala @@ -5,6 +5,7 @@ package grammars import purescala.Types.Typed +/** The empty expression grammar */ case class Empty[T <: Typed]() extends ExpressionGrammar[T] { - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = Nil + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = Nil } diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/leon/grammars/EqualityGrammar.scala index e9463a771..fdcf079fa 100644 --- a/src/main/scala/leon/grammars/EqualityGrammar.scala +++ b/src/main/scala/leon/grammars/EqualityGrammar.scala @@ -8,11 +8,15 @@ import purescala.Constructors._ import bonsai._ +/** A grammar of equalities + * + * @param types The set of types for which equalities will be generated + */ case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { - override def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => types.toList map { tp => - nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }) + nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }, Tags.Equals) } case _ => Nil diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/leon/grammars/ExpressionGrammar.scala index ac394ab84..198410f3d 100644 --- a/src/main/scala/leon/grammars/ExpressionGrammar.scala +++ b/src/main/scala/leon/grammars/ExpressionGrammar.scala @@ -6,23 +6,35 @@ package grammars import purescala.Expressions._ import purescala.Types._ import purescala.Common._ +import transformers.Union import scala.collection.mutable.{HashMap => MutableMap} +/** Represents a context-free grammar of expressions + * + * @tparam T The type of nonterminal symbols for this grammar + */ abstract class ExpressionGrammar[T <: Typed] { - type Gen = Generator[T, Expr] + type Prod = ProductionRule[T, Expr] - private[this] val cache = new MutableMap[T, Seq[Gen]]() + private[this] val cache = new MutableMap[T, Seq[Prod]]() - def terminal(builder: => Expr) = { - Generator[T, Expr](Nil, { (subs: Seq[Expr]) => builder }) + /** Generates a [[ProductionRule]] without nonterminal symbols */ + def terminal(builder: => Expr, tag: Tags.Tag = Tags.Top) = { + ProductionRule[T, Expr](Nil, { (subs: Seq[Expr]) => builder }, tag) } - def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr)): Generator[T, Expr] = { - Generator[T, Expr](subs, builder) + /** Generates a [[ProductionRule]] with nonterminal symbols */ + def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr), tag: Tags.Tag = Tags.Top): ProductionRule[T, Expr] = { + ProductionRule[T, Expr](subs, builder, tag) } - def getProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = { + /** The list of production rules for this grammar for a given nonterminal. + * This is the cached version of [[getProductions]] which clients should use. + * + * @param t The nonterminal for which production rules will be generated + */ + def getProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -30,9 +42,13 @@ abstract class ExpressionGrammar[T <: Typed] { }) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] + /** The list of production rules for this grammar for a given nonterminal + * + * @param t The nonterminal for which production rules will be generated + */ + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] - def filter(f: Gen => Boolean) = { + def filter(f: Prod => Boolean) = { new ExpressionGrammar[T] { def computeProductions(t: T)(implicit ctx: LeonContext) = ExpressionGrammar.this.computeProductions(t).filter(f) } @@ -44,14 +60,19 @@ abstract class ExpressionGrammar[T <: Typed] { final def printProductions(printer: String => Unit)(implicit ctx: LeonContext) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { t => - FreshIdentifier(Console.BOLD+t.asString+Console.RESET, t.getType).toVariable - } + for ((t, gs) <- cache) { + val lhs = f"${Console.BOLD}${t.asString}%50s${Console.RESET} ::=" + if (gs.isEmpty) { + printer(s"$lhs ε") + } else for (g <- gs) { + val subs = g.subTrees.map { t => + FreshIdentifier(Console.BOLD + t.asString + Console.RESET, t.getType).toVariable + } - val gen = g.builder(subs).asString + val gen = g.builder(subs).asString - printer(f"${Console.BOLD}${t.asString}%30s${Console.RESET} ::= $gen") + printer(s"$lhs $gen") + } } } } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 14f923939..1233fb193 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -10,8 +10,14 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Expressions._ +/** Generates non-recursive function calls + * + * @param currentFunction The currend function for which no calls will be generated + * @param types The candidate real type parameters for [[currentFunction]] + * @param exclude An additional set of functions for which no calls will be generated + */ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { // Prevents recursive calls @@ -73,7 +79,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val funcs = visibleFunDefsFromMain(prog).toSeq.sortBy(_.id).flatMap(getCandidates).filterNot(filter) funcs.map{ tfd => - nonTerminal(tfd.params.map(_.getType), { sub => FunctionInvocation(tfd, sub) }) + nonTerminal(tfd.params.map(_.getType), FunctionInvocation(tfd, _), Tags.tagOf(tfd.fd, isSafe = false)) } } } diff --git a/src/main/scala/leon/grammars/Generator.scala b/src/main/scala/leon/grammars/Generator.scala deleted file mode 100644 index 18d132e2c..000000000 --- a/src/main/scala/leon/grammars/Generator.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import bonsai.{Generator => Gen} - -object GrammarTag extends Enumeration { - val Top = Value -} -import GrammarTag._ - -class Generator[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value) extends Gen[T,R](subTrees, builder) -object Generator { - def apply[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value = Top) = new Generator(subTrees, builder, tag) -} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 23b1dd5a1..06aba3d5f 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -7,6 +7,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.Types._ import purescala.TypeOps._ +import transformers.OneOf import synthesis.{SynthesisContext, Problem} @@ -16,6 +17,7 @@ object Grammars { BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || + Constants(currentFunction.fullBody) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecursiveCalls(prog, ws, pc) } @@ -28,3 +30,4 @@ object Grammars { g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b)) } } + diff --git a/src/main/scala/leon/grammars/NonTerminal.scala b/src/main/scala/leon/grammars/NonTerminal.scala index 7492ffac5..600189ffa 100644 --- a/src/main/scala/leon/grammars/NonTerminal.scala +++ b/src/main/scala/leon/grammars/NonTerminal.scala @@ -5,7 +5,14 @@ package grammars import purescala.Types._ -case class NonTerminal[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed { +/** A basic non-terminal symbol of a grammar. + * + * @param t The type of which expressions will be generated + * @param l A label that characterizes this [[NonTerminal]] + * @param depth The optional depth within the syntax tree where this [[NonTerminal]] is. + * @tparam L The type of label for this NonTerminal. + */ +case class NonTerminal[L](t: TypeTree, l: L, depth: Option[Int] = None) extends Typed { val getType = t override def asString(implicit ctx: LeonContext) = t.asString+"#"+l+depth.map(d => "@"+d).getOrElse("") diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/leon/grammars/ProductionRule.scala new file mode 100644 index 000000000..ded6a3c98 --- /dev/null +++ b/src/main/scala/leon/grammars/ProductionRule.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import bonsai.Generator + +/** Represents a production rule of a non-terminal symbol of an [[ExpressionGrammar]]. + * + * @param subTrees The nonterminals that are used in the right-hand side of this [[ProductionRule]] + * (and will generate deeper syntax trees). + * @param builder A function that builds the syntax tree that this [[ProductionRule]] represents from nested trees. + * @param tag Gives information about the nature of this production rule. + * @tparam T The type of nonterminal symbols of the grammar + * @tparam R The type of syntax trees of the grammar + */ +case class ProductionRule[T, R](override val subTrees: Seq[T], override val builder: Seq[R] => R, tag: Tags.Tag) + extends Generator[T,R](subTrees, builder) diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 1bbcb0523..df2090858 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -9,15 +9,24 @@ import purescala.ExprOps._ import purescala.Expressions._ import synthesis.utils.Helpers._ +/** Generates recursive calls that will not trivially result in non-termination. + * + * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions + * @param pc The path condition for the generated [[Expr]] by this grammar + */ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog, t, ws, pc) calls.map { - case (e, free) => + case (fi, free) => val freeSeq = free.toSeq - nonTerminal(freeSeq.map(_.getType), { sub => replaceFromIDs(freeSeq.zip(sub).toMap, e) }) + nonTerminal( + freeSeq.map(_.getType), + { sub => replaceFromIDs(freeSeq.zip(sub).toMap, fi) }, + Tags.tagOf(fi.tfd.fd, isSafe = true) + ) } } } diff --git a/src/main/scala/leon/grammars/SimilarTo.scala b/src/main/scala/leon/grammars/SimilarTo.scala index 77e912792..3a7708e9a 100644 --- a/src/main/scala/leon/grammars/SimilarTo.scala +++ b/src/main/scala/leon/grammars/SimilarTo.scala @@ -3,21 +3,24 @@ package leon package grammars +import transformers._ import purescala.Types._ import purescala.TypeOps._ import purescala.Extractors._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ import purescala.Expressions._ import synthesis._ +/** A grammar that generates expressions by inserting small variations in [[e]] + * @param e The [[Expr]] to which small variations will be inserted + * @param terminals A set of [[Expr]]s that may be inserted into [[e]] as small variations + */ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[NonTerminal[String]] { val excludeFCalls = sctx.settings.functionsToIgnore - val normalGrammar = DepthBoundedGrammar(EmbeddedGrammar( + val normalGrammar: ExpressionGrammar[NonTerminal[String]] = DepthBoundedGrammar(EmbeddedGrammar( BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || @@ -37,9 +40,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - private[this] var similarCache: Option[Map[L, Seq[Gen]]] = None + private[this] var similarCache: Option[Map[L, Seq[Prod]]] = None - def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Prod] = { t match { case NonTerminal(_, "B", _) => normalGrammar.computeProductions(t) case _ => @@ -54,7 +57,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Gen)] = { + def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Prod)] = { def getLabel(t: TypeTree) = { val tpe = bestRealType(t) @@ -67,9 +70,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte case _ => false } - def rec(e: Expr, gl: L): Seq[(L, Gen)] = { + def rec(e: Expr, gl: L): Seq[(L, Prod)] = { - def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = { + def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Prod)] = { val subGls = subs.map { s => getLabel(s.getType) } // All the subproductions for sub gl @@ -81,8 +84,8 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val swaps = if (subs.size > 1 && !isCommutative(e)) { - (for (i <- 0 until subs.size; - j <- i+1 until subs.size) yield { + (for (i <- subs.indices; + j <- i+1 until subs.size) yield { if (subs(i).getType == subs(j).getType) { val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i)) @@ -98,18 +101,18 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte allSubs ++ injectG ++ swaps } - def cegis(gl: L): Seq[(L, Gen)] = { + def cegis(gl: L): Seq[(L, Prod)] = { normalGrammar.getProductions(gl).map(gl -> _) } - def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(BVMinus(e, IntLiteral(1))), gl -> terminal(BVPlus (e, IntLiteral(1))) ) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def intVariations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(Minus(e, InfiniteIntegerLiteral(1))), gl -> terminal(Plus (e, InfiniteIntegerLiteral(1))) @@ -118,7 +121,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... - def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { + def ccVariations(gl: L, cc: CaseClass): Seq[(L, Prod)] = { val CaseClass(cct, args) = cc val neighbors = cct.root.knownCCDescendants diff Seq(cct) @@ -129,7 +132,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = e match { + val subs: Seq[(L, Prod)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) diff --git a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/SizeBoundedGrammar.scala deleted file mode 100644 index 1b25e30f6..000000000 --- a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import purescala.Types._ -import leon.utils.SeqUtils.sumTo - -case class SizedLabel[T <: Typed](underlying: T, size: Int) extends Typed { - val getType = underlying.getType - - override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" -} - -case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { - def computeProductions(sl: SizedLabel[T])(implicit ctx: LeonContext): Seq[Gen] = { - if (sl.size <= 0) { - Nil - } else if (sl.size == 1) { - g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { gen => - terminal(gen.builder(Seq())) - } - } else { - g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { gen => - val sizes = sumTo(sl.size-1, gen.subTrees.size) - - for (ss <- sizes) yield { - val subSizedLabels = (gen.subTrees zip ss) map (s => SizedLabel(s._1, s._2)) - - nonTerminal(subSizedLabels, gen.builder) - } - } - } - } -} diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/leon/grammars/Tags.scala new file mode 100644 index 000000000..4a6b6fca4 --- /dev/null +++ b/src/main/scala/leon/grammars/Tags.scala @@ -0,0 +1,65 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Types.CaseClassType +import purescala.Definitions.FunDef + +object Tags { + /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ + abstract class Tag + case object Top extends Tag // Tag for the top-level of the grammar (default) + case object Zero extends Tag // Tag for 0 + case object One extends Tag // Tag for 1 + case object BooleanC extends Tag // Tag for boolean constants + case object Constant extends Tag // Tag for other constants + case object And extends Tag // Tags for boolean operations + case object Or extends Tag + case object Not extends Tag + case object Plus extends Tag // Tags for arithmetic operations + case object Minus extends Tag + case object Times extends Tag + case object Mod extends Tag + case object Div extends Tag + case object Variable extends Tag // Tag for variables + case object Equals extends Tag // Tag for equality + /** Constructors like Tuple, CaseClass... + * + * @param isTerminal If true, this constructor represents a terminal symbol + * (in practice, case class with 0 fields) + */ + case class Constructor(isTerminal: Boolean) extends Tag + /** Tag for function calls + * + * @param isMethod Whether the function called is a method + * @param isSafe Whether this constructor represents a safe function call. + * We need this because this call implicitly contains a variable, + * so we want to allow constants in all arguments. + */ + case class FunCall(isMethod: Boolean, isSafe: Boolean) extends Tag + + /** The set of tags that represent constants */ + val isConst: Set[Tag] = Set(Zero, One, Constant, BooleanC, Constructor(true)) + + /** The set of tags that represent commutative operations */ + val isCommut: Set[Tag] = Set(Plus, Times, Equals) + + /** The set of tags which have trivial results for equal arguments */ + val symmetricTrivial = Set(Minus, And, Or, Equals, Div, Mod) + + /** Tags which allow constants in all their operands + * + * In reality, the current version never allows that: it is only allowed in safe function calls + * which by construction contain a hidden reference to a variable. + * TODO: Experiment with different conditions, e.g. are constants allowed in + * top-level/ general function calls/ constructors/...? + */ + def allConstArgsAllowed(t: Tag) = t match { + case FunCall(_, true) => true + case _ => false + } + + def tagOf(cct: CaseClassType) = Constructor(cct.fields.isEmpty) + def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 98850c8df..8548c1a95 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -6,31 +6,32 @@ package grammars import purescala.Types._ import purescala.Expressions._ +/** A grammar of values (ground terms) */ case object ValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)) + terminal(BooleanLiteral(true), Tags.One), + terminal(BooleanLiteral(false), Tags.Zero) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - terminal(IntLiteral(5)) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One), + terminal(IntLiteral(5), Tags.Constant) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - terminal(InfiniteIntegerLiteral(5)) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One), + terminal(InfiniteIntegerLiteral(5), Tags.Constant) ) case StringType => List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("foo")), - terminal(StringLiteral("bar")) + terminal(StringLiteral(""), Tags.Constant), + terminal(StringLiteral("a"), Tags.Constant), + terminal(StringLiteral("foo"), Tags.Constant), + terminal(StringLiteral("bar"), Tags.Constant) ) case tp: TypeParameter => @@ -40,28 +41,29 @@ case object ValueGrammar extends ExpressionGrammar[TypeTree] { case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(stps.isEmpty)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), - nonTerminal(List(base, base), { case elems => FiniteSet(elems.toSet, base) }) + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), + nonTerminal(List(base, base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)) ) case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala new file mode 100644 index 000000000..02e045497 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala @@ -0,0 +1,21 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +/** Limits a grammar to a specific expression depth */ +case class DepthBoundedGrammar[L](g: ExpressionGrammar[NonTerminal[L]], bound: Int) extends ExpressionGrammar[NonTerminal[L]] { + def computeProductions(l: NonTerminal[L])(implicit ctx: LeonContext): Seq[Prod] = g.computeProductions(l).flatMap { + case gen => + if (l.depth == Some(bound) && gen.isNonTerminal) { + Nil + } else if (l.depth.exists(_ > bound)) { + Nil + } else { + List ( + nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) + ) + } + } +} diff --git a/src/main/scala/leon/grammars/EmbeddedGrammar.scala b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala similarity index 74% rename from src/main/scala/leon/grammars/EmbeddedGrammar.scala rename to src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala index 8dcbc6ec1..d989a8804 100644 --- a/src/main/scala/leon/grammars/EmbeddedGrammar.scala +++ b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala @@ -2,10 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.Constructors._ +import leon.purescala.Types.Typed /** * Embed a grammar Li->Expr within a grammar Lo->Expr @@ -13,9 +12,9 @@ import purescala.Constructors._ * We rely on a bijection between Li and Lo labels */ case class EmbeddedGrammar[Ti <: Typed, To <: Typed](innerGrammar: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] { - def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Prod] = { innerGrammar.computeProductions(oToi(t)).map { innerGen => - nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder) + nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder, innerGen.tag) } } } diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/leon/grammars/transformers/OneOf.scala similarity index 56% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/leon/grammars/transformers/OneOf.scala index 0e10c0961..5c57c6a1a 100644 --- a/src/main/scala/leon/grammars/OneOf.scala +++ b/src/main/scala/leon/grammars/transformers/OneOf.scala @@ -2,14 +2,15 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.Constructors._ +import purescala.Expressions.Expr +import purescala.Types.TypeTree +import purescala.TypeOps.isSubtypeOf +/** Generates one production rule for each expression in a sequence that has compatible type */ case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => terminal(i) diff --git a/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala new file mode 100644 index 000000000..5abff1aa7 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala @@ -0,0 +1,59 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import purescala.Types.Typed +import utils.SeqUtils._ + +/** Adds information about size to a nonterminal symbol */ +case class SizedNonTerm[T <: Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" +} + +/** Limits a grammar by producing expressions of size bounded by the [[SizedNonTerm.size]] of a given [[SizedNonTerm]]. + * + * In case of commutative operations, the grammar will produce trees skewed to the right + * (i.e. the right subtree will always be larger). Notice we do not lose generality in case of + * commutative operations. + */ +case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T], optimizeCommut: Boolean) extends ExpressionGrammar[SizedNonTerm[T]] { + def computeProductions(sl: SizedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.isTerminal).map { gen => + terminal(gen.builder(Seq()), gen.tag) + } + } else { + g.getProductions(sl.underlying).filter(_.isNonTerminal).flatMap { gen => + + // Ad-hoc equality that does not take into account position of nonterminals. + // TODO: Ugly and hacky + def characteristic(t: T): AnyRef = t match { + case TaggedNonTerm(underlying, tag, _, isConst) => + (underlying, tag, isConst) + case other => + other + } + + // Optimization: When we have a commutative operation and all the labels are the same, + // we can skew the expression to always be right-heavy + val sizes = if(optimizeCommut && Tags.isCommut(gen.tag) && gen.subTrees.map(characteristic).toSet.size == 1) { + sumToOrdered(sl.size-1, gen.arity) + } else { + sumTo(sl.size-1, gen.arity) + } + + for (ss <- sizes) yield { + val subSizedLabels = (gen.subTrees zip ss) map (s => SizedNonTerm(s._1, s._2)) + + nonTerminal(subSizedLabels, gen.builder, gen.tag) + } + } + } + } +} diff --git a/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala new file mode 100644 index 000000000..a95306f7d --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala @@ -0,0 +1,112 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import leon.purescala.Types.Typed +import Tags._ + +/** Adds to a nonterminal information about about the tag of its parent's [[leon.grammars.ProductionRule.tag]] + * and additional information. + * + * @param underlying The underlying nonterminal + * @param tag The tag of the parent of this nonterminal + * @param pos The index of this nonterminal in its father's production rule + * @param isConst Whether this nonterminal is obliged to generate/not generate constants. + * + */ +case class TaggedNonTerm[T <: Typed](underlying: T, tag: Tag, pos: Int, isConst: Option[Boolean]) extends Typed { + val getType = underlying.getType + + private val cString = isConst match { + case Some(true) => "↓" + case Some(false) => "↑" + case None => "○" + } + + /** [[isConst]] is printed as follows: ↓ for constants only, ↑ for nonconstants only, + * ○ for anything allowed. + */ + override def asString(implicit ctx: LeonContext): String = s"$underlying%$tag@$pos$cString" +} + +/** Constraints a grammar to reduce redundancy by utilizing information provided by the [[TaggedNonTerm]]. + * + * 1) In case of associative operations, right associativity is enforced. + * 2) Does not generate + * - neutral and absorbing elements (incl. boolean equality) + * - nested negations + * - trivial operations for symmetric arguments, e.g. a == a + * 3) Excludes method calls on nullary case objects, e.g. Nil().size + * 4) Enforces that no constant trees are generated (and recursively for each subtree) + * + * @param g The underlying untagged grammar + */ +case class TaggedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[TaggedNonTerm[T]] { + + private def exclude(tag: Tag, pos: Int): Set[Tag] = (tag, pos) match { + case (Top, _) => Set() + case (And, 0) => Set(And, BooleanC) + case (And, 1) => Set(BooleanC) + case (Or, 0) => Set(Or, BooleanC) + case (Or, 1) => Set(BooleanC) + case (Plus, 0) => Set(Plus, Zero, One) + case (Plus, 1) => Set(Zero) + case (Minus, 1) => Set(Zero) + case (Not, _) => Set(Not, BooleanC) + case (Times, 0) => Set(Times, Zero, One) + case (Times, 1) => Set(Zero, One) + case (Equals,_) => Set(Not, BooleanC) + case (Div | Mod, 0 | 1) => Set(Zero, One) + case (FunCall(true, _), 0) => Set(Constructor(true)) // Don't allow Nil().size etc. + case _ => Set() + } + + def computeProductions(t: TaggedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + + // Point (4) for this level + val constFilter: g.Prod => Boolean = t.isConst match { + case Some(b) => + innerGen => isConst(innerGen.tag) == b + case None => + _ => true + } + + g.computeProductions(t.underlying) + // Include only constants iff constants are forced, only non-constants iff they are forced + .filter(constFilter) + // Points (1), (2). (3) + .filterNot { innerGen => exclude(t.tag, t.pos)(innerGen.tag) } + .flatMap { innerGen => + + def nt(isConst: Int => Option[Boolean]) = nonTerminal( + innerGen.subTrees.zipWithIndex.map { + case (t, pos) => TaggedNonTerm(t, innerGen.tag, pos, isConst(pos)) + }, + innerGen.builder, + innerGen.tag + ) + + def powerSet[A](t: Set[A]): Set[Set[A]] = { + @scala.annotation.tailrec + def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + pwr(t, Set(Set.empty[A])) + } + + // Allow constants everywhere if this is allowed, otherwise demand at least 1 variable. + // Aka. tag subTrees correctly so point (4) is enforced in the lower level + // (also, make sure we treat terminals correctly). + if (innerGen.isTerminal || allConstArgsAllowed(innerGen.tag)) { + Seq(nt(_ => None)) + } else { + val indices = innerGen.subTrees.indices.toSet + (powerSet(indices) - indices) map (indices => nt(x => Some(indices(x)))) + } + } + } + +} diff --git a/src/main/scala/leon/grammars/Or.scala b/src/main/scala/leon/grammars/transformers/Union.scala similarity index 73% rename from src/main/scala/leon/grammars/Or.scala rename to src/main/scala/leon/grammars/transformers/Union.scala index e691a2459..471625ac3 100644 --- a/src/main/scala/leon/grammars/Or.scala +++ b/src/main/scala/leon/grammars/transformers/Union.scala @@ -2,8 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ +import purescala.Types.Typed case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { @@ -11,6 +12,6 @@ case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGr case g => Seq(g) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = subGrammars.flatMap(_.getProductions(t)) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 81ebbdbec..4085050d2 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,6 +275,14 @@ object DefOps { None } + /** + * + * @param p + * @param fdMapF + * @param fiMapF + * @return + */ + def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 81b8420c3..20cca91e1 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -13,7 +13,7 @@ import ExprOps.preMap object TypeOps { def typeDepth(t: TypeTree): Int = t match { - case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max + case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 37f187679..19095f6f4 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -242,7 +242,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou val maxValid = 400 val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](ValueGrammar.getProductions) val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 5077f9467..5728d3b14 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -182,7 +182,7 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { }) getOrElse { // If the input contains free variables, it does not provide concrete examples. // We will instantiate them according to a simple grammar to get them. - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](ValueGrammar.getProductions) val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType })) val instantiations = values.map { v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index cd27e272d..6583d4ca8 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -35,6 +35,9 @@ abstract class PreprocessingRule(name: String) extends Rule(name) { /** Contains the list of all available rules for synthesis */ object Rules { + + private val newCegis = true + /** Returns the list of all available rules for synthesis */ def all = List[Rule]( StringRender, @@ -54,8 +57,8 @@ object Rules { OptimisticGround, EqualitySplit, InequalitySplit, - CEGIS, - TEGIS, + if(newCegis) CEGIS2 else CEGIS, + //TEGIS, //BottomUpTEGIS, rules.Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 4bb10d38c..8ab07929d 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -45,6 +45,10 @@ object SourceInfo { ci } + if (results.isEmpty) { + ctx.reporter.warning("No 'choose' found. Maybe the functions you chose do not exist?") + } + results.sortBy(_.source.getPos) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index b9ba6df1f..373e4fb56 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -3,13 +3,11 @@ package leon package synthesis -import purescala.ExprOps._ - +import purescala.ExprOps.replace import purescala.ScalaPrinter -import leon.utils._ import purescala.Definitions.{Program, FunDef} -import leon.utils.ASCIIHelpers +import leon.utils._ import graph._ object SynthesisPhase extends TransformationPhase { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index bafed6ec2..efd1ad13e 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -70,21 +70,19 @@ class Synthesizer(val context : LeonContext, // Print out report for synthesis, if necessary reporter.ifDebug { printer => - import java.io.FileWriter import java.text.SimpleDateFormat import java.util.Date val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") val benchName = categoryName+"."+ci.fd.id.name - var time = lastTime/1000.0; + val time = lastTime/1000.0 val defs = visibleDefsFrom(ci.fd)(program).collect { case cd: ClassDef => 1 + cd.fields.size case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) } - val psize = defs.sum; - + val psize = defs.sum val (size, calls, proof) = result.headOption match { case Some((sol, trusted)) => diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index bfa6a6212..244925b90 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -69,15 +69,15 @@ object QuestionBuilder { /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ object SpecialStringValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\t")), + terminal(StringLiteral("Lara 2007")) + ) + case _ => ValueGrammar.computeProductions(t) } } } @@ -140,7 +140,7 @@ class QuestionBuilder[T <: Expr]( def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree,Expr]](value_enumerator.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree,Expr]](value_enumerator.getProductions) val values = enum.iterator(tupleTypeWrap(_argTypes)) val instantiations = values.map { v => input.zip(unwrapTuple(v, input.size)) @@ -172,4 +172,4 @@ class QuestionBuilder[T <: Expr]( } questions.toList.sortBy(_questionSorMethod(_)) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 2f3869af1..f716997f4 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -51,13 +51,13 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = tests.size - var compiled = Map[Generator[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() + var compiled = Map[ProductionRule[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() /** * Compile Generators to functions from Expr to Expr. The compiled * generators will be passed to the enumerator */ - def compile(gen: Generator[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { + def compile(gen: ProductionRule[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { compiled.getOrElse(gen, { val executor = if (gen.subTrees.isEmpty) { @@ -108,7 +108,7 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val targetType = tupleTypeWrap(p.xs.map(_.getType)) val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - val enum = new BottomUpEnumerator[T, Expr, Expr, Generator[T, Expr]]( + val enum = new BottomUpEnumerator[T, Expr, Expr, ProductionRule[T, Expr]]( grammar.getProductions, wrappedTests, { (vecs, gen) => diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala index 1fcf01d52..139c6116e 100644 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGIS.scala @@ -4,16 +4,29 @@ package leon package synthesis package rules -import purescala.Types._ - import grammars._ -import utils._ +import grammars.transformers._ +import purescala.Types.TypeTree case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { CegisParams( grammar = Grammars.typeDepthBound(Grammars.default(sctx, p), 2), // This limits type depth - rootLabel = {(tpe: TypeTree) => tpe } + rootLabel = {(tpe: TypeTree) => tpe }, + maxUnfoldings = 12, + optimizations = false ) } } + +case object CEGIS2 extends CEGISLike[TaggedNonTerm[TypeTree]]("CEGIS2") { + def getParams(sctx: SynthesisContext, p: Problem) = { + val base = CEGIS.getParams(sctx,p).grammar + CegisParams( + grammar = TaggedGrammar(base), + rootLabel = TaggedNonTerm(_, Tags.Top, 0, None), + maxUnfoldings = 12, + optimizations = true + ) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index f643a7bb0..368608b90 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -4,10 +4,6 @@ package leon package synthesis package rules -import leon.utils.SeqUtils -import solvers._ -import grammars._ - import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ @@ -16,18 +12,24 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors._ -import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} +import solvers._ +import grammars._ +import grammars.transformers._ +import leon.utils.SeqUtils import evaluators._ import datagen._ import codegen.CodeGenParams +import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} + abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 5 + maxUnfoldings: Int = 5, + optimizations: Boolean ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams @@ -36,7 +38,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val exSolverTo = 2000L val cexSolverTo = 2000L - // Track non-deterministic programs up to 10'000 programs, or give up + // Track non-deterministic programs up to 100'000 programs, or give up val nProgramsLimit = 100000 val sctx = hctx.sctx @@ -48,6 +50,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // Limits the number of programs CEGIS will specifically validate individually val validateUpTo = 3 + val passingRatio = 10 val interruptManager = sctx.context.interruptManager @@ -61,13 +64,13 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private var termSize = 0 - val grammar = SizeBoundedGrammar(params.grammar) + val grammar = SizeBoundedGrammar(params.grammar, params.optimizations) - def rootLabel = SizedLabel(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) + def rootLabel = SizedNonTerm(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) - var nAltsCache = Map[SizedLabel[T], Int]() + var nAltsCache = Map[SizedNonTerm[T], Int]() - def countAlternatives(l: SizedLabel[T]): Int = { + def countAlternatives(l: SizedNonTerm[T]): Int = { if (!(nAltsCache contains l)) { val count = grammar.getProductions(l).map { gen => gen.subTrees.map(countAlternatives).product @@ -102,7 +105,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // C identifiers corresponding to p.xs - private var rootC: Identifier = _ + private var rootC: Identifier = _ private var bs: Set[Identifier] = Set() @@ -110,19 +113,19 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { class CGenerator { - private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + private var buffers = Map[SizedNonTerm[T], Stream[Identifier]]() - private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + private var slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) - private def streamOf(t: SizedLabel[T]): Stream[Identifier] = Stream.continually( + private def streamOf(t: SizedNonTerm[T]): Stream[Identifier] = Stream.continually( FreshIdentifier(t.asString, t.getType, true) ) def rewind(): Unit = { - slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) } - def getNext(t: SizedLabel[T]) = { + def getNext(t: SizedNonTerm[T]) = { if (!(buffers contains t)) { buffers += t -> streamOf(t) } @@ -146,7 +149,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { id } - def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + def defineCTreeFor(l: SizedNonTerm[T], c: Identifier): Unit = { if (!(cTree contains c)) { val cGen = new CGenerator() @@ -182,6 +185,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { printer => printer("Grammar so far:") grammar.printProductions(printer) + printer("") } bsOrdered = bs.toSeq.sorted @@ -233,12 +237,47 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { cache(c) } - SeqUtils.cartesianProduct(seqs).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) + SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) + } + + def redundant(e: Expr): Boolean = { + val (op1, op2) = e match { + case Minus(o1, o2) => (o1, o2) + case Modulo(o1, o2) => (o1, o2) + case Division(o1, o2) => (o1, o2) + case BVMinus(o1, o2) => (o1, o2) + case BVRemainder(o1, o2) => (o1, o2) + case BVDivision(o1, o2) => (o1, o2) + + case And(Seq(Not(o1), Not(o2))) => (o1, o2) + case And(Seq(Not(o1), o2)) => (o1, o2) + case And(Seq(o1, Not(o2))) => (o1, o2) + case And(Seq(o1, o2)) => (o1, o2) + + case Or(Seq(Not(o1), Not(o2))) => (o1, o2) + case Or(Seq(Not(o1), o2)) => (o1, o2) + case Or(Seq(o1, Not(o2))) => (o1, o2) + case Or(Seq(o1, o2)) => (o1, o2) + + case SetUnion(o1, o2) => (o1, o2) + case SetIntersection(o1, o2) => (o1, o2) + case SetDifference(o1, o2) => (o1, o2) + + case Equals(Not(o1), Not(o2)) => (o1, o2) + case Equals(Not(o1), o2) => (o1, o2) + case Equals(o1, Not(o2)) => (o1, o2) + case Equals(o1, o2) => (o1, o2) + case _ => return false } + + op1 == op2 } - allProgramsFor(Seq(rootC)) + allProgramsFor(Seq(rootC))/* filterNot { bs => + val res = params.optimizations && exists(redundant)(getExpr(bs)) + if (!res) excludeProgram(bs, false) + res + }*/ } private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]], @@ -432,10 +471,9 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - - // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { case (b, builder, cs) => @@ -455,56 +493,64 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { val origImpl = cTreeFd.fullBody - val cexs = for (bs <- bss.toSeq) yield { + var cexs = Seq[Seq[Expr]]() + + for (bs <- bss.toSeq) { val outerSol = getExpr(bs) val innerSol = outerExprToInnerExpr(outerSol) - + //println(s"Testing $outerSol") cTreeFd.fullBody = innerSol val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - //println("Solving for: "+cnstr.asString) - - val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) + val eval = new DefaultEvaluator(ctx, innerProgram) - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) + if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { + //println(s"Program $outerSol fails!") + excludeProgram(bs, true) + cTreeFd.fullBody = origImpl + } else { + //println("Solving for: "+cnstr.asString) + + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs, true) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} + + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) + //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") + cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) - } else { - None - } + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } + } + } finally { + solverf.reclaim(solver) + solverf.shutdown() + cTreeFd.fullBody = origImpl } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - cTreeFd.fullBody = origImpl } } - Right(cexs.flatten) + Right(cexs) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() @@ -663,6 +709,10 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val sctx = hctx.sctx implicit val ctx = sctx.context + import leon.utils.Timer + val timer = new Timer + timer.start + val ndProgram = new NonDeterministicProgram(p) ndProgram.init() @@ -677,7 +727,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.grammar.printProductions(printer) } - // We populate the list of examples with a predefined one + // We populate the list of examples with a defined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.eb.examples @@ -793,12 +843,22 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } val nPassing = prunedPrograms.size - sctx.reporter.debug("#Programs passing tests: "+nPassing) + val nTotal = ndProgram.allProgramsCount() + + /*locally { + val progs = ndProgram.allPrograms() map ndProgram.getExpr + val ground = progs count isGround + println("Programs") + progs take 100 foreach println + println(s"$ground ground out of $nTotal") + }*/ + + sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.ifDebug{ printer => - for (p <- prunedPrograms.take(10)) { + for (p <- prunedPrograms.take(100)) { printer(" - "+ndProgram.getExpr(p).asString) } - if(nPassing > 10) { + if(nPassing > 100) { printer(" - ...") } } @@ -812,16 +872,23 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - if (nPassing == 0 || interruptManager.isInterrupted) { // No test passed, we can skip solver and unfold again, if possible skipCESearch = true } else { var doFilter = true - if (validateUpTo > 0) { - // Validate the first N programs individually - ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match { + // If the number of pruned programs is very small, or by far smaller than the number of total programs, + // we hypothesize it will be easier to just validate them individually. + // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? + val programsToValidate = if (nTotal / nPassing > passingRatio || nPassing < 10) { + prunedPrograms + } else { + prunedPrograms.take(validateUpTo) + } + + if (programsToValidate.nonEmpty) { + ndProgram.validatePrograms(programsToValidate) match { case Left(sols) if sols.nonEmpty => doFilter = false result = Some(RuleClosed(sols)) @@ -919,6 +986,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.warning("CEGIS crashed: "+e.getMessage) e.printStackTrace() RuleFailed() + } finally { + ctx.reporter.info(s"CEGIS ran for ${timer.stop} ms") } } }) diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala index c12edac07..357a3d888 100644 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGLESS.scala @@ -4,10 +4,10 @@ package leon package synthesis package rules +import leon.grammars.transformers.Union import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ -import utils._ import grammars._ import Witnesses._ @@ -24,7 +24,7 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { val inputs = p.as.map(_.toVariable) sctx.reporter.ifDebug { printer => - printer("Guides available:") + printer("Guides available:") for (g <- guides) { printer(" - "+g.asString(ctx)) } @@ -35,7 +35,8 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { CegisParams( grammar = guidedGrammar, rootLabel = { (tpe: TypeTree) => NonTerminal(tpe, "G0") }, - maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max + maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max, + optimizations = false ) } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 2ae2b1d5d..d3b4c823d 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -83,7 +83,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } } - var eb = p.qeb.mapIns { info => + val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => ebMapInfo.get(id) match { case Some(m) => @@ -103,7 +103,8 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case CaseClass(ct, args) => val (cts, es) = args.zip(ct.fields).map { case (CaseClassSelector(ct, e, id), field) if field.id == id => (ct, e) - case _ => return e + case _ => + return e }.unzip if (cts.distinct.size == 1 && es.distinct.size == 1) { @@ -126,7 +127,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - val s = {substAll(reverseMap, _:Expr)} andThen { simplePostTransform(recompose) } + val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose) Some(decomp(List(sub), forwardMap(s), s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala index a6e060f89..eea2b0504 100644 --- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala @@ -67,7 +67,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[T, Expr, ProductionRule[T, Expr]](grammar.getProductions) val targetType = tupleTypeWrap(p.xs.map(_.getType)) diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index acd285a45..4bfedc4ac 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -34,7 +34,18 @@ object Helpers { } } - def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(Expr, Set[Identifier])] = { + /** Given an initial set of function calls provided by a list of [[Terminating]], + * returns function calls that will hopefully be safe to call recursively from within this initial function calls. + * + * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. + * + * @param prog The current program + * @param tpe The expected type for the returned function calls + * @param ws Helper predicates that contain [[Terminating]]s with the initial calls + * @param pc The path condition + * @return A list of pairs of (safe function call, holes), where holes stand for the rest of the arguments of the function. + */ + def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(FunctionInvocation, Set[Identifier])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 002f2ebed..363808e45 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -42,6 +42,20 @@ object SeqUtils { } } } + + def sumToOrdered(sum: Int, arity: Int): Seq[Seq[Int]] = { + def rec(sum: Int, arity: Int): Seq[Seq[Int]] = { + require(arity > 0) + if (sum < 0) Nil + else if (arity == 1) Seq(Seq(sum)) + else for { + n <- 0 to sum / arity + rest <- rec(sum - arity * n, arity - 1) + } yield n +: rest.map(n + _) + } + + rec(sum, arity) filterNot (_.head == 0) + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { -- GitLab