diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 555006c197087ffa844a7bf3d82821c3dfd97538..b033058d8cb78fd3be7f2700d9997fd0df0c69cc 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -17,5 +17,5 @@ case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { ExpressionGrammars.default(sctx, p) } - def getGrammarLabel(id: Identifier): TypeTree = id.getType + def getRootLabel(tpe: TypeTree): TypeTree = tpe } diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index 5d2e203f03e81193272da0000c8d6d5e73749313..5470b1769f0557f935098e9763eaa011e7289f85 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -33,7 +33,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar[T] - def getGrammarLabel(id: Identifier): T + def getRootLabel(tpe: TypeTree): T val maxUnfoldings = 3 @@ -67,7 +67,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // b -> Set(c1, c2) means c1 and c2 are uninterpreted behind b, requires b to be closed private var guardedTerms: Map[Identifier, Set[Identifier]] = Map(initGuard -> p.xs.toSet) - private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> getGrammarLabel(x)) + private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> getRootLabel(x.getType)) def isBClosed(b: Identifier) = guardedTerms.contains(b) diff --git a/src/main/scala/leon/synthesis/rules/Cegless.scala b/src/main/scala/leon/synthesis/rules/Cegless.scala index d5fbb2619376d24fd995865ce2851448dc3b7758..bca7d6787712b4e359ffb0b319472c23ac889236 100644 --- a/src/main/scala/leon/synthesis/rules/Cegless.scala +++ b/src/main/scala/leon/synthesis/rules/Cegless.scala @@ -32,7 +32,7 @@ case object CEGLESS extends CEGISLike[Label[String]]("CEGLESS") { guidedGrammar } - def getGrammarLabel(id: Identifier): Label[String] = Label(id.getType, "G0") + def getRootLabel(tpe: TypeTree): Label[String] = Label(tpe, "G0") } diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala index 3ce5a908902304cd0f39ed856b5cf34717b2408a..736b7035380dda4bbd718a257278b7dcbbb9dcae 100644 --- a/src/main/scala/leon/synthesis/rules/Tegis.scala +++ b/src/main/scala/leon/synthesis/rules/Tegis.scala @@ -4,125 +4,21 @@ package leon package synthesis package rules -import solvers._ -import solvers.z3._ - import purescala.Trees._ import purescala.Common._ import purescala.Definitions._ import purescala.TypeTrees._ -import purescala.TreeOps._ -import purescala.DefOps._ -import purescala.TypeTreeOps._ -import purescala.Extractors._ -import purescala.ScalaPrinter import purescala.Constructors._ -import scala.collection.mutable.{Map=>MutableMap, ArrayBuffer} - import evaluators._ import datagen._ -import codegen.CodeGenParams import utils._ -import bonsai._ -import bonsai.enumerators._ - -case object TEGIS extends Rule("TEGIS") { - - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { - - val grammar = ExpressionGrammars.default(sctx, p) - - var tests = p.getTests(sctx).map(_.ins).distinct - if (tests.nonEmpty) { - - val evalParams = CodeGenParams(maxFunctionInvocations = 2000, checkContracts = true) - //val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) - //val evaluator = new DefaultEvaluator(sctx.context, sctx.program) - val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - - val interruptManager = sctx.context.interruptManager - - val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getProductions _) - - val (targetType, isWrapped) = if (p.xs.size == 1) { - (p.xs.head.getType, false) - } else { - (TupleType(p.xs.map(_.getType)), true) - } - - val timers = sctx.context.timers.synthesis.rules.tegis; - - val allExprs = enum.iterator(targetType) - - var failStat = Map[Seq[Expr], Int]().withDefaultValue(0) - - var candidate: Option[Expr] = None - var enumLimit = 10000; - var n = 1; - - def findNext(): Option[Expr] = { - candidate = None - timers.generating.start() - allExprs.take(enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => - val exprToTest = if (!isWrapped) { - Let(p.xs.head, e, p.phi) - } else { - letTuple(p.xs, e, p.phi) - } - - sctx.reporter.debug("Got expression "+e) - timers.testing.start() - if (tests.forall{ case t => - val ts = System.currentTimeMillis - val res = evaluator.eval(exprToTest, p.as.zip(t).toMap) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - sctx.reporter.debug("Test "+t+" passed!") - true - case _ => - sctx.reporter.debug("Test "+t+" failed on "+e) - failStat += t -> (failStat(t) + 1) - false - } - res - }) { - if (isWrapped) { - candidate = Some(e) - } else { - candidate = Some(Tuple(Seq(e))) - } - } - timers.testing.stop() - - if (n % 50 == 0) { - tests = tests.sortBy(t => -failStat(t)) - } - n += 1 - } - timers.generating.stop() - - candidate - } - - def toStream(): Stream[Solution] = { - findNext() match { - case Some(e) => - Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream()) - case None => - Stream.empty - } - } - - RuleClosed(toStream()) - } else { - RuleFailed() - } - } - }) +case object TEGIS extends TEGISLike[TypeTree]("TEGIS") { + def getGrammar(sctx: SynthesisContext, p: Problem) = { + ExpressionGrammars.default(sctx, p) } + + def getRootLabel(tpe: TypeTree): TypeTree = tpe } diff --git a/src/main/scala/leon/synthesis/rules/TegisLike.scala b/src/main/scala/leon/synthesis/rules/TegisLike.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed5003d47a27f4d9bb600496dd6ec4b6c14c19f3 --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/TegisLike.scala @@ -0,0 +1,122 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.TypeTrees._ +import purescala.Extractors._ +import purescala.Constructors._ + +import evaluators._ +import datagen._ +import codegen.CodeGenParams + +import utils._ + +import bonsai._ +import bonsai.enumerators._ + +abstract class TEGISLike[T <% Typed](name: String) extends Rule(name) { + def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar[T] + + def getRootLabel(tpe: TypeTree): T + + val enumLimit = 10000; + val testsReorderInterval = 50; // Every X test filterings, we reorder tests with most filtering first. + + def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + + List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { + def apply(sctx: SynthesisContext): RuleApplication = { + + val grammar = getGrammar(sctx, p) + + var tests = p.getTests(sctx).map(_.ins).distinct + if (tests.nonEmpty) { + + val evalParams = CodeGenParams(maxFunctionInvocations = 2000, checkContracts = true) + val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) + + val interruptManager = sctx.context.interruptManager + + val enum = new MemoizedEnumerator[T, Expr](grammar.getProductions _) + + val (targetType, isWrapped) = if (p.xs.size == 1) { + (p.xs.head.getType, false) + } else { + (TupleType(p.xs.map(_.getType)), true) + } + + val timers = sctx.context.timers.synthesis.rules.tegis; + + val allExprs = enum.iterator(getRootLabel(targetType)) + + var failStat = Map[Seq[Expr], Int]().withDefaultValue(0) + + var candidate: Option[Expr] = None + var n = 1; + + def findNext(): Option[Expr] = { + candidate = None + timers.generating.start() + allExprs.take(enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => + val exprToTest = if (!isWrapped) { + Let(p.xs.head, e, p.phi) + } else { + letTuple(p.xs, e, p.phi) + } + + sctx.reporter.debug("Got expression "+e) + timers.testing.start() + if (tests.forall{ case t => + val ts = System.currentTimeMillis + val res = evaluator.eval(exprToTest, p.as.zip(t).toMap) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => + sctx.reporter.debug("Test "+t+" passed!") + true + case _ => + sctx.reporter.debug("Test "+t+" failed on "+e) + failStat += t -> (failStat(t) + 1) + false + } + res + }) { + if (isWrapped) { + candidate = Some(e) + } else { + candidate = Some(Tuple(Seq(e))) + } + } + timers.testing.stop() + + if (n % testsReorderInterval == 0) { + tests = tests.sortBy(t => -failStat(t)) + } + n += 1 + } + timers.generating.stop() + + candidate + } + + def toStream(): Stream[Solution] = { + findNext() match { + case Some(e) => + Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream()) + case None => + Stream.empty + } + } + + RuleClosed(toStream()) + } else { + RuleFailed() + } + } + }) + } +}