diff --git a/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala b/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala index c7351ca8b27390693ec479bfb922a1c5b2c04823..8adeeeae5904d71935769ebf9ad913c1e81e9d60 100644 --- a/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala +++ b/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala @@ -199,20 +199,22 @@ class ContextGrammar[SymbolTag, TerminalData] { def markovize_horizontal_filtered(pred: NonTerminal => Boolean, recursive: Boolean): Grammar = { var toDuplicate = Map[NonTerminal, Set[NonTerminal]]() var originals = Map[NonTerminal, NonTerminal]() - def getOriginal(nt: NonTerminal): NonTerminal = originals.get(nt).map(getOriginal).getOrElse(nt) + def getOriginal(nt: NonTerminal): NonTerminal = { + originals.get(nt).map(nt2 => if(nt2 != nt) getOriginal(nt2) else nt2).getOrElse(nt) + } val c = new MarkovizationContext(pred) { def process_sequence(ls: Seq[Symbol]): List[Symbol] = { val (_, res) = ((ListBuffer[Symbol](), ListBuffer[Symbol]()) /: ls) { case ((lbold, lbnew), nt: NonTerminal) if pred(nt) => val context_version = nt.copy(hcontext = lbold.toList) toDuplicate += nt -> (toDuplicate.getOrElse(nt, Set.empty[NonTerminal]) + context_version) - originals += context_version -> nt - for(descendant <- getDescendants(nt)) { + if(context_version != nt) originals += context_version -> nt + for(descendant <- getDescendants(nt) if descendant != nt) { val descendant_context_version = descendant.copy(hcontext = lbold.toList) toDuplicate += descendant -> (toDuplicate.getOrElse(descendant, Set.empty[NonTerminal]) + descendant_context_version) originals += descendant_context_version -> descendant } - for(ascendant <- getAncestors(nt)) { + for(ascendant <- getAncestors(nt) if ascendant != nt) { val acendant_context_version = ascendant.copy(hcontext = lbold.toList) toDuplicate += ascendant -> (toDuplicate.getOrElse(ascendant, Set.empty[NonTerminal]) + acendant_context_version) originals += acendant_context_version -> ascendant @@ -270,7 +272,7 @@ class ContextGrammar[SymbolTag, TerminalData] { /** Applies horizontal markovization to the grammar (add the left history to every node and duplicate rules as needed. * Is idempotent. */ def markovize_horizontal(): Grammar = { - markovize_horizontal_filtered(_ => true, true) + markovize_horizontal_filtered(_ => true, false) } /** Same as vertical markovization, but we add in the vertical context only the nodes coming from a "different abstract hierarchy". Top-level nodes count as a different hierarchy. diff --git a/src/test/scala/leon/integration/solvers/ContextGrammarSuite.scala b/src/test/scala/leon/integration/solvers/ContextGrammarSuite.scala index 86216a2dafda219434cc2ca7e700874bf2613de9..a2b2317acb78c448d2a8db879d84cf6a50e2cfe5 100644 --- a/src/test/scala/leon/integration/solvers/ContextGrammarSuite.scala +++ b/src/test/scala/leon/integration/solvers/ContextGrammarSuite.scala @@ -120,8 +120,8 @@ class ContextGrammarSuite extends FunSuite with Matchers with ScalaFutures { AB -> Expansion(List(List(x))) )) - grammar1.markovize_horizontal_filtered(_.tag == "B") should equalGrammar (grammar2) - grammar2.markovize_horizontal_filtered(_.tag == "B") should equalGrammar (grammar2) + grammar1.markovize_horizontal_filtered(_.tag == "B", false) should equalGrammar (grammar2) + grammar2.markovize_horizontal_filtered(_.tag == "B", false) should equalGrammar (grammar2) } test("Vertical Markovization simple") {