diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala index cd86c707ddb893918e512f6d7101cc4cc92b6405..23c1ed5b882e9197e14610ad15695fd6e5663d54 100644 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ b/src/main/scala/leon/datagen/GrammarDataGen.scala @@ -20,7 +20,7 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] implicit val ctx = evaluator.context def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](grammar.getProductions) enum.iterator(tpe) } diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/leon/grammars/BaseGrammar.scala index f11f937498051eb47c2c522a0faa2a1499545175..6e0a2ee5e6842255aac5755c9ee27005e5360eb5 100644 --- a/src/main/scala/leon/grammars/BaseGrammar.scala +++ b/src/main/scala/leon/grammars/BaseGrammar.scala @@ -7,56 +7,65 @@ import purescala.Types._ import purescala.Expressions._ import purescala.Constructors._ +/** The basic grammar for Leon expressions. + * Generates the most obvious expressions for a given type, + * without regard of context (variables in scope, current function etc.) + * Also does some trivial simplifications. + */ case object BaseGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)), - nonTerminal(List(BooleanType), { case Seq(a) => not(a) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + terminal(BooleanLiteral(false), Tags.BooleanC), + terminal(BooleanLiteral(true), Tags.BooleanC), + nonTerminal(List(BooleanType), { case Seq(a) => not(a) }, Tags.Not), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }, Tags.And), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }, Tags.Or ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }, Tags.Times) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }, Tags.Times), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Modulo(a, b) }, Tags.Mod), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Division(a, b) }, Tags.Div) ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(isTerminal = false)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), nonTerminal(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -64,7 +73,7 @@ case object BaseGrammar extends ExpressionGrammar[TypeTree] { case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/leon/grammars/Constants.scala new file mode 100644 index 0000000000000000000000000000000000000000..81c55346052668e0d82b05ee240867eb1e5c468c --- /dev/null +++ b/src/main/scala/leon/grammars/Constants.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Expressions._ +import purescala.Types.TypeTree +import purescala.ExprOps.collect +import purescala.Extractors.IsTyped + +/** Generates constants found in an [[leon.purescala.Expressions.Expr]]. + * Some constants that are generated by other grammars (like 0, 1) will be excluded + */ +case class Constants(e: Expr) extends ExpressionGrammar[TypeTree] { + + private val excluded: Set[Expr] = Set( + InfiniteIntegerLiteral(1), + InfiniteIntegerLiteral(0), + IntLiteral(1), + IntLiteral(0), + BooleanLiteral(true), + BooleanLiteral(false) + ) + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { + val literals = collect[Expr]{ + case IsTyped(l:Literal[_], `t`) => Set(l) + case _ => Set() + }(e) + + (literals -- excluded map (terminal(_, Tags.Constant))).toSeq + } +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/DepthBoundedGrammar.scala deleted file mode 100644 index fc999be644bf2c4a7a20a73403cf7b1001bb9b68..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -case class DepthBoundedGrammar[T](g: ExpressionGrammar[NonTerminal[T]], bound: Int) extends ExpressionGrammar[NonTerminal[T]] { - def computeProductions(l: NonTerminal[T])(implicit ctx: LeonContext): Seq[Gen] = g.computeProductions(l).flatMap { - case gen => - if (l.depth == Some(bound) && gen.subTrees.nonEmpty) { - Nil - } else if (l.depth.exists(_ > bound)) { - Nil - } else { - List ( - nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) - ) - } - } -} diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/leon/grammars/Empty.scala index 70ebddc98f21fc872aef8635fe36de7e9ba9bbce..737f9cdf389454f403a6581e13eec7fafa383f34 100644 --- a/src/main/scala/leon/grammars/Empty.scala +++ b/src/main/scala/leon/grammars/Empty.scala @@ -5,6 +5,7 @@ package grammars import purescala.Types.Typed +/** The empty expression grammar */ case class Empty[T <: Typed]() extends ExpressionGrammar[T] { - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = Nil + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = Nil } diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/leon/grammars/EqualityGrammar.scala index e9463a771204d877a4d748c373b6d198e2c2591b..fdcf079fa736c7996b686f26bace5832c1da6adc 100644 --- a/src/main/scala/leon/grammars/EqualityGrammar.scala +++ b/src/main/scala/leon/grammars/EqualityGrammar.scala @@ -8,11 +8,15 @@ import purescala.Constructors._ import bonsai._ +/** A grammar of equalities + * + * @param types The set of types for which equalities will be generated + */ case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { - override def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => types.toList map { tp => - nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }) + nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }, Tags.Equals) } case _ => Nil diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/leon/grammars/ExpressionGrammar.scala index ac394ab840bddf0d498080a04e447ce66de07caa..198410f3df5f959b01b00c2fe9b80901a37226ed 100644 --- a/src/main/scala/leon/grammars/ExpressionGrammar.scala +++ b/src/main/scala/leon/grammars/ExpressionGrammar.scala @@ -6,23 +6,35 @@ package grammars import purescala.Expressions._ import purescala.Types._ import purescala.Common._ +import transformers.Union import scala.collection.mutable.{HashMap => MutableMap} +/** Represents a context-free grammar of expressions + * + * @tparam T The type of nonterminal symbols for this grammar + */ abstract class ExpressionGrammar[T <: Typed] { - type Gen = Generator[T, Expr] + type Prod = ProductionRule[T, Expr] - private[this] val cache = new MutableMap[T, Seq[Gen]]() + private[this] val cache = new MutableMap[T, Seq[Prod]]() - def terminal(builder: => Expr) = { - Generator[T, Expr](Nil, { (subs: Seq[Expr]) => builder }) + /** Generates a [[ProductionRule]] without nonterminal symbols */ + def terminal(builder: => Expr, tag: Tags.Tag = Tags.Top) = { + ProductionRule[T, Expr](Nil, { (subs: Seq[Expr]) => builder }, tag) } - def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr)): Generator[T, Expr] = { - Generator[T, Expr](subs, builder) + /** Generates a [[ProductionRule]] with nonterminal symbols */ + def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr), tag: Tags.Tag = Tags.Top): ProductionRule[T, Expr] = { + ProductionRule[T, Expr](subs, builder, tag) } - def getProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = { + /** The list of production rules for this grammar for a given nonterminal. + * This is the cached version of [[getProductions]] which clients should use. + * + * @param t The nonterminal for which production rules will be generated + */ + def getProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -30,9 +42,13 @@ abstract class ExpressionGrammar[T <: Typed] { }) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] + /** The list of production rules for this grammar for a given nonterminal + * + * @param t The nonterminal for which production rules will be generated + */ + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] - def filter(f: Gen => Boolean) = { + def filter(f: Prod => Boolean) = { new ExpressionGrammar[T] { def computeProductions(t: T)(implicit ctx: LeonContext) = ExpressionGrammar.this.computeProductions(t).filter(f) } @@ -44,14 +60,19 @@ abstract class ExpressionGrammar[T <: Typed] { final def printProductions(printer: String => Unit)(implicit ctx: LeonContext) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { t => - FreshIdentifier(Console.BOLD+t.asString+Console.RESET, t.getType).toVariable - } + for ((t, gs) <- cache) { + val lhs = f"${Console.BOLD}${t.asString}%50s${Console.RESET} ::=" + if (gs.isEmpty) { + printer(s"$lhs ε") + } else for (g <- gs) { + val subs = g.subTrees.map { t => + FreshIdentifier(Console.BOLD + t.asString + Console.RESET, t.getType).toVariable + } - val gen = g.builder(subs).asString + val gen = g.builder(subs).asString - printer(f"${Console.BOLD}${t.asString}%30s${Console.RESET} ::= $gen") + printer(s"$lhs $gen") + } } } } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 14f92393934c18804bdb130e9c1617b915a347bd..1233fb1931a83b5ca674019be0c85144339dd19f 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -10,8 +10,14 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Expressions._ +/** Generates non-recursive function calls + * + * @param currentFunction The currend function for which no calls will be generated + * @param types The candidate real type parameters for [[currentFunction]] + * @param exclude An additional set of functions for which no calls will be generated + */ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { // Prevents recursive calls @@ -73,7 +79,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val funcs = visibleFunDefsFromMain(prog).toSeq.sortBy(_.id).flatMap(getCandidates).filterNot(filter) funcs.map{ tfd => - nonTerminal(tfd.params.map(_.getType), { sub => FunctionInvocation(tfd, sub) }) + nonTerminal(tfd.params.map(_.getType), FunctionInvocation(tfd, _), Tags.tagOf(tfd.fd, isSafe = false)) } } } diff --git a/src/main/scala/leon/grammars/Generator.scala b/src/main/scala/leon/grammars/Generator.scala deleted file mode 100644 index 18d132e2c25ea222324dc05809220f12d0fb7100..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/Generator.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import bonsai.{Generator => Gen} - -object GrammarTag extends Enumeration { - val Top = Value -} -import GrammarTag._ - -class Generator[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value) extends Gen[T,R](subTrees, builder) -object Generator { - def apply[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value = Top) = new Generator(subTrees, builder, tag) -} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 23b1dd5a14cfeb82dd4555832e777597615b337e..06aba3d5f5343cc7e2807854f0b4665bfa1a602c 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -7,6 +7,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.Types._ import purescala.TypeOps._ +import transformers.OneOf import synthesis.{SynthesisContext, Problem} @@ -16,6 +17,7 @@ object Grammars { BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || + Constants(currentFunction.fullBody) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecursiveCalls(prog, ws, pc) } @@ -28,3 +30,4 @@ object Grammars { g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b)) } } + diff --git a/src/main/scala/leon/grammars/NonTerminal.scala b/src/main/scala/leon/grammars/NonTerminal.scala index 7492ffac5c17df326084f846857c6ac3bebe1775..600189ffa06378841f6bf3285f7f1bd7bb6116f5 100644 --- a/src/main/scala/leon/grammars/NonTerminal.scala +++ b/src/main/scala/leon/grammars/NonTerminal.scala @@ -5,7 +5,14 @@ package grammars import purescala.Types._ -case class NonTerminal[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed { +/** A basic non-terminal symbol of a grammar. + * + * @param t The type of which expressions will be generated + * @param l A label that characterizes this [[NonTerminal]] + * @param depth The optional depth within the syntax tree where this [[NonTerminal]] is. + * @tparam L The type of label for this NonTerminal. + */ +case class NonTerminal[L](t: TypeTree, l: L, depth: Option[Int] = None) extends Typed { val getType = t override def asString(implicit ctx: LeonContext) = t.asString+"#"+l+depth.map(d => "@"+d).getOrElse("") diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/leon/grammars/ProductionRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..ded6a3c9832001af5d1bacc1f44096000482035d --- /dev/null +++ b/src/main/scala/leon/grammars/ProductionRule.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import bonsai.Generator + +/** Represents a production rule of a non-terminal symbol of an [[ExpressionGrammar]]. + * + * @param subTrees The nonterminals that are used in the right-hand side of this [[ProductionRule]] + * (and will generate deeper syntax trees). + * @param builder A function that builds the syntax tree that this [[ProductionRule]] represents from nested trees. + * @param tag Gives information about the nature of this production rule. + * @tparam T The type of nonterminal symbols of the grammar + * @tparam R The type of syntax trees of the grammar + */ +case class ProductionRule[T, R](override val subTrees: Seq[T], override val builder: Seq[R] => R, tag: Tags.Tag) + extends Generator[T,R](subTrees, builder) diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 1bbcb0523158ac95713f5a0d4a16f0f35e14edf4..df20908581a2576388d079854b3edd3255253e27 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -9,15 +9,24 @@ import purescala.ExprOps._ import purescala.Expressions._ import synthesis.utils.Helpers._ +/** Generates recursive calls that will not trivially result in non-termination. + * + * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions + * @param pc The path condition for the generated [[Expr]] by this grammar + */ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog, t, ws, pc) calls.map { - case (e, free) => + case (fi, free) => val freeSeq = free.toSeq - nonTerminal(freeSeq.map(_.getType), { sub => replaceFromIDs(freeSeq.zip(sub).toMap, e) }) + nonTerminal( + freeSeq.map(_.getType), + { sub => replaceFromIDs(freeSeq.zip(sub).toMap, fi) }, + Tags.tagOf(fi.tfd.fd, isSafe = true) + ) } } } diff --git a/src/main/scala/leon/grammars/SimilarTo.scala b/src/main/scala/leon/grammars/SimilarTo.scala index 77e912792965d860fc934eb016370c8f2b57fd8f..3a7708e9a77960ffbfde98d478d2ca7c73c713d0 100644 --- a/src/main/scala/leon/grammars/SimilarTo.scala +++ b/src/main/scala/leon/grammars/SimilarTo.scala @@ -3,21 +3,24 @@ package leon package grammars +import transformers._ import purescala.Types._ import purescala.TypeOps._ import purescala.Extractors._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ import purescala.Expressions._ import synthesis._ +/** A grammar that generates expressions by inserting small variations in [[e]] + * @param e The [[Expr]] to which small variations will be inserted + * @param terminals A set of [[Expr]]s that may be inserted into [[e]] as small variations + */ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[NonTerminal[String]] { val excludeFCalls = sctx.settings.functionsToIgnore - val normalGrammar = DepthBoundedGrammar(EmbeddedGrammar( + val normalGrammar: ExpressionGrammar[NonTerminal[String]] = DepthBoundedGrammar(EmbeddedGrammar( BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || @@ -37,9 +40,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - private[this] var similarCache: Option[Map[L, Seq[Gen]]] = None + private[this] var similarCache: Option[Map[L, Seq[Prod]]] = None - def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Prod] = { t match { case NonTerminal(_, "B", _) => normalGrammar.computeProductions(t) case _ => @@ -54,7 +57,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Gen)] = { + def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Prod)] = { def getLabel(t: TypeTree) = { val tpe = bestRealType(t) @@ -67,9 +70,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte case _ => false } - def rec(e: Expr, gl: L): Seq[(L, Gen)] = { + def rec(e: Expr, gl: L): Seq[(L, Prod)] = { - def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = { + def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Prod)] = { val subGls = subs.map { s => getLabel(s.getType) } // All the subproductions for sub gl @@ -81,8 +84,8 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val swaps = if (subs.size > 1 && !isCommutative(e)) { - (for (i <- 0 until subs.size; - j <- i+1 until subs.size) yield { + (for (i <- subs.indices; + j <- i+1 until subs.size) yield { if (subs(i).getType == subs(j).getType) { val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i)) @@ -98,18 +101,18 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte allSubs ++ injectG ++ swaps } - def cegis(gl: L): Seq[(L, Gen)] = { + def cegis(gl: L): Seq[(L, Prod)] = { normalGrammar.getProductions(gl).map(gl -> _) } - def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(BVMinus(e, IntLiteral(1))), gl -> terminal(BVPlus (e, IntLiteral(1))) ) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def intVariations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(Minus(e, InfiniteIntegerLiteral(1))), gl -> terminal(Plus (e, InfiniteIntegerLiteral(1))) @@ -118,7 +121,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... - def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { + def ccVariations(gl: L, cc: CaseClass): Seq[(L, Prod)] = { val CaseClass(cct, args) = cc val neighbors = cct.root.knownCCDescendants diff Seq(cct) @@ -129,7 +132,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = e match { + val subs: Seq[(L, Prod)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) diff --git a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/SizeBoundedGrammar.scala deleted file mode 100644 index 1b25e30f61aa74598feb255366fe10a153bc9e30..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import purescala.Types._ -import leon.utils.SeqUtils.sumTo - -case class SizedLabel[T <: Typed](underlying: T, size: Int) extends Typed { - val getType = underlying.getType - - override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" -} - -case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { - def computeProductions(sl: SizedLabel[T])(implicit ctx: LeonContext): Seq[Gen] = { - if (sl.size <= 0) { - Nil - } else if (sl.size == 1) { - g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { gen => - terminal(gen.builder(Seq())) - } - } else { - g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { gen => - val sizes = sumTo(sl.size-1, gen.subTrees.size) - - for (ss <- sizes) yield { - val subSizedLabels = (gen.subTrees zip ss) map (s => SizedLabel(s._1, s._2)) - - nonTerminal(subSizedLabels, gen.builder) - } - } - } - } -} diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/leon/grammars/Tags.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a6b6fca491b8db6f74622edd9298ec5cd6053b0 --- /dev/null +++ b/src/main/scala/leon/grammars/Tags.scala @@ -0,0 +1,65 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Types.CaseClassType +import purescala.Definitions.FunDef + +object Tags { + /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ + abstract class Tag + case object Top extends Tag // Tag for the top-level of the grammar (default) + case object Zero extends Tag // Tag for 0 + case object One extends Tag // Tag for 1 + case object BooleanC extends Tag // Tag for boolean constants + case object Constant extends Tag // Tag for other constants + case object And extends Tag // Tags for boolean operations + case object Or extends Tag + case object Not extends Tag + case object Plus extends Tag // Tags for arithmetic operations + case object Minus extends Tag + case object Times extends Tag + case object Mod extends Tag + case object Div extends Tag + case object Variable extends Tag // Tag for variables + case object Equals extends Tag // Tag for equality + /** Constructors like Tuple, CaseClass... + * + * @param isTerminal If true, this constructor represents a terminal symbol + * (in practice, case class with 0 fields) + */ + case class Constructor(isTerminal: Boolean) extends Tag + /** Tag for function calls + * + * @param isMethod Whether the function called is a method + * @param isSafe Whether this constructor represents a safe function call. + * We need this because this call implicitly contains a variable, + * so we want to allow constants in all arguments. + */ + case class FunCall(isMethod: Boolean, isSafe: Boolean) extends Tag + + /** The set of tags that represent constants */ + val isConst: Set[Tag] = Set(Zero, One, Constant, BooleanC, Constructor(true)) + + /** The set of tags that represent commutative operations */ + val isCommut: Set[Tag] = Set(Plus, Times, Equals) + + /** The set of tags which have trivial results for equal arguments */ + val symmetricTrivial = Set(Minus, And, Or, Equals, Div, Mod) + + /** Tags which allow constants in all their operands + * + * In reality, the current version never allows that: it is only allowed in safe function calls + * which by construction contain a hidden reference to a variable. + * TODO: Experiment with different conditions, e.g. are constants allowed in + * top-level/ general function calls/ constructors/...? + */ + def allConstArgsAllowed(t: Tag) = t match { + case FunCall(_, true) => true + case _ => false + } + + def tagOf(cct: CaseClassType) = Constructor(cct.fields.isEmpty) + def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 98850c8df4adcf3e776970c176cf37c251823917..8548c1a9542b9f1b0acb96fd3d993fecd6e24c1d 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -6,31 +6,32 @@ package grammars import purescala.Types._ import purescala.Expressions._ +/** A grammar of values (ground terms) */ case object ValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)) + terminal(BooleanLiteral(true), Tags.One), + terminal(BooleanLiteral(false), Tags.Zero) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - terminal(IntLiteral(5)) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One), + terminal(IntLiteral(5), Tags.Constant) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - terminal(InfiniteIntegerLiteral(5)) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One), + terminal(InfiniteIntegerLiteral(5), Tags.Constant) ) case StringType => List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("foo")), - terminal(StringLiteral("bar")) + terminal(StringLiteral(""), Tags.Constant), + terminal(StringLiteral("a"), Tags.Constant), + terminal(StringLiteral("foo"), Tags.Constant), + terminal(StringLiteral("bar"), Tags.Constant) ) case tp: TypeParameter => @@ -40,28 +41,29 @@ case object ValueGrammar extends ExpressionGrammar[TypeTree] { case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(stps.isEmpty)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), - nonTerminal(List(base, base), { case elems => FiniteSet(elems.toSet, base) }) + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), + nonTerminal(List(base, base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)) ) case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..02e045497a28db205d4a33a300ac0b742510920a --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala @@ -0,0 +1,21 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +/** Limits a grammar to a specific expression depth */ +case class DepthBoundedGrammar[L](g: ExpressionGrammar[NonTerminal[L]], bound: Int) extends ExpressionGrammar[NonTerminal[L]] { + def computeProductions(l: NonTerminal[L])(implicit ctx: LeonContext): Seq[Prod] = g.computeProductions(l).flatMap { + case gen => + if (l.depth == Some(bound) && gen.isNonTerminal) { + Nil + } else if (l.depth.exists(_ > bound)) { + Nil + } else { + List ( + nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) + ) + } + } +} diff --git a/src/main/scala/leon/grammars/EmbeddedGrammar.scala b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala similarity index 74% rename from src/main/scala/leon/grammars/EmbeddedGrammar.scala rename to src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala index 8dcbc6ec10f9aa42895e5f876cdd4d72479de229..d989a8804b32f62697b7f31e498e61393a12c35b 100644 --- a/src/main/scala/leon/grammars/EmbeddedGrammar.scala +++ b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala @@ -2,10 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.Constructors._ +import leon.purescala.Types.Typed /** * Embed a grammar Li->Expr within a grammar Lo->Expr @@ -13,9 +12,9 @@ import purescala.Constructors._ * We rely on a bijection between Li and Lo labels */ case class EmbeddedGrammar[Ti <: Typed, To <: Typed](innerGrammar: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] { - def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Prod] = { innerGrammar.computeProductions(oToi(t)).map { innerGen => - nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder) + nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder, innerGen.tag) } } } diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/leon/grammars/transformers/OneOf.scala similarity index 56% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/leon/grammars/transformers/OneOf.scala index 0e10c096151c1fdf83d3c7e7f10c4a4a6518215b..5c57c6a1a48179e2d813398aa651022df6cae35a 100644 --- a/src/main/scala/leon/grammars/OneOf.scala +++ b/src/main/scala/leon/grammars/transformers/OneOf.scala @@ -2,14 +2,15 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.Constructors._ +import purescala.Expressions.Expr +import purescala.Types.TypeTree +import purescala.TypeOps.isSubtypeOf +/** Generates one production rule for each expression in a sequence that has compatible type */ case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => terminal(i) diff --git a/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..5abff1aa79df47158a785bb160b54a7cb5d77ed5 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala @@ -0,0 +1,59 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import purescala.Types.Typed +import utils.SeqUtils._ + +/** Adds information about size to a nonterminal symbol */ +case class SizedNonTerm[T <: Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" +} + +/** Limits a grammar by producing expressions of size bounded by the [[SizedNonTerm.size]] of a given [[SizedNonTerm]]. + * + * In case of commutative operations, the grammar will produce trees skewed to the right + * (i.e. the right subtree will always be larger). Notice we do not lose generality in case of + * commutative operations. + */ +case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T], optimizeCommut: Boolean) extends ExpressionGrammar[SizedNonTerm[T]] { + def computeProductions(sl: SizedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.isTerminal).map { gen => + terminal(gen.builder(Seq()), gen.tag) + } + } else { + g.getProductions(sl.underlying).filter(_.isNonTerminal).flatMap { gen => + + // Ad-hoc equality that does not take into account position of nonterminals. + // TODO: Ugly and hacky + def characteristic(t: T): AnyRef = t match { + case TaggedNonTerm(underlying, tag, _, isConst) => + (underlying, tag, isConst) + case other => + other + } + + // Optimization: When we have a commutative operation and all the labels are the same, + // we can skew the expression to always be right-heavy + val sizes = if(optimizeCommut && Tags.isCommut(gen.tag) && gen.subTrees.map(characteristic).toSet.size == 1) { + sumToOrdered(sl.size-1, gen.arity) + } else { + sumTo(sl.size-1, gen.arity) + } + + for (ss <- sizes) yield { + val subSizedLabels = (gen.subTrees zip ss) map (s => SizedNonTerm(s._1, s._2)) + + nonTerminal(subSizedLabels, gen.builder, gen.tag) + } + } + } + } +} diff --git a/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..a95306f7d8a0188891248d67196598448ec69737 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala @@ -0,0 +1,112 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import leon.purescala.Types.Typed +import Tags._ + +/** Adds to a nonterminal information about about the tag of its parent's [[leon.grammars.ProductionRule.tag]] + * and additional information. + * + * @param underlying The underlying nonterminal + * @param tag The tag of the parent of this nonterminal + * @param pos The index of this nonterminal in its father's production rule + * @param isConst Whether this nonterminal is obliged to generate/not generate constants. + * + */ +case class TaggedNonTerm[T <: Typed](underlying: T, tag: Tag, pos: Int, isConst: Option[Boolean]) extends Typed { + val getType = underlying.getType + + private val cString = isConst match { + case Some(true) => "↓" + case Some(false) => "↑" + case None => "○" + } + + /** [[isConst]] is printed as follows: ↓ for constants only, ↑ for nonconstants only, + * ○ for anything allowed. + */ + override def asString(implicit ctx: LeonContext): String = s"$underlying%$tag@$pos$cString" +} + +/** Constraints a grammar to reduce redundancy by utilizing information provided by the [[TaggedNonTerm]]. + * + * 1) In case of associative operations, right associativity is enforced. + * 2) Does not generate + * - neutral and absorbing elements (incl. boolean equality) + * - nested negations + * - trivial operations for symmetric arguments, e.g. a == a + * 3) Excludes method calls on nullary case objects, e.g. Nil().size + * 4) Enforces that no constant trees are generated (and recursively for each subtree) + * + * @param g The underlying untagged grammar + */ +case class TaggedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[TaggedNonTerm[T]] { + + private def exclude(tag: Tag, pos: Int): Set[Tag] = (tag, pos) match { + case (Top, _) => Set() + case (And, 0) => Set(And, BooleanC) + case (And, 1) => Set(BooleanC) + case (Or, 0) => Set(Or, BooleanC) + case (Or, 1) => Set(BooleanC) + case (Plus, 0) => Set(Plus, Zero, One) + case (Plus, 1) => Set(Zero) + case (Minus, 1) => Set(Zero) + case (Not, _) => Set(Not, BooleanC) + case (Times, 0) => Set(Times, Zero, One) + case (Times, 1) => Set(Zero, One) + case (Equals,_) => Set(Not, BooleanC) + case (Div | Mod, 0 | 1) => Set(Zero, One) + case (FunCall(true, _), 0) => Set(Constructor(true)) // Don't allow Nil().size etc. + case _ => Set() + } + + def computeProductions(t: TaggedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + + // Point (4) for this level + val constFilter: g.Prod => Boolean = t.isConst match { + case Some(b) => + innerGen => isConst(innerGen.tag) == b + case None => + _ => true + } + + g.computeProductions(t.underlying) + // Include only constants iff constants are forced, only non-constants iff they are forced + .filter(constFilter) + // Points (1), (2). (3) + .filterNot { innerGen => exclude(t.tag, t.pos)(innerGen.tag) } + .flatMap { innerGen => + + def nt(isConst: Int => Option[Boolean]) = nonTerminal( + innerGen.subTrees.zipWithIndex.map { + case (t, pos) => TaggedNonTerm(t, innerGen.tag, pos, isConst(pos)) + }, + innerGen.builder, + innerGen.tag + ) + + def powerSet[A](t: Set[A]): Set[Set[A]] = { + @scala.annotation.tailrec + def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + pwr(t, Set(Set.empty[A])) + } + + // Allow constants everywhere if this is allowed, otherwise demand at least 1 variable. + // Aka. tag subTrees correctly so point (4) is enforced in the lower level + // (also, make sure we treat terminals correctly). + if (innerGen.isTerminal || allConstArgsAllowed(innerGen.tag)) { + Seq(nt(_ => None)) + } else { + val indices = innerGen.subTrees.indices.toSet + (powerSet(indices) - indices) map (indices => nt(x => Some(indices(x)))) + } + } + } + +} diff --git a/src/main/scala/leon/grammars/Or.scala b/src/main/scala/leon/grammars/transformers/Union.scala similarity index 73% rename from src/main/scala/leon/grammars/Or.scala rename to src/main/scala/leon/grammars/transformers/Union.scala index e691a245984eaeb11277b9278505b49cf623fed3..471625ac3c22c22456f49f366ed26e5195b5f4ab 100644 --- a/src/main/scala/leon/grammars/Or.scala +++ b/src/main/scala/leon/grammars/transformers/Union.scala @@ -2,8 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ +import purescala.Types.Typed case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { @@ -11,6 +12,6 @@ case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGr case g => Seq(g) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = subGrammars.flatMap(_.getProductions(t)) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 81ebbdbec38e1484baf106eac9652d62297ae493..4085050d251721bce587f2d4b82baff09e4acae0 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,6 +275,14 @@ object DefOps { None } + /** + * + * @param p + * @param fdMapF + * @param fiMapF + * @return + */ + def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 81b8420c3aae4161cf8f26a041539d1147290484..20cca91e1b8f4ebd2f2541d32d772cf55fbdfa79 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -13,7 +13,7 @@ import ExprOps.preMap object TypeOps { def typeDepth(t: TypeTree): Int = t match { - case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max + case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 37f187679897bbec908ffeaf18f5385198f915c9..19095f6f4406ef398fedfb64860c03a34c92f250 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -242,7 +242,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou val maxValid = 400 val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](ValueGrammar.getProductions) val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 5077f9467dab2edd3abf8fdf7da01d5eae59ff5e..5728d3b14df333300f80896ff3baf6f379342b65 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -182,7 +182,7 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { }) getOrElse { // If the input contains free variables, it does not provide concrete examples. // We will instantiate them according to a simple grammar to get them. - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](ValueGrammar.getProductions) val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType })) val instantiations = values.map { v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index cd27e272d53e9669f4c2d1f2c8e07356819b4827..6583d4ca805c4a3f8335dbd4aa374602c7f132a8 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -35,6 +35,9 @@ abstract class PreprocessingRule(name: String) extends Rule(name) { /** Contains the list of all available rules for synthesis */ object Rules { + + private val newCegis = true + /** Returns the list of all available rules for synthesis */ def all = List[Rule]( StringRender, @@ -54,8 +57,8 @@ object Rules { OptimisticGround, EqualitySplit, InequalitySplit, - CEGIS, - TEGIS, + if(newCegis) CEGIS2 else CEGIS, + //TEGIS, //BottomUpTEGIS, rules.Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 4bb10d38c9ffc7a7667d165b84e4f65c1edc9e0c..8ab07929d78479656f18ce1fd652cfa7ef870e17 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -45,6 +45,10 @@ object SourceInfo { ci } + if (results.isEmpty) { + ctx.reporter.warning("No 'choose' found. Maybe the functions you chose do not exist?") + } + results.sortBy(_.source.getPos) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index b9ba6df1f688e01edf631e995ba2f8623bcfc5fe..373e4fb56e9f51639ebf45f16fa49069d07a28e5 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -3,13 +3,11 @@ package leon package synthesis -import purescala.ExprOps._ - +import purescala.ExprOps.replace import purescala.ScalaPrinter -import leon.utils._ import purescala.Definitions.{Program, FunDef} -import leon.utils.ASCIIHelpers +import leon.utils._ import graph._ object SynthesisPhase extends TransformationPhase { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index bafed6ec2bab51539bfc0547563bbad2aeea873e..efd1ad13e0538f855487e94b7b1a35d7d893627f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -70,21 +70,19 @@ class Synthesizer(val context : LeonContext, // Print out report for synthesis, if necessary reporter.ifDebug { printer => - import java.io.FileWriter import java.text.SimpleDateFormat import java.util.Date val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") val benchName = categoryName+"."+ci.fd.id.name - var time = lastTime/1000.0; + val time = lastTime/1000.0 val defs = visibleDefsFrom(ci.fd)(program).collect { case cd: ClassDef => 1 + cd.fields.size case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) } - val psize = defs.sum; - + val psize = defs.sum val (size, calls, proof) = result.headOption match { case Some((sol, trusted)) => diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index bfa6a62120af6171d001b6a026630734eb6c10fd..244925b90a9bbab378eb47c991b145cf33ffde6c 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -69,15 +69,15 @@ object QuestionBuilder { /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ object SpecialStringValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\t")), + terminal(StringLiteral("Lara 2007")) + ) + case _ => ValueGrammar.computeProductions(t) } } } @@ -140,7 +140,7 @@ class QuestionBuilder[T <: Expr]( def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree,Expr]](value_enumerator.getProductions) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree,Expr]](value_enumerator.getProductions) val values = enum.iterator(tupleTypeWrap(_argTypes)) val instantiations = values.map { v => input.zip(unwrapTuple(v, input.size)) @@ -172,4 +172,4 @@ class QuestionBuilder[T <: Expr]( } questions.toList.sortBy(_questionSorMethod(_)) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 2f3869af16b71f9635e36d27774f55a7cee7140c..f716997f4b14e3d634185d2235b6ddd8dcbd1d76 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -51,13 +51,13 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = tests.size - var compiled = Map[Generator[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() + var compiled = Map[ProductionRule[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() /** * Compile Generators to functions from Expr to Expr. The compiled * generators will be passed to the enumerator */ - def compile(gen: Generator[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { + def compile(gen: ProductionRule[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { compiled.getOrElse(gen, { val executor = if (gen.subTrees.isEmpty) { @@ -108,7 +108,7 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val targetType = tupleTypeWrap(p.xs.map(_.getType)) val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - val enum = new BottomUpEnumerator[T, Expr, Expr, Generator[T, Expr]]( + val enum = new BottomUpEnumerator[T, Expr, Expr, ProductionRule[T, Expr]]( grammar.getProductions, wrappedTests, { (vecs, gen) => diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala index 1fcf01d52088ea9d4d25d184a673ef8335a8d260..139c6116e4b19d289e5a95ecd866cf3549486a6a 100644 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGIS.scala @@ -4,16 +4,29 @@ package leon package synthesis package rules -import purescala.Types._ - import grammars._ -import utils._ +import grammars.transformers._ +import purescala.Types.TypeTree case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { CegisParams( grammar = Grammars.typeDepthBound(Grammars.default(sctx, p), 2), // This limits type depth - rootLabel = {(tpe: TypeTree) => tpe } + rootLabel = {(tpe: TypeTree) => tpe }, + maxUnfoldings = 12, + optimizations = false ) } } + +case object CEGIS2 extends CEGISLike[TaggedNonTerm[TypeTree]]("CEGIS2") { + def getParams(sctx: SynthesisContext, p: Problem) = { + val base = CEGIS.getParams(sctx,p).grammar + CegisParams( + grammar = TaggedGrammar(base), + rootLabel = TaggedNonTerm(_, Tags.Top, 0, None), + maxUnfoldings = 12, + optimizations = true + ) + } +} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index f643a7bb0e69a6f1384bb2a149a261fbe8eaba9d..368608b9031de196cadc06961eabc35c4b7aa2a1 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -4,10 +4,6 @@ package leon package synthesis package rules -import leon.utils.SeqUtils -import solvers._ -import grammars._ - import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ @@ -16,18 +12,24 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors._ -import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} +import solvers._ +import grammars._ +import grammars.transformers._ +import leon.utils.SeqUtils import evaluators._ import datagen._ import codegen.CodeGenParams +import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} + abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 5 + maxUnfoldings: Int = 5, + optimizations: Boolean ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams @@ -36,7 +38,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val exSolverTo = 2000L val cexSolverTo = 2000L - // Track non-deterministic programs up to 10'000 programs, or give up + // Track non-deterministic programs up to 100'000 programs, or give up val nProgramsLimit = 100000 val sctx = hctx.sctx @@ -48,6 +50,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // Limits the number of programs CEGIS will specifically validate individually val validateUpTo = 3 + val passingRatio = 10 val interruptManager = sctx.context.interruptManager @@ -61,13 +64,13 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private var termSize = 0 - val grammar = SizeBoundedGrammar(params.grammar) + val grammar = SizeBoundedGrammar(params.grammar, params.optimizations) - def rootLabel = SizedLabel(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) + def rootLabel = SizedNonTerm(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) - var nAltsCache = Map[SizedLabel[T], Int]() + var nAltsCache = Map[SizedNonTerm[T], Int]() - def countAlternatives(l: SizedLabel[T]): Int = { + def countAlternatives(l: SizedNonTerm[T]): Int = { if (!(nAltsCache contains l)) { val count = grammar.getProductions(l).map { gen => gen.subTrees.map(countAlternatives).product @@ -102,7 +105,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // C identifiers corresponding to p.xs - private var rootC: Identifier = _ + private var rootC: Identifier = _ private var bs: Set[Identifier] = Set() @@ -110,19 +113,19 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { class CGenerator { - private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + private var buffers = Map[SizedNonTerm[T], Stream[Identifier]]() - private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + private var slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) - private def streamOf(t: SizedLabel[T]): Stream[Identifier] = Stream.continually( + private def streamOf(t: SizedNonTerm[T]): Stream[Identifier] = Stream.continually( FreshIdentifier(t.asString, t.getType, true) ) def rewind(): Unit = { - slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) } - def getNext(t: SizedLabel[T]) = { + def getNext(t: SizedNonTerm[T]) = { if (!(buffers contains t)) { buffers += t -> streamOf(t) } @@ -146,7 +149,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { id } - def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + def defineCTreeFor(l: SizedNonTerm[T], c: Identifier): Unit = { if (!(cTree contains c)) { val cGen = new CGenerator() @@ -182,6 +185,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { printer => printer("Grammar so far:") grammar.printProductions(printer) + printer("") } bsOrdered = bs.toSeq.sorted @@ -233,12 +237,47 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { cache(c) } - SeqUtils.cartesianProduct(seqs).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) + SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) + } + + def redundant(e: Expr): Boolean = { + val (op1, op2) = e match { + case Minus(o1, o2) => (o1, o2) + case Modulo(o1, o2) => (o1, o2) + case Division(o1, o2) => (o1, o2) + case BVMinus(o1, o2) => (o1, o2) + case BVRemainder(o1, o2) => (o1, o2) + case BVDivision(o1, o2) => (o1, o2) + + case And(Seq(Not(o1), Not(o2))) => (o1, o2) + case And(Seq(Not(o1), o2)) => (o1, o2) + case And(Seq(o1, Not(o2))) => (o1, o2) + case And(Seq(o1, o2)) => (o1, o2) + + case Or(Seq(Not(o1), Not(o2))) => (o1, o2) + case Or(Seq(Not(o1), o2)) => (o1, o2) + case Or(Seq(o1, Not(o2))) => (o1, o2) + case Or(Seq(o1, o2)) => (o1, o2) + + case SetUnion(o1, o2) => (o1, o2) + case SetIntersection(o1, o2) => (o1, o2) + case SetDifference(o1, o2) => (o1, o2) + + case Equals(Not(o1), Not(o2)) => (o1, o2) + case Equals(Not(o1), o2) => (o1, o2) + case Equals(o1, Not(o2)) => (o1, o2) + case Equals(o1, o2) => (o1, o2) + case _ => return false } + + op1 == op2 } - allProgramsFor(Seq(rootC)) + allProgramsFor(Seq(rootC))/* filterNot { bs => + val res = params.optimizations && exists(redundant)(getExpr(bs)) + if (!res) excludeProgram(bs, false) + res + }*/ } private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]], @@ -432,10 +471,9 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - - // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { case (b, builder, cs) => @@ -455,56 +493,64 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { val origImpl = cTreeFd.fullBody - val cexs = for (bs <- bss.toSeq) yield { + var cexs = Seq[Seq[Expr]]() + + for (bs <- bss.toSeq) { val outerSol = getExpr(bs) val innerSol = outerExprToInnerExpr(outerSol) - + //println(s"Testing $outerSol") cTreeFd.fullBody = innerSol val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - //println("Solving for: "+cnstr.asString) - - val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) + val eval = new DefaultEvaluator(ctx, innerProgram) - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) + if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { + //println(s"Program $outerSol fails!") + excludeProgram(bs, true) + cTreeFd.fullBody = origImpl + } else { + //println("Solving for: "+cnstr.asString) + + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs, true) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} + + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) + //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") + cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) - } else { - None - } + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } + } + } finally { + solverf.reclaim(solver) + solverf.shutdown() + cTreeFd.fullBody = origImpl } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - cTreeFd.fullBody = origImpl } } - Right(cexs.flatten) + Right(cexs) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() @@ -663,6 +709,10 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val sctx = hctx.sctx implicit val ctx = sctx.context + import leon.utils.Timer + val timer = new Timer + timer.start + val ndProgram = new NonDeterministicProgram(p) ndProgram.init() @@ -677,7 +727,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.grammar.printProductions(printer) } - // We populate the list of examples with a predefined one + // We populate the list of examples with a defined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.eb.examples @@ -793,12 +843,22 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } val nPassing = prunedPrograms.size - sctx.reporter.debug("#Programs passing tests: "+nPassing) + val nTotal = ndProgram.allProgramsCount() + + /*locally { + val progs = ndProgram.allPrograms() map ndProgram.getExpr + val ground = progs count isGround + println("Programs") + progs take 100 foreach println + println(s"$ground ground out of $nTotal") + }*/ + + sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.ifDebug{ printer => - for (p <- prunedPrograms.take(10)) { + for (p <- prunedPrograms.take(100)) { printer(" - "+ndProgram.getExpr(p).asString) } - if(nPassing > 10) { + if(nPassing > 100) { printer(" - ...") } } @@ -812,16 +872,23 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - if (nPassing == 0 || interruptManager.isInterrupted) { // No test passed, we can skip solver and unfold again, if possible skipCESearch = true } else { var doFilter = true - if (validateUpTo > 0) { - // Validate the first N programs individually - ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match { + // If the number of pruned programs is very small, or by far smaller than the number of total programs, + // we hypothesize it will be easier to just validate them individually. + // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? + val programsToValidate = if (nTotal / nPassing > passingRatio || nPassing < 10) { + prunedPrograms + } else { + prunedPrograms.take(validateUpTo) + } + + if (programsToValidate.nonEmpty) { + ndProgram.validatePrograms(programsToValidate) match { case Left(sols) if sols.nonEmpty => doFilter = false result = Some(RuleClosed(sols)) @@ -919,6 +986,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.warning("CEGIS crashed: "+e.getMessage) e.printStackTrace() RuleFailed() + } finally { + ctx.reporter.info(s"CEGIS ran for ${timer.stop} ms") } } }) diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala index c12edac075bc8525d395d5f792ef4579c0d109f1..357a3d8882cc5feb4e2687f231079748f21715bc 100644 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGLESS.scala @@ -4,10 +4,10 @@ package leon package synthesis package rules +import leon.grammars.transformers.Union import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ -import utils._ import grammars._ import Witnesses._ @@ -24,7 +24,7 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { val inputs = p.as.map(_.toVariable) sctx.reporter.ifDebug { printer => - printer("Guides available:") + printer("Guides available:") for (g <- guides) { printer(" - "+g.asString(ctx)) } @@ -35,7 +35,8 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { CegisParams( grammar = guidedGrammar, rootLabel = { (tpe: TypeTree) => NonTerminal(tpe, "G0") }, - maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max + maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max, + optimizations = false ) } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 2ae2b1d5d0292a6ed725055e61b1b4af4100a63c..d3b4c823dd7110763316d121407bcf94820c5826 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -83,7 +83,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } } - var eb = p.qeb.mapIns { info => + val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => ebMapInfo.get(id) match { case Some(m) => @@ -103,7 +103,8 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case CaseClass(ct, args) => val (cts, es) = args.zip(ct.fields).map { case (CaseClassSelector(ct, e, id), field) if field.id == id => (ct, e) - case _ => return e + case _ => + return e }.unzip if (cts.distinct.size == 1 && es.distinct.size == 1) { @@ -126,7 +127,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - val s = {substAll(reverseMap, _:Expr)} andThen { simplePostTransform(recompose) } + val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose) Some(decomp(List(sub), forwardMap(s), s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala index a6e060f89809451d12be2830a69a4b0202824bdb..eea2b05040c677f0966049838cd0fd85da832e8e 100644 --- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala @@ -67,7 +67,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[T, Expr, ProductionRule[T, Expr]](grammar.getProductions) val targetType = tupleTypeWrap(p.xs.map(_.getType)) diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index acd285a4570f93ee9dd85ba3dd29a7e4b120c25a..4bfedc4acbe59440ac7f3382c8187ae201775f02 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -34,7 +34,18 @@ object Helpers { } } - def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(Expr, Set[Identifier])] = { + /** Given an initial set of function calls provided by a list of [[Terminating]], + * returns function calls that will hopefully be safe to call recursively from within this initial function calls. + * + * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. + * + * @param prog The current program + * @param tpe The expected type for the returned function calls + * @param ws Helper predicates that contain [[Terminating]]s with the initial calls + * @param pc The path condition + * @return A list of pairs of (safe function call, holes), where holes stand for the rest of the arguments of the function. + */ + def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(FunctionInvocation, Set[Identifier])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 002f2ebedc8a6dfb265fbf101c2185b3bfa17ce1..363808e457f8df39e56ffd9124f27a553b9535bc 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -42,6 +42,20 @@ object SeqUtils { } } } + + def sumToOrdered(sum: Int, arity: Int): Seq[Seq[Int]] = { + def rec(sum: Int, arity: Int): Seq[Seq[Int]] = { + require(arity > 0) + if (sum < 0) Nil + else if (arity == 1) Seq(Seq(sum)) + else for { + n <- 0 to sum / arity + rest <- rec(sum - arity * n, arity - 1) + } yield n +: rest.map(n + _) + } + + rec(sum, arity) filterNot (_.head == 0) + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] {