diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index d109c73fb987417b4694094a27f1bb92c99470f6..98ca2a4b21ac6b13ef0c867f66d22fd9dfee7e66 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -25,10 +25,18 @@ import evaluators._ import datagen._ import codegen.CodeGenParams -import utils.ExpressionGrammar +import utils._ +case object CEGIS extends CEGISLike("CEGIS") { + def getGrammar(sctx: SynthesisContext, p: Problem) = { + ExpressionGrammars.default(sctx, p) + } +} + + +abstract class CEGISLike(name: String) extends Rule(name) { + def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar -case object CEGIS extends Rule("CEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { // CEGIS Flags to actiave or de-activate features @@ -51,7 +59,7 @@ case object CEGIS extends Rule("CEGIS") { class NonDeterministicProgram(val p: Problem, val initGuard: Identifier) { - val grammar = new ExpressionGrammar(sctx, p) + val grammar = getGrammar(sctx, p) // b -> (c, ex) means the clause b => c == ex var mappings: Map[Identifier, (Identifier, Expr)] = Map() @@ -309,7 +317,7 @@ case object CEGIS extends Rule("CEGIS") { for ((parentGuard, recIds) <- guardedTerms; recId <- recIds) { - var alts = grammar.getGenerators(recId.getType) + var alts = grammar.getProductions(recId.getType) if (finalUnrolling) { alts = alts.filter(_.subTrees.isEmpty) } @@ -356,7 +364,7 @@ case object CEGIS extends Rule("CEGIS") { sctx.reporter.ifDebug { printer => printer("Grammar so far:"); - grammar.printGrammar(printer) + grammar.printProductions(printer) } //program = And(program :: newClauses) diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala index c6f57ed0f9bb7b5383e2d45bd0f8edb4e81347b8..ada8bce10dba8afb7a1c28c8191306c56eda7f70 100644 --- a/src/main/scala/leon/synthesis/rules/Tegis.scala +++ b/src/main/scala/leon/synthesis/rules/Tegis.scala @@ -24,7 +24,7 @@ import evaluators._ import datagen._ import codegen.CodeGenParams -import utils.ExpressionGrammar +import utils._ import bonsai._ import bonsai.enumerators._ @@ -32,7 +32,7 @@ import bonsai.enumerators._ case object TEGIS extends Rule("TEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val grammar = new ExpressionGrammar(sctx, p) + val grammar = ExpressionGrammars.default(sctx, p) var tests = p.getTests(sctx).map(_.ins).distinct if (tests.nonEmpty) { @@ -46,7 +46,7 @@ case object TEGIS extends Rule("TEGIS") { val interruptManager = sctx.context.interruptManager - val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getGenerators) + val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getProductions _) val (targetType, isWrapped) = if (p.xs.size == 1) { (p.xs.head.getType, false) diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 84c6487ae73c36462506523482b130bc25cbafc5..a9c606ab7fdb950ac453b8196bd3bf892f1e4394 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -18,257 +18,193 @@ import purescala.ScalaPrinter import scala.collection.mutable.{HashMap => MutableMap} -class ExpressionGrammar(ctx: LeonContext, prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pathCondition: Expr) { - def this(sctx: SynthesisContext, p: Problem) = { - this(sctx.context, sctx.program, p.as.map(_.toVariable), sctx.functionContext, p.pc) - } - +abstract class ExpressionGrammar { type Gen = Generator[TypeTree, Expr] private[this] val cache = new MutableMap[TypeTree, Seq[Gen]]() - def getGenerators(t: TypeTree): Seq[Gen] = { + def getProductions(t: TypeTree): Seq[Gen] = { cache.getOrElse(t, { - val res = computeGenerators(t) + val res = computeProductions(t) cache += t -> res res }) } - def computeGenerators(t: TypeTree): Seq[Gen] = { - computeBaseGenerators(t) ++ - computeInputGenerators(t) ++ - computeFcallGenerators(t) ++ - computeSafeRecCalls(t) - } + def computeProductions(t: TypeTree): Seq[Gen] - def computeBaseGenerators(t: TypeTree): Seq[Gen] = t match { - case BooleanType => - List( - Generator(Nil, { _ => BooleanLiteral(true) }), - Generator(Nil, { _ => BooleanLiteral(false) }) - ) - case Int32Type => - List( - Generator(Nil, { _ => IntLiteral(0) }), - Generator(Nil, { _ => IntLiteral(1) }), - Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Plus(a, b) }), - Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Minus(a, b) }), - Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Times(a, b) }) - ) - case TupleType(stps) => - List(Generator(stps, { sub => Tuple(sub) })) - - case cct: CaseClassType => - List( - Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) - ) - - case act: AbstractClassType => - act.knownCCDescendents.map { cct => - Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) - } - - case st @ SetType(base) => - List( - Generator(List(base), { case elems => FiniteSet(elems.toSet).setType(st) }), - Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), - Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), - Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) - ) - - case _ => - Nil + final def ||(that: ExpressionGrammar): ExpressionGrammar = { + ExpressionGrammar.Or(Seq(this, that)) } - def computeInputGenerators(t: TypeTree): Seq[Gen] = { - inputs.collect { - case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i }) + final def printProductions(printer: String => Unit) { + for ((t, gs) <- cache; g <- gs) { + val subs = g.subTrees.map { tpe => FreshIdentifier(tpe.toString).setType(tpe).toVariable } + val gen = g.builder(subs) + + printer(f"$t%30s ::= "+gen) } } +} - def computeFcallGenerators(t: TypeTree): Seq[Gen] = { - - def getCandidates(fd: FunDef): Seq[TypedFunDef] = { - // Prevents recursive calls - val cfd = currentFunction - - val isRecursiveCall = (prog.callGraph.transitiveCallers(cfd) + cfd) contains fd - - val isNotSynthesizable = fd.body match { - case Some(b) => - !containsChoose(b) - - case None => - false - } +object ExpressionGrammar { + case class Or(gs: Seq[ExpressionGrammar]) extends ExpressionGrammar { + val subGrammars: Seq[ExpressionGrammar] = gs.flatMap { + case o: Or => o.subGrammars + case g => Seq(g) + } + def computeProductions(t: TypeTree): Seq[Gen] = + subGrammars.flatMap(_.getProductions(t)) + } +} - if (!isRecursiveCall && isNotSynthesizable) { - val free = fd.tparams.map(_.tp) - canBeSubtypeOf(fd.returnType, free, t) match { - case Some(tpsMap) => - val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) - - if (tpsMap.size < free.size) { - /* Some type params remain free, we want to assign them: - * - * List[T] => Int, for instance, will be found when - * requesting Int, but we need to assign T to viable - * types. For that we use problem inputs as heuristic, - * and look for instantiations of T such that input <?: - * List[T]. - */ - inputs.map(_.getType).distinct.flatMap { (atpe: TypeTree) => - var finalFree = free.toSet -- tpsMap.keySet - var finalMap = tpsMap - - for (ptpe <- tfd.params.map(_.tpe).distinct) { - canBeSubtypeOf(atpe, finalFree.toSeq, ptpe) match { - case Some(ntpsMap) => - finalFree --= ntpsMap.keySet - finalMap ++= ntpsMap - case _ => - } - } - - if (finalFree.isEmpty) { - List(fd.typed(free.map(tp => finalMap.getOrElse(tp, tp)))) - } else { - Nil - } - } - } else { - /* All type parameters that used to be free are assigned - */ - List(tfd) - } - case None => - Nil +object ExpressionGrammars { + + case object BaseGrammar extends ExpressionGrammar { + def computeProductions(t: TypeTree): Seq[Gen] = t match { + case BooleanType => + List( + Generator(Nil, { _ => BooleanLiteral(true) }), + Generator(Nil, { _ => BooleanLiteral(false) }) + ) + case Int32Type => + List( + Generator(Nil, { _ => IntLiteral(0) }), + Generator(Nil, { _ => IntLiteral(1) }), + Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Plus(a, b) }), + Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Minus(a, b) }), + Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Times(a, b) }) + ) + case TupleType(stps) => + List(Generator(stps, { sub => Tuple(sub) })) + + case cct: CaseClassType => + List( + Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + ) + + case act: AbstractClassType => + act.knownCCDescendents.map { cct => + Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) } - } else { - Nil - } - } - val funcs = functionsAvailable(prog).toSeq.flatMap(getCandidates) + case st @ SetType(base) => + List( + Generator(List(base), { case elems => FiniteSet(elems.toSet).setType(st) }), + Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), + Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), + Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) + ) - funcs.map{ tfd => - Generator[TypeTree, Expr](tfd.params.map(_.tpe), { sub => FunctionInvocation(tfd, sub) }) + case _ => + Nil } } - def computeSafeRecCalls(t: TypeTree): Seq[Gen] = { - val calls = terminatingCalls(prog, t, pathCondition) - - calls.map { - case (e, free) => - val freeSeq = free.toSeq - Generator[TypeTree, Expr](freeSeq.map(_.getType), { sub => - replaceFromIDs(freeSeq.zip(sub).toMap, e) - }) + case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar { + def computeProductions(t: TypeTree): Seq[Gen] = { + inputs.collect { + case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i }) + } } } - def computeSubexpressionGenerators(canPlacehold : Expr => Boolean)(e : Expr) : Seq[Gen] = { - - /** A simple Generator API **/ - - def gen(tps : Seq[TypeTree], f : Seq[Expr] => Expr) : Gen = - Generator[TypeTree, Expr](tps,f) - - // A generator that accepts a single type, and always regenerates its input - // (simple placeholder of 1 position) - def wildcardGen(tp : TypeTree) = gen(Seq(tp), { case Seq(x) => x }) - - // A generator that always regenerates its input - def const(e: Expr) : Gen = gen(Seq(), _ => e) - - // Creates a new generator by applying f on the result of g.builder - def map(f : Expr => Expr)(g : Gen) : Gen = { - gen(g.subTrees, es => f(g.builder(es)) ) - } - - // Concatenate a sequence of generators into a generator. - // The arity of the resulting generator is the total arity of the constituting generators. - // builder is the function combining the results of the partial generators - def concat(gens : Seq[Gen], builder : Seq[Expr] => Expr ) : Gen = { - val types = gens flatMap { _.subTrees } - gen( - types, - exprs => { - assert(exprs.length == types.length) // Total arity is arity of subgenerators - var remaining = exprs - val fromSubGens = for (gen <- gens) yield { - val (current, rem) = remaining splitAt gen.arity - remaining = rem - gen.builder(current) - } - builder(fromSubGens) - } - ) - - } - - - def rec(e : Expr) : Seq[Gen] = { - - // Add an additional wildcard generator, if current expression passes the filter - def optWild(gens : Seq[Gen]) : Seq[Gen] = - if (canPlacehold(e)) { - wildcardGen(e.getType) +: gens - } - else gens - - - e match { - - case t : Terminal => - // In case of Terminal, we either return the terminal itself, or the input expression - optWild(Seq(const(t))) - - case UnaryOperator(sub, builder) => - val fromSub = for (subGen <- rec(sub)) yield map(builder)(subGen) - optWild(fromSub) - - case BinaryOperator(e1,e2,builder) => - val fromSub = for { - subGen1 <- rec(e1) - subGen2 <- rec(e2) - } yield concat(Seq(subGen1, subGen2), { case Seq(e1,e2) => builder(e1,e2) }) - - optWild(fromSub) - - case NAryOperator(subExpressions, builder) => + case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree]) extends ExpressionGrammar { + def computeProductions(t: TypeTree): Seq[Gen] = { + + def getCandidates(fd: FunDef): Seq[TypedFunDef] = { + // Prevents recursive calls + val cfd = currentFunction + + val isRecursiveCall = (prog.callGraph.transitiveCallers(cfd) + cfd) contains fd + + val isNotSynthesizable = fd.body match { + case Some(b) => + !containsChoose(b) + + case None => + false + } + + + if (!isRecursiveCall && isNotSynthesizable) { + val free = fd.tparams.map(_.tp) + canBeSubtypeOf(fd.returnType, free, t) match { + case Some(tpsMap) => + val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) + + if (tpsMap.size < free.size) { + /* Some type params remain free, we want to assign them: + * + * List[T] => Int, for instance, will be found when + * requesting Int, but we need to assign T to viable + * types. For that we use list of input types as heuristic, + * and look for instantiations of T such that input <?: + * List[T]. + */ + types.distinct.flatMap { (atpe: TypeTree) => + var finalFree = free.toSet -- tpsMap.keySet + var finalMap = tpsMap + + for (ptpe <- tfd.params.map(_.tpe).distinct) { + canBeSubtypeOf(atpe, finalFree.toSeq, ptpe) match { + case Some(ntpsMap) => + finalFree --= ntpsMap.keySet + finalMap ++= ntpsMap + case _ => + } + } + + if (finalFree.isEmpty) { + List(fd.typed(free.map(tp => finalMap.getOrElse(tp, tp)))) + } else { + Nil + } + } + } else { + /* All type parameters that used to be free are assigned + */ + List(tfd) + } + case None => + Nil + } + } else { + Nil + } + } + + val funcs = functionsAvailable(prog).toSeq.flatMap(getCandidates) + + funcs.map{ tfd => + Generator[TypeTree, Expr](tfd.params.map(_.tpe), { sub => FunctionInvocation(tfd, sub) }) + } + } + } - def combinations[A](seqs : Seq[Seq[A]]) : Seq[Seq[A]] = { - if (seqs.isEmpty) Seq(Seq()) - else for { - hd <- seqs.head - tl <- combinations(seqs.tail) - } yield hd +: tl - } + case class SafeRecCalls(prog: Program, pc: Expr) extends ExpressionGrammar { + def computeProductions(t: TypeTree): Seq[Gen] = { + val calls = terminatingCalls(prog, t, pc) - val combos = combinations(subExpressions map rec) - val fromSub = combos map { concat(_, builder) } - - optWild(fromSub) + calls.map { + case (e, free) => + val freeSeq = free.toSeq + Generator[TypeTree, Expr](freeSeq.map(_.getType), { sub => + replaceFromIDs(freeSeq.zip(sub).toMap, e) + }) } } - - rec(e) - } - def computeCompleteSubexpressionGenerators = inputs flatMap computeSubexpressionGenerators{ _ => true} - - - def printGrammar(printer: String => Unit) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { tpe => FreshIdentifier(tpe.toString).setType(tpe).toVariable } - val gen = g.builder(subs) + def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pc: Expr): ExpressionGrammar = { + BaseGrammar || + OneOf(inputs) || + FunctionCalls(prog, currentFunction, inputs.map(_.getType)) || + SafeRecCalls(prog, pc) + } - printer(f"$t%30s ::= "+gen) - } + def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar = { + default(sctx.program, p.as.map(_.toVariable), sctx.functionContext, p.pc) } }