diff --git a/src/integration/scala/leon/test/solvers/TimeoutSolverSuite.scala b/src/integration/scala/leon/test/solvers/TimeoutSolverSuite.scala index 7f9fc9f35d8cfab41a47bf235b54299c917c5dfc..0bdb5e99c0770562ae9ddc2f2c805d58a578fb99 100644 --- a/src/integration/scala/leon/test/solvers/TimeoutSolverSuite.scala +++ b/src/integration/scala/leon/test/solvers/TimeoutSolverSuite.scala @@ -37,6 +37,8 @@ class TimeoutSolverSuite extends LeonTestSuite { def free() {} + def reset() {} + def getModel = ??? } diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala index 5b14c05da378511c4437a498d6b477578f4fc101..e01d8f7edea4caeb9059c70ef68cc48eb8fb085a 100644 --- a/src/main/scala/leon/solvers/EnumerationSolver.scala +++ b/src/main/scala/leon/solvers/EnumerationSolver.scala @@ -21,30 +21,27 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends private var interrupted = false - var freeVars = List[List[Identifier]](Nil) - var constraints = List[List[Expr]](Nil) + val freeVars = new IncrementalSet[Identifier]() + val constraints = new IncrementalSeq[Expr]() def assertCnstr(expression: Expr): Unit = { - constraints = (constraints.head :+ expression) :: constraints.tail - - val newFreeVars = (variablesOf(expression) -- freeVars.flatten).toList - - freeVars = (freeVars.head ::: newFreeVars) :: freeVars.tail + constraints += expression + freeVars ++= variablesOf(expression) } def push() = { - freeVars = Nil :: freeVars - constraints = Nil :: constraints + freeVars.push() + constraints.push() } - def pop(lvl: Int) = { - freeVars = freeVars.drop(lvl) - constraints = constraints.drop(lvl) + def pop() = { + freeVars.pop() + constraints.pop() } def reset() = { - freeVars = List(Nil) - constraints = List(Nil) + freeVars.clear() + constraints.clear() interrupted = false datagen = None } @@ -59,8 +56,8 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends None } else { modelMap = Map() - val allFreeVars = freeVars.reverse.flatten - val allConstraints = constraints.reverse.flatten + val allFreeVars = freeVars.toSet.toSeq.sortBy(_.name) + val allConstraints = constraints.toSeq val it = datagen.get.generateFor(allFreeVars, andJoin(allConstraints), 1, maxTried) @@ -86,7 +83,7 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends } def free() = { - constraints = Nil + constraints.clear() } def interrupt(): Unit = { diff --git a/src/main/scala/leon/solvers/GroundSolver.scala b/src/main/scala/leon/solvers/GroundSolver.scala index 443706fdf1d9ccaf7e86820bd4453f1543b0de66..ad4ab6dc7a290d79f0bddf520b2248a48eb681dc 100644 --- a/src/main/scala/leon/solvers/GroundSolver.scala +++ b/src/main/scala/leon/solvers/GroundSolver.scala @@ -11,9 +11,10 @@ import purescala.Expressions.{BooleanLiteral, Expr} import purescala.ExprOps.isGround import purescala.Constructors.andJoin import utils.Interruptible +import utils.IncrementalSeq // This solver only "solves" ground terms by evaluating them -class GroundSolver(val context: LeonContext, val program: Program) extends Solver with Interruptible { +class GroundSolver(val context: LeonContext, val program: Program) extends IncrementalSolver with Interruptible { context.interruptManager.registerForInterrupts(this) @@ -22,15 +23,17 @@ class GroundSolver(val context: LeonContext, val program: Program) extends Solve def name: String = "ground" - private var assertions: List[Expr] = Nil + private val assertions = new IncrementalSeq[Expr]() // Ground terms will always have the empty model def getModel: Map[Identifier, Expr] = Map() - def assertCnstr(expression: Expr): Unit = assertions ::= expression + def assertCnstr(expression: Expr): Unit = { + assertions += expression + } def check: Option[Boolean] = { - val expr = andJoin(assertions) + val expr = andJoin(assertions.toSeq) if (isGround(expr)) { evaluator.eval(expr) match { @@ -49,10 +52,20 @@ class GroundSolver(val context: LeonContext, val program: Program) extends Solve } } - def free(): Unit = assertions = Nil + def free(): Unit = { + assertions.clear() + } + + def push() = { + assertions.push() + } + + def pop() = { + assertions.pop() + } def reset() = { - assertions = Nil + assertions.reset() } def interrupt(): Unit = {} diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala index ec0480c6c16b0654ab0aa999e4a52812c221cbc7..c935558c0a9bfe71f23f7eeef16fe95825022f9c 100644 --- a/src/main/scala/leon/solvers/IncrementalSolver.scala +++ b/src/main/scala/leon/solvers/IncrementalSolver.scala @@ -5,6 +5,6 @@ package solvers trait IncrementalSolver extends Solver { def push(): Unit - def pop(lvl: Int = 1): Unit + def pop(): Unit } diff --git a/src/main/scala/leon/solvers/ResettableSolver.scala b/src/main/scala/leon/solvers/ResettableSolver.scala deleted file mode 100644 index ca5eba931d3a99e63ac9d8a8d4057dec99041e07..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/ResettableSolver.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Expressions.Expr - -trait ResettableSolver extends Solver { - def reset() -} diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index b93e6826a652fdddaeaaca3740f50e3a00efe900..4b24f4a6e87fd1092788e7cb782f76c7c032f6d5 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -9,6 +9,7 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ +import utils._ import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} import templates._ @@ -31,8 +32,8 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In protected var lastCheckResult : (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) - protected var varsInVC = List[Set[Identifier]](Set()) - protected var frameExpressions = List[List[Expr]](Nil) + private val freeVars = new IncrementalSet[Identifier]() + private val constraints = new IncrementalSeq[Expr]() protected var interrupted : Boolean = false @@ -69,14 +70,15 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In val solver = underlying def assertCnstr(expression: Expr) { - frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + constraints += expression val freeIds = variablesOf(expression) - varsInVC = (varsInVC.head ++ freeIds) :: varsInVC.tail - val freeVars = freeIds.map(_.toVariable: Expr) + freeVars ++= freeIds - val bindings = freeVars.zip(freeVars).toMap + val newVars = freeIds.map(_.toVariable: Expr) + + val bindings = newVars.zip(newVars).toMap val newClauses = unrollingBank.getClauses(expression, bindings) @@ -88,15 +90,15 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In def push() { unrollingBank.push() solver.push() - varsInVC = Set[Identifier]() :: varsInVC - frameExpressions = Nil :: frameExpressions + freeVars.push() + constraints.push() } - def pop(lvl: Int = 1) { - unrollingBank.pop(lvl) - solver.pop(lvl) - varsInVC = varsInVC.drop(lvl) - frameExpressions = frameExpressions.drop(lvl) + def pop() { + unrollingBank.pop() + solver.pop() + freeVars.pop() + constraints.pop() } def check: Option[Boolean] = { @@ -112,8 +114,8 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In def isValidModel(model: Map[Identifier, Expr], silenceErrors: Boolean = false): Boolean = { import EvaluationResults._ - val expr = andJoin(frameExpressions.flatten) - val allVars = varsInVC.flatten.toSet + val expr = andJoin(constraints.toSeq) + val allVars = freeVars.toSet val fullModel = allVars.map(v => v -> model.getOrElse(v, simplestValue(v.getType))).toMap @@ -241,7 +243,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In } def getModel: Map[Identifier,Expr] = { - val allVars = varsInVC.flatten.toSet + val allVars = freeVars.toSet lastCheckResult match { case (true, Some(true), Some(m)) => m.filterKeys(allVars) @@ -253,8 +255,8 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In override def reset() = { underlying.reset() lastCheckResult = (false, None, None) - varsInVC = List(Set()) - frameExpressions = List(Nil) + freeVars.reset() + constraints.reset() interrupted = false } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 478c17d5d2af82ad178a65a81fb2bac2aaf6443e..8bfb6d92fabc264e91be9b89dceec4b5fc54baed 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -31,7 +31,7 @@ import _root_.smtlib.{Interpreter => SMTInterpreter} abstract class SMTLIBSolver(val context: LeonContext, val program: Program) - extends IncrementalSolver with ResettableSolver with Interruptible { + extends IncrementalSolver with Interruptible { /* Solver name */ @@ -783,9 +783,7 @@ abstract class SMTLIBSolver(val context: LeonContext, sendCommand(Push(1)) } - override def pop(lvl: Int = 1): Unit = { - assert(lvl == 1, "Current implementation only supports lvl = 1") - + override def pop(): Unit = { constructors.pop() selectors.pop() testers.pop() diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index 9d826865059987dc778c187ad1732d9fed0ea8bc..df1974f715acafe132af48ff8861b31d0dfae79e 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -7,67 +7,66 @@ package templates import purescala.Common._ import purescala.Expressions._ import purescala.Types._ +import utils._ -class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[T]) { +class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[T]) extends IncrementalState { implicit val debugSection = utils.DebugSectionSolver private val encoder = templateGenerator.encoder // Keep which function invocation is guarded by which guard, // also specify the generation of the blocker. - private var callInfoStack = List[Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]](Map()) - private def callInfo = callInfoStack.head - private def callInfo_= (v: Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]) = { - callInfoStack = v :: callInfoStack.tail - } + private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() + private def callInfo = callInfos.toMap // Function instantiations have their own defblocker - private var defBlockersStack = List[Map[TemplateCallInfo[T], T]](Map.empty) - private def defBlockers = defBlockersStack.head - private def defBlockers_= (v: Map[TemplateCallInfo[T], T]) : Unit = { - defBlockersStack = v :: defBlockersStack.tail - } + private val defBlockerss = new IncrementalMap[TemplateCallInfo[T], T]() + private def defBlockers = defBlockerss.toMap - private var appInfoStack = List[Map[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]](Map()) - private def appInfo = appInfoStack.head - private def appInfo_= (v: Map[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]) : Unit = { - appInfoStack = v :: appInfoStack.tail - } + private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() + private def appInfo = appInfos.toMap - private var appBlockersStack = List[Map[(T, App[T]), T]](Map.empty) - private def appBlockers = appBlockersStack.head - private def appBlockers_= (v: Map[(T, App[T]), T]) : Unit = { - appBlockersStack = v :: appBlockersStack.tail - } + private val appBlockerss = new IncrementalMap[(T, App[T]), T]() + private def appBlockers = appBlockerss.toMap - private var blockerToAppStack = List[Map[T, (T, App[T])]](Map.empty) - private def blockerToApp = blockerToAppStack.head - private def blockerToApp_= (v: Map[T, (T, App[T])]) : Unit = { - blockerToAppStack = v :: blockerToAppStack.tail + private val blockerToApps = new IncrementalMap[T, (T, App[T])]() + private def blockerToApp = blockerToApps.toMap + + private val functionVarss = new IncrementalMap[TypeTree, Set[T]]() + private def functionVars = functionVarss.toMap + + def push() { + callInfos.push() + defBlockerss.push() + appInfos.push() + appBlockerss.push() + blockerToApps.push() + functionVarss.push() } - private var functionVarsStack = List[Map[TypeTree, Set[T]]](Map.empty.withDefaultValue(Set.empty)) - private def functionVars = functionVarsStack.head - private def functionVars_= (v: Map[TypeTree, Set[T]]) : Unit = { - functionVarsStack = v :: functionVarsStack.tail + def pop() { + callInfos.pop() + defBlockerss.pop() + appInfos.pop() + appBlockerss.pop() + blockerToApps.pop() + functionVarss.pop() } - def push() { - appInfoStack = appInfo :: appInfoStack - callInfoStack = callInfo :: callInfoStack - defBlockersStack = defBlockers :: defBlockersStack - blockerToAppStack = blockerToApp :: blockerToAppStack - functionVarsStack = functionVars :: functionVarsStack - appBlockersStack = appBlockers :: appBlockersStack + def clear() { + callInfos.clear() + defBlockerss.clear() + appInfos.clear() + appBlockerss.clear() + functionVarss.clear() } - def pop(lvl: Int) { - appInfoStack = appInfoStack.drop(lvl) - callInfoStack = callInfoStack.drop(lvl) - defBlockersStack = defBlockersStack.drop(lvl) - blockerToAppStack = blockerToAppStack.drop(lvl) - functionVarsStack = functionVarsStack.drop(lvl) - appBlockersStack = appBlockersStack.drop(lvl) + def reset() { + callInfos.reset() + defBlockerss.reset() + appInfos.reset() + appBlockerss.reset() + functionVarss.reset() } def dumpBlockers() = { @@ -140,7 +139,7 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { case app @ (blocker, App(caller, tpe, _)) => val firstB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) - val freeEq = functionVars(tpe).toSeq.map(t => encoder.mkEquals(t, caller)) + val freeEq = functionVars.getOrElse(tpe, Set()).toSeq.map(t => encoder.mkEquals(t, caller)) val clause = encoder.mkImplies(encoder.mkNot(encoder.mkOr((freeEq :+ firstB) : _*)), encoder.mkNot(blocker)) appBlockers += app -> firstB @@ -170,7 +169,7 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ val trArgs = template.tfd.params.map(vd => bindings(Variable(vd.id))) for (vd <- template.tfd.params if vd.getType.isInstanceOf[FunctionType]) { - functionVars += vd.getType -> (functionVars(vd.getType) + bindings(vd.toVariable)) + functionVars += vd.getType -> (functionVars.getOrElse(vd.getType, Set()) + bindings(vd.toVariable)) } // ...now this template defines clauses that are all guarded @@ -232,13 +231,13 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ var newClauses : Seq[T] = Seq.empty - val callInfos = ids.flatMap(id => callInfo.get(id).map(id -> _)) - callInfo = callInfo -- ids + val newCallInfos = ids.flatMap(id => callInfo.get(id).map(id -> _)) + callInfo --= ids val apps = ids.flatMap(id => blockerToApp.get(id)) val appInfos = apps.map(app => app -> appInfo(app)) - blockerToApp = blockerToApp -- ids - appInfo = appInfo -- apps + blockerToApp --= ids + appInfo --= apps for ((app, (_, _, _, _, infos)) <- appInfos if infos.nonEmpty) { val extension = extendAppBlock(app, infos) @@ -246,7 +245,7 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ newClauses :+= extension } - for ((id, (gen, _, _, infos)) <- callInfos; info @ TemplateCallInfo(tfd, args) <- infos) { + for ((id, (gen, _, _, infos)) <- newCallInfos; info @ TemplateCallInfo(tfd, args) <- infos) { var newCls = Seq[T]() val defBlocker = defBlockers.get(info) match { diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index c5a54685000852d099f6ee4a78a27e67f0e74747..c88a7677b1b05c2715bbbcd660ce370a0a09f6dc 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -4,7 +4,7 @@ package leon package solvers package z3 -import utils.IncrementalBijection +import utils._ import _root_.z3.scala._ import purescala.Common._ @@ -150,9 +150,9 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val solver = z3.mkSolver() - private var varsInVC = List[Set[Identifier]](Set()) + private val freeVars = new IncrementalSet[Identifier]() + private var constraints = new IncrementalSeq[Expr]() - private var frameExpressions = List[List[Expr]](Nil) val unrollingBank = new UnrollingBank(reporter, templateGenerator) @@ -160,19 +160,16 @@ class FairZ3Solver(val context : LeonContext, val program: Program) errors.push() solver.push() unrollingBank.push() - varsInVC = Set[Identifier]() :: varsInVC - frameExpressions = Nil :: frameExpressions + freeVars.push() + constraints.push() } - def pop(lvl: Int = 1) { - for (i <- 1 until lvl) { - errors.pop() - } - - solver.pop(lvl) - unrollingBank.pop(lvl) - varsInVC = varsInVC.drop(lvl) - frameExpressions = frameExpressions.drop(lvl) + def pop() { + errors.pop() + solver.pop(1) + unrollingBank.pop() + freeVars.pop() + constraints.pop() } override def check: Option[Boolean] = { @@ -198,17 +195,17 @@ class FairZ3Solver(val context : LeonContext, val program: Program) def assertCnstr(expression: Expr) { try { - val freeVars = variablesOf(expression) - varsInVC = (varsInVC.head ++ freeVars) :: varsInVC.tail + val newFreeVars = variablesOf(expression) + freeVars ++= newFreeVars // We make sure all free variables are registered as variables - freeVars.foreach { v => + freeVars.toSet.foreach { v => variables.cachedB(Variable(v)) { templateGenerator.encoder.encodeId(v) } } - frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + constraints += expression val newClauses = unrollingBank.getClauses(expression, variables.aToB) @@ -232,7 +229,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { foundDefinitiveAnswer = false - def entireFormula = andJoin(assumptions.toSeq ++ frameExpressions.flatten) + def entireFormula = andJoin(assumptions.toSeq ++ constraints.toSeq) def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { foundDefinitiveAnswer = true @@ -272,7 +269,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) reporter.debug(" - Finished search with blocked literals") - lazy val allVars = varsInVC.flatten.toSet + lazy val allVars = freeVars.toSet res match { case None => diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 4b67702a953618e7da3b3ced528e1522896cc6a1..1106777377fa9de534bb0f1b60b1151197592f83 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -6,6 +6,7 @@ package solvers.z3 import z3.scala._ import leon.solvers._ +import utils.IncrementalSet import purescala.Common._ import purescala.Definitions._ @@ -42,13 +43,16 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) def push() { solver.push() + freeVariables.push() } - def pop(lvl: Int = 1) { - solver.pop(lvl) + def pop() { + solver.pop(1) + freeVariables.pop() } - private var freeVariables = Set[Identifier]() + private val freeVariables = new IncrementalSet[Identifier]() + def assertCnstr(expression: Expr) { freeVariables ++= variablesOf(expression) solver.assertCnstr(toZ3Formula(expression)) @@ -62,7 +66,7 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) } def getModel = { - modelToMap(solver.getModel(), freeVariables) + modelToMap(solver.getModel(), freeVariables.toSet) } def getUnsatCore = { diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/leon/utils/IncrementalBijection.scala index 99d0d70cdcc263d55e325edd4143b7788a2e2c9f..d90d68ec6a2fef80efff8ace6224c960007de5db 100644 --- a/src/main/scala/leon/utils/IncrementalBijection.scala +++ b/src/main/scala/leon/utils/IncrementalBijection.scala @@ -2,16 +2,10 @@ package leon.utils -class IncrementalBijection[A,B] extends Bijection[A,B] { +class IncrementalBijection[A,B] extends Bijection[A,B] with IncrementalState { private var a2bStack = List[Map[A,B]]() private var b2aStack = List[Map[B,A]]() - override def clear() : Unit = { - super.clear() - a2bStack = Nil - b2aStack = Nil - } - private def recursiveGet[T,U](stack: List[Map[T,U]], t: T): Option[U] = stack match { case t2u :: xs => t2u.get(t) orElse recursiveGet(xs, t) case Nil => None @@ -41,6 +35,12 @@ class IncrementalBijection[A,B] extends Bijection[A,B] { override def aSet = a2b.keySet ++ a2bStack.flatMap(_.keySet) override def bSet = b2a.keySet ++ b2aStack.flatMap(_.keySet) + def reset() : Unit = { + super.clear() + a2bStack = Nil + b2aStack = Nil + } + def push(): Unit = { a2bStack = a2b :: a2bStack b2aStack = b2a :: b2aStack @@ -54,5 +54,5 @@ class IncrementalBijection[A,B] extends Bijection[A,B] { a2bStack = a2bStack.tail b2aStack = b2aStack.tail } - + } diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/leon/utils/IncrementalMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..4515351de19cf83301f8888d7c52c35b386dc73b --- /dev/null +++ b/src/main/scala/leon/utils/IncrementalMap.scala @@ -0,0 +1,43 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.utils + +import scala.collection.mutable.{Stack, Map => MMap} + +class IncrementalMap[A, B] extends IncrementalState { + private[this] val stack = new Stack[MMap[A, B]]() + + def clear(): Unit = { + stack.clear() + } + + def reset(): Unit = { + clear() + push() + } + + def push(): Unit = { + val last = if (stack.isEmpty) { + MMap[A,B]() + } else { + MMap[A,B]() ++ stack.head + } + stack.push(last) + } + + def pop(): Unit = { + stack.pop() + } + + def +=(a: A, b: B): Unit = { + stack.head += a -> b + } + + def ++=(as: Traversable[(A, B)]): Unit = { + stack.head ++= as + } + + def toMap = stack.head + + push() +} diff --git a/src/main/scala/leon/utils/IncrementalSeq.scala b/src/main/scala/leon/utils/IncrementalSeq.scala new file mode 100644 index 0000000000000000000000000000000000000000..9f66d66a895aeec9f11c64d7a59eef5268b39eb6 --- /dev/null +++ b/src/main/scala/leon/utils/IncrementalSeq.scala @@ -0,0 +1,35 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.utils + +import scala.collection.mutable.Stack +import scala.collection.mutable.ArrayBuffer + +class IncrementalSeq[A] extends IncrementalState { + private[this] val stack = new Stack[ArrayBuffer[A]]() + + def clear() : Unit = { + stack.clear() + } + + def reset(): Unit = { + clear() + push() + } + + def push(): Unit = { + stack.push(new ArrayBuffer()) + } + + def pop(): Unit = { + stack.pop() + } + + def +=(e: A): Unit = { + stack.head += e + } + + def toSeq = stack.toSeq.flatten + + push() +} diff --git a/src/main/scala/leon/utils/IncrementalSet.scala b/src/main/scala/leon/utils/IncrementalSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..95b473c756cb0a7a51835112c04d52bc4b404a46 --- /dev/null +++ b/src/main/scala/leon/utils/IncrementalSet.scala @@ -0,0 +1,38 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.utils + +import scala.collection.mutable.{Stack, Set => MSet} + +class IncrementalSet[A] extends IncrementalState { + private[this] val stack = new Stack[MSet[A]]() + + def clear(): Unit = { + stack.clear() + } + + def reset(): Unit = { + clear() + push() + } + + def push(): Unit = { + stack.push(MSet()) + } + + def pop(): Unit = { + stack.pop() + } + + def +=(a: A): Unit = { + stack.head += a + } + + def ++=(as: Traversable[A]): Unit = { + stack.head ++= as + } + + def toSet = stack.toSet.flatten + + push() +} diff --git a/src/main/scala/leon/utils/IncrementalState.scala b/src/main/scala/leon/utils/IncrementalState.scala new file mode 100644 index 0000000000000000000000000000000000000000..b84606af2a7f8e0975bc37cc0f5d037e62c4e2a7 --- /dev/null +++ b/src/main/scala/leon/utils/IncrementalState.scala @@ -0,0 +1,9 @@ +package leon.utils + +trait IncrementalState { + def push(): Unit + def pop(): Unit + + def clear(): Unit + def reset(): Unit +}