diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index fc26196b608dad2c67e0125c600c4c1197695fa7..f04a463cfe6dfaf969782c014a9cc344e72e012f 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -18,11 +18,36 @@ object Main { } lazy val allOptions = allPhases.flatMap(_.definedOptions) ++ Set( - LeonFlagOptionDef ("synthesis", "--synthesis", "Partial synthesis or choose() constructs"), - LeonFlagOptionDef ("xlang", "--xlang", "Support for extra program constructs (imperative,...)"), - LeonFlagOptionDef ("parse", "--parse", "Checks only whether the program is valid PureScala"), - LeonValueOptionDef("debug", "--debug=[1-5]", "Debug level"), - LeonFlagOptionDef ("help", "--help", "This help") + LeonFlagOptionDef ("synthesis", "--synthesis", "Partial synthesis or choose() constructs"), + LeonFlagOptionDef ("xlang", "--xlang", "Support for extra program constructs (imperative,...)"), + LeonFlagOptionDef ("parse", "--parse", "Checks only whether the program is valid PureScala"), + LeonValueOptionDef("debug", "--debug=[1-5]", "Debug level"), + LeonFlagOptionDef ("help", "--help", "This help") + + // Unimplemented Options: + // + // LeonFlagOptionDef("uniqid", "--uniqid", "When pretty-printing purescala trees, show identifiers IDs"), + // LeonValueOptionDef("extensions", "--extensions=ex1:...", "Specifies a list of qualified class names of extensions to be loaded"), + // LeonFlagOptionDef("nodefaults", "--nodefaults", "Runs only the analyses provided by the extensions"), + // LeonValueOptionDef("functions", "--functions=fun1:...", "Only generates verification conditions for the specified functions"), + // LeonFlagOptionDef("unrolling", "--unrolling=[0,1,2]", "Unrolling depth for recursive functions" ), + // LeonFlagOptionDef("axioms", "--axioms", "Generate simple forall axioms for recursive functions when possible" ), + // LeonFlagOptionDef("tolerant", "--tolerant", "Silently extracts non-pure function bodies as ''unknown''"), + // LeonFlagOptionDef("bapa", "--bapa", "Use BAPA Z3 extension (incompatible with many other things)"), + // LeonFlagOptionDef("impure", "--impure", "Generate testcases only for impure functions"), + // LeonValueOptionDef("testcases", "--testcases=[1,2]", "Number of testcases to generate per function"), + // LeonValueOptionDef("testbounds", "--testbounds=l:u", "Lower and upper bounds for integers in recursive datatypes"), + // LeonValueOptionDef("timeout", "--timeout=N", "Sets a timeout of N seconds"), + // LeonFlagOptionDef("XP", "--XP", "Enable weird transformations and other bug-producing features"), + // LeonFlagOptionDef("BV", "--BV", "Use bit-vectors for integers"), + // LeonFlagOptionDef("prune", "--prune", "Use additional SMT queries to rule out some unrollings"), + // LeonFlagOptionDef("cores", "--cores", "Use UNSAT cores in the unrolling/refinement step"), + // LeonFlagOptionDef("quickcheck", "--quickcheck", "Use QuickCheck-like random search"), + // LeonFlagOptionDef("parallel", "--parallel", "Run all solvers in parallel"), + // LeonFlagOptionDef("noLuckyTests", "--noLuckyTests", "Do not perform additional tests to potentially find models early"), + // LeonFlagOptionDef("noverifymodel", "--noverifymodel", "Do not verify the correctness of models returned by Z3"), + // LeonValueOptionDef("tags", "--tags=t1:...", "Filter out debug information that are not of one of the given tags"), + // LeonFlagOptionDef("oneline", "--oneline", "Reduce the output to a single line: valid if all properties were valid, invalid if at least one is invalid, unknown else") ) def displayHelp(reporter: Reporter) { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c169ba5f913b100c4118baaed4cb6cc4b574345c..3ad1d4f7a1b7ece5823bb84238d42a002c27c734 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -30,10 +30,15 @@ object Extractors { case ArrayMake(t) => Some((t, ArrayMake)) case Waypoint(i, t) => Some((t, (expr: Expr) => Waypoint(i, expr))) case e@Epsilon(t) => Some((t, (expr: Expr) => Epsilon(expr).setType(e.getType).setPosInfo(e))) + case ue: UnaryExtractable => ue.extract case _ => None } } + trait UnaryExtractable { + def extract: Option[(Expr, (Expr)=>Expr)]; + } + object BinaryOperator { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { case Equals(t1,t2) => Some((t1,t2,Equals.apply)) @@ -68,11 +73,17 @@ object Extractors { case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) case Concat(t1,t2) => Some((t1,t2,Concat)) case ListAt(t1,t2) => Some((t1,t2,ListAt)) + case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, body))) case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh))) + case ex: BinaryExtractable => ex.extract case _ => None } } + trait BinaryExtractable { + def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)]; + } + object NAryOperator { def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPosInfo(fi)))) @@ -89,10 +100,16 @@ object Extractors { case Distinct(args) => Some((args, Distinct)) case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) case Tuple(args) => Some((args, Tuple)) + case IfExpr(cond, then, elze) => Some((Seq(cond, then, elze), (as: Seq[Expr]) => IfExpr(as(0), as(1), as(2)))) + case ex: NAryExtractable => ex.extract case _ => None } } + trait NAryExtractable { + def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)]; + } + object SimplePatternMatching { def isSimple(me: MatchExpr) : Boolean = unapply(me).isDefined diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index a0494a8e440a799ca8d2d600a5a6fda6d9268170..0e63a4a27465728906a133625b24d15d0398635a 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -950,4 +950,63 @@ object TreeOps { } fix(searchAndReplaceDFS(transform), expr) } + + def genericTransform[C](pre: (Expr, C) => (Expr, C), + post: (Expr, C) => (Expr, C), + combiner: (Expr, C, Seq[C]) => C)(init: C)(expr: Expr) = { + + def rec(eIn: Expr, cIn: C): (Expr, C) = { + + val (expr, ctx) = pre(eIn, cIn) + + val (newExpr, newC) = expr match { + case t: Terminal => + (expr, ctx) + + case UnaryOperator(e, builder) => + val (e1, c) = rec(e, ctx) + val newE = builder(e1) + + (newE, combiner(newE, ctx, Seq(c))) + + case BinaryOperator(e1, e2, builder) => + val (ne1, c1) = rec(e1, ctx) + val (ne2, c2) = rec(e2, ctx) + val newE = builder(ne1, ne2) + + (newE, combiner(newE, ctx, Seq(c1, c2))) + + case NAryOperator(es, builder) => + val (nes, cs) = es.map(e => rec(e, ctx)).unzip + val newE = builder(nes) + + (newE, combiner(newE, ctx, cs)) + + case e => + sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + } + + post(newExpr, newC) + } + + rec(expr, init) + } + + def noPre[C] (e: Expr, c: C) = (e, c) + def noPost[C](e: Expr, c: C) = (e, c) + def noCombiner[C](e: Expr, initC: C, subCs: Seq[C]) = initC + + def patternMatchReconstruction(e: Expr): Expr = { + case class Context() + + def pre(e: Expr, c: Context): (Expr, Context) = e match { + case IfExpr(cond, then, elze) => + println("Found IF: "+e) + (e, c) + case _ => + (e, c) + } + + genericTransform[Context](pre, noPost, noCombiner)(Context())(e)._1 + } } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 246f335ada8a1c0d084ce1ac7e355db619805320..3cb5280ede5a1dfb10aedf4c5f051480b2866b70 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -6,6 +6,7 @@ object Trees { import Common._ import TypeTrees._ import Definitions._ + import Extractors._ /* EXPRESSIONS */ @@ -42,7 +43,9 @@ object Trees { case class Epsilon(pred: Expr) extends Expr with ScalacPositional - case class Choose(vars: List[Identifier], pred: Expr) extends Expr with ScalacPositional + case class Choose(vars: List[Identifier], pred: Expr) extends Expr with ScalacPositional with UnaryExtractable { + def extract = Some((pred, (e: Expr) => Choose(vars, e).setPosInfo(this))) + } /* Like vals */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 3138c2f4bf5ba4e9de1f23f1c4790ecaa931c511..e61c6f3d008b84e230fff84bfb81940dbc75b7c0 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -1,10 +1,10 @@ package leon package synthesis +import purescala.TreeOps._ import solvers.TrivialSolver import solvers.z3.FairZ3Solver -import purescala.TreeOps.simplifyLets import purescala.Trees.Expr import purescala.ScalaPrinter import purescala.Definitions.Program @@ -41,7 +41,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { // Simplify expressions val simplifiers = List[Expr => Expr]( - simplifyLets _ + simplifyLets _, + patternMatchReconstruction _ ) val chooseToExprs = solutions.mapValues(sol => simplifiers.foldLeft(sol.toExpr){ (x, sim) => sim(x) }) diff --git a/testcases/synthesis/Matching.scala b/testcases/synthesis/Matching.scala index 21a7cf442e9165130e5a0c095bc4c86b608b9e69..1be8706b677e62c1d366f845edc80b684be0a401 100644 --- a/testcases/synthesis/Matching.scala +++ b/testcases/synthesis/Matching.scala @@ -1,7 +1,7 @@ import leon.Utils._ object Matching { - def t1(a: NatList) = choose( (x: Nat) => Cons(x, Nil()) == a) + def t1(a: NatList) = choose( (x: Nat) => Cons(x, Nil()) == a) abstract class Nat case class Z() extends Nat