/* Copyright 2009-2014 EPFL, Lausanne */

package leon
package synthesis
package rules

import leon.utils.SeqUtils
import solvers._
import solvers.z3._

import purescala.Expressions._
import purescala.Common._
import purescala.Definitions._
import purescala.Types._
import purescala.ExprOps._
import purescala.DefOps._
import purescala.Constructors._

import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer}

import evaluators._
import datagen._
import codegen.CodeGenParams

import utils._

abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {

  /**
   * Parameters of one CEGIS instantiation.
   *
   * @param grammar       grammar whose productions are used to unfold candidate expressions
   * @param rootLabel     maps the type of an output variable to the grammar label it is unfolded from
   * @param maxUnfoldings maximal number of unfolding rounds; when 0, the rule yields no instantiation
   */
  case class CegisParams(
    grammar: ExpressionGrammar[T],
    rootLabel: TypeTree => T,
    maxUnfoldings: Int = 3
  )

  /** Provides the CEGIS parameters (grammar, root labels, unfolding bound) to use for problem `p`. */
  def getParams(sctx: SynthesisContext, p: Problem): CegisParams

  def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = {
    val exSolverTo  = 2000L
    val cexSolverTo = 2000L

    // Track non-deterministic programs up to 100'000 programs, or give up
    val nProgramsLimit = 100000

    val sctx = hctx.sctx
    val ctx  = sctx.context

    // CEGIS Flags to activate or deactivate features
    val useOptTimeout         = sctx.settings.cegisUseOptTimeout.getOrElse(true)
    val useVanuatoo           = sctx.settings.cegisUseVanuatoo.getOrElse(false)
    val useShrink             = sctx.settings.cegisUseShrink.getOrElse(true)

    // Limits the number of programs CEGIS will specifically validate individually
    val validateUpTo          = 5

    // Shrink the program when the ratio of passing cases is less than the threshold
    val shrinkThreshold       = 1.0/2

    val interruptManager      = sctx.context.interruptManager

    val params = getParams(sctx, p)

    if (params.maxUnfoldings == 0) {
      return Nil
    }

    class NonDeterministicProgram(val p: Problem) {

      private val grammar = params.grammar

      /**
       * Different view of the tree of expressions:
       *
       * Case used to illustrate the different views, assuming encoding:
       *
       *   b1 => c1 == F(c2, c3)
       *   b2 => c1 == G(c4, c5)
       *   b3 => c6 == H(c4, c5)
       *
       * Each 'c' maps to its alternatives, given as triples of
       * (guarding b, expression, set of sub-cs used by the expression):
       *
       * c1 -> Seq(
       *         (b1, F(c2, c3), Set(c2, c3))
       *         (b2, G(c4, c5), Set(c4, c5))
       *       )
       * c6 -> Seq(
       *         (b3, H(c4, c5), Set(c4, c5))
       *       )
       */
      private var cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]] = Map()

      /**
       * Computes dependencies of c's
       *
       * c1 -> Set(c2, c3, c4, c5)
       */
      private var cDeps: Map[Identifier, Set[Identifier]] = Map()

      /**
       * Keeps track of blocked Bs and which C are affected, assuming cs are undefined:
       *
       * b2 -> Set(c4)
       * b3 -> Set(c4)
       */
      private var closedBs: Map[Identifier, Set[Identifier]] = Map()

      /**
       * Maps c identifiers to grammar labels
       *
       * Labels allows us to use grammars that are not only type-based
       */
      private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType))

      private var bs: Set[Identifier]        = Set()
      private var bsOrdered: Seq[Identifier] = Seq()

      /**
       * Checks if 'b' is active, i.e. NOT closed. A 'b' is closed when it is
       * registered in `closedBs` (meaning it depends on uninterpreted terms)
       */
      def isBActive(b: Identifier) = !closedBs.contains(b)


      /**
       * Counts the programs that can be built from the active alternatives of
       * the current tree, without enumerating them.
       *
       * The number of programs for a 'c' is the sum, over its active
       * alternatives, of the product of the counts of the alternative's
       * sub-cs (1 for a ground alternative). The total is the product over
       * all output variables.
       *
       * Counting is done with arbitrary precision and the result saturates at
       * `Int.MaxValue`: plain `Int` arithmetic could silently wrap around to a
       * small or negative number, defeating callers that compare the count
       * against an upper limit.
       */
      def allProgramsCount(): Int = {
        var nAltsCache = Map[Identifier, BigInt]()

        def nAltsFor(c: Identifier): BigInt = {
          if (!(nAltsCache contains c)) {
            val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield {
              if (subcs.isEmpty) {
                BigInt(1)
              } else {
                subcs.toSeq.map(nAltsFor).product
              }
            }

            nAltsCache += c -> subs.sum
          }
          nAltsCache(c)
        }

        val total = p.xs.map(nAltsFor).product

        // Saturate instead of overflowing
        total.min(BigInt(Int.MaxValue)).toInt
      }

      /**
       * Enumerates every program of the current tree, each given as the set
       * of Bs that must be true to select it.
       *
       * Yields no program at all when the search space exceeds `nProgramsLimit`.
       */
      def allPrograms(): Traversable[Set[Identifier]] = {
        if (allProgramsCount() > nProgramsLimit) {
          Seq()
        } else {
          // Programs already computed for a given c
          var memo = Map[Identifier, Seq[Set[Identifier]]]()

          // All B-assignments rooted at a single c, through its active alternatives
          def programsOf(c: Identifier): Seq[Set[Identifier]] = {
            if (!memo.contains(c)) {
              val perAlternative = cTree(c).filter(alt => isBActive(alt._1)).map {
                case (b, _, subcs) =>
                  if (subcs.isEmpty) {
                    Seq(Set(b))
                  } else {
                    allProgramsFor(subcs).map(sub => sub + b)
                  }
              }
              memo += c -> perAlternative.flatten
            }
            memo(c)
          }

          // Combines one choice per c into full assignments
          def allProgramsFor(cs: Set[Identifier]): Seq[Set[Identifier]] = {
            val choices = cs.toSeq.map(programsOf)

            SeqUtils.cartesianProduct(choices).map { combination =>
              combination.foldLeft(Set[Identifier]())(_ ++ _)
            }
          }

          allProgramsFor(p.xs.toSet)
        }
      }

      /**
       * Prints the given tree to standard output, one 'c' at a time with all
       * of its alternatives.
       *
       * Closed alternatives are flagged with a cross, and alternatives whose
       * 'b' is in `markedBs` are printed in green.
       *
       * @param cTree    the tree to print (shadows the field of the same name)
       * @param markedBs Bs to highlight
       */
      private def debugCExpr(cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]],
                             markedBs: Set[Identifier] = Set()): Unit = {
        println(" -- -- -- -- -- ")
        for ((c, alts) <- cTree) {
          println()
          println(f"$c%-4s :=")
          for ((b, ex, cs) <- alts ) {
            // Single-character marker (U+2A2F, vector cross product) so that
            // columns stay aligned; written as an escape to be immune to
            // source-encoding mix-ups
            val active = if (isBActive(b)) " " else "\u2A2F"
            val markS   = if (markedBs(b)) Console.GREEN else ""
            val markE   = if (markedBs(b)) Console.RESET else ""

            println(f"      $markS$active  $b%-4s => $ex%-40s [$cs]$markE")
          }
        }
      }

      /**
       * Builds the expression computing the outputs from the current tree.
       *
       * Every 'c' becomes a chain of if-expressions over its active
       * alternatives (falling back to the simplest value of its type), cs are
       * bound with lets following their dependencies, and each 'b' variable
       * is finally replaced by a lookup in the b array.
       */
      private def computeCExpr(): Expr = {

        val lets = cTree.map { case (c, alts) =>
          val fallback: Expr = simplestValue(c.getType)

          val body = alts.filter(alt => isBActive(alt._1)).foldLeft(fallback) {
            case (rest, (b, ex, _)) => IfExpr(b.toVariable, ex, rest)
          }

          c -> body
        }

        // We order the lets based on dependencies
        def defFor(c: Identifier): Expr = {
          val definedDeps = cDeps(c).filter(lets.contains)

          definedDeps.foldLeft(lets(c)) { (inner, dep) =>
            Let(dep, defFor(dep), inner)
          }
        }

        val simplified = simplifyLets(tupleWrap(p.xs.map(defFor)))

        // b at position i in bsOrdered is read from slot i of the b array
        val bToSlot: Map[Expr, Expr] = bsOrdered.zipWithIndex.map {
          case (b, i) => Variable(b) -> ArraySelect(bArrayId.toVariable, IntLiteral(i))
        }.toMap

        replace(bToSlot, simplified)
      }


      /**
       * Information about the final Program representing CEGIS solutions at
       * the current unfolding level
       */
      private val outerSolution = {
        val part = new PartialSolution(hctx.search.g, true)
        e : Expr => part.solutionAround(hctx.currentNode)(e).getOrElse {
          sctx.reporter.fatalError("Unable to create outer solution")
        }
      }

      private val bArrayId = FreshIdentifier("bArray", ArrayType(BooleanType), true)

      private var cTreeFd = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true),
                               Seq(),
                               p.outType,
                               p.as.map(id => ValDef(id)),
                               DefType.MethodDef
                             )

      private var phiFd   = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true),
                               Seq(),
                               BooleanType,
                               p.as.map(id => ValDef(id)),
                               DefType.MethodDef
                             )

      private var programCTree: Program = _

      // Map functions from original program to cTree program
      private var fdMapCTree: Map[FunDef, FunDef] = _

      private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _

      /**
       * Builds `programCTree`, the program used to evaluate and solve for
       * candidates, and records the mapping from original to new functions.
       *
       * The new program contains `cTreeFd` (body left empty here) and `phiFd`
       * (checks `p.phi` on the outputs of `cTreeFd`). Every affected function
       * receives the b array as an extra last parameter.
       */
      private def initializeCTreeProgram(): Unit = {

        // CEGIS is solved by called cTree function (without bs yet)
        val fullSol = outerSolution(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)))


        val chFd = hctx.ci.fd
        val prog0 = hctx.program

        // Functions that need the extra b array argument: transitive callers
        // of the function holding the choose, that function itself, the two
        // synthetic functions and the definitions of the outer solution
        val affected = prog0.callGraph.transitiveCallers(chFd) ++ Set(chFd, cTreeFd, phiFd) ++ fullSol.defs

        cTreeFd.body = None
        // phi(as) := let xs = cTree(as) in phi
        phiFd.body   = Some(
          letTuple(p.xs,
                   FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)),
                   p.phi)
        )

        val prog1 = addFunDefs(prog0, Seq(cTreeFd, phiFd) ++ fullSol.defs, chFd)

        val (prog2, fdMap2) = replaceFunDefs(prog1)({
          case fd if affected(fd) =>
            // Add the b array argument to all affected functions
            val nfd = new FunDef(fd.id.freshen,
                                 fd.tparams,
                                 fd.returnType,
                                 fd.params :+ ValDef(bArrayId),
                                 fd.defType)
            nfd.copyContentFrom(fd)
            nfd.copiedFrom(fd)

            // In the function holding the choose, the choose itself is
            // replaced by the guarded term of the outer solution
            if (fd == chFd) {
              nfd.fullBody = replace(Map(hctx.ci.ch -> fullSol.guardedTerm), nfd.fullBody)
            }

            Some(nfd)

          case _ =>
            None
        }, {
          // Calls to replaced functions pass the b array along
          case (FunctionInvocation(old, args), newfd) if old.fd != newfd =>
            Some(FunctionInvocation(newfd.typed(old.tps), args :+ bArrayId.toVariable))
          case _ =>
            None
        })

        // From now on cTreeFd/phiFd refer to their copies inside programCTree
        programCTree = prog2
        cTreeFd      = fdMap2(cTreeFd)
        phiFd        = fdMap2(phiFd)
        fdMapCTree   = fdMap2
      }

      /**
       * Installs `cTree` as the body of `cTreeFd` and rebuilds `tester`.
       *
       * Calls in `cTree` to functions that were replaced in `programCTree`
       * are redirected to their replacement, with the b array appended.
       *
       * @param cTree the expression computing the outputs (shadows the field of the same name)
       */
      private def setCExpr(cTree: Expr): Unit = {

        cTreeFd.body = Some(preMap{
          case FunctionInvocation(TypedFunDef(fd, tps), args) if fdMapCTree contains fd =>
            Some(FunctionInvocation(fdMapCTree(fd).typed(tps), args :+ bArrayId.toVariable))
          case _ =>
            None
        }(cTree))

        //println("-- "*30)
        //println(programCTree)
        //println(".. "*30)

        // A new evaluator is created every time the body of cTreeFd changes
        val evaluator  = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default)


        // Evaluates phiFd on the given inputs; the b array holds true exactly
        // at the positions (in bsOrdered) of the bs contained in bValues
        tester =
          { (ins: Seq[Expr], bValues: Set[Identifier]) =>
            val bsValue = finiteArray(bsOrdered.map(b => BooleanLiteral(bValues(b))), None, BooleanType)
            val args = ins :+ bsValue

            val fi = FunctionInvocation(phiFd.typed, args)

            evaluator.eval(fi, Map())
          }
      }


      /**
       * Refreshes the program after the tree changed: creates the program on
       * first use, then installs the newly computed expression.
       */
      private def updateCTree(): Unit = {
        val needsInit = programCTree eq null
        if (needsInit) initializeCTreeProgram()

        val cExpr = computeCExpr()
        setCExpr(cExpr)
      }

      /**
       * Runs the program selected by `bValues` on the inputs `ins`.
       *
       * @return true only if evaluation succeeds and yields `true`; runtime
       *         and evaluator errors count as failures (the latter are reported)
       */
      def testForProgram(bValues: Set[Identifier])(ins: Seq[Expr]): Boolean = {
        val outcome = tester(ins, bValues)

        outcome match {
          case EvaluationResults.Successful(value) =>
            value == BooleanLiteral(true)

          case EvaluationResults.RuntimeError(_) =>
            false

          case EvaluationResults.EvaluatorError(msg) =>
            sctx.reporter.error("Error testing CE: "+msg)
            false
        }
      }



      /**
       * Reconstructs the concrete expression selected by `bValues`.
       *
       * For each 'c', the first alternative whose 'b' is in `bValues` is used,
       * with its sub-cs recursively replaced by their own values; a 'c'
       * without selected alternative gets the simplest value of its type.
       */
      def getExpr(bValues: Set[Identifier]): Expr = {
        def valueOf(c: Identifier): Expr = {
          val selected = cTree(c).find { case (b, _, _) => bValues(b) }

          selected match {
            case Some((_, ex, subCs)) =>
              val bindings = subCs.map(sub => sub -> valueOf(sub)).toMap
              substAll(bindings, ex)

            case None =>
              simplestValue(c.getType)
          }
        }

        tupleWrap(p.xs.map(valueOf))
      }

      /**
       * Validates the given programs one by one with a solver.
       *
       * @param bss the programs to validate, each given as the set of its true Bs
       * @return `Left` with a solution as soon as one program is proven valid
       *         (or, with `useOptTimeout`, as soon as the solver times out on
       *         one); otherwise `Right` with the counter-examples found for
       *         the invalid programs
       */
      def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = {
        try {
          val cexs = for (bs <- bss.toSeq) yield {
            val sol = getExpr(bs)

            val fullSol = outerSolution(sol)

            val prog = addFunDefs(hctx.program, fullSol.defs, hctx.ci.fd)

            // Temporarily give the choose an implementation; it is reset in
            // the outer finally
            hctx.ci.ch.impl = Some(fullSol.guardedTerm)

            // Satisfiable iff some input satisfying pc violates phi
            val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi)))
            //println("Solving for: "+cnstr)

            val solver = (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(cexSolverTo)
            try {
              solver.assertCnstr(cnstr)
              solver.check match {
                case Some(true) =>
                  excludeProgram(bs)
                  val model = solver.getModel
                  //println("Found counter example: ")
                  //for ((s, v) <- model) {
                  //  println(" "+s.asString+" -> "+v.asString)
                  //}

                  //val evaluator  = new DefaultEvaluator(ctx, prog)
                  //println(evaluator.eval(cnstr, model))

                  // Inputs missing from the model get the simplest value of their type
                  Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType))))

                case Some(false) =>
                  // UNSAT, valid program
                  // (this return sits inside the for/yield closure, so it is a
                  // non-local return; both finally blocks still run)
                  return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, true)))

                case None =>
                  if (useOptTimeout) {
                    // Interpret timeout in CE search as "the candidate is valid"
                    sctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
                    // Optimistic valid solution
                    return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, false)))
                  } else {
                    None
                  }
              }
            } finally {
              solver.free()
            }
          }

          Right(cexs.flatten)
        } finally {
          hctx.ci.ch.impl = None
        }
      }

      var excludedPrograms = ArrayBuffer[Set[Identifier]]()

      /**
       * Explicitly removes the program selected by `bValues` from the search
       * space; only its active Bs are recorded.
       */
      def excludeProgram(bValues: Set[Identifier]): Unit = {
        val activeOnly = bValues.filter(b => isBActive(b))
        excludedPrograms += activeOnly
      }

      /**
       * Shrinks the non-deterministic program to the provided set of
       * alternatives only
       *
       * @param remainingBs    Bs to keep; the Bs they are nested under are kept too
       * @param finalUnfolding when false, the currently closed Bs are kept as
       *                       well (they may still be unfolded later)
       */
      def shrinkTo(remainingBs: Set[Identifier], finalUnfolding: Boolean): Unit = {
        //println("Shrinking!")

        val initialBs = remainingBs ++ (if (finalUnfolding) Set() else closedBs.keySet)

        // cParent: child c -> a c having an alternative that uses it (the last one visited wins)
        // cOfB:    b       -> the c this b is an alternative of
        // underBs: child c -> all bs whose alternative uses it
        var cParent = Map[Identifier, Identifier]()
        var cOfB    = Map[Identifier, Identifier]()
        var underBs = Map[Identifier, Set[Identifier]]()

        for ((cparent, alts) <- cTree;
             (b, _, cs) <- alts) {

          cOfB += b -> cparent

          for (cchild <- cs) {
            underBs += cchild -> (underBs.getOrElse(cchild, Set()) + b)
            cParent += cchild -> cparent
          }
        }

        // b itself plus, transitively, every b under which the c of b is used
        def bParents(b: Identifier): Set[Identifier] = {
          val parentBs = underBs.getOrElse(cOfB(b), Set())

          Set(b) ++ parentBs.flatMap(bParents)
        }

        // include parents
        val keptBs = initialBs.flatMap(bParents)

        //println("Initial Bs: "+initialBs)
        //println("Keeping Bs: "+keptBs)

        //debugCExpr(cTree, keptBs)

        var newCTree = Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]]()

        // Keep only the alternatives guarded by a kept b
        for ((c, alts) <- cTree) yield {
          newCTree += c -> alts.filter(a => keptBs(a._1))
        }

        // Drops from c the alternatives using deadC; if c is left without
        // alternative, it dies too and the removal propagates to its parent
        def removeDeadAlts(c: Identifier, deadC: Identifier) {
          if (newCTree contains c) {
            val alts = newCTree(c)
            val newAlts = alts.filterNot(a => a._3 contains deadC)

            if (newAlts.isEmpty) {
              for (cp <- cParent.get(c)) {
                removeDeadAlts(cp, c)
              }
              newCTree -= c
            } else {
              newCTree += c -> newAlts
            }
          }
        }

        //println("BETWEEN")
        //debugCExpr(newCTree, keptBs)

        // Iterates over the map as it is when the loop starts; newCTree is
        // reassigned while looping
        for ((c, alts) <- newCTree if alts.isEmpty) {
          for (cp <- cParent.get(c)) {
            removeDeadAlts(cp, c)
          }
          newCTree -= c
        }

        var newCDeps = Map[Identifier, Set[Identifier]]()

        // NOTE(review): dependencies are recomputed from the pre-shrink cTree,
        // not from newCTree, so they may mention removed cs -- confirm this is
        // intended (computeCExpr only follows dependencies still in the tree)
        for ((c, alts) <- cTree) yield {
          newCDeps += c -> alts.map(_._3).toSet.flatten
        }

        cTree        = newCTree
        cDeps        = newCDeps
        closedBs     = closedBs.filterKeys(keptBs)

        bs           = cTree.map(_._2.map(_._1)).flatten.toSet
        bsOrdered    = bs.toSeq.sortBy(_.id)

        // Forget excluded programs mentioning a b that no longer exists
        excludedPrograms = excludedPrograms.filter(_.forall(bs))

        //debugCExpr(cTree)
        updateCTree()
      }

      /**
       * Hands out 'c' identifiers per grammar label.
       *
       * Identifiers of a label are created lazily and remembered; after a
       * `reset()`, the same identifiers are handed out again in the same order.
       */
      class CGenerator {
        // Lazily grown sequence of identifiers for each label
        private var buffers = Map[T, Stream[Identifier]]()

        // Number of identifiers handed out per label since the last reset
        private var slots = Map[T, Int]().withDefaultValue(0)

        private def streamOf(t: T): Stream[Identifier] =
          Stream.continually(FreshIdentifier("c", t.getType, true))

        def reset(): Unit = {
          slots = Map[T, Int]().withDefaultValue(0)
        }

        def getNext(t: T) = {
          val ids = buffers.getOrElse(t, {
            val created = streamOf(t)
            buffers += t -> created
            created
          })

          val position = slots(t)
          slots += t -> (position + 1)

          ids(position)
        }
      }

      /**
       * Unfolds the tree by one level: every 'c' that is currently left open
       * (the output variables on the first call, the cs under closed Bs
       * afterwards) receives one alternative per production of its label.
       *
       * @param finalUnfolding when true, only productions without sub-trees are used
       * @return whether at least one production was added
       */
      def unfold(finalUnfolding: Boolean): Boolean = {
        var newBs = Set[Identifier]()
        var unfoldedSomething = false

        def freshB() = {
          val id = FreshIdentifier("B", BooleanType, true)
          newBs += id
          id
        }

        val unfoldBehind = if (cTree.isEmpty) {
          p.xs
        } else {
          closedBs.flatMap(_._2).toSet
        }

        // Closed Bs are recomputed from scratch below
        closedBs = Map[Identifier, Set[Identifier]]()

        // Set of Cs that still have no active alternatives after unfolding
        var postClosedCs = Set[Identifier]()

        for (c <- unfoldBehind) {
          var alts = grammar.getProductions(labels(c))

          if (finalUnfolding) {
            alts = alts.filter(_.subTrees.isEmpty)
          }

          val cGen = new CGenerator()

          val cTreeInfos = if (alts.nonEmpty) {
            for (gen <- alts) yield {
              val b = freshB()

              // Optimize labels
              // (resetting the generator makes alternatives of the same c
              // share their sub-cs for identical labels)
              cGen.reset()

              val cToLabel = for (t <- gen.subTrees) yield {
                cGen.getNext(t) -> t
              }


              labels ++= cToLabel

              val cs = cToLabel.map(_._1)
              val ex = gen.builder(cs.map(_.toVariable))

              // An alternative with sub-cs is closed until those get unfolded
              if (cs.nonEmpty) {
                closedBs += b -> cs.toSet
              }

              //println(" + "+b+" => "+c+" = "+ex)

              unfoldedSomething = true

              (b, ex, cs.toSet)
            }
          } else {
            // Happens in final unfolding when no alts have ground terms
            val b = freshB()
            closedBs += b -> Set()

            Seq((b, simplestValue(c.getType), Set[Identifier]()))
          }

          cTree += c -> cTreeInfos
          cDeps += c -> cTreeInfos.map(_._3).toSet.flatten
        }

        sctx.reporter.ifDebug { printer =>
          printer("Grammar so far:")
          grammar.printProductions(printer)
        }

        bs           = bs ++ newBs
        bsOrdered    = bs.toSeq.sortBy(_.id)

        /**
         * Close dead-ends
         *
         * Find 'c' that have no active alternatives, then close all 'b's that
         * depend on such "dead" 'c's
         */
        var deadCs = Set[Identifier]()

        for ((c, alts) <- cTree) {
          if (alts.forall{ case (b, _, _) => !isBActive(b) }) {
            deadCs += c
          }
        }

        for ((_, alts) <- cTree; (b, _, cs) <- alts) {
          if ((cs & deadCs).nonEmpty) {
            // Registers b as closed, keeping the cs already recorded for it
            closedBs += (b -> closedBs.getOrElse(b, Set()))
          }
        }

        //debugCExpr(cTree)
        updateCTree()

        unfoldedSomething
      }

      /**
       * Asks the solver for a tentative program: an assignment to the Bs
       * (together with some inputs) that satisfies the path condition and
       * `phi`, selects exactly one active alternative for each 'c', uses no
       * closed 'b', and is not one of the excluded programs.
       *
       * @return `Some(Some(bs))` with the Bs set to true in the model found;
       *         `Some(None)` when there is no such program or when the solver
       *         timed out
       */
      def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = {
        val solver  = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo)

        // Everything happens inside the try so that the solver is released
        // even if building or asserting a constraint fails
        try {
          val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
          //debugCExpr(cTree)

          //println(" --- PhiFD ---")
          //println(phiFd.fullBody.asString(ctx))

          // The b array is made of the b variables themselves, so that the
          // solver is free to choose their values
          val fixedBs = finiteArray(bsOrdered.map(_.toVariable), None, BooleanType)
          val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr)

          val toFind = and(p.pc, cnstrFixed)
          //println(" --- Constraints ---")
          //println(" - "+toFind)
          solver.assertCnstr(toFind)

          // oneOfBs: exactly one active alternative per c
          //println(" -- OneOf:")
          for ((_, alts) <- cTree) {
            val activeBs = alts.map(_._1).filter(isBActive)

            // At most one: no two distinct active bs are true together
            val either = for (a1 <- activeBs; a2 <- activeBs if a1.globalId < a2.globalId) yield {
              Or(Not(a1.toVariable), Not(a2.toVariable))
            }

            if (activeBs.nonEmpty) {
              //println(" - "+andJoin(either))
              solver.assertCnstr(andJoin(either))

              // At least one
              val oneOf = orJoin(activeBs.map(_.toVariable))
              //println(" - "+oneOf)
              solver.assertCnstr(oneOf)
            }
          }

          // Closed bs must be false
          //println(" -- Active:")
          val isActive = andJoin(bsOrdered.filterNot(isBActive).map(id => Not(id.toVariable)))
          //println("  - "+isActive)
          solver.assertCnstr(isActive)

          //println(" -- Excluded:")
          for (ex <- excludedPrograms) {
            val notThisProgram = Not(andJoin(ex.map(_.toVariable).toSeq))

            //println(f"  - $notThisProgram%-40s ("+getExpr(ex)+")")
            solver.assertCnstr(notThisProgram)
          }

          solver.check match {
            case Some(true) =>
              val model = solver.getModel

              val bModel = bs.filter(b => model.get(b).contains(BooleanLiteral(true)))

              //println("Tentative expr: "+getExpr(bModel))
              Some(Some(bModel))

            case Some(false) =>
              //println("UNSAT!")
              Some(None)

            case None =>
              /**
               * If the remaining tentative programs are all infeasible, it
               * might timeout instead of returning Some(false). We might still
               * benefit from unfolding further
               */
              ctx.reporter.debug("Timeout while getting tentative program!")
              Some(None)
          }
        } finally {
          solver.free()
        }
      }

      /**
       * Looks for inputs satisfying the path condition on which the program
       * selected by `bs` violates `phi`.
       *
       * @param bs the Bs that are true in the program to check
       * @return `Some(Some(cex))` with a counter-example (inputs missing from
       *         the model get the simplest value of their type);
       *         `Some(None)` when no counter-example exists;
       *         `None` when the solver could not decide
       */
      def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = {
        val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo)

        // Everything happens inside the try so that the solver is released
        // even if building or asserting a constraint fails
        try {
          val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))

          // The b array is fixed to the program under scrutiny
          val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType)
          val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr)

          solver.assertCnstr(p.pc)
          solver.assertCnstr(Not(cnstrFixed))

          solver.check match {
            case Some(true) =>
              val model = solver.getModel
              val cex = p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))

              Some(Some(cex))

            case Some(false) =>
              Some(None)

            case None =>
              None
          }
        } finally {
          solver.free()
        }
      }

      /** Currently does nothing; called once the CEGIS run is over. */
      def free(): Unit = {

      }
    }

    List(new RuleInstantiation(this.name) {
      /**
       * Runs CEGIS on the problem.
       *
       * Starting from initial example inputs, each round unfolds the
       * non-deterministic program once, discards the programs failing on known
       * examples, validates a few of the survivors individually, and then
       * alternates between asking for a tentative program and looking for a
       * counter-example to it, until a solution is found, the unfolding bound
       * is reached, or the search is interrupted.
       *
       * @return a closed application holding the solution, or `RuleFailed()`
       */
      def apply(hctx: SearchContext): RuleApplication = {
        var result: Option[RuleApplication]   = None
        val sctx = hctx.sctx

        val ndProgram = new NonDeterministicProgram(p)
        var unfolding = 1
        val maxUnfoldings = params.maxUnfoldings

        sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings")

        var baseExampleInputs: ArrayBuffer[Seq[Expr]] = new ArrayBuffer[Seq[Expr]]()

        // We populate the list of examples with a predefined one
        sctx.reporter.debug("Acquiring initial list of examples")

        val ef = new ExamplesFinder(sctx.context, sctx.program)
        baseExampleInputs ++= ef.extractTests(p).map(_.ins).toSet

        val pc = p.pc

        if (pc == BooleanLiteral(true)) {
          baseExampleInputs += p.as.map(a => simplestValue(a.getType))
        } else {
          val solver = sctx.newSolver.setTimeout(exSolverTo)

          // The constraint is asserted inside the try so that the solver is
          // released even if asserting fails
          try {
            solver.assertCnstr(pc)

            solver.check match {
              case Some(true) =>
                val model = solver.getModel
                baseExampleInputs += p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))

              case Some(false) =>
                sctx.reporter.debug("Path-condition seems UNSAT")
                return RuleFailed()

              case None =>
                sctx.reporter.warning("Solver could not solve path-condition")
                return RuleFailed() // This is not necessary though, but probably wanted
            }
          } finally {
            solver.free()
          }
        }

        sctx.reporter.ifDebug { debug =>
          baseExampleInputs.foreach { in =>
            debug("  - "+in.mkString(", "))
          }
        }

        /**
         * We generate tests for discarding potential programs
         */
        val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) {
          new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, pc, 20, 3000)
        } else {
          val evaluator  = new DualEvaluator(sctx.context, sctx.program, CodeGenParams.default)
          new GrammarDataGen(evaluator, ExpressionGrammars.ValueGrammar).generateFor(p.as, pc, 20, 1000)
        }

        // Every generated input is also remembered in baseExampleInputs
        val cachedInputIterator = new Iterator[Seq[Expr]] {
          def next() = {
            val i = inputIterator.next()
            baseExampleInputs += i
            i
          }

          def hasNext = {
            inputIterator.hasNext
          }
        }

        val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0)

        def hasInputExamples = baseExampleInputs.size > 0 || cachedInputIterator.hasNext

        var n = 1
        def allInputExamples() = {
          // Every 1000 calls, put the examples that failed the most first
          if (n % 1000 == 0) {
            baseExampleInputs = baseExampleInputs.sortBy(e => -failedTestsStats(e))
          }
          n += 1
          baseExampleInputs.iterator ++ cachedInputIterator
        }

        try {
          do {
            var skipCESearch = false

            // Unfold formula
            val unfoldSuccess = ndProgram.unfold(unfolding == maxUnfoldings)

            if (!unfoldSuccess) {
              unfolding = maxUnfoldings
            }

            // Compute all programs that have not been excluded yet
            var prunedPrograms: Set[Set[Identifier]] = ndProgram.allPrograms().toSet

            val nInitial = prunedPrograms.size
            sctx.reporter.debug("#Programs: "+nInitial)
            //sctx.reporter.ifDebug{ printer =>
            //  val limit = 100

            //  for (p <- prunedPrograms.take(limit)) {
            //    val ps = p.toSeq.sortBy(_.id).mkString(", ")
            //    printer(f" - $ps%-40s - "+ndProgram.getExpr(p))
            //  }
            //  if(nInitial > limit) {
            //    printer(" - ...")
            //  }
            //}

            var wrongPrograms = Set[Set[Identifier]]()

            // We further filter the set of working programs to remove those that fail on known examples
            if (hasInputExamples) {
              for (bs <- prunedPrograms if !interruptManager.isInterrupted) {
                var valid = true
                val examples = allInputExamples()
                while(valid && examples.hasNext) {
                  val e = examples.next()
                  if (!ndProgram.testForProgram(bs)(e)) {
                    failedTestsStats(e) += 1
                    sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs)}%-80s failed on: ${e.mkString(", ")}")
                    wrongPrograms += bs
                    prunedPrograms -= bs

                    // Progress report, once every 1000 discarded programs
                    if (wrongPrograms.size % 1000 == 0) {
                      sctx.reporter.debug("..."+wrongPrograms.size)
                    }

                    valid = false
                  }
                }
              }
            }

            val nPassing = prunedPrograms.size
            sctx.reporter.debug("#Programs passing tests: "+nPassing)
            sctx.reporter.ifDebug{ printer =>
              for (p <- prunedPrograms.take(10)) {
                printer(" - "+ndProgram.getExpr(p))
              }
              if(nPassing > 10) {
                printer(" - ...")
              }
            }
            sctx.reporter.debug("#Tests: "+baseExampleInputs.size)
            sctx.reporter.ifDebug{ printer =>
              for (i <- baseExampleInputs.take(10)) {
                printer(" - "+i.mkString(", "))
              }
              if(baseExampleInputs.size > 10) {
                printer(" - ...")
              }
            }


            if (nPassing == 0 || interruptManager.isInterrupted) {
              // No test passed, we can skip solver and unfold again, if possible
              skipCESearch = true
            } else {
              var doFilter = true

              if (validateUpTo > 0) {
                // Validate the first N programs individualy
                ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match {
                  case Left(sols) if sols.nonEmpty =>
                    doFilter = false
                    result = Some(RuleClosed(sols))
                  case Right(cexs) =>
                    baseExampleInputs ++= cexs

                    if (nPassing <= validateUpTo) {
                      // All programs failed verification, we filter everything out and unfold
                      //ndProgram.shrinkTo(Set(), unfolding == maxUnfoldings)
                      doFilter     = false
                      skipCESearch = true
                    }
                }
              }

              if (doFilter) {
                if (nPassing < nInitial * shrinkThreshold && useShrink) {
                  // We shrink the program to only use the bs mentionned
                  val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _)
                  ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings)
                } else {
                  wrongPrograms.foreach {
                    ndProgram.excludeProgram
                  }
                }
              }
            }

            // CEGIS Loop at a given unfolding level
            while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) {

              ndProgram.solveForTentativeProgram() match {
                case Some(Some(bs)) =>
                  // Should we validate this program with Z3?

                  val validateWithZ3 = if (hasInputExamples) {

                    if (allInputExamples().forall(ndProgram.testForProgram(bs))) {
                      // All valid inputs also work with this, we need to
                      // make sure by validating this candidate with z3
                      true
                    } else {
                      // One valid input failed with this candidate, we can skip
                      ndProgram.excludeProgram(bs)
                      false
                    }
                  } else {
                    // No inputs or capability to test, we need to ask Z3
                    true
                  }

                  if (validateWithZ3) {
                    ndProgram.solveForCounterExample(bs) match {
                      case Some(Some(inputsCE)) =>
                        // Found counter example!
                        baseExampleInputs += inputsCE

                        // Retest whether the newly found C-E invalidates all programs
                        if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(inputsCE))) {
                          skipCESearch = true
                        } else {
                          ndProgram.excludeProgram(bs)
                        }

                      case Some(None) =>
                        // Found no counter example! Program is a valid solution
                        val expr = ndProgram.getExpr(bs)
                        result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr)))

                      case None =>
                        // We are not sure
                        if (useOptTimeout) {
                          // Interpret timeout in CE search as "the candidate is valid"
                          sctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
                          val expr = ndProgram.getExpr(bs)
                          result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false)))
                        } else {
                          result = Some(RuleFailed())
                        }
                    }
                  }

                case Some(None) =>
                  skipCESearch = true

                case None =>
                  result = Some(RuleFailed())
              }
            }

            unfolding += 1
          } while(unfolding <= maxUnfoldings && result.isEmpty && !interruptManager.isInterrupted)

          result.getOrElse(RuleFailed())

        } catch {
          // Deliberately broad: any crash in CEGIS only fails this rule
          case e: Throwable =>
            sctx.reporter.warning("CEGIS crashed: "+e.getMessage)
            e.printStackTrace()
            RuleFailed()
        } finally {
          ndProgram.free()
        }
      }
    })
  }
}