From 895a196227835a46702170124a01a6b9771b9aae Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Mon, 16 Jun 2014 12:08:44 +0200 Subject: [PATCH] Extract tests from specification: passes(in,out)(Map(1 -> 2)) --- library/lang/package.scala | 8 ++ .../frontends/scalac/CodeExtraction.scala | 22 +++- .../scala/leon/purescala/Constructors.scala | 27 +++++ .../scala/leon/synthesis/InOutExample.scala | 13 ++ src/main/scala/leon/synthesis/Problem.scala | 113 ++++++++++++++++++ src/main/scala/leon/synthesis/Rules.scala | 24 ++-- .../scala/leon/synthesis/rules/Cegis.scala | 29 ++++- 7 files changed, 218 insertions(+), 18 deletions(-) create mode 100644 src/main/scala/leon/purescala/Constructors.scala create mode 100644 src/main/scala/leon/synthesis/InOutExample.scala diff --git a/library/lang/package.scala b/library/lang/package.scala index 41cd25b65..a3600b868 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -27,4 +27,12 @@ package object lang { @ignore def error[T](reason: String): T = sys.error(reason) + + def passes[A, B](in: A, out: B)(tests: Map[A,B]): Boolean = { + if (tests contains in) { + tests(in) == out + } else { + true + } + } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3a98f859a..9826e01a0 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -394,8 +394,11 @@ trait CodeExtraction extends ASTExtractors { // We collect the methods for (d <- tmpl.body) d match { + case EmptyTree => + // ignore + case t if isIgnored(t.symbol) => - //ignore + // ignore case t @ ExFunctionDef(fsym, _, _, _, _) if !fsym.isSynthetic && !fsym.isAccessor => if (parent.isDefined) { @@ -1045,10 +1048,6 @@ trait CodeExtraction extends ASTExtractors { case ExOr(l, r) => Or(extractTree(l), extractTree(r)) case ExNot(e) => Not(extractTree(e)) case ExUMinus(e) => UMinus(extractTree(e)) - case ExPlus(l, r) => Plus(extractTree(l), extractTree(r)) - case ExMinus(l, r) => Minus(extractTree(l), extractTree(r)) - case ExTimes(l, r) => Times(extractTree(l), extractTree(r)) - case ExDiv(l, r) => Division(extractTree(l), extractTree(r)) case ExMod(l, r) => Modulo(extractTree(l), extractTree(r)) case ExNotEquals(l, r) => Not(Equals(extractTree(l), extractTree(r))) case ExGreaterThan(l, r) => GreaterThan(extractTree(l), extractTree(r)) @@ -1214,6 +1213,19 @@ trait CodeExtraction extends ASTExtractors { CaseClassSelector(cct, rec, fieldID) + // Int methods + case (IsTyped(a1, Int32Type), "+", List(IsTyped(a2, Int32Type))) => + Plus(a1, a2) + + case (IsTyped(a1, Int32Type), "-", List(IsTyped(a2, Int32Type))) => + Minus(a1, a2) + + case (IsTyped(a1, Int32Type), "*", List(IsTyped(a2, Int32Type))) => + Times(a1, a2) + + case (IsTyped(a1, Int32Type), "/", List(IsTyped(a2, Int32Type))) => + Division(a1, a2) + // Set methods case (IsTyped(a1, SetType(b1)), "min", Nil) => SetMin(a1).setType(b1) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala new file mode 100644 index 000000000..d02804e45 --- /dev/null +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -0,0 +1,27 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package purescala + +import utils._ + +object Constructors { + import Trees._ + import Common._ + + def tupleSelect(t: Expr, index: Int) = t match { + case Tuple(es) => + es(index-1) + case _ => + TupleSelect(t, index) + } + + def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match { + case Nil => + body + case x :: Nil => + Let(x, tupleSelect(value, 1), body) + case xs => + LetTuple(xs, value, body) + } +} diff --git a/src/main/scala/leon/synthesis/InOutExample.scala b/src/main/scala/leon/synthesis/InOutExample.scala new file mode 100644 index 000000000..6c64f3977 --- /dev/null +++ b/src/main/scala/leon/synthesis/InOutExample.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis + +import purescala.Trees.Expr + +case class InOutExample(ins: Seq[Expr], outs: Seq[Expr]) { + def inExample = InExample(ins) +} + + +case class InExample(ins: Seq[Expr]) diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index a227eb7c1..2ed3e5d17 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -11,6 +11,119 @@ import leon.purescala.Common._ // ⟦ as ⟨ C | phi ⟩ xs ⟧ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifier]) { override def toString = "⟦ "+as.mkString(";")+", "+(if (pc != BooleanLiteral(true)) pc+" ≺ " else "")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ " + + def getTests(sctx: SynthesisContext): Seq[InOutExample] = { + import purescala.Extractors._ + import evaluators._ + + val TopLevelAnds(predicates) = And(pc, phi) + + val ev = new DefaultEvaluator(sctx.context, sctx.program) + + def isValidExample(ex: InOutExample): Boolean = { + val mapping = Map((as zip ex.ins) ++ (xs zip ex.outs): _*) + + ev.eval(And(pc, phi), mapping) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => true + case _ => false + } + } + + // Returns a list of identifiers, and extractors + def andThen(pf1: PartialFunction[Expr, Expr], pf2: PartialFunction[Expr, Expr]): PartialFunction[Expr, Expr] = { + Function.unlift(pf1.lift(_) flatMap pf2.lift) + } + + /** + * Extract ids in ins/outs args, and compute corresponding extractors for values map + * + * Examples: + * (a,b) => + * a -> _.1 + * b -> _.2 + * + * Cons(a, Cons(b, c)) => + * a -> _.head + * b -> _.tail.head + * c -> _.tail.tail + */ + def extractIds(e: Expr): Seq[(Identifier, PartialFunction[Expr, Expr])] = e match { + case Variable(id) => + List((id, { case e => e })) + case Tuple(vs) => + vs.map(extractIds).zipWithIndex.flatMap{ case (ids, i) => + ids.map{ case (id, e) => + (id, andThen({ case Tuple(vs) => vs(i) }, e)) + } + } + case CaseClass(cct, args) => + args.map(extractIds).zipWithIndex.flatMap { case (ids, i) => + ids.map{ case (id, e) => + (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) } ,e)) + } + } + + case _ => + sctx.reporter.warning("Unnexpected pattern in test-ids extraction: "+e) + Nil + } + + def exprToIds(e: Expr): List[Identifier] = e match { + case Variable(i) => List(i) + case Tuple(is) => is.collect { case Variable(i) => i }.toList + case _ => Nil + } + + val testClusters = predicates.collect { + case FunctionInvocation(tfd, List(in, out, FiniteMap(inouts))) if tfd.id.name == "passes" => + val infos = extractIds(Tuple(Seq(in, out))) + val exs = inouts.map{ case (i, o) => Tuple(Seq(i, o)) } + + // Check whether we can extract all ids from example + val results = exs.collect { case e if infos.forall(_._2.isDefinedAt(e)) => + infos.map{ case (id, f) => id -> f(e) }.toMap + } + + results + } + + /** + * we now need to consolidate different clusters of compatible tests together + * t1: a->1, c->3 + * t2: a->1, b->4 + * => a->1, b->4, c->3 + */ + + def isCompatible(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = { + val ks = m1.keySet & m2.keySet + ks.nonEmpty && ks.map(m1) == ks.map(m2) + } + + def mergeTest(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = { + if (!isCompatible(m1, m2)) { + m1 + } else { + m1 ++ m2 + } + } + + var consolidated = Set[Map[Identifier, Expr]]() + for (ts <- testClusters; t <- ts) { + consolidated += t + + consolidated = consolidated.map { c => + mergeTest(c, t) + } + } + + // Finally, we keep complete tests covering all as++xs + val requiredIds = (as ++ xs).toSet + val complete = consolidated.filter{ t => (t.keySet & requiredIds) == requiredIds } + + complete.toSeq.map { m => + InOutExample(as.map(m), xs.map(m)) + }.filter(isValidExample) + } } object Problem { diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 5f75db650..2a1b90973 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -141,16 +141,24 @@ object RuleInstantiation { } } -abstract class Rule(val name: String) { +abstract class Rule(val name: String) extends RuleHelpers { def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] - val priority: RulePriority = RulePriorityDefault - - def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in) - def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replace(what.map(w => Variable(w._1) -> w._2), in) + val priority: RulePriority = RulePriorityDefault implicit val debugSection = leon.utils.DebugSectionSynthesis + override def toString = "R: "+name +} + +abstract class NormalizingRule(name: String) extends Rule(name) { + override val priority = RulePriorityNormalizing +} + +trait RuleHelpers { + def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replaceFromIDs(Map(what), in) + def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in) + val forward: List[Solution] => Option[Solution] = { case List(s) => Some(Solution(s.pre, s.defs, s.term)) @@ -169,10 +177,4 @@ abstract class Rule(val name: String) { case _ => None } - - override def toString = "R: "+name -} - -abstract class NormalizingRule(name: String) extends Rule(name) { - override val priority = RulePriorityNormalizing } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 99cdcda04..37d19bb36 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -55,7 +55,19 @@ case object CEGIS extends Rule("CEGIS") { { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) } case Int32Type => - { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) } + { () => + val ground = List((IntLiteral(0), Set[Identifier]()), (IntLiteral(1), Set[Identifier]())) + val ops = List[Function2[Expr, Expr, Expr]]( + (a,b) => Plus(a,b), + (a,b) => Minus(a,b), + (a,b) => Times(a,b) + ) + + ops.map{f => + val ids = List(FreshIdentifier("a", true).setType(Int32Type), FreshIdentifier("b", true).setType(Int32Type)) + (f(ids(0).toVariable, ids(1).toVariable), ids.toSet) + } ++ ground + } case TupleType(tps) => { () => @@ -460,6 +472,8 @@ case object CEGIS extends Rule("CEGIS") { // We populate the list of examples with a predefined one sctx.reporter.debug("Acquiring list of examples") + baseExampleInputs ++= p.getTests(sctx).map(_.ins).toSet + if (p.pc == BooleanLiteral(true)) { baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs } else { @@ -485,6 +499,12 @@ case object CEGIS extends Rule("CEGIS") { } } + sctx.reporter.ifDebug { debug => + baseExampleInputs.foreach { in => + debug(" - "+in.mkString(", ")) + } + } + val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) } else { @@ -509,6 +529,7 @@ case object CEGIS extends Rule("CEGIS") { for (prog <- programs) { val expr = ndProgram.determinize(prog) val res = Equals(Tuple(p.xs.map(Variable(_))), expr) + val solver3 = sctx.newSolver.setTimeout(cexSolverTo) solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) @@ -609,7 +630,11 @@ case object CEGIS extends Rule("CEGIS") { needMoreUnrolling = true; } else if (nPassing <= testUpTo) { // Immediate Test - result = Some(checkForPrograms(prunedPrograms)) + checkForPrograms(prunedPrograms) match { + case rs: RuleSuccess => + result = Some(rs) + case _ => + } } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { // We filter the Bss so that the formula we give to z3 is much smalled val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) -- GitLab