diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala index a0a9c9d92ee78397cce95dd0a52042fa8bc63330..f079399f9995e409b6b72881ef8f81d92c273556 100644 --- a/src/main/scala/leon/LeonOption.scala +++ b/src/main/scala/leon/LeonOption.scala @@ -25,8 +25,10 @@ abstract class LeonOptionDef[+A] { try { parser(s) } catch { case _ : IllegalArgumentException => - reporter.error(s"Invalid option usage: $usageDesc") - Main.displayHelp(reporter, error = true) + reporter.fatalError( + s"Invalid option usage: --$name\n" + + "Try 'leon --help' for more information." + ) } } diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index f238d4a854bd881c74a3205ef5b727a220e34ae1..db49d0425e39ac888badb90c97aa7c8432d0ee96 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -116,8 +116,10 @@ object Main { } // Find respective LeonOptionDef, or report an unknown option val df = allOptions.find(_. name == name).getOrElse{ - initReporter.error(s"Unknown option: $name") - displayHelp(initReporter, error = true) + initReporter.fatalError( + s"Unknown option: $name\n" + + "Try 'leon --help' for more information." + ) } df.parse(value)(initReporter) } diff --git a/src/main/scala/leon/SharedOptions.scala b/src/main/scala/leon/SharedOptions.scala index 839dda206b4a702ad5011049a38c76d7dcd21478..b68e64c83a6fd42e445dd4b5fa3b3bd42a935de4 100644 --- a/src/main/scala/leon/SharedOptions.scala +++ b/src/main/scala/leon/SharedOptions.scala @@ -5,12 +5,11 @@ package leon import leon.utils.{DebugSections, DebugSection} import OptionParsers._ -/* - * This object contains options that are shared among different modules of Leon. - * - * Options that determine the pipeline of Leon are not stored here, - * but in MainComponent in Main.scala. - */ +/** This object contains options that are shared among different modules of Leon. + * + * Options that determine the pipeline of Leon are not stored here, + * but in [[Main.MainComponent]] instead. + */ object SharedOptions extends LeonComponent { val name = "sharedOptions" @@ -45,7 +44,7 @@ object SharedOptions extends LeonComponent { val name = "debug" val description = { val sects = DebugSections.all.toSeq.map(_.name).sorted - val (first, second) = sects.splitAt(sects.length/2) + val (first, second) = sects.splitAt(sects.length/2 + 1) "Enable detailed messages per component.\nAvailable:\n" + " " + first.mkString(", ") + ",\n" + " " + second.mkString(", ") @@ -61,8 +60,6 @@ object SharedOptions extends LeonComponent { Set(rs) case None => throw new IllegalArgumentException - //initReporter.error("Section "+s+" not found, available: "+DebugSections.all.map(_.name).mkString(", ")) - //Set() } } } diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala index cd86c707ddb893918e512f6d7101cc4cc92b6405..04541e78a6e63d1fa8670f4d4aeb12dd9c4417a2 100644 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ b/src/main/scala/leon/datagen/GrammarDataGen.scala @@ -4,14 +4,17 @@ package leon package datagen import purescala.Expressions._ -import purescala.Types.TypeTree +import purescala.Types._ import purescala.Common._ import purescala.Constructors._ import purescala.Extractors._ +import purescala.ExprOps._ import evaluators._ import bonsai.enumerators._ import grammars._ +import utils.UniqueCounter +import utils.SeqUtils.cartesianProduct /** Utility functions to generate values of a given type. * In fact, it could be used to generate *terms* of a given type, @@ -19,9 +22,40 @@ import grammars._ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] = ValueGrammar) extends DataGenerator { implicit val ctx = evaluator.context + // Assume e contains generic values with index 0. + // Return a series of expressions with all normalized combinations of generic values. + private def expandGenerics(e: Expr): Seq[Expr] = { + val c = new UniqueCounter[TypeParameter] + val withUniqueCounters: Expr = postMap { + case GenericValue(t, _) => + Some(GenericValue(t, c.next(t))) + case _ => None + }(e) + + val indices = c.current + + val (tps, substInt) = (for { + tp <- indices.keySet.toSeq + } yield tp -> (for { + from <- 0 to indices(tp) + to <- 0 to from + } yield (from, to))).unzip + + val combos = cartesianProduct(substInt) + + val substitutions = combos map { subst => + tps.zip(subst).map { case (tp, (from, to)) => + (GenericValue(tp, from): Expr) -> (GenericValue(tp, to): Expr) + }.toMap + } + + substitutions map (replace(_, withUniqueCounters)) + + } + def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](grammar.getProductions) - enum.iterator(tpe) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](grammar.getProductions) + enum.iterator(tpe).flatMap(expandGenerics) } def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { @@ -51,4 +85,8 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] } } + def generateMapping(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int) = { + generateFor(ins, satisfying, maxValid, maxEnumerated) map (ins zip _) + } + } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index d7dd1a947ead97ccd1c30a6a89256e24e10654f4..48e2f48dca0b3e2170a21db9f00409ebdb46173a 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1138,7 +1138,7 @@ trait CodeExtraction extends ASTExtractors { case _ => (Nil, restTree) } - LetDef(funDefWithBody +: other_fds, block) + letDef(funDefWithBody +: other_fds, block) // FIXME case ExDefaultValueFunction diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/leon/grammars/BaseGrammar.scala index f11f937498051eb47c2c522a0faa2a1499545175..6e0a2ee5e6842255aac5755c9ee27005e5360eb5 100644 --- a/src/main/scala/leon/grammars/BaseGrammar.scala +++ b/src/main/scala/leon/grammars/BaseGrammar.scala @@ -7,56 +7,65 @@ import purescala.Types._ import purescala.Expressions._ import purescala.Constructors._ +/** The basic grammar for Leon expressions. + * Generates the most obvious expressions for a given type, + * without regard of context (variables in scope, current function etc.) + * Also does some trivial simplifications. + */ case object BaseGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)), - nonTerminal(List(BooleanType), { case Seq(a) => not(a) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + terminal(BooleanLiteral(false), Tags.BooleanC), + terminal(BooleanLiteral(true), Tags.BooleanC), + nonTerminal(List(BooleanType), { case Seq(a) => not(a) }, Tags.Not), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }, Tags.And), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }, Tags.Or ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }, Tags.Times) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }, Tags.Times), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Modulo(a, b) }, Tags.Mod), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Division(a, b) }, Tags.Div) ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(isTerminal = false)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), nonTerminal(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -64,7 +73,7 @@ case object BaseGrammar extends ExpressionGrammar[TypeTree] { case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/leon/grammars/Constants.scala new file mode 100644 index 0000000000000000000000000000000000000000..81c55346052668e0d82b05ee240867eb1e5c468c --- /dev/null +++ b/src/main/scala/leon/grammars/Constants.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Expressions._ +import purescala.Types.TypeTree +import purescala.ExprOps.collect +import purescala.Extractors.IsTyped + +/** Generates constants found in an [[leon.purescala.Expressions.Expr]]. + * Some constants that are generated by other grammars (like 0, 1) will be excluded + */ +case class Constants(e: Expr) extends ExpressionGrammar[TypeTree] { + + private val excluded: Set[Expr] = Set( + InfiniteIntegerLiteral(1), + InfiniteIntegerLiteral(0), + IntLiteral(1), + IntLiteral(0), + BooleanLiteral(true), + BooleanLiteral(false) + ) + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { + val literals = collect[Expr]{ + case IsTyped(l:Literal[_], `t`) => Set(l) + case _ => Set() + }(e) + + (literals -- excluded map (terminal(_, Tags.Constant))).toSeq + } +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/DepthBoundedGrammar.scala deleted file mode 100644 index fc999be644bf2c4a7a20a73403cf7b1001bb9b68..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -case class DepthBoundedGrammar[T](g: ExpressionGrammar[NonTerminal[T]], bound: Int) extends ExpressionGrammar[NonTerminal[T]] { - def computeProductions(l: NonTerminal[T])(implicit ctx: LeonContext): Seq[Gen] = g.computeProductions(l).flatMap { - case gen => - if (l.depth == Some(bound) && gen.subTrees.nonEmpty) { - Nil - } else if (l.depth.exists(_ > bound)) { - Nil - } else { - List ( - nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) - ) - } - } -} diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/leon/grammars/Empty.scala index 70ebddc98f21fc872aef8635fe36de7e9ba9bbce..737f9cdf389454f403a6581e13eec7fafa383f34 100644 --- a/src/main/scala/leon/grammars/Empty.scala +++ b/src/main/scala/leon/grammars/Empty.scala @@ -5,6 +5,7 @@ package grammars import purescala.Types.Typed +/** The empty expression grammar */ case class Empty[T <: Typed]() extends ExpressionGrammar[T] { - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = Nil + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = Nil } diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/leon/grammars/EqualityGrammar.scala index e9463a771204d877a4d748c373b6d198e2c2591b..a2f9c41360ada03334ace63eca3ca46f9f6d5ff7 100644 --- a/src/main/scala/leon/grammars/EqualityGrammar.scala +++ b/src/main/scala/leon/grammars/EqualityGrammar.scala @@ -6,13 +6,15 @@ package grammars import purescala.Types._ import purescala.Constructors._ -import bonsai._ - +/** A grammar of equalities + * + * @param types The set of types for which equalities will be generated + */ case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { - override def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => types.toList map { tp => - nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }) + nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }, Tags.Equals) } case _ => Nil diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/leon/grammars/ExpressionGrammar.scala index ac394ab840bddf0d498080a04e447ce66de07caa..3179312b7f65444eb3e8c39357fd449e13339c8f 100644 --- a/src/main/scala/leon/grammars/ExpressionGrammar.scala +++ b/src/main/scala/leon/grammars/ExpressionGrammar.scala @@ -6,23 +6,37 @@ package grammars import purescala.Expressions._ import purescala.Types._ import purescala.Common._ +import transformers.Union +import utils.Timer import scala.collection.mutable.{HashMap => MutableMap} +/** Represents a context-free grammar of expressions + * + * @tparam T The type of nonterminal symbols for this grammar + */ abstract class ExpressionGrammar[T <: Typed] { - type Gen = Generator[T, Expr] - private[this] val cache = new MutableMap[T, Seq[Gen]]() + type Prod = ProductionRule[T, Expr] - def terminal(builder: => Expr) = { - Generator[T, Expr](Nil, { (subs: Seq[Expr]) => builder }) + private[this] val cache = new MutableMap[T, Seq[Prod]]() + + /** Generates a [[ProductionRule]] without nonterminal symbols */ + def terminal(builder: => Expr, tag: Tags.Tag = Tags.Top, cost: Int = 1) = { + ProductionRule[T, Expr](Nil, { (subs: Seq[Expr]) => builder }, tag, cost) } - def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr)): Generator[T, Expr] = { - Generator[T, Expr](subs, builder) + /** Generates a [[ProductionRule]] with nonterminal symbols */ + def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr), tag: Tags.Tag = Tags.Top, cost: Int = 1): ProductionRule[T, Expr] = { + ProductionRule[T, Expr](subs, builder, tag, cost) } - def getProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = { + /** The list of production rules for this grammar for a given nonterminal. + * This is the cached version of [[getProductions]] which clients should use. + * + * @param t The nonterminal for which production rules will be generated + */ + def getProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -30,9 +44,13 @@ abstract class ExpressionGrammar[T <: Typed] { }) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] + /** The list of production rules for this grammar for a given nonterminal + * + * @param t The nonterminal for which production rules will be generated + */ + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] - def filter(f: Gen => Boolean) = { + def filter(f: Prod => Boolean) = { new ExpressionGrammar[T] { def computeProductions(t: T)(implicit ctx: LeonContext) = ExpressionGrammar.this.computeProductions(t).filter(f) } @@ -44,14 +62,19 @@ abstract class ExpressionGrammar[T <: Typed] { final def printProductions(printer: String => Unit)(implicit ctx: LeonContext) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { t => - FreshIdentifier(Console.BOLD+t.asString+Console.RESET, t.getType).toVariable - } + for ((t, gs) <- cache) { + val lhs = f"${Console.BOLD}${t.asString}%50s${Console.RESET} ::=" + if (gs.isEmpty) { + printer(s"$lhs ε") + } else for (g <- gs) { + val subs = g.subTrees.map { t => + FreshIdentifier(Console.BOLD + t.asString + Console.RESET, t.getType).toVariable + } - val gen = g.builder(subs).asString + val gen = g.builder(subs).asString - printer(f"${Console.BOLD}${t.asString}%30s${Console.RESET} ::= $gen") + printer(s"$lhs $gen") + } } } } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 14f92393934c18804bdb130e9c1617b915a347bd..1233fb1931a83b5ca674019be0c85144339dd19f 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -10,8 +10,14 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Expressions._ +/** Generates non-recursive function calls + * + * @param currentFunction The currend function for which no calls will be generated + * @param types The candidate real type parameters for [[currentFunction]] + * @param exclude An additional set of functions for which no calls will be generated + */ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { // Prevents recursive calls @@ -73,7 +79,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val funcs = visibleFunDefsFromMain(prog).toSeq.sortBy(_.id).flatMap(getCandidates).filterNot(filter) funcs.map{ tfd => - nonTerminal(tfd.params.map(_.getType), { sub => FunctionInvocation(tfd, sub) }) + nonTerminal(tfd.params.map(_.getType), FunctionInvocation(tfd, _), Tags.tagOf(tfd.fd, isSafe = false)) } } } diff --git a/src/main/scala/leon/grammars/Generator.scala b/src/main/scala/leon/grammars/Generator.scala deleted file mode 100644 index 18d132e2c25ea222324dc05809220f12d0fb7100..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/Generator.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import bonsai.{Generator => Gen} - -object GrammarTag extends Enumeration { - val Top = Value -} -import GrammarTag._ - -class Generator[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value) extends Gen[T,R](subTrees, builder) -object Generator { - def apply[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value = Top) = new Generator(subTrees, builder, tag) -} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 23b1dd5a14cfeb82dd4555832e777597615b337e..06aba3d5f5343cc7e2807854f0b4665bfa1a602c 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -7,6 +7,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.Types._ import purescala.TypeOps._ +import transformers.OneOf import synthesis.{SynthesisContext, Problem} @@ -16,6 +17,7 @@ object Grammars { BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || + Constants(currentFunction.fullBody) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecursiveCalls(prog, ws, pc) } @@ -28,3 +30,4 @@ object Grammars { g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b)) } } + diff --git a/src/main/scala/leon/grammars/NonTerminal.scala b/src/main/scala/leon/grammars/NonTerminal.scala index 7492ffac5c17df326084f846857c6ac3bebe1775..600189ffa06378841f6bf3285f7f1bd7bb6116f5 100644 --- a/src/main/scala/leon/grammars/NonTerminal.scala +++ b/src/main/scala/leon/grammars/NonTerminal.scala @@ -5,7 +5,14 @@ package grammars import purescala.Types._ -case class NonTerminal[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed { +/** A basic non-terminal symbol of a grammar. + * + * @param t The type of which expressions will be generated + * @param l A label that characterizes this [[NonTerminal]] + * @param depth The optional depth within the syntax tree where this [[NonTerminal]] is. + * @tparam L The type of label for this NonTerminal. + */ +case class NonTerminal[L](t: TypeTree, l: L, depth: Option[Int] = None) extends Typed { val getType = t override def asString(implicit ctx: LeonContext) = t.asString+"#"+l+depth.map(d => "@"+d).getOrElse("") diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/leon/grammars/ProductionRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..fc493a7d9d17557a26a8e54ff4615d39ba922190 --- /dev/null +++ b/src/main/scala/leon/grammars/ProductionRule.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import bonsai.Generator + +/** Represents a production rule of a non-terminal symbol of an [[ExpressionGrammar]]. + * + * @param subTrees The nonterminals that are used in the right-hand side of this [[ProductionRule]] + * (and will generate deeper syntax trees). + * @param builder A function that builds the syntax tree that this [[ProductionRule]] represents from nested trees. + * @param tag Gives information about the nature of this production rule. + * @tparam T The type of nonterminal symbols of the grammar + * @tparam R The type of syntax trees of the grammar + */ +case class ProductionRule[T, R](override val subTrees: Seq[T], override val builder: Seq[R] => R, tag: Tags.Tag, cost: Int = 1) + extends Generator[T,R](subTrees, builder) diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 1bbcb0523158ac95713f5a0d4a16f0f35e14edf4..f3234176a8c17378a7a5f027f38cbd42069ae7d6 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -9,15 +9,25 @@ import purescala.ExprOps._ import purescala.Expressions._ import synthesis.utils.Helpers._ +/** Generates recursive calls that will not trivially result in non-termination. + * + * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions + * @param pc The path condition for the generated [[Expr]] by this grammar + */ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog, t, ws, pc) calls.map { - case (e, free) => + case (fi, free) => val freeSeq = free.toSeq - nonTerminal(freeSeq.map(_.getType), { sub => replaceFromIDs(freeSeq.zip(sub).toMap, e) }) + nonTerminal( + freeSeq.map(_.getType), + { sub => replaceFromIDs(freeSeq.zip(sub).toMap, fi) }, + Tags.tagOf(fi.tfd.fd, isSafe = true), + 2 + ) } } } diff --git a/src/main/scala/leon/grammars/SimilarTo.scala b/src/main/scala/leon/grammars/SimilarTo.scala index 77e912792965d860fc934eb016370c8f2b57fd8f..3a7708e9a77960ffbfde98d478d2ca7c73c713d0 100644 --- a/src/main/scala/leon/grammars/SimilarTo.scala +++ b/src/main/scala/leon/grammars/SimilarTo.scala @@ -3,21 +3,24 @@ package leon package grammars +import transformers._ import purescala.Types._ import purescala.TypeOps._ import purescala.Extractors._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ import purescala.Expressions._ import synthesis._ +/** A grammar that generates expressions by inserting small variations in [[e]] + * @param e The [[Expr]] to which small variations will be inserted + * @param terminals A set of [[Expr]]s that may be inserted into [[e]] as small variations + */ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[NonTerminal[String]] { val excludeFCalls = sctx.settings.functionsToIgnore - val normalGrammar = DepthBoundedGrammar(EmbeddedGrammar( + val normalGrammar: ExpressionGrammar[NonTerminal[String]] = DepthBoundedGrammar(EmbeddedGrammar( BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || @@ -37,9 +40,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - private[this] var similarCache: Option[Map[L, Seq[Gen]]] = None + private[this] var similarCache: Option[Map[L, Seq[Prod]]] = None - def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Prod] = { t match { case NonTerminal(_, "B", _) => normalGrammar.computeProductions(t) case _ => @@ -54,7 +57,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Gen)] = { + def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Prod)] = { def getLabel(t: TypeTree) = { val tpe = bestRealType(t) @@ -67,9 +70,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte case _ => false } - def rec(e: Expr, gl: L): Seq[(L, Gen)] = { + def rec(e: Expr, gl: L): Seq[(L, Prod)] = { - def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = { + def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Prod)] = { val subGls = subs.map { s => getLabel(s.getType) } // All the subproductions for sub gl @@ -81,8 +84,8 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val swaps = if (subs.size > 1 && !isCommutative(e)) { - (for (i <- 0 until subs.size; - j <- i+1 until subs.size) yield { + (for (i <- subs.indices; + j <- i+1 until subs.size) yield { if (subs(i).getType == subs(j).getType) { val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i)) @@ -98,18 +101,18 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte allSubs ++ injectG ++ swaps } - def cegis(gl: L): Seq[(L, Gen)] = { + def cegis(gl: L): Seq[(L, Prod)] = { normalGrammar.getProductions(gl).map(gl -> _) } - def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(BVMinus(e, IntLiteral(1))), gl -> terminal(BVPlus (e, IntLiteral(1))) ) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def intVariations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(Minus(e, InfiniteIntegerLiteral(1))), gl -> terminal(Plus (e, InfiniteIntegerLiteral(1))) @@ -118,7 +121,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... - def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { + def ccVariations(gl: L, cc: CaseClass): Seq[(L, Prod)] = { val CaseClass(cct, args) = cc val neighbors = cct.root.knownCCDescendants diff Seq(cct) @@ -129,7 +132,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = e match { + val subs: Seq[(L, Prod)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) diff --git a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/SizeBoundedGrammar.scala deleted file mode 100644 index 1b25e30f61aa74598feb255366fe10a153bc9e30..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import purescala.Types._ -import leon.utils.SeqUtils.sumTo - -case class SizedLabel[T <: Typed](underlying: T, size: Int) extends Typed { - val getType = underlying.getType - - override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" -} - -case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { - def computeProductions(sl: SizedLabel[T])(implicit ctx: LeonContext): Seq[Gen] = { - if (sl.size <= 0) { - Nil - } else if (sl.size == 1) { - g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { gen => - terminal(gen.builder(Seq())) - } - } else { - g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { gen => - val sizes = sumTo(sl.size-1, gen.subTrees.size) - - for (ss <- sizes) yield { - val subSizedLabels = (gen.subTrees zip ss) map (s => SizedLabel(s._1, s._2)) - - nonTerminal(subSizedLabels, gen.builder) - } - } - } - } -} diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/leon/grammars/Tags.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a6b6fca491b8db6f74622edd9298ec5cd6053b0 --- /dev/null +++ b/src/main/scala/leon/grammars/Tags.scala @@ -0,0 +1,65 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Types.CaseClassType +import purescala.Definitions.FunDef + +object Tags { + /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ + abstract class Tag + case object Top extends Tag // Tag for the top-level of the grammar (default) + case object Zero extends Tag // Tag for 0 + case object One extends Tag // Tag for 1 + case object BooleanC extends Tag // Tag for boolean constants + case object Constant extends Tag // Tag for other constants + case object And extends Tag // Tags for boolean operations + case object Or extends Tag + case object Not extends Tag + case object Plus extends Tag // Tags for arithmetic operations + case object Minus extends Tag + case object Times extends Tag + case object Mod extends Tag + case object Div extends Tag + case object Variable extends Tag // Tag for variables + case object Equals extends Tag // Tag for equality + /** Constructors like Tuple, CaseClass... + * + * @param isTerminal If true, this constructor represents a terminal symbol + * (in practice, case class with 0 fields) + */ + case class Constructor(isTerminal: Boolean) extends Tag + /** Tag for function calls + * + * @param isMethod Whether the function called is a method + * @param isSafe Whether this constructor represents a safe function call. + * We need this because this call implicitly contains a variable, + * so we want to allow constants in all arguments. + */ + case class FunCall(isMethod: Boolean, isSafe: Boolean) extends Tag + + /** The set of tags that represent constants */ + val isConst: Set[Tag] = Set(Zero, One, Constant, BooleanC, Constructor(true)) + + /** The set of tags that represent commutative operations */ + val isCommut: Set[Tag] = Set(Plus, Times, Equals) + + /** The set of tags which have trivial results for equal arguments */ + val symmetricTrivial = Set(Minus, And, Or, Equals, Div, Mod) + + /** Tags which allow constants in all their operands + * + * In reality, the current version never allows that: it is only allowed in safe function calls + * which by construction contain a hidden reference to a variable. + * TODO: Experiment with different conditions, e.g. are constants allowed in + * top-level/ general function calls/ constructors/...? + */ + def allConstArgsAllowed(t: Tag) = t match { + case FunCall(_, true) => true + case _ => false + } + + def tagOf(cct: CaseClassType) = Constructor(cct.fields.isEmpty) + def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 98850c8df4adcf3e776970c176cf37c251823917..d3c42201728f4b03d9518b3db503cff9189dcc8b 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -6,62 +6,64 @@ package grammars import purescala.Types._ import purescala.Expressions._ +/** A grammar of values (ground terms) */ case object ValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)) + terminal(BooleanLiteral(true), Tags.One), + terminal(BooleanLiteral(false), Tags.Zero) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - terminal(IntLiteral(5)) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One), + terminal(IntLiteral(5), Tags.Constant) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - terminal(InfiniteIntegerLiteral(5)) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One), + terminal(InfiniteIntegerLiteral(5), Tags.Constant) ) case StringType => List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("foo")), - terminal(StringLiteral("bar")) + terminal(StringLiteral(""), Tags.Constant), + terminal(StringLiteral("a"), Tags.Constant), + terminal(StringLiteral("foo"), Tags.Constant), + terminal(StringLiteral("bar"), Tags.Constant) ) case tp: TypeParameter => - for (ind <- (1 to 3).toList) yield { - terminal(GenericValue(tp, ind)) - } + List( + terminal(GenericValue(tp, 0)) + ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(stps.isEmpty)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), - nonTerminal(List(base, base), { case elems => FiniteSet(elems.toSet, base) }) + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), + nonTerminal(List(base, base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)) ) case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..02e045497a28db205d4a33a300ac0b742510920a --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala @@ -0,0 +1,21 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +/** Limits a grammar to a specific expression depth */ +case class DepthBoundedGrammar[L](g: ExpressionGrammar[NonTerminal[L]], bound: Int) extends ExpressionGrammar[NonTerminal[L]] { + def computeProductions(l: NonTerminal[L])(implicit ctx: LeonContext): Seq[Prod] = g.computeProductions(l).flatMap { + case gen => + if (l.depth == Some(bound) && gen.isNonTerminal) { + Nil + } else if (l.depth.exists(_ > bound)) { + Nil + } else { + List ( + nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) + ) + } + } +} diff --git a/src/main/scala/leon/grammars/EmbeddedGrammar.scala b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala similarity index 74% rename from src/main/scala/leon/grammars/EmbeddedGrammar.scala rename to src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala index 8dcbc6ec10f9aa42895e5f876cdd4d72479de229..d989a8804b32f62697b7f31e498e61393a12c35b 100644 --- a/src/main/scala/leon/grammars/EmbeddedGrammar.scala +++ b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala @@ -2,10 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.Constructors._ +import leon.purescala.Types.Typed /** * Embed a grammar Li->Expr within a grammar Lo->Expr @@ -13,9 +12,9 @@ import purescala.Constructors._ * We rely on a bijection between Li and Lo labels */ case class EmbeddedGrammar[Ti <: Typed, To <: Typed](innerGrammar: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] { - def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Prod] = { innerGrammar.computeProductions(oToi(t)).map { innerGen => - nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder) + nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder, innerGen.tag) } } } diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/leon/grammars/transformers/OneOf.scala similarity index 56% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/leon/grammars/transformers/OneOf.scala index 0e10c096151c1fdf83d3c7e7f10c4a4a6518215b..5c57c6a1a48179e2d813398aa651022df6cae35a 100644 --- a/src/main/scala/leon/grammars/OneOf.scala +++ b/src/main/scala/leon/grammars/transformers/OneOf.scala @@ -2,14 +2,15 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.Constructors._ +import purescala.Expressions.Expr +import purescala.Types.TypeTree +import purescala.TypeOps.isSubtypeOf +/** Generates one production rule for each expression in a sequence that has compatible type */ case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => terminal(i) diff --git a/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b605359fdf18d02f08d105a9cccc58757b99262 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala @@ -0,0 +1,59 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import purescala.Types.Typed +import utils.SeqUtils._ + +/** Adds information about size to a nonterminal symbol */ +case class SizedNonTerm[T <: Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" +} + +/** Limits a grammar by producing expressions of size bounded by the [[SizedNonTerm.size]] of a given [[SizedNonTerm]]. + * + * In case of commutative operations, the grammar will produce trees skewed to the right + * (i.e. the right subtree will always be larger). Notice we do not lose generality in case of + * commutative operations. + */ +case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T], optimizeCommut: Boolean) extends ExpressionGrammar[SizedNonTerm[T]] { + def computeProductions(sl: SizedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.isTerminal).map { gen => + terminal(gen.builder(Seq()), gen.tag) + } + } else { + g.getProductions(sl.underlying).filter(_.isNonTerminal).flatMap { gen => + + // Ad-hoc equality that does not take into account position etc.of TaggedNonTerminal's + // TODO: Ugly and hacky + def characteristic(t: T): Typed = t match { + case TaggedNonTerm(underlying, _, _, _) => + underlying + case other => + other + } + + // Optimization: When we have a commutative operation and all the labels are the same, + // we can skew the expression to always be right-heavy + val sizes = if(optimizeCommut && Tags.isCommut(gen.tag) && gen.subTrees.map(characteristic).toSet.size == 1) { + sumToOrdered(sl.size-gen.cost, gen.arity) + } else { + sumTo(sl.size-gen.cost, gen.arity) + } + + for (ss <- sizes) yield { + val subSizedLabels = (gen.subTrees zip ss) map (s => SizedNonTerm(s._1, s._2)) + + nonTerminal(subSizedLabels, gen.builder, gen.tag) + } + } + } + } +} diff --git a/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..43ce13e850ed1b52460ef1a74d7b039adacbd519 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala @@ -0,0 +1,111 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import leon.purescala.Types.Typed +import Tags._ + +/** Adds to a nonterminal information about about the tag of its parent's [[leon.grammars.ProductionRule.tag]] + * and additional information. + * + * @param underlying The underlying nonterminal + * @param tag The tag of the parent of this nonterminal + * @param pos The index of this nonterminal in its father's production rule + * @param isConst Whether this nonterminal is obliged to generate/not generate constants. + * + */ +case class TaggedNonTerm[T <: Typed](underlying: T, tag: Tag, pos: Int, isConst: Option[Boolean]) extends Typed { + val getType = underlying.getType + + private val cString = isConst match { + case Some(true) => "↓" + case Some(false) => "↑" + case None => "○" + } + + /** [[isConst]] is printed as follows: ↓ for constants only, ↑ for nonconstants only, + * ○ for anything allowed. + */ + override def asString(implicit ctx: LeonContext): String = s"$underlying%$tag@$pos$cString" +} + +/** Constraints a grammar to reduce redundancy by utilizing information provided by the [[TaggedNonTerm]]. + * + * 1) In case of associative operations, right associativity is enforced. + * 2) Does not generate + * - neutral and absorbing elements (incl. boolean equality) + * - nested negations + * 3) Excludes method calls on nullary case objects, e.g. Nil().size + * 4) Enforces that no constant trees are generated (and recursively for each subtree) + * + * @param g The underlying untagged grammar + */ +case class TaggedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[TaggedNonTerm[T]] { + + private def exclude(tag: Tag, pos: Int): Set[Tag] = (tag, pos) match { + case (Top, _) => Set() + case (And, 0) => Set(And, BooleanC) + case (And, 1) => Set(BooleanC) + case (Or, 0) => Set(Or, BooleanC) + case (Or, 1) => Set(BooleanC) + case (Plus, 0) => Set(Plus, Zero, One) + case (Plus, 1) => Set(Zero) + case (Minus, 1) => Set(Zero) + case (Not, _) => Set(Not, BooleanC) + case (Times, 0) => Set(Times, Zero, One) + case (Times, 1) => Set(Zero, One) + case (Equals,_) => Set(Not, BooleanC) + case (Div | Mod, 0 | 1) => Set(Zero, One) + case (FunCall(true, _), 0) => Set(Constructor(true)) // Don't allow Nil().size etc. + case _ => Set() + } + + def computeProductions(t: TaggedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + + // Point (4) for this level + val constFilter: g.Prod => Boolean = t.isConst match { + case Some(b) => + innerGen => isConst(innerGen.tag) == b + case None => + _ => true + } + + g.computeProductions(t.underlying) + // Include only constants iff constants are forced, only non-constants iff they are forced + .filter(constFilter) + // Points (1), (2). (3) + .filterNot { innerGen => exclude(t.tag, t.pos)(innerGen.tag) } + .flatMap { innerGen => + + def nt(isConst: Int => Option[Boolean]) = nonTerminal( + innerGen.subTrees.zipWithIndex.map { + case (t, pos) => TaggedNonTerm(t, innerGen.tag, pos, isConst(pos)) + }, + innerGen.builder, + innerGen.tag + ) + + def powerSet[A](t: Set[A]): Set[Set[A]] = { + @scala.annotation.tailrec + def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + pwr(t, Set(Set.empty[A])) + } + + // Allow constants everywhere if this is allowed, otherwise demand at least 1 variable. + // Aka. tag subTrees correctly so point (4) is enforced in the lower level + // (also, make sure we treat terminals correctly). + if (innerGen.isTerminal || allConstArgsAllowed(innerGen.tag)) { + Seq(nt(_ => None)) + } else { + val indices = innerGen.subTrees.indices.toSet + (powerSet(indices) - indices) map (indices => nt(x => Some(indices(x)))) + } + } + } + +} diff --git a/src/main/scala/leon/grammars/Or.scala b/src/main/scala/leon/grammars/transformers/Union.scala similarity index 73% rename from src/main/scala/leon/grammars/Or.scala rename to src/main/scala/leon/grammars/transformers/Union.scala index e691a245984eaeb11277b9278505b49cf623fed3..471625ac3c22c22456f49f366ed26e5195b5f4ab 100644 --- a/src/main/scala/leon/grammars/Or.scala +++ b/src/main/scala/leon/grammars/transformers/Union.scala @@ -2,8 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ +import purescala.Types.Typed case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { @@ -11,6 +12,6 @@ case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGr case g => Seq(g) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = subGrammars.flatMap(_.getProductions(t)) } diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 99803ec2adea8a7c00dceacb579cdb4f10366b13..5740b0d5f567adc66d064edde5f8ab62ec45c6f5 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -75,7 +75,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { /** Returns the set of free variables in an expression */ def variablesOf(expr: Expr): Set[Identifier] = { - import leon.xlang.Expressions.LetVar + import leon.xlang.Expressions._ fold[Set[Identifier]] { case (e, subs) => val subvs = subs.flatten.toSet @@ -176,7 +176,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { case l @ Let(i,e,b) => val newID = FreshIdentifier(i.name, i.getType, alwaysShowUniqueID = true).copiedFrom(i) - Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) + Some(Let(newID, e, replaceFromIDs(Map(i -> Variable(newID)), b))) case _ => None }(expr) @@ -331,7 +331,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { def simplerLet(t: Expr) : Option[Expr] = t match { case letExpr @ Let(i, t: Terminal, b) if isDeterministic(b) => - Some(replace(Map(Variable(i) -> t), b)) + Some(replaceFromIDs(Map(i -> t), b)) case letExpr @ Let(i,e,b) if isDeterministic(b) => { val occurrences = count { @@ -342,7 +342,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { if(occurrences == 0) { Some(b) } else if(occurrences == 1) { - Some(replace(Map(Variable(i) -> e), b)) + Some(replaceFromIDs(Map(i -> e), b)) } else { None } @@ -353,7 +353,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { val (remIds, remExprs) = (ids zip exprs).filter { case (id, value: Terminal) => - newBody = replace(Map(Variable(id) -> value), newBody) + newBody = replaceFromIDs(Map(id -> value), newBody) //we replace, so we drop old false case (id, value) => @@ -1863,7 +1863,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { fds ++= nfds - Some(LetDef(nfds.map(_._2), b)) + Some(letDef(nfds.map(_._2), b)) case FunctionInvocation(tfd, args) => if (fds contains tfd.fd) { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 69dbc4429aac6d31444cd024944d03096c7b6cc8..88edef68586474c24e7a13b70a22e9beed7441b8 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -76,10 +76,6 @@ object Expressions { val getType = tpe } - case class Old(id: Identifier) extends Expr with Terminal { - val getType = id.getType - } - /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* * * @param pred The precondition formula inside ``require(...)`` @@ -165,7 +161,7 @@ object Expressions { * @param body The body of the expression after the function */ case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr { - assert(fds.nonEmpty) + require(fds.nonEmpty) val getType = body.getType } @@ -594,7 +590,10 @@ object Expressions { /** $encodingof `lhs.subString(start, end)` for strings */ case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr { val getType = { - if (expr.getType == StringType && (start == IntegerType || start == Int32Type) && (end == IntegerType || end == Int32Type)) StringType + val ext = expr.getType + val st = start.getType + val et = end.getType + if (ext == StringType && (st == IntegerType || st == Int32Type) && (et == IntegerType || et == Int32Type)) StringType else Untyped } } @@ -786,7 +785,7 @@ object Expressions { * * [[exprs]] should always contain at least 2 elements. * If you are not sure about this requirement, you should use - * [[purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] + * [[leon.purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] * * @param exprs The expressions in the tuple */ @@ -799,7 +798,7 @@ object Expressions { * * Index is 1-based, first element of tuple is 1. * If you are not sure that [[tuple]] is indeed of a TupleType, - * you should use [[purescala.Constructors$.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] + * you should use [[leon.purescala.Constructors.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] */ case class TupleSelect(tuple: Expr, index: Int) extends Expr { require(index >= 1) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 039bef507339294874ad59e7e38dfe335b8f6a2f..44d454a9c0a4f1127ed0533f872c66f1ec6bc067 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -81,15 +81,12 @@ class PrettyPrinter(opts: PrinterOptions, } p"$name" - case Old(id) => - p"old($id)" - case Variable(id) => p"$id" case Let(b,d,e) => - p"""|val $b = $d - |$e""" + p"""|val $b = $d + |$e""" case LetDef(a::q,body) => p"""|$a diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 67cc994649e6b454ee8389fbfdd7674da4aebb6a..11f0c187e144873c386a01702326056e636e1225 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -9,14 +9,12 @@ import Common._ import Expressions._ import Types._ import Definitions._ -import org.apache.commons.lang3.StringEscapeUtils -/** This pretty-printer only print valid scala syntax */ +/** This pretty-printer only prints valid scala syntax */ class ScalaPrinter(opts: PrinterOptions, opgm: Option[Program], sb: StringBuffer = new StringBuffer) extends PrettyPrinter(opts, opgm, sb) { - private val dbquote = "\"" override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { tree match { diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index f0ff379ffd6ec0419d631777c623098b21838a57..380554d8c3d8c566bb6da32d274ecf1eea199d11 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -7,6 +7,7 @@ import Common._ import Definitions._ import Expressions._ import Extractors._ +import Constructors.letDef class ScopeSimplifier extends Transformer { case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { @@ -59,7 +60,7 @@ class ScopeSimplifier extends Transformer { for((newFd, fd) <- fds_mapping) { newFd.fullBody = rec(fd.fullBody, newScope) } - LetDef(fds_mapping.map(_._1), rec(body, newScope)) + letDef(fds_mapping.map(_._1), rec(body, newScope)) case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 20b8e5fdbee22a2b110fdc85839e9495cc19c7d4..14cf3b8250682a71285fffb3fd34e2e25608cdb1 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -13,7 +13,7 @@ import ExprOps.preMap object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree] { def typeDepth(t: TypeTree): Int = t match { - case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max + case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } def typeParamsOf(t: TypeTree): Set[TypeParameter] = { @@ -335,7 +335,7 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree } val newBd = srec(subCalls(bd)).copiedFrom(bd) - LetDef(newFds, newBd).copiedFrom(l) + letDef(newFds, newBd).copiedFrom(l) case l @ Lambda(args, body) => val newArgs = args.map { arg => diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 37f187679897bbec908ffeaf18f5385198f915c9..9dcd782a3323138e5bfe4c68822fba614e2065e7 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -3,6 +3,7 @@ package leon package repair +import leon.datagen.GrammarDataGen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ @@ -25,7 +26,6 @@ import synthesis.Witnesses._ import synthesis.graph.{dotGenIds, DotGenerator} import rules._ -import grammars._ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) { implicit val ctx = ctx0 @@ -155,7 +155,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou }(DebugSectionReport) if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+ dotGenIds.nextGlobal + ".dot") } @@ -236,29 +236,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou def discoverTests(): ExamplesBank = { - import bonsai.enumerators._ - val maxEnumerated = 1000 val maxValid = 400 val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - - val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) - - val filtering: Seq[Expr] => Boolean = fd.precondition match { - case None => - _ => true - case Some(pre) => - val argIds = fd.paramIds - evaluator.compile(pre, argIds) match { - case Some(evalFun) => - val sat = EvaluationResults.Successful(BooleanLiteral(true)); - { (es: Seq[Expr]) => evalFun(new solvers.Model((argIds zip es).toMap)) == sat } - case None => - { _ => false } - } - } val inputsToExample: Seq[Expr] => Example = { ins => evaluator.eval(functionInvocation(fd, ins)) match { @@ -269,10 +250,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou } } - val generatedTests = inputs - .take(maxEnumerated) - .filter(filtering) - .take(maxValid) + val dataGen = new GrammarDataGen(evaluator) + + val generatedTests = dataGen + .generateFor(fd.paramIds, fd.precOrTrue, maxValid, maxEnumerated) .map(inputsToExample) .toList diff --git a/src/main/scala/leon/solvers/SolverUnsupportedError.scala b/src/main/scala/leon/solvers/SolverUnsupportedError.scala index 5d519160d7aed9fce7a42584c8d53806e53e265a..2efc8ea39b0da8494b2cd1309b3dcf9c2ca9cec3 100644 --- a/src/main/scala/leon/solvers/SolverUnsupportedError.scala +++ b/src/main/scala/leon/solvers/SolverUnsupportedError.scala @@ -7,7 +7,7 @@ import purescala.Common.Tree object SolverUnsupportedError { def msg(t: Tree, s: Solver, reason: Option[String]) = { - s" is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") + s"(of ${t.getClass}) is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") } } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index f0b7d5f91a9f53f32dd6eb9590c988618f0536d4..7c3486ff533a630bbd1bce6bde31d86969163e7c 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -526,7 +526,7 @@ trait AbstractZ3Solver extends Solver { rec(RawArrayValue(from, elems.map{ case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }.toMap, CaseClass(library.noneType(t), Seq()))) + }, CaseClass(library.noneType(t), Seq()))) case MapApply(m, k) => val mt @ MapType(_, t) = normalizeType(m.getType) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 5077f9467dab2edd3abf8fdf7da01d5eae59ff5e..1755966e0223c0a4fa75eb01eb4a681ed1ddbf1d 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -6,13 +6,10 @@ package synthesis import purescala.Expressions._ import purescala.Definitions._ import purescala.ExprOps._ -import purescala.Types.TypeTree import purescala.Common._ import purescala.Constructors._ -import purescala.Extractors._ import evaluators._ import grammars._ -import bonsai.enumerators._ import codegen._ import datagen._ import solvers._ @@ -110,9 +107,9 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { val datagen = new GrammarDataGen(evaluator, ValueGrammar) val solverDataGen = new SolverDataGen(ctx, program, (ctx, pgm) => SolverFactory(() => new FairZ3Solver(ctx, pgm))) - val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) - val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) ExamplesBank(generatedExamples.toSeq ++ solverExamples.toList, Nil) } @@ -180,34 +177,20 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case (a, b, c) => None }) getOrElse { + // If the input contains free variables, it does not provide concrete examples. // We will instantiate them according to a simple grammar to get them. - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType })) - val instantiations = values.map { - v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap - } - def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match { - case Some(guard) => - // in -> e should be enough. We shouldn't find any subexpressions of in. - evaluator.eval(replace(Map(in -> e), guard), mapping) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } + val dataGen = new GrammarDataGen(evaluator) + + val theGuard = replace(Map(in -> pattExpr), cs.optGuard.getOrElse(BooleanLiteral(true))) - case None => - true + dataGen.generateFor(freeVars, theGuard, examplesPerCase, 1000).toSeq map { vals => + val inst = freeVars.zip(vals).toMap + val inR = replaceFromIDs(inst, pattExpr) + val outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) + (inR, outR) } - - if(cs.optGuard == Some(BooleanLiteral(false))) { - Nil - } else (for { - inst <- instantiations.toSeq - inR = replaceFromIDs(inst, pattExpr) - outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) - if filterGuard(inR, inst) - } yield (inR, outR)).take(examplesPerCase) } } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index cd27e272d53e9669f4c2d1f2c8e07356819b4827..3a86ca64a79238aec4e59b197e001fd4c24660b4 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -35,8 +35,10 @@ abstract class PreprocessingRule(name: String) extends Rule(name) { /** Contains the list of all available rules for synthesis */ object Rules { + + def all: List[Rule] = all(false) /** Returns the list of all available rules for synthesis */ - def all = List[Rule]( + def all(naiveGrammar: Boolean): List[Rule] = List[Rule]( StringRender, Unification.DecompTrivialClash, Unification.OccursCheck, // probably useless @@ -54,8 +56,8 @@ object Rules { OptimisticGround, EqualitySplit, InequalitySplit, - CEGIS, - TEGIS, + if(naiveGrammar) NaiveCEGIS else CEGIS, + //TEGIS, //BottomUpTEGIS, rules.Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 4bb10d38c9ffc7a7667d165b84e4f65c1edc9e0c..8ab07929d78479656f18ce1fd652cfa7ef870e17 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -45,6 +45,10 @@ object SourceInfo { ci } + if (results.isEmpty) { + ctx.reporter.warning("No 'choose' found. Maybe the functions you chose do not exist?") + } + results.sortBy(_.source.getPos) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index b9ba6df1f688e01edf631e995ba2f8623bcfc5fe..ac4d30614d8269ce78a92a95e232855eb40d9fbd 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -3,13 +3,11 @@ package leon package synthesis -import purescala.ExprOps._ - +import purescala.ExprOps.replace import purescala.ScalaPrinter -import leon.utils._ import purescala.Definitions.{Program, FunDef} -import leon.utils.ASCIIHelpers +import leon.utils._ import graph._ object SynthesisPhase extends TransformationPhase { @@ -21,11 +19,13 @@ object SynthesisPhase extends TransformationPhase { val optDerivTrees = LeonFlagOptionDef( "derivtrees", "Generate derivation trees", false) // CEGIS options - val optCEGISOptTimeout = LeonFlagOptionDef( "cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true) - val optCEGISVanuatoo = LeonFlagOptionDef( "cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISOptTimeout = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true ) + val optCEGISVanuatoo = LeonFlagOptionDef("cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISNaiveGrammar = LeonFlagOptionDef("cegis:naive", "Use the old naive grammar for CEGIS", false) + val optCEGISMaxSize = LeonLongOptionDef("cegis:maxsize", "Maximum size of expressions synthesized by CEGIS", 5L, "N") override val definedOptions : Set[LeonOptionDef[Any]] = - Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo) + Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo, optCEGISNaiveGrammar, optCEGISMaxSize) def processOptions(ctx: LeonContext): SynthesisSettings = { val ms = ctx.findOption(optManual) @@ -53,11 +53,13 @@ object SynthesisPhase extends TransformationPhase { timeoutMs = timeout map { _ * 1000 }, generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), costModel = costModel, - rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), + rules = Rules.all(ctx.findOptionOrDefault(optCEGISNaiveGrammar)) ++ + (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), manualSearch = ms, functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet }, - cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout), - cegisUseVanuatoo = ctx.findOption(optCEGISVanuatoo) + cegisUseOptTimeout = ctx.findOptionOrDefault(optCEGISOptTimeout), + cegisUseVanuatoo = ctx.findOptionOrDefault(optCEGISVanuatoo), + cegisMaxSize = ctx.findOptionOrDefault(optCEGISMaxSize).toInt ) } @@ -80,7 +82,7 @@ object SynthesisPhase extends TransformationPhase { try { if (options.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+dotGenIds.nextGlobal+".dot") } diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index 5202818e18765ebf4086ef41d1685967a14940d0..61dc24ece71081c0f02f5bdcb38d9d9eeb0fee14 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -16,7 +16,8 @@ case class SynthesisSettings( functionsToIgnore: Set[FunDef] = Set(), // Cegis related options - cegisUseOptTimeout: Option[Boolean] = None, - cegisUseVanuatoo: Option[Boolean] = None + cegisUseOptTimeout: Boolean = true, + cegisUseVanuatoo : Boolean = false, + cegisMaxSize: Int = 5 ) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index bafed6ec2bab51539bfc0547563bbad2aeea873e..efd1ad13e0538f855487e94b7b1a35d7d893627f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -70,21 +70,19 @@ class Synthesizer(val context : LeonContext, // Print out report for synthesis, if necessary reporter.ifDebug { printer => - import java.io.FileWriter import java.text.SimpleDateFormat import java.util.Date val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") val benchName = categoryName+"."+ci.fd.id.name - var time = lastTime/1000.0; + val time = lastTime/1000.0 val defs = visibleDefsFrom(ci.fd)(program).collect { case cd: ClassDef => 1 + cd.fields.size case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) } - val psize = defs.sum; - + val psize = defs.sum val (size, calls, proof) = result.headOption match { case Some((sol, trusted)) => diff --git a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala index 6e9dc237667e286cdbd9a82875a986dfbf6b8aba..6ce5f020ae88c649458b857afca76b207b68916f 100644 --- a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala @@ -3,15 +3,12 @@ package leon package synthesis package disambiguation -import leon.LeonContext -import leon.purescala.Expressions._ import purescala.Common.FreshIdentifier import purescala.Constructors.{ and, tupleWrap } import purescala.Definitions.{ FunDef, Program, ValDef } import purescala.ExprOps.expressionToPattern -import purescala.Expressions.{ BooleanLiteral, Equals, Expr, Lambda, MatchCase, Passes, Variable, WildcardPattern } import purescala.Extractors.TopLevelAnds -import leon.purescala.Expressions._ +import purescala.Expressions._ /** * @author Mikael diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index bfa6a62120af6171d001b6a026630734eb6c10fd..d77c06c97de2808e38ab4c0e218fb614a99d8298 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -1,21 +1,18 @@ package leon package synthesis.disambiguation +import datagen.GrammarDataGen import synthesis.Solution import evaluators.DefaultEvaluator import purescala.Expressions._ import purescala.ExprOps -import purescala.Constructors._ -import purescala.Extractors._ import purescala.Types.{StringType, TypeTree} import purescala.Common.Identifier import purescala.Definitions.Program import purescala.DefOps -import grammars.ValueGrammar -import bonsai.enumerators.MemoizedEnumerator +import grammars._ import solvers.ModelBuilder import scala.collection.mutable.ListBuffer -import grammars._ object QuestionBuilder { /** Sort methods for questions. You can build your own */ @@ -69,15 +66,15 @@ object QuestionBuilder { /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ object SpecialStringValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\t")), + terminal(StringLiteral("Lara 2007")) + ) + case _ => ValueGrammar.computeProductions(t) } } } @@ -92,11 +89,9 @@ object QuestionBuilder { * * @tparam T A subtype of Expr that will be the type used in the Question[T] results. * @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType - * @param ruleApplication The set of solutions for the body of f * @param filter A function filtering which outputs should be considered for comparison. - * It takes as input the sequence of outputs already considered for comparison, and the new output. - * It should return Some(result) if the result can be shown, and None else. - * @return An ordered + * It takes as input the sequence of outputs already considered for comparison, and the new output. + * It should return Some(result) if the result can be shown, and None else. * */ class QuestionBuilder[T <: Expr]( @@ -139,25 +134,22 @@ class QuestionBuilder[T <: Expr]( /** Returns a list of input/output questions to ask to the user. */ def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil - - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree,Expr]](value_enumerator.getProductions) - val values = enum.iterator(tupleTypeWrap(_argTypes)) - val instantiations = values.map { - v => input.zip(unwrapTuple(v, input.size)) - } - - val enumerated_inputs = instantiations.take(expressionsToTake).toList - + + val datagen = new GrammarDataGen(new DefaultEvaluator(c, p), value_enumerator) + val enumerated_inputs = datagen.generateMapping(input, BooleanLiteral(true), expressionsToTake, expressionsToTake).toList + val solution = solutions.head val alternatives = solutions.drop(1).take(solutionsToTake).toList val questions = ListBuffer[Question[T]]() - for{possible_input <- enumerated_inputs - current_output_nonfiltered <- run(solution, possible_input) - current_output <- filter(Seq(), current_output_nonfiltered)} { + for { + possibleInput <- enumerated_inputs + currentOutputNonFiltered <- run(solution, possibleInput) + currentOutput <- filter(Seq(), currentOutputNonFiltered) + } { - val alternative_outputs = ((ListBuffer[T](current_output) /: alternatives) { (prev, alternative) => - run(alternative, possible_input) match { - case Some(alternative_output) if alternative_output != current_output => + val alternative_outputs = (ListBuffer[T](currentOutput) /: alternatives) { (prev, alternative) => + run(alternative, possibleInput) match { + case Some(alternative_output) if alternative_output != currentOutput => filter(prev, alternative_output) match { case Some(alternative_output_filtered) => prev += alternative_output_filtered @@ -165,11 +157,11 @@ class QuestionBuilder[T <: Expr]( } case _ => prev } - }).drop(1).toList.distinct - if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(current_output)) { - questions += Question(possible_input.map(_._2), current_output, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) + }.drop(1).toList.distinct + if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(currentOutput)) { + questions += Question(possibleInput.map(_._2), currentOutput, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) } } questions.toList.sortBy(_questionSorMethod(_)) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 7da38716116f51d89e751a8aa12d709be776e17c..78ef7b371487a6711d3508b9712f7806e9c551e0 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -6,7 +6,11 @@ import leon.utils.UniqueCounter import java.io.{File, FileWriter, BufferedWriter} -class DotGenerator(g: Graph) { +class DotGenerator(search: Search) { + + implicit val ctx = search.ctx + + val g = search.g private val idCounter = new UniqueCounter[Unit] idCounter.nextGlobal // Start with 1 @@ -80,12 +84,14 @@ class DotGenerator(g: Graph) { } def nodeDesc(n: Node): String = n match { - case an: AndNode => an.ri.toString - case on: OrNode => on.p.toString + case an: AndNode => an.ri.asString + case on: OrNode => on.p.asString } def drawNode(res: StringBuffer, name: String, n: Node) { + val index = n.parent.map(_.descendants.indexOf(n) + " ").getOrElse("") + def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") val color = if (n.isSolved) { @@ -109,10 +115,10 @@ class DotGenerator(g: Graph) { res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+"</TD></TR>" } - res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>" + res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(index + nodeDesc(n)))+"</TD></TR>" if (n.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.toString))+"</TD></TR>" + res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.asString))+"</TD></TR>" } res append "</TABLE>>, shape = \"none\" ];\n" @@ -126,4 +132,4 @@ class DotGenerator(g: Graph) { } } -object dotGenIds extends UniqueCounter[Unit] \ No newline at end of file +object dotGenIds extends UniqueCounter[Unit] diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 98554a5ae492972e0b7b3915979d9af829d81555..c630e315d9777110b5dcde7adc42cf6172161af3 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(val ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p) def findNodeToExpandFrom(n: Node): Option[Node] diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index df2c44193412a55af004dfa7695901044a4b5b53..d3dc6347280a45642935e6ea3c314246a3cb6958 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -65,7 +65,7 @@ case object ADTSplit extends Rule("ADT Split.") { case Some((id, act, cases)) => val oas = p.as.filter(_ != id) - val subInfo = for(ccd <- cases) yield { + val subInfo0 = for(ccd <- cases) yield { val cct = CaseClassType(ccd, act.tps) val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList @@ -89,6 +89,10 @@ case object ADTSplit extends Rule("ADT Split.") { (cct, subProblem, subPattern) } + val subInfo = subInfo0.sortBy{ case (cct, _, _) => + cct.fieldsTypes.count { t => t == act } + } + val onSuccess: List[Solution] => Option[Solution] = { case sols => diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 2f3869af16b71f9635e36d27774f55a7cee7140c..4c12f58224427c1d74654638e28965d746f93d54 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -14,7 +14,6 @@ import codegen.CodeGenParams import grammars._ import bonsai.enumerators._ -import bonsai.{Generator => Gen} case object BottomUpTEGIS extends BottomUpTEGISLike[TypeTree]("BU TEGIS") { def getGrammar(sctx: SynthesisContext, p: Problem) = { @@ -51,13 +50,13 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = tests.size - var compiled = Map[Generator[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() + var compiled = Map[ProductionRule[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() /** * Compile Generators to functions from Expr to Expr. The compiled * generators will be passed to the enumerator */ - def compile(gen: Generator[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { + def compile(gen: ProductionRule[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { compiled.getOrElse(gen, { val executor = if (gen.subTrees.isEmpty) { @@ -108,7 +107,7 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val targetType = tupleTypeWrap(p.xs.map(_.getType)) val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - val enum = new BottomUpEnumerator[T, Expr, Expr, Generator[T, Expr]]( + val enum = new BottomUpEnumerator[T, Expr, Expr, ProductionRule[T, Expr]]( grammar.getProductions, wrappedTests, { (vecs, gen) => diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala index 1fcf01d52088ea9d4d25d184a673ef8335a8d260..b0de64ed05458d22cc113170dc850e2c1e2f6a3b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGIS.scala @@ -4,16 +4,31 @@ package leon package synthesis package rules -import purescala.Types._ - import grammars._ -import utils._ +import grammars.transformers._ +import purescala.Types.TypeTree -case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { +/** Basic implementation of CEGIS that uses a naive grammar */ +case object NaiveCEGIS extends CEGISLike[TypeTree]("Naive CEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { CegisParams( grammar = Grammars.typeDepthBound(Grammars.default(sctx, p), 2), // This limits type depth - rootLabel = {(tpe: TypeTree) => tpe } + rootLabel = {(tpe: TypeTree) => tpe }, + optimizations = false + ) + } +} + +/** More advanced implementation of CEGIS that uses a less permissive grammar + * and some optimizations + */ +case object CEGIS extends CEGISLike[TaggedNonTerm[TypeTree]]("CEGIS") { + def getParams(sctx: SynthesisContext, p: Problem) = { + val base = NaiveCEGIS.getParams(sctx,p).grammar + CegisParams( + grammar = TaggedGrammar(base), + rootLabel = TaggedNonTerm(_, Tags.Top, 0, None), + optimizations = true ) } } diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index d577f7f9fe1f260f4af9ca4d3cb20ca868e37fcc..291e485d70b80b580095a484e02448166de9e18c 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -4,10 +4,6 @@ package leon package synthesis package rules -import leon.utils.SeqUtils -import solvers._ -import grammars._ - import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ @@ -16,44 +12,59 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors._ -import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} +import solvers._ +import grammars._ +import grammars.transformers._ +import leon.utils._ import evaluators._ import datagen._ import codegen.CodeGenParams +import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} + abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 5 + optimizations: Boolean, + maxSize: Option[Int] = None ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val exSolverTo = 2000L val cexSolverTo = 2000L - // Track non-deterministic programs up to 10'000 programs, or give up + // Track non-deterministic programs up to 100'000 programs, or give up val nProgramsLimit = 100000 val sctx = hctx.sctx val ctx = sctx.context + val timers = ctx.timers.synthesis.cegis + // CEGIS Flags to activate or deactivate features - val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useOptTimeout = sctx.settings.cegisUseOptTimeout + val useVanuatoo = sctx.settings.cegisUseVanuatoo // Limits the number of programs CEGIS will specifically validate individually val validateUpTo = 3 + val passingRatio = 10 val interruptManager = sctx.context.interruptManager val params = getParams(sctx, p) - if (params.maxUnfoldings == 0) { + // If this CEGISLike forces a maxSize, take it, otherwise find it in the settings + val maxSize = params.maxSize.getOrElse(sctx.settings.cegisMaxSize) + + ctx.reporter.debug(s"This is $name. Settings: optimizations = ${params.optimizations}, maxSize = $maxSize, vanuatoo=$useVanuatoo, optTimeout=$useOptTimeout") + + if (maxSize == 0) { return Nil } @@ -61,13 +72,13 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private var termSize = 0 - val grammar = SizeBoundedGrammar(params.grammar) + val grammar = SizeBoundedGrammar(params.grammar, params.optimizations) - def rootLabel = SizedLabel(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) + def rootLabel = SizedNonTerm(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) - var nAltsCache = Map[SizedLabel[T], Int]() + var nAltsCache = Map[SizedNonTerm[T], Int]() - def countAlternatives(l: SizedLabel[T]): Int = { + def countAlternatives(l: SizedNonTerm[T]): Int = { if (!(nAltsCache contains l)) { val count = grammar.getProductions(l).map { gen => gen.subTrees.map(countAlternatives).product @@ -91,18 +102,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { * b3 => c6 == H(c4, c5) * * c1 -> Seq( - * (b1, F(c2, c3), Set(c2, c3)) - * (b2, G(c4, c5), Set(c4, c5)) + * (b1, F(_, _), Seq(c2, c3)) + * (b2, G(_, _), Seq(c4, c5)) * ) * c6 -> Seq( - * (b3, H(c7, c8), Set(c7, c8)) + * (b3, H(_, _), Seq(c7, c8)) * ) */ private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() // C identifiers corresponding to p.xs - private var rootC: Identifier = _ + private var rootC: Identifier = _ private var bs: Set[Identifier] = Set() @@ -110,19 +121,19 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { class CGenerator { - private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + private var buffers = Map[SizedNonTerm[T], Stream[Identifier]]() - private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + private var slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) - private def streamOf(t: SizedLabel[T]): Stream[Identifier] = Stream.continually( + private def streamOf(t: SizedNonTerm[T]): Stream[Identifier] = Stream.continually( FreshIdentifier(t.asString, t.getType, true) ) def rewind(): Unit = { - slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) } - def getNext(t: SizedLabel[T]) = { + def getNext(t: SizedNonTerm[T]) = { if (!(buffers contains t)) { buffers += t -> streamOf(t) } @@ -140,13 +151,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def updateCTree(): Unit = { + ctx.timers.synthesis.cegis.updateCTree.start() def freshB() = { val id = FreshIdentifier("B", BooleanType, true) bs += id id } - def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + def defineCTreeFor(l: SizedNonTerm[T], c: Identifier): Unit = { if (!(cTree contains c)) { val cGen = new CGenerator() @@ -182,11 +194,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { printer => printer("Grammar so far:") grammar.printProductions(printer) + printer("") } bsOrdered = bs.toSeq.sorted + excludedPrograms = ArrayBuffer() setCExpr(computeCExpr()) + ctx.timers.synthesis.cegis.updateCTree.stop() } /** @@ -233,9 +248,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { cache(c) } - SeqUtils.cartesianProduct(seqs).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) - } + SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) } allProgramsFor(Seq(rootC)) @@ -287,7 +300,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e) } } else { - Error(c.getType, "Impossibru") + Error(c.getType, s"Empty production rule: $c") } cToFd(c).fullBody = body @@ -325,11 +338,10 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { solFd.fullBody = Ensuring( FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), - Lambda(p.xs.map(ValDef(_)), p.phi) + Lambda(p.xs.map(ValDef), p.phi) ) - - phiFd.body = Some( + phiFd.body = Some( letTuple(p.xs, FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), p.phi) @@ -373,46 +385,56 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private val innerPhi = outerExprToInnerExpr(p.phi) private var programCTree: Program = _ - private var tester: (Example, Set[Identifier]) => EvaluationResults.Result[Expr] = _ + + private var evaluator: DefaultEvaluator = _ private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = { val (cTree, newFds) = cTreeInfo cTreeFd.body = Some(cTree) programCTree = addFunDefs(innerProgram, newFds, cTreeFd) + evaluator = new DefaultEvaluator(sctx.context, programCTree) //println("-- "*30) //println(programCTree.asString) //println(".. "*30) + } - //val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) - val evaluator = new DefaultEvaluator(sctx.context, programCTree) - - tester = - { (ex: Example, bValues: Set[Identifier]) => - // TODO: Test output value as well - val envMap = bs.map(b => b -> BooleanLiteral(bValues(b))).toMap - - ex match { - case InExample(ins) => - val fi = FunctionInvocation(phiFd.typed, ins) - evaluator.eval(fi, envMap) + def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - case InOutExample(ins, outs) => - val fi = FunctionInvocation(cTreeFd.typed, ins) - val eq = equality(fi, tupleWrap(outs)) - evaluator.eval(eq, envMap) - } - } - } + val origImpl = cTreeFd.fullBody + val outerSol = getExpr(bValues) + val innerSol = outerExprToInnerExpr(outerSol) + val cnstr = letTuple(p.xs, innerSol, innerPhi) + cTreeFd.fullBody = innerSol + + timers.testForProgram.start() + val res = ex match { + case InExample(ins) => + evaluator.eval(cnstr, p.as.zip(ins).toMap) + + case InOutExample(ins, outs) => + val eq = equality(innerSol, tupleWrap(outs)) + evaluator.eval(eq, p.as.zip(ins).toMap) + } + timers.testForProgram.stop() + cTreeFd.fullBody = origImpl - def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - tester(ex, bValues) match { + res match { case EvaluationResults.Successful(res) => res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => + /*if (err.contains("Empty production rule")) { + println(programCTree.asString) + println(bValues) + println(ex) + println(this.getExpr(bValues)) + (new Throwable).printStackTrace() + println(err) + println() + }*/ sctx.reporter.debug("RE testing CE: "+err) false @@ -420,18 +442,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.debug("Error testing CE: "+err) false } - } - + } // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { case (b, builder, cs) => builder(cs.map(getCValue)) }.getOrElse { - simplestValue(c.getType) + Error(c.getType, "Impossible assignment of bs") } } @@ -445,60 +467,70 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { val origImpl = cTreeFd.fullBody - val cexs = for (bs <- bss.toSeq) yield { + var cexs = Seq[Seq[Expr]]() + + for (bs <- bss.toSeq) { val outerSol = getExpr(bs) val innerSol = outerExprToInnerExpr(outerSol) - + //println(s"Testing $outerSol") cTreeFd.fullBody = innerSol val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - //println("Solving for: "+cnstr.asString) + val eval = new DefaultEvaluator(ctx, innerProgram) - val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) - - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) + if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { + //println(s"Program $outerSol fails!") + excludeProgram(bs, true) + cTreeFd.fullBody = origImpl + } else { + //println("Solving for: "+cnstr.asString) + + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs, true) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} + + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) + //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") + cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) - } else { - None - } + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } + } + } finally { + solverf.reclaim(solver) + solverf.shutdown() + cTreeFd.fullBody = origImpl } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - cTreeFd.fullBody = origImpl } } - Right(cexs.flatten) + Right(cexs) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() + def allProgramsClosed = allProgramsCount() <= excludedPrograms.size + // Explicitly remove program computed by bValues from the search space // // If the bValues comes from models, we make sure the bValues we exclude @@ -542,9 +574,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { //println(" --- Constraints ---") //println(" - "+toFind.asString) try { - //TODO: WHAT THE F IS THIS? - //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))) - //solver.assertCnstr(bsOrNotBs) solver.assertCnstr(toFind) for ((c, alts) <- cTree) { @@ -660,9 +689,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.init() var unfolding = 1 - val maxUnfoldings = params.maxUnfoldings - - sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings") var baseExampleInputs: ArrayBuffer[Example] = new ArrayBuffer[Example]() @@ -670,7 +696,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.grammar.printProductions(printer) } - // We populate the list of examples with a predefined one + // We populate the list of examples with a defined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.eb.examples @@ -708,7 +734,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - /** * We generate tests for discarding potential programs */ @@ -738,8 +763,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { try { do { - var skipCESearch = false - // Unfold formula ndProgram.unfold() @@ -748,6 +771,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val nInitial = prunedPrograms.size sctx.reporter.debug("#Programs: "+nInitial) + //sctx.reporter.ifDebug{ printer => // val limit = 100 @@ -764,34 +788,33 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // We further filter the set of working programs to remove those that fail on known examples if (hasInputExamples) { + timers.filter.start() for (bs <- prunedPrograms if !interruptManager.isInterrupted) { - var valid = true val examples = allInputExamples() - while(valid && examples.hasNext) { - val e = examples.next() - if (!ndProgram.testForProgram(bs)(e)) { - failedTestsStats(e) += 1 - sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") - wrongPrograms += bs - prunedPrograms -= bs - - valid = false - } + examples.find(e => !ndProgram.testForProgram(bs)(e)).foreach { e => + failedTestsStats(e) += 1 + sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") + wrongPrograms += bs + prunedPrograms -= bs } if (wrongPrograms.size+1 % 1000 == 0) { sctx.reporter.debug("..."+wrongPrograms.size) } } + timers.filter.stop() } val nPassing = prunedPrograms.size - sctx.reporter.debug("#Programs passing tests: "+nPassing) + val nTotal = ndProgram.allProgramsCount() + //println(s"Iotal: $nTotal, passing: $nPassing") + + sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.ifDebug{ printer => - for (p <- prunedPrograms.take(10)) { + for (p <- prunedPrograms.take(100)) { printer(" - "+ndProgram.getExpr(p).asString) } - if(nPassing > 10) { + if(nPassing > 100) { printer(" - ...") } } @@ -805,94 +828,86 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } + // We can skip CE search if - we have excluded all programs or - we do so with validatePrograms + var skipCESearch = nPassing == 0 || interruptManager.isInterrupted || { + // If the number of pruned programs is very small, or by far smaller than the number of total programs, + // we hypothesize it will be easier to just validate them individually. + // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? + val (programsToValidate, otherPrograms) = if (nTotal / nPassing > passingRatio || nPassing < 10) { + (prunedPrograms, Nil) + } else { + prunedPrograms.splitAt(validateUpTo) + } - if (nPassing == 0 || interruptManager.isInterrupted) { - // No test passed, we can skip solver and unfold again, if possible - skipCESearch = true - } else { - var doFilter = true - - if (validateUpTo > 0) { - // Validate the first N programs individualy - ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match { - case Left(sols) if sols.nonEmpty => - doFilter = false - result = Some(RuleClosed(sols)) - case Right(cexs) => - baseExampleInputs ++= cexs.map(InExample) - - if (nPassing <= validateUpTo) { - // All programs failed verification, we filter everything out and unfold - doFilter = false - skipCESearch = true + ndProgram.validatePrograms(programsToValidate) match { + case Left(sols) if sols.nonEmpty => + // Found solution! Exit CEGIS + result = Some(RuleClosed(sols)) + true + case Right(cexs) => + // Found some counterexamples + val newCexs = cexs.map(InExample) + baseExampleInputs ++= newCexs + // Retest whether the newly found C-E invalidates some programs + for (p <- otherPrograms if !interruptManager.isInterrupted) { + // Exclude any programs that fail at least one new cex + newCexs.find { cex => !ndProgram.testForProgram(p)(cex) }.foreach { cex => + failedTestsStats(cex) += 1 + ndProgram.excludeProgram(p, true) } - } + } + // If we excluded all programs, we can skip CE search + programsToValidate.size >= nPassing } + } - if (doFilter) { - sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") - wrongPrograms.foreach { - ndProgram.excludeProgram(_, true) - } + if (!skipCESearch) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") + wrongPrograms.foreach { + ndProgram.excludeProgram(_, true) } } // CEGIS Loop at a given unfolding level - while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { + while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) { + timers.loop.start() ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => - // Should we validate this program with Z3? - - val validateWithZ3 = if (hasInputExamples) { - - if (allInputExamples().forall(ndProgram.testForProgram(bs))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - println("testing failed ?!") - // One valid input failed with this candidate, we can skip + // No inputs to test or all valid inputs also work with this. + // We need to make sure by validating this candidate with z3 + sctx.reporter.debug("Found tentative model, need to validate!") + ndProgram.solveForCounterExample(bs) match { + case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:" + inputsCE) + val ce = InExample(inputsCE) + // Found counter example! Exclude this program + baseExampleInputs += ce ndProgram.excludeProgram(bs, false) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 - true - } - sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") - - if (validateWithZ3) { - ndProgram.solveForCounterExample(bs) match { - case Some(Some(inputsCE)) => - sctx.reporter.debug("Found counter-example:"+inputsCE) - val ce = InExample(inputsCE) - // Found counter example! - baseExampleInputs += ce - - // Retest whether the newly found C-E invalidates all programs - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { - skipCESearch = true - } else { - ndProgram.excludeProgram(bs, false) - } - - case Some(None) => - // Found no counter example! Program is a valid solution + + // Retest whether the newly found C-E invalidates some programs + prunedPrograms.foreach { p => + if (!ndProgram.testForProgram(p)(ce)) ndProgram.excludeProgram(p, true) + } + + case Some(None) => + // Found no counter example! Program is a valid solution + val expr = ndProgram.getExpr(bs) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) + + case None => + // We are not sure + sctx.reporter.debug("Unknown") + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - - case None => - // We are not sure - sctx.reporter.debug("Unknown") - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) - } else { - result = Some(RuleFailed()) - } - } + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) + } else { + // Ok, we failed to validate, exclude this program + ndProgram.excludeProgram(bs, false) + // TODO: Make CEGIS fail early when it fails on 1 program? + // result = Some(RuleFailed()) + } } case Some(None) => @@ -901,11 +916,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case None => result = Some(RuleFailed()) } + + timers.loop.stop() } unfolding += 1 - } while(unfolding <= maxUnfoldings && result.isEmpty && !interruptManager.isInterrupted) + } while(unfolding <= maxSize && result.isEmpty && !interruptManager.isInterrupted) + if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() result.getOrElse(RuleFailed()) } catch { diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala index c12edac075bc8525d395d5f792ef4579c0d109f1..36cc7f9e65dae8af9d8c17d4db936dc4400c0ece 100644 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGLESS.scala @@ -4,10 +4,10 @@ package leon package synthesis package rules +import leon.grammars.transformers.Union import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ -import utils._ import grammars._ import Witnesses._ @@ -24,7 +24,7 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { val inputs = p.as.map(_.toVariable) sctx.reporter.ifDebug { printer => - printer("Guides available:") + printer("Guides available:") for (g <- guides) { printer(" - "+g.asString(ctx)) } @@ -35,7 +35,8 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { CegisParams( grammar = guidedGrammar, rootLabel = { (tpe: TypeTree) => NonTerminal(tpe, "G0") }, - maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max + optimizations = false, + maxSize = Some((0 +: guides.map(depth(_) + 1)).max) ) } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 2ae2b1d5d0292a6ed725055e61b1b4af4100a63c..d3b4c823dd7110763316d121407bcf94820c5826 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -83,7 +83,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } } - var eb = p.qeb.mapIns { info => + val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => ebMapInfo.get(id) match { case Some(m) => @@ -103,7 +103,8 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case CaseClass(ct, args) => val (cts, es) = args.zip(ct.fields).map { case (CaseClassSelector(ct, e, id), field) if field.id == id => (ct, e) - case _ => return e + case _ => + return e }.unzip if (cts.distinct.size == 1 && es.distinct.size == 1) { @@ -126,7 +127,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - val s = {substAll(reverseMap, _:Expr)} andThen { simplePostTransform(recompose) } + val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose) Some(decomp(List(sub), forwardMap(s), s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index f03c54b560e81428851c83b86b9430d4e706e20f..7b72da5396b179a82d4f8111335b8c82b11c76c8 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -7,35 +7,28 @@ package rules import scala.annotation.tailrec import scala.collection.mutable.ListBuffer -import bonsai.enumerators.MemoizedEnumerator -import leon.evaluators.DefaultEvaluator -import leon.evaluators.StringTracingEvaluator -import leon.synthesis.programsets.DirectProgramSet -import leon.synthesis.programsets.JoinProgramSet -import leon.purescala.Common.FreshIdentifier -import leon.purescala.Common.Identifier -import leon.purescala.DefOps -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.ValDef -import leon.purescala.ExprOps -import leon.solvers.Model -import leon.solvers.ModelBuilder -import leon.solvers.string.StringSolver -import leon.utils.DebugSectionSynthesis +import purescala.Common._ +import purescala.Types._ import purescala.Constructors._ import purescala.Definitions._ -import purescala.ExprOps._ import purescala.Expressions._ import purescala.Extractors._ import purescala.TypeOps -import purescala.Types._ +import purescala.DefOps +import purescala.ExprOps + +import evaluators.StringTracingEvaluator +import synthesis.programsets.DirectProgramSet +import synthesis.programsets.JoinProgramSet + +import solvers.ModelBuilder +import solvers.string.StringSolver /** A template generator for a given type tree. * Extend this class using a concrete type tree, * Then use the apply method to get a hole which can be a placeholder for holes in the template. - * Each call to the ``.instantiate` method of the subsequent Template will provide different instances at each position of the hole. + * Each call to the `.instantiate` method of the subsequent Template will provide different instances at each position of the hole. */ abstract class TypedTemplateGenerator(t: TypeTree) { import StringRender.WithIds diff --git a/src/main/scala/leon/synthesis/rules/TEGIS.scala b/src/main/scala/leon/synthesis/rules/TEGIS.scala index d7ec34617ee7dc50745c3b6839511e2c00a6037e..3d496d0597e1947af0eb83504be5af449d7854f1 100644 --- a/src/main/scala/leon/synthesis/rules/TEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/TEGIS.scala @@ -6,7 +6,6 @@ package rules import purescala.Types._ import grammars._ -import utils._ case object TEGIS extends TEGISLike[TypeTree]("TEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala index 91084ae4f6d69d055c36f0ce2c75bc4b41bfa763..93e97de6f1ad97c40def5b77c0d79fbb60282633 100644 --- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala @@ -12,6 +12,7 @@ import datagen._ import evaluators._ import codegen.CodeGenParams import grammars._ +import leon.utils.GrowableIterable import scala.collection.mutable.{HashMap => MutableMap} @@ -40,7 +41,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useVanuatoo = sctx.settings.cegisUseVanuatoo val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000) @@ -53,8 +54,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0) - def hasInputExamples = gi.nonEmpty - var n = 1 def allInputExamples() = { if (n == 10 || n == 50 || n % 500 == 0) { @@ -64,14 +63,12 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { gi.iterator } - var tests = p.eb.valids.map(_.ins).distinct - if (gi.nonEmpty) { - val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) - val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) + val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) + val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[T, Expr, ProductionRule[T, Expr]](grammar.getProductions) val targetType = tupleTypeWrap(p.xs.map(_.getType)) @@ -80,7 +77,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val allExprs = enum.iterator(params.rootLabel(targetType)) var candidate: Option[Expr] = None - var n = 1 def findNext(): Option[Expr] = { candidate = None @@ -111,14 +107,9 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { candidate } - def toStream: Stream[Solution] = { - findNext() match { - case Some(e) => - Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream) - case None => - Stream.empty - } - } + val toStream = Stream.continually(findNext()).takeWhile(_.nonEmpty).map( e => + Solution(BooleanLiteral(true), Set(), e.get, isTrusted = false) + ) RuleClosed(toStream) } else { diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index acd285a4570f93ee9dd85ba3dd29a7e4b120c25a..4bfedc4acbe59440ac7f3382c8187ae201775f02 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -34,7 +34,18 @@ object Helpers { } } - def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(Expr, Set[Identifier])] = { + /** Given an initial set of function calls provided by a list of [[Terminating]], + * returns function calls that will hopefully be safe to call recursively from within this initial function calls. + * + * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. + * + * @param prog The current program + * @param tpe The expected type for the returned function calls + * @param ws Helper predicates that contain [[Terminating]]s with the initial calls + * @param pc The path condition + * @return A list of pairs of (safe function call, holes), where holes stand for the rest of the arguments of the function. + */ + def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(FunctionInvocation, Set[Identifier])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc diff --git a/src/main/scala/leon/utils/GrowableIterable.scala b/src/main/scala/leon/utils/GrowableIterable.scala index d05a9f06576a9e3748ba0ba5fdd33656cc9ac457..0b32fe6261b3bd41cf6bb8ad11fcc6161b47d44b 100644 --- a/src/main/scala/leon/utils/GrowableIterable.scala +++ b/src/main/scala/leon/utils/GrowableIterable.scala @@ -1,4 +1,4 @@ -package leon +package leon.utils import scala.collection.mutable.ArrayBuffer diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 002f2ebedc8a6dfb265fbf101c2185b3bfa17ce1..f2290a68d11bc668c348af3954af16b95f0f7d88 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -34,7 +34,10 @@ object SeqUtils { } def sumTo(sum: Int, arity: Int): Seq[Seq[Int]] = { - if (arity == 1) { + require(arity >= 1) + if (sum < arity) { + Nil + } else if (arity == 1) { Seq(Seq(sum)) } else { (1 until sum).flatMap{ n => @@ -42,6 +45,20 @@ object SeqUtils { } } } + + def sumToOrdered(sum: Int, arity: Int): Seq[Seq[Int]] = { + def rec(sum: Int, arity: Int): Seq[Seq[Int]] = { + require(arity > 0) + if (sum < 0) Nil + else if (arity == 1) Seq(Seq(sum)) + else for { + n <- 0 to sum / arity + rest <- rec(sum - arity * n, arity - 1) + } yield n +: rest.map(n + _) + } + + rec(sum, arity) filterNot (_.head == 0) + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { diff --git a/src/main/scala/leon/utils/UniqueCounter.scala b/src/main/scala/leon/utils/UniqueCounter.scala index 06a6c0bb4b1badd63df38c3285c5fd8514d249fb..7c7862747271a67d899b9a590bc2d9c5fbb7de40 100644 --- a/src/main/scala/leon/utils/UniqueCounter.scala +++ b/src/main/scala/leon/utils/UniqueCounter.scala @@ -17,4 +17,5 @@ class UniqueCounter[K] { globalId } + def current = nameIds } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index f4f603393728dd5a7b748c990486533b1cd18db6..45fa8bea46c71643c68a5a69f5f18e9318c4c449 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -125,7 +125,7 @@ object UnitElimination extends TransformationPhase { } } - LetDef(newFds, rest) + letDef(newFds, rest) } case ite@IfExpr(cond, tExpr, eExpr) => diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala index 4e126827bd6cf352692c43e8433857b8894615d4..1bd9a695788877bd2a0034ec151294daddd5ab59 100644 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -8,7 +8,6 @@ import Expressions._ import ExprOps._ import Definitions._ import Constructors._ -import xlang.Expressions._ object InjectAsserts extends SimpleLeonPhase[Program, Program] { diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala index d627e0d284f4933cd6ed7ecefbe7dd13f4e658f8..98214ee640bd95227c0b759113ab74d2c9555d94 100644 --- a/src/main/scala/leon/xlang/Expressions.scala +++ b/src/main/scala/leon/xlang/Expressions.scala @@ -15,6 +15,14 @@ object Expressions { trait XLangExpr extends Expr + case class Old(id: Identifier) extends XLangExpr with Terminal with PrettyPrintable { + val getType = id.getType + + def printWith(implicit pctx: PrinterContext): Unit = { + p"old($id)" + } + } + case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with Extractable with PrettyPrintable { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { Some((exprs :+ last, exprs => Block(exprs.init, exprs.last))) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 5feb45dda12662fcd904df25453c5f41e76beb6e..9a4e300e781cd1ab341160fa0097e731c757f642 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -218,7 +218,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { case LetDef(fds, b) => if(fds.size > 1) { - //TODO: no support for true mutually recursion + //TODO: no support for true mutual recursion toFunction(LetDef(Seq(fds.head), LetDef(fds.tail, b))) } else { diff --git a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala index c70df950e0768ca71dd7bba013ac04cd7404edab..ca2b4a3c98107c73c9244486200dd0a44a348cb2 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala @@ -251,6 +251,7 @@ object SortedList { case "insertSorted" => Decomp("Assert isSorted(in1)", List( Decomp("ADT Split on 'in1'", List( + Close("CEGIS"), Decomp("Ineq. Split on 'head*' and 'v*'", List( Close("CEGIS"), Decomp("Equivalent Inputs *", List( @@ -259,8 +260,7 @@ object SortedList { )) )), Close("CEGIS") - )), - Close("CEGIS") + )) )) )) } diff --git a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala index fe01946d158153d2dd9ae2a3be2234ee4cd18aa9..0f30a5ba1a95d39e78a1594f39804c8161e919a6 100644 --- a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +++ b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala @@ -72,17 +72,12 @@ object BatchedQueue { def enqueue(v: T): Queue[T] = { require(invariant) - f match { - case Cons(h, t) => - Queue(f, Cons(v, r)) - case Nil() => - Queue(Cons(v, f), Nil()) - } - + ???[Queue[T]] } ensuring { (res: Queue[T]) => - res.invariant && - res.toList.last == v && - res.content == this.content ++ Set(v) + res.invariant && + res.toList.last == v && + res.size == size + 1 && + res.content == this.content ++ Set(v) } } } diff --git a/testcases/synthesis/etienne-thesis/run.sh b/testcases/synthesis/etienne-thesis/run.sh index ee64d86702076bf5ff909c3437f321498a2afe68..924b99cc57386f1dba92bfb97017b41a801cd8ea 100755 --- a/testcases/synthesis/etienne-thesis/run.sh +++ b/testcases/synthesis/etienne-thesis/run.sh @@ -1,7 +1,7 @@ #!/bin/bash function run { - cmd="./leon --debug=report --timeout=30 --synthesis $1" + cmd="./leon --debug=report --timeout=30 --synthesis --cegis:maxsize=5 $1" echo "Running " $cmd echo "------------------------------------------------------------------------------------------------------------------" $cmd; @@ -35,9 +35,9 @@ run testcases/synthesis/etienne-thesis/UnaryNumerals/Distinct.scala run testcases/synthesis/etienne-thesis/UnaryNumerals/Mult.scala # BatchedQueue -#run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala run testcases/synthesis/etienne-thesis/BatchedQueue/Dequeue.scala # AddressBook -#run testcases/synthesis/etienne-thesis/AddressBook/Make.scala +run testcases/synthesis/etienne-thesis/AddressBook/Make.scala run testcases/synthesis/etienne-thesis/AddressBook/Merge.scala