package leon
package synthesis
package utils

import bonsai._

import Helpers._

import purescala.Trees.{Or => LeonOr, _}
import purescala.Common._
import purescala.Definitions._
import purescala.TypeTrees._
import purescala.TreeOps._
import purescala.DefOps._
import purescala.TypeTreeOps._
import purescala.Extractors._
import purescala.ScalaPrinter

import scala.language.implicitConversions
import scala.collection.mutable.{HashMap => MutableMap}

/** A grammar mapping each non-terminal of type `T` to a sequence of bonsai
  * [[Generator]]s (productions) building [[Expr]]s.
  *
  * Productions are computed by [[computeProductions]] and memoized per
  * non-terminal by [[getProductions]].
  *
  * @tparam T non-terminal type; must be viewable as [[Typed]] so that
  *           sub-trees can be given a Leon type.
  */
abstract class ExpressionGrammar[T <% Typed] {
  type Gen = Generator[T, Expr]

  private[this] val cache = new MutableMap[T, Seq[Gen]]()

  /** Cached productions for non-terminal `t` (computed on first request). */
  def getProductions(t: T): Seq[Gen] = {
    cache.getOrElseUpdate(t, computeProductions(t))
  }

  /** Uncached production computation; implemented by each concrete grammar. */
  def computeProductions(t: T): Seq[Gen]

  /** A grammar keeping only the productions of this grammar that satisfy `f`. */
  def filter(f: Gen => Boolean): ExpressionGrammar[T] = {
    val that = this
    new ExpressionGrammar[T] {
      def computeProductions(t: T) = that.computeProductions(t).filter(f)
    }
  }

  /** Union of this grammar and `that`. */
  final def ||(that: ExpressionGrammar[T]): ExpressionGrammar[T] = {
    ExpressionGrammars.Or(Seq(this, that))
  }

  /** Prints every production computed (and cached) so far, one per line.
    * Sub-trees are rendered as fresh variables named after their non-terminal.
    */
  final def printProductions(printer: String => Unit): Unit = {
    for ((t, gs) <- cache; g <- gs) {
      val subs = g.subTrees.map { sub =>
        FreshIdentifier(Console.BOLD + sub.toString + Console.RESET).setType(sub.getType).toVariable
      }
      val gen = g.builder(subs)

      printer(f"${Console.BOLD}$t%30s${Console.RESET} ::= $gen")
    }
  }
}

object ExpressionGrammars {

  /** Union of several grammars; nested `Or`s are flattened. */
  case class Or[T <% Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] {
    val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap {
      case o: Or[T] => o.subGrammars
      case g        => Seq(g)
    }

    def computeProductions(t: T): Seq[Gen] =
      subGrammars.flatMap(_.getProductions(t))
  }

  /** Grammar with no productions. */
  case class Empty[T <% Typed]() extends ExpressionGrammar[T] {
    def computeProductions(t: T): Seq[Gen] = Nil
  }

  /** Small generic grammar of literals and operators for booleans, 32-bit
    * integers, tuples, case classes, sets and unit.
    */
  case object BaseGrammar extends ExpressionGrammar[TypeTree] {
    def computeProductions(t: TypeTree): Seq[Gen] = t match {
      case BooleanType =>
        List(
          Generator(Nil, { _ => BooleanLiteral(true) }),
          Generator(Nil, { _ => BooleanLiteral(false) }),
          Generator(List(BooleanType),              { case Seq(a)    => Not(a) }),
          Generator(List(BooleanType, BooleanType), { case Seq(a, b) => And(a, b) }),
          Generator(List(BooleanType, BooleanType), { case Seq(a, b) => LeonOr(a, b) }),
          Generator(List(Int32Type,   Int32Type),   { case Seq(a, b) => LessThan(a, b) }),
          Generator(List(Int32Type,   Int32Type),   { case Seq(a, b) => LessEquals(a, b) }),
          Generator(List(Int32Type,   Int32Type),   { case Seq(a, b) => Equals(a, b) }),
          Generator(List(BooleanType, BooleanType), { case Seq(a, b) => Equals(a, b) })
        )

      case Int32Type =>
        List(
          Generator(Nil, { _ => IntLiteral(0) }),
          Generator(Nil, { _ => IntLiteral(1) }),
          Generator(List(Int32Type, Int32Type), { case Seq(a, b) => Plus(a, b) }),
          Generator(List(Int32Type, Int32Type), { case Seq(a, b) => Minus(a, b) }),
          Generator(List(Int32Type, Int32Type), { case Seq(a, b) => Times(a, b) })
        )

      case TupleType(stps) =>
        List(Generator(stps, { sub => Tuple(sub) }))

      case cct: CaseClassType =>
        List(
          Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs) })
        )

      case act: AbstractClassType =>
        act.knownCCDescendents.map { cct =>
          Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs) })
        }

      case st @ SetType(base) =>
        List(
          Generator(List(base),   { case elems     => FiniteSet(elems.toSet).setType(st) }),
          Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }),
          Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }),
          Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) })
        )

      case UnitType =>
        List(Generator(Nil, { case _ => UnitLiteral() }))

      case _ =>
        Nil
    }
  }

  /** Grammar of ground values only (literals and constructors, no operators). */
  case object ValueGrammar extends ExpressionGrammar[TypeTree] {
    def computeProductions(t: TypeTree): Seq[Gen] = t match {
      case BooleanType =>
        List(
          Generator(Nil, { _ => BooleanLiteral(true) }),
          Generator(Nil, { _ => BooleanLiteral(false) })
        )

      case Int32Type =>
        List(
          Generator(Nil, { _ => IntLiteral(0) }),
          Generator(Nil, { _ => IntLiteral(1) }),
          Generator(Nil, { _ => IntLiteral(-1) })
        )

      case tp @ TypeParameter(_) =>
        for (ind <- (1 to 3).toList) yield {
          Generator[TypeTree, Expr](Nil, { _ => GenericValue(tp, ind) })
        }

      case TupleType(stps) =>
        List(Generator(stps, { sub => Tuple(sub) }))

      case cct: CaseClassType =>
        List(
          Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs) })
        )

      case act: AbstractClassType =>
        act.knownCCDescendents.map { cct =>
          Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs) })
        }

      case st @ SetType(base) =>
        List(
          // Nullary production: the builder receives no elements, i.e. the
          // empty set. Without it the empty set value is not producible.
          Generator(Nil,              { case elems => FiniteSet(elems.toSet).setType(st) }),
          Generator(List(base),       { case elems => FiniteSet(elems.toSet).setType(st) }),
          Generator(List(base, base), { case elems => FiniteSet(elems.toSet).setType(st) })
        )

      case UnitType =>
        List(Generator(Nil, { case _ => UnitLiteral() }))

      case _ =>
        Nil
    }
  }

  /** Grammar producing, as terminals, each input whose type is a subtype of
    * the requested type.
    */
  case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] {
    def computeProductions(t: TypeTree): Seq[Gen] = {
      inputs.collect {
        case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i })
      }
    }
  }

  /** A non-terminal: a Leon type tagged with a label `l` and an optional depth. */
  case class Label[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed {
    def getType = t

    override def toString = t.toString + "#" + l + depth.map(d => "@" + d).getOrElse("")
  }

  /** Grammar of expressions "similar" to the seed expression `e`.
    *
    * Non-terminals labelled `"B"` are served by a depth-1 bounded version of
    * the default grammar (base operators, `terminals`, function calls and safe
    * recursive calls). Every other label comes from [[computeSimilar]], which
    * walks `e` and offers local variations of each sub-expression.
    */
  case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[Label[String]] {

    val excludeFCalls = sctx.settings.functionsToIgnore

    val normalGrammar = BoundedGrammar(
      EmbeddedGrammar(
        BaseGrammar ||
        OneOf(terminals.toSeq) ||
        FunctionCalls(sctx.program, sctx.functionContext, p.as.map(_.getType), excludeFCalls) ||
        SafeRecCalls(sctx.program, p.ws, p.pc),
        { (t: TypeTree) => Label(t, "B", None) },
        { (l: Label[String]) => l.getType }
      ),
      1
    )

    type L = Label[String]

    private var counter = -1

    /** Returns a fresh, strictly increasing index used to name labels. */
    def getNext(): Int = {
      counter += 1
      counter
    }

    // Strict map: `mapValues` would return a view re-evaluated on every lookup.
    lazy val allSimilar =
      computeSimilar(e).groupBy(_._1).map { case (label, prods) => label -> prods.map(_._2) }

    def computeProductions(t: L): Seq[Gen] = t match {
      case Label(_, "B", _) => normalGrammar.computeProductions(t)
      case _                => allSimilar.getOrElse(t, Nil)
    }

    /** Builds the (label, production) pairs describing variations of `e`.
      *
      * Each sub-expression gets a fresh `"G<n>"` label. Its productions are:
      * the original expression with one child replaced by that child's
      * variations, argument swaps for non-commutative operators, ±1 for
      * integers, sibling case-class constructors with identical field types,
      * and matching-type `terminals`.
      */
    def computeSimilar(e: Expr): Seq[(L, Gen)] = {

      def getLabel(t: TypeTree) = {
        val tpe = bestRealType(t)
        val c = getNext()
        Label(tpe, "G" + c)
      }

      def isCommutative(e: Expr) = e match {
        case _: Plus | _: Times => true
        case _                  => false
      }

      def rec(e: Expr, gl: L): Seq[(L, Gen)] = {

        def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = {
          // Labels for the children are allocated before recursing into them.
          val subGls = subs.map { s => getLabel(s.getType) }

          // All the subproductions for sub gl
          val allSubs = (subs zip subGls).flatMap { case (se, sgl) => rec(se, sgl) }

          // Inject fix at one place
          val injectG = for ((sgl, i) <- subGls.zipWithIndex) yield {
            gl -> Generator[L, Expr](Seq(sgl), { case Seq(ge) => builder(subs.updated(i, ge)) })
          }

          // Swap same-typed arguments (pointless for commutative operators).
          val swaps = if (subs.size > 1 && !isCommutative(e)) {
            (for (i <- 0 until subs.size; j <- i + 1 until subs.size) yield {
              if (subs(i).getType == subs(j).getType) {
                val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i))
                Some(gl -> Generator[L, Expr](Seq(), { _ => builder(swapSubs) }))
              } else {
                None
              }
            }).flatten
          } else {
            Nil
          }

          allSubs ++ injectG ++ swaps
        }

        def cegis(gl: L): Seq[(L, Gen)] = {
          normalGrammar.getProductions(gl).map(gl -> _)
        }

        def intVariations(gl: L, e: Expr): Seq[(L, Gen)] = {
          Seq(
            gl -> Generator(Nil, { _ => Minus(e, IntLiteral(1)) }),
            gl -> Generator(Nil, { _ => Plus(e, IntLiteral(1)) })
          )
        }

        // Find neighbor case classes that are compatible with the arguments:
        // Turns And(e1, e2) into Or(e1, e2)...
        def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = {
          val CaseClass(cct, args) = cc

          val neighbors = cct.parent.map(_.knownCCDescendents).getOrElse(Seq()).filter(_ != cct)

          for (scct <- neighbors if scct.fieldsTypes == cct.fieldsTypes) yield {
            gl -> Generator[L, Expr](Nil, { _ => CaseClass(scct, args) })
          }
        }

        val subs: Seq[(L, Gen)] = (e match {
          case _: Terminal | _: Let | _: LetTuple | _: LetDef | _: MatchExpr =>
            gens(e, gl, Nil, { _ => e }) ++ cegis(gl)

          case cc @ CaseClass(cct, exprs) =>
            gens(e, gl, exprs, { case ss => CaseClass(cct, ss) }) ++ ccVariations(gl, cc)

          case FunctionInvocation(TypedFunDef(fd, _), _) if excludeFCalls contains fd =>
            // We allow only exact call, and/or cegis extensions
            /*Seq(el -> Generator[L, Expr](Nil, { _ => e })) ++*/
            cegis(gl)

          case UnaryOperator(sub, builder) =>
            gens(e, gl, List(sub), { case Seq(s) => builder(s) })

          case BinaryOperator(sub1, sub2, builder) =>
            gens(e, gl, List(sub1, sub2), { case Seq(s1, s2) => builder(s1, s2) })

          case NAryOperator(args, builder) =>
            gens(e, gl, args, { case ss => builder(ss) })
        }) ++ (if (e.getType == Int32Type) intVariations(gl, e) else Nil)

        val terminalsMatching = terminals.collect {
          case IsTyped(term, tpe) if tpe == gl.getType && term != e =>
            gl -> Generator[L, Expr](Nil, { _ => term })
        }

        subs ++ terminalsMatching
      }

      rec(e, getLabel(e.getType))
    }
  }

  /** Calls to available functions that are deterministic and cannot make the
    * current function recursive, and whose return type can be instantiated to
    * the requested type. Functions in `exclude` are never offered.
    */
  case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] {
    def computeProductions(t: TypeTree): Seq[Gen] = {

      // Callers of the current function (transitively), plus itself: calling
      // any of them would introduce recursion. Invariant across candidates.
      lazy val recursiveFuns = prog.callGraph.transitiveCallers(currentFunction) + currentFunction

      def getCandidates(fd: FunDef): Seq[TypedFunDef] = {
        // Prevents recursive calls
        val isRecursiveCall = recursiveFuns contains fd
        val isDet = fd.body.map(isDeterministic).getOrElse(false)

        if (!isRecursiveCall && isDet) {
          val free = fd.tparams.map(_.tp)

          canBeSubtypeOf(fd.returnType, free, t, rhsFixed = true) match {
            case Some(tpsMap) =>
              val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))

              if (tpsMap.size < free.size) {
                /* Some type params remain free, we want to assign them:
                 *
                 * List[T] => Int, for instance, will be found when
                 * requesting Int, but we need to assign T to viable
                 * types. For that we use list of input types as heuristic,
                 * and look for instantiations of T such that input <?:
                 * List[T].
                 */
                types.distinct.flatMap { (atpe: TypeTree) =>
                  val paramTypes = tfd.params.map(_.tpe).distinct

                  val (finalFree, finalMap) = paramTypes.foldLeft((free.toSet -- tpsMap.keySet, tpsMap)) {
                    case ((remaining, assigned), ptpe) =>
                      canBeSubtypeOf(atpe, remaining.toSeq, ptpe) match {
                        case Some(ntpsMap) => (remaining -- ntpsMap.keySet, assigned ++ ntpsMap)
                        case None          => (remaining, assigned)
                      }
                  }

                  if (finalFree.isEmpty) {
                    List(fd.typed(free.map(tp => finalMap.getOrElse(tp, tp))))
                  } else {
                    Nil
                  }
                }
              } else {
                /* All type parameters that used to be free are assigned */
                List(tfd)
              }

            case None =>
              Nil
          }
        } else {
          Nil
        }
      }

      // Excluded functions are dropped before instantiation; `distinct` removes
      // identical instantiations reached through different input types.
      val funcs = functionsAvailable(prog).toSeq
        .filterNot(fd => exclude contains fd)
        .flatMap(getCandidates)
        .distinct

      funcs.map { tfd =>
        Generator[TypeTree, Expr](tfd.params.map(_.tpe), { sub => FunctionInvocation(tfd, sub) })
      }
    }
  }

  /** Restricts `g` to expressions of depth at most `bound`.
    *
    * At depth `bound` only terminal productions are kept; beyond it, none.
    * Sub-trees get depth `l.depth + 1`, or 1 when `l` has no depth yet.
    */
  case class BoundedGrammar[T](g: ExpressionGrammar[Label[T]], bound: Int) extends ExpressionGrammar[Label[T]] {
    def computeProductions(l: Label[T]): Seq[Gen] = {
      val childDepth = l.depth.map(_ + 1).orElse(Some(1))

      g.computeProductions(l).flatMap { gen =>
        if (l.depth == Some(bound) && gen.subTrees.nonEmpty) {
          None
        } else if (l.depth.exists(_ > bound)) {
          None
        } else {
          Some(Generator(gen.subTrees.map(sl => sl.copy(depth = childDepth)), gen.builder))
        }
      }
    }
  }

  /** Re-exposes grammar `g` over non-terminals `To`: a request for `t` is
    * answered with `g`'s productions for `oToi(t)`, whose sub-trees are mapped
    * back through `iToo`.
    */
  case class EmbeddedGrammar[Ti <% Typed, To <% Typed](g: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] {
    def computeProductions(t: To): Seq[Gen] = g.computeProductions(oToi(t)).map { gen =>
      Generator(gen.subTrees.map(iToo), gen.builder)
    }
  }

  /** Recursive-call productions obtained from `terminatingCalls(prog, t, ws, pc)`.
    * Each call's free identifiers become sub-trees and are substituted back by
    * the builder. (Termination guarantees presumably come from
    * `terminatingCalls` itself — not visible here.)
    */
  case class SafeRecCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] {
    def computeProductions(t: TypeTree): Seq[Gen] = {
      val calls = terminatingCalls(prog, t, ws, pc)

      calls.map {
        case (e, free) =>
          val freeSeq = free.toSeq
          Generator[TypeTree, Expr](freeSeq.map(_.getType), { sub =>
            replaceFromIDs(freeSeq.zip(sub).toMap, e)
          })
      }
    }
  }

  /** Base operators, the given inputs, function calls and safe recursive calls. */
  def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, exclude: Set[FunDef], ws: Expr, pc: Expr): ExpressionGrammar[TypeTree] = {
    BaseGrammar ||
    OneOf(inputs) ||
    FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) ||
    SafeRecCalls(prog, ws, pc)
  }

  /** [[default]] grammar configured from a synthesis context and problem. */
  def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar[TypeTree] = {
    default(sctx.program, p.as.map(_.toVariable), sctx.functionContext, sctx.settings.functionsToIgnore, p.ws, p.pc)
  }

  /** Keeps only productions whose sub-tree types all have type depth `<= b`. */
  def depthBound[T <% Typed](g: ExpressionGrammar[T], b: Int) = {
    g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b))
  }
}