diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index d3cd29907c5cd8c30fb574d06ddd1682e5bfa59a..127e862cfaa6a18faa9ca6a73b93eacb2fb4cac2 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -257,8 +257,8 @@ object ImperativeCodeElimination extends Pass { val (argVal, argScope, argFun) = toFunction(a) (recons(argVal).setType(u.getType), argScope, argFun) } - case (t: Terminal) => (t, (body: Expr) => body, Map()) + case (t: Terminal) => (t, (body: Expr) => body, Map()) case _ => sys.error("not supported: " + expr) } diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 9af1b41e731e32ff4b25942c4ec4f6ac63f91b0a..584caaa6fcc7a074982bedd80e91d41901f6af12 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -4,6 +4,9 @@ import scala.tools.nsc.{Global,Settings=>NSCSettings,SubComponent,CompilerComman import purescala.Definitions.Program +import synthesis.Synthesizer +import purescala.ScalaPrinter + object Main { import leon.{Reporter,DefaultReporter,Analysis} @@ -32,11 +35,10 @@ object Main { private def defaultAction(program: Program, reporter: Reporter) : Unit = { Logger.debug("Default action on program: " + program, 3, "main") - val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, /*FunctionHoisting,*/ Simplificator)) - val program2 = passManager.run(program) - assert(program2.isPure) - val analysis = new Analysis(program2, reporter) - analysis.analyse + //val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, /*FunctionHoisting,*/ Simplificator)) + //val program2 = passManager.run(program) + assert(program.isPure) + val program2 = new Synthesizer().synthesizeAll(program) } private def runWithSettings(args : Array[String], settings : NSCSettings, printerFunction : String=>Unit, actionOnProgram : Option[Program=>Unit] = None) : Unit = { diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 4778c9b2dbc01d0291a408adb27ab520d8b93778..e32b3dd7a888b3cc04dbb9ab43e3ff7823d74777 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -830,6 +830,7 @@ object Trees { case i @ IfExpr(a1,a2,a3) => compute(i, combine(combine(rec(a1), rec(a2)), rec(a3))) case m @ MatchExpr(scrut, cses) => compute(m, (scrut +: cses.flatMap(_.expressions)).map(rec(_)).reduceLeft(combine)) case a @ AnonymousFunction(es, ev) => compute(a, (es.flatMap(e => e._1 ++ Seq(e._2)) ++ Seq(ev)).map(rec(_)).reduceLeft(combine)) + case c @ Choose(args, body) => compute(c, rec(body)) case t: Terminal => compute(t, convert(t)) case unhandled => scala.sys.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled) } @@ -870,7 +871,6 @@ object Trees { case ArrayMake(_) => false case ArrayClone(_) => false case Epsilon(_) => false - case Choose(_, _) => false case _ => true } def combine(b1: Boolean, b2: Boolean) = b1 && b2 @@ -884,7 +884,6 @@ object Trees { case ArrayMake(_) => false case ArrayClone(_) => false case Epsilon(_) => false - case Choose(_, _) => false case _ => b } treeCatamorphism(convert, combine, compute, expr) diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 7cedb5d0207e1fd01f591e09ccd621cb2b7a93d5..b9dc74af176689e0a534bb6c753051781ac06a87 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -1,6 +1,11 @@ package leon package synthesis +object Rules { + def all = List() +} + + abstract class Rule(val name: String) { def isApplicable(p: Problem, parent: Task): List[Task] } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index ba4c3c4f1a387d6cd6c64ac3fb442bd0d48242cd..ab5f75432c13de551f0eb0ba525a8f10516314c1 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -1,9 +1,14 @@ package leon package synthesis +import purescala.Definitions.Program + import collection.mutable.PriorityQueue class Synthesizer(rules: List[Rule]) { + def this() = this(Rules.all) + + def applyRules(p: Problem, parent: Task): List[Task] = { rules.flatMap(_.isApplicable(p, parent)) } @@ -48,7 +53,28 @@ class Synthesizer(rules: List[Rule]) { } + println + println(" ++ RESULT ++ ") + println("==> "+p+" ⊢ "+solution) + solution } + + def test1 = { + import purescala.Common._ + import purescala.Trees._ + import purescala.TypeTrees._ + + val aID = FreshIdentifier("a").setType(Int32Type) + val a = Variable(aID) + val p = Problem(Nil, And(GreaterThan(a, IntLiteral(2)), Equals(a, IntLiteral(3))), List(aID)) + + synthesize(p) + } + + def synthesizeAll(p: Program): Program = { + test1 + p + } } diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala index 959f564527048e6896c646cd76f693c721fd8346..69a3232808e305ca445aebee1fb220fa3a6c53c2 100644 --- a/src/main/scala/leon/synthesis/Task.scala +++ b/src/main/scala/leon/synthesis/Task.scala @@ -27,7 +27,12 @@ class Task( subSolutions += p -> s if (subSolutions.size == subProblems.size) { - notifyParent(construct(subProblems map subSolutions)) + + val solution = construct(subProblems map subSolutions) + + println(": "+problem+" ⊢ "+solution) + + notifyParent(solution) } }