-
Emmanouil (Manos) Koukoutos authored
Problem.scala 6.81 KiB
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package synthesis
import leon.purescala.Trees._
import leon.purescala.TreeOps._
import leon.purescala.TypeTrees.TypeTree
import leon.purescala.Common._
import leon.purescala.Constructors._
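// Defines a synthesis problem of the form:
// ⟦ as ⟨ pc | phi ⟩ xs ⟧
// where `as` are the inputs, `pc` the path condition, `phi` the specification and `xs` the outputs.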
/**
 * A synthesis problem ⟦ as ⟨ pc | phi ⟩ xs ⟧: find values for the outputs `xs`
 * satisfying `phi`, given inputs `as` for which `pc` holds.
 *
 * @param as  input identifiers
 * @param pc  path condition over the inputs
 * @param phi specification relating the inputs `as` to the outputs `xs`
 * @param xs  output identifiers to synthesize
 */
case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifier]) {

  override def toString: String =
    "⟦ "+as.mkString(";")+", " + (if (pc != BooleanLiteral(true)) pc+" ≺ " else "") + " ⟨ "+phi+" ⟩ " + xs.mkString(";") + " ⟧ "

  /**
   * Extracts concrete examples for this problem from the `passes` constructs
   * occurring in `pc && phi`.
   *
   * Steps:
   *  1. Every `passes` yields a cluster of partial assignments (identifier -> value),
   *     obtained by matching each case's input/output pair against the scrutinee.
   *  2. Assignments that share at least one identifier and agree on every shared
   *     identifier are merged together.
   *  3. Assignments covering all of `as ++ xs` become [[InOutExample]]s; those
   *     covering only all of `as` become [[InExample]]s; the rest are dropped.
   *  4. Each example is kept only if `pc` (and `phi`, for input/output examples)
   *     evaluates to `true` under it.
   *
   * NOTE(review): the merge in step 2 is greedy, so its result may depend on the
   * iteration order of the collected clusters.
   *
   * @param sctx synthesis context supplying the program, evaluation context and reporter
   * @return the valid examples found
   */
  def getTests(sctx: SynthesisContext): Seq[Example] = {
    import purescala.Extractors._
    import evaluators._

    val predicates = and(pc, phi)
    // A single evaluator serves both example normalisation and example validation.
    val evaluator = new DefaultEvaluator(sctx.context, sctx.program)
    // NOTE(review): presumably strips witness markers the evaluator cannot handle — confirm.
    val safePc = removeWitnesses(sctx.program)(pc)

    /** True iff `ex` satisfies `pc`, and also `phi` when it carries output values. */
    def isValidExample(ex: Example): Boolean = {
      val (mapping, cond) = ex match {
        case io: InOutExample =>
          (((as zip io.ins) ++ (xs zip io.outs)).toMap, And(safePc, phi))
        case other =>
          ((as zip other.ins).toMap, safePc)
      }
      evaluator.eval(cond, mapping) match {
        case EvaluationResults.Successful(BooleanLiteral(true)) => true
        case _ => false
      }
    }

    /**
     * Sequential composition of partial functions: defined on `x` when `pf1` is
     * defined on `x` and `pf2` is defined on `pf1(x)`.
     */
    def andThen(pf1: PartialFunction[Expr, Expr], pf2: PartialFunction[Expr, Expr]): PartialFunction[Expr, Expr] =
      Function.unlift((x: Expr) => pf1.lift(x).flatMap(pf2.lift))

    /**
     * Extracts the ids occurring in an ins/outs pattern, each paired with an
     * extractor that retrieves that id's value from a matching concrete value.
     *
     * Examples:
     *   (a,b) =>
     *     a -> _.1
     *     b -> _.2
     *
     *   Cons(a, Cons(b, c)) =>
     *     a -> _.head
     *     b -> _.tail.head
     *     c -> _.tail.tail
     *
     * Unsupported sub-patterns are reported as warnings and contribute no ids.
     */
    def extractIds(e: Expr): Seq[(Identifier, PartialFunction[Expr, Expr])] = e match {
      case Variable(id) =>
        List((id, { case v => v }))
      case Tuple(vs) =>
        vs.map(extractIds).zipWithIndex.flatMap { case (ids, i) =>
          ids.map { case (id, f) =>
            (id, andThen({ case Tuple(elems) => elems(i) }, f))
          }
        }
      case CaseClass(cct, args) =>
        args.map(extractIds).zipWithIndex.flatMap { case (ids, i) =>
          ids.map { case (id, f) =>
            (id, andThen({ case CaseClass(cct2, fields) if cct2 == cct => fields(i) }, f))
          }
        }
      case _ =>
        sctx.reporter.warning("Unexpected pattern in test-ids extraction: "+e)
        Nil
    }

    /**
     * Turns one `passes` case into (input, output) pairs.
     *
     * - Guarded cases are not supported: an error is reported and no pairs are produced.
     * - A case whose right-hand side is `out` itself yields nothing (trivial example).
     * - If the case pattern has free variables, they are instantiated with values
     *   enumerated from `ValueGrammar`, producing at most `examplesPerCase` pairs.
     */
    def toIOExamples(in: Expr, out: Expr, cs: MatchCase): Seq[(Expr, Expr)] = {
      import utils.ExpressionGrammars.ValueGrammar
      import leon.utils.StreamUtils.cartesianProduct
      import bonsai._
      import bonsai.enumerators._

      val examplesPerCase = 5

      def doSubstitute(subs: Seq[(Identifier, Expr)], e: Expr) =
        subs.foldLeft(e) {
          case (from, (id, to)) => replaceFromIDs(Map(id -> to), from)
        }

      cs.optGuard match {
        case Some(guard) =>
          sctx.reporter.error("Cannot handle guards in example extraction. @" + guard.getPos)
          Seq()
        case None if cs.rhs == out =>
          // The trivial example
          Seq()
        case None =>
          // The pattern as an (input) expression; it may contain free variables.
          val (pattExpr, ieMap) = patternToExpression(cs.pattern, in.getType)
          val rhs = doSubstitute(ieMap, cs.rhs)
          val freeVars = variablesOf(pattExpr).toSeq
          if (freeVars.isEmpty) {
            // Fully concrete input: a single input/output pair.
            Seq((pattExpr, rhs))
          } else {
            // Free variables do not give concrete examples on their own, so
            // instantiate them with values enumerated from a simple grammar.
            val enumerator = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions _)
            val valuesByType = freeVars.map(_.getType).map { tp => (tp, enumerator.iterator(tp).toStream) }.toMap
            val values = freeVars.map(v => valuesByType(v.getType))
            cartesianProduct(values).map { vs =>
              val inst = freeVars.zip(vs).toMap
              (replaceFromIDs(inst, pattExpr), replaceFromIDs(inst, rhs))
            }.take(examplesPerCase)
          }
      }
    }

    val testClusters = collect[Map[Identifier, Expr]] {
      case p @ Passes(ins, out, cases) =>
        val ioPairs = cases flatMap { toIOExamples(ins, out, _) }
        val infos = extractIds(p.scrutinee)
        // Evaluate each (input, output) tuple to a value when possible; keep it symbolic otherwise.
        val exs = ioPairs.map { case (input, output) =>
          val test = Tuple(Seq(input, output))
          val freeIds = variablesOf(test)
          evaluator.eval(test, freeIds.map(id => id -> id.toVariable).toMap) match {
            case EvaluationResults.Successful(res) => res
            case _ => test
          }
        }
        // Keep only the examples from which every id can be extracted.
        val results = exs.collect { case ex if infos.forall(_._2.isDefinedAt(ex)) =>
          infos.map { case (id, f) => id -> f(ex) }.toMap
        }
        results.toSet
      case _ =>
        Set()
    }(predicates)

    /**
     * Two partial assignments are compatible when they share at least one
     * identifier and assign the same value to every shared identifier.
     */
    def isCompatible(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]): Boolean = {
      val shared = m1.keySet & m2.keySet
      // Compare per key: comparing only the sets of values would accept swapped values.
      shared.nonEmpty && shared.forall(k => m1(k) == m2(k))
    }

    /**
     * Merges `m2` into `m1` when they are compatible; otherwise returns `m1` unchanged.
     * For example, t1: a->1, c->3 and t2: a->1, b->4 merge into a->1, b->4, c->3.
     */
    def mergeTest(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]): Map[Identifier, Expr] =
      if (isCompatible(m1, m2)) m1 ++ m2 else m1

    // Add each cluster, then merge it into every compatible accumulated assignment.
    val consolidated = testClusters.foldLeft(Set.empty[Map[Identifier, Expr]]) { (acc, t) =>
      (acc + t).map(c => mergeTest(c, t))
    }

    // Keep assignments that cover all inputs (and, for input/output examples, all outputs).
    val insIds = as.toSet
    val allIds = insIds ++ xs
    val examples = consolidated.toSeq.flatMap { t =>
      val ids = t.keySet
      if (allIds subsetOf ids) Some(InOutExample(as.map(t), xs.map(t)))
      else if (insIds subsetOf ids) Some(InExample(as.map(t)))
      else None
    }

    examples.filter(isValidExample)
  }
}
/** Factory methods for [[Problem]]. */
object Problem {

  /**
   * Builds a synthesis problem from a `choose` construct.
   *
   * The outputs are the variables bound by `ch`, and the specification is `ch.pred`
   * after `simplifyLets`. The inputs are every other variable of `pc && spec`.
   *
   * @param ch the `choose` expression to convert
   * @param pc path condition under which `ch` occurs; defaults to `true`
   * @return the corresponding [[Problem]]
   */
  def fromChoose(ch: Choose, pc: Expr = BooleanLiteral(true)): Problem = {
    val outputs = ch.vars
    val spec = simplifyLets(ch.pred)
    val mentioned = variablesOf(And(pc, spec))
    Problem(as = (mentioned -- outputs).toList, pc = pc, phi = spec, xs = outputs)
  }
}