From cc2f0edb73b07cd4c58c9f63d0a07fd1db304365 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <etienne.kneuss@epfl.ch> Date: Fri, 12 Dec 2014 13:41:23 +0100 Subject: [PATCH] labels should be bestRealType. Filter cases by incompatible return type --- .../scala/leon/purescala/Constructors.scala | 18 ++- src/main/scala/leon/repair/Repairman.scala | 115 +++++++++--------- .../synthesis/rules/EquivalentInputs.scala | 10 +- .../synthesis/utils/ExpressionGrammar.scala | 13 +- 4 files changed, 89 insertions(+), 67 deletions(-) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index b68f8903a..7f56b2c32 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -7,6 +7,7 @@ import utils._ object Constructors { import Trees._ + import TypeTreeOps._ import Common._ import TypeTrees._ @@ -52,8 +53,8 @@ object Constructors { case more => TupleType(more) } - private def filterCases(scrutType : TypeTree, cases: Seq[MatchCase]): Seq[MatchCase] = { - scrutType match { + private def filterCases(scrutType : TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = { + val casesFiltered = scrutType match { case c: CaseClassType => cases.filter(_.pattern match { case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false @@ -66,13 +67,20 @@ object Constructors { case t => scala.sys.error("Constructing match expression on non-supported type: "+t) } + + resType match { + case Some(tpe) => + casesFiltered.filter(c => isSubtypeOf(c.rhs.getType, tpe) || isSubtypeOf(tpe, c.rhs.getType)) + case None => + casesFiltered + } } def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives = - Gives(scrutinee, filterCases(scrutinee.getType, cases)) + Gives(scrutinee, filterCases(scrutinee.getType, None, cases)) def passes(in : Expr, out : Expr, cases : Seq[MatchCase]): Expr = { - val resultingCases = filterCases(in.getType, cases) + val resultingCases = filterCases(in.getType, Some(out.getType), cases) if (resultingCases.nonEmpty) { Passes(in, out, resultingCases) } else { @@ -81,7 +89,7 @@ object Constructors { } def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : Expr ={ - val filtered = filterCases(scrutinee.getType, cases) + val filtered = filterCases(scrutinee.getType, None, cases) if (filtered.nonEmpty) MatchExpr(scrutinee, filtered) else diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 561359715..0d87c91ce 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -30,6 +30,64 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout implicit val debugSection = DebugSectionRepair + def repair() = { + reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id)) + val (tests, isVerified) = discoverTests + + if (isVerified) { + reporter.info("Program verifies!") + } + + reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem")) + val synth = getSynthesizer(tests) + val p = synth.problem + + var solutions = List[Solution]() + + reporter.info(ASCIIHelpers.title("3. Synthesizing")) + reporter.info(p) + + synth.synthesize() match { + case (search, sols) => + for (sol <- sols) { + + // Validate solution if not trusted + if (!sol.isTrusted) { + reporter.info("Found untrusted solution! Verifying...") + val (npr, fds) = synth.solutionToProgram(sol) + + getVerificationCounterExamples(fds.head, npr) match { + case Some(ces) => + reporter.error("I ended up finding this counter example:\n"+ces.mkString(" | ")) + + case None => + solutions ::= sol + reporter.info("Solution was not trusted but verification passed!") + } + } else { + reporter.info("Found trusted solution!") + solutions ::= sol + } + } + + if (synth.options.generateDerivationTrees) { + val dot = new DotGenerator(search.g) + dot.writeFile("derivation"+DotGenerator.nextId()+".dot") + } + + if (solutions.isEmpty) { + reporter.error(ASCIIHelpers.title("Failed to repair!")) + } else { + reporter.info(ASCIIHelpers.title("Repair successful:")) + for ((sol, i) <- solutions.zipWithIndex) { + reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) + val expr = sol.toSimplifiedExpr(ctx, program) + reporter.info(ScalaPrinter(expr)); + } + } + } + } + def getSynthesizer(tests: List[Example]): Synthesizer = { // Create a fresh function val nid = FreshIdentifier(fd.id.name+"_repair").copiedFrom(fd.id) @@ -224,7 +282,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout } - def discoverTests: List[Example] = { + def discoverTests: (List[Example], Boolean) = { import bonsai._ import bonsai.enumerators._ @@ -288,62 +346,9 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout // Try to verify, if it fails, we have at least one CE val ces = getVerificationCounterExamples(fd, program) getOrElse Nil - tests ++ ces + (tests ++ ces, ces.isEmpty) } - def repair() = { - reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id)) - val tests = discoverTests - - reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem")) - val synth = getSynthesizer(tests) - val p = synth.problem - - var solutions = List[Solution]() - - reporter.info(ASCIIHelpers.title("3. Synthesizing")) - reporter.info(p) - - synth.synthesize() match { - case (search, sols) => - for (sol <- sols) { - - // Validate solution if not trusted - if (!sol.isTrusted) { - reporter.info("Found untrusted solution! Verifying...") - val (npr, fds) = synth.solutionToProgram(sol) - - getVerificationCounterExamples(fds.head, npr) match { - case Some(ces) => - reporter.error("I ended up finding this counter example:\n"+ces.mkString(" | ")) - - case None => - solutions ::= sol - reporter.info("Solution was not trusted but verification passed!") - } - } else { - reporter.info("Found trusted solution!") - solutions ::= sol - } - } - - if (synth.options.generateDerivationTrees) { - val dot = new DotGenerator(search.g) - dot.writeFile("derivation"+DotGenerator.nextId()+".dot") - } - - if (solutions.isEmpty) { - reporter.error(ASCIIHelpers.title("Failed to repair!")) - } else { - reporter.info(ASCIIHelpers.title("Repair successful:")) - for ((sol, i) <- solutions.zipWithIndex) { - reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) - val expr = sol.toSimplifiedExpr(ctx, program) - reporter.info(ScalaPrinter(expr)); - } - } - } - } // ununsed for now, but implementation could be useful later private def disambiguate(p: Problem, sol1: Solution, sol2: Solution): Option[(InOutExample, InOutExample)] = { diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index 99057db0c..3a0847389 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -8,6 +8,7 @@ import leon.utils._ import purescala.Trees._ import purescala.TreeOps._ import purescala.Extractors._ +import purescala.Constructors._ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { @@ -51,11 +52,18 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { val substs = discoverEquivalences(clauses) + val postsToInject = substs.collect { + case (FunctionInvocation(tfd, args), e) if tfd.hasPostcondition => + val Some((id, post)) = tfd.postcondition + + replaceFromIDs((tfd.params.map(_.id) zip args).toMap + (id -> e), post) + } + if (substs.nonEmpty) { val simplifier = Simplifiers.bestEffort(sctx.context, sctx.program) _ val sub = p.copy(ws = replaceSeq(substs, p.ws), - pc = simplifier(replaceSeq(substs, p.pc)), + pc = simplifier(andJoin(replaceSeq(substs, p.pc) +: postsToInject)), phi = simplifier(replaceSeq(substs, p.phi))) List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name)) diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 565a6c2bc..8ea105ce8 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -186,8 +186,9 @@ object ExpressionGrammars { def computeSimilar(e : Expr) : Seq[(L, Gen)] = { def getLabelPair(t: TypeTree) = { + val tpe = bestRealType(t) val c = getNext - (Label(t, "E"+c), Label(t, "G"+c)) + (Label(tpe, "E"+c), Label(tpe, "G"+c)) } def isCommutative(e: Expr) = e match { @@ -282,12 +283,12 @@ object ExpressionGrammars { val res = rec(e, el, gl) - //for ((t, g) <- res) { - // val subs = g.subTrees.map { t => FreshIdentifier(t.toString).setType(t.getType).toVariable} - // val gen = g.builder(subs) + for ((t, g) <- res) { + val subs = g.subTrees.map { t => FreshIdentifier(t.toString).setType(t.getType).toVariable} + val gen = g.builder(subs) - // println(f"$t%30s ::= "+gen) - //} + println(f"$t%30s ::= "+gen) + } res } } -- GitLab