diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 2047e175597ef4b8791ece32281441378023c39a..72cc54c21084f0032cc794aa5c7a66bdb4de98ed 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -98,7 +98,6 @@ object Extractors { } } )) - case FiniteMap(args) => { val subArgs = args.flatMap{case (k, v) => Seq(k, v)} val builder: (Seq[Expr]) => Expr = (as: Seq[Expr]) => { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 2b914a9716811ca60542d3d520e91b8f88740a90..a5c9d9a89e334121ce9edf53523336cdf033296b 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -142,7 +142,7 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { case SetUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup case MultisetUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup case MapUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - case SetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) + case SetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) case MultisetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) case SetIntersection(l,r) => ppBinary(l, r, " \u2229 ", lvl) // \cap case MultisetIntersection(l,r) => ppBinary(l, r, " \u2229 ", lvl) // \cap diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 4cf4d8bbf794d4eac77bfc4bbd4cd34a65ab0c72..d7bd3e9689bc856486d9aa975a6653c183c88021 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -318,6 +318,7 @@ object Trees { object Not { def apply(expr : Expr) : Expr = expr match { case Not(e) => e + case BooleanLiteral(v) => BooleanLiteral(!v) case _ => new Not(expr) } @@ -478,6 +479,7 @@ object Trees { case class SubsetOf(set1: Expr, set2: Expr) extends Expr with FixedType { val fixedType = BooleanType } + case class SetIntersection(set1: Expr, set2: Expr) extends Expr { leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) } @@ -487,8 +489,12 @@ object Trees { case class SetDifference(set1: Expr, set2: Expr) extends Expr { leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) } - case class SetMin(set: Expr) extends Expr - case class SetMax(set: Expr) extends Expr + case class SetMin(set: Expr) extends Expr with FixedType { + val fixedType = Int32Type + } + case class SetMax(set: Expr) extends Expr with FixedType { + val fixedType = Int32Type + } /* Multiset expressions */ case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index fd4c65252e46a0d5c7460b79bae1a3a2f263892c..6565a2148c7959ae2a19c402ad18c5aef93dcb14 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -6,7 +6,7 @@ package solvers import purescala.Common._ import purescala.Trees._ -case class SimpleSolverAPI(sf: SolverFactory[Solver]) { +case class SimpleSolverAPI[S <: Solver](sf: SolverFactory[S]) { def solveVALID(expression: Expr): Option[Boolean] = { val s = sf.getNewSolver() s.assertCnstr(Not(expression)) diff --git a/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala b/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..6ef2752d900bdc78922a6ffcad83c48581c35ca3 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala @@ -0,0 +1,173 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.TreeOps._ + +import scala.collection.mutable.{Map=>MutableMap} + +class DNFSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { + val description = "DNF around a base solver" + val name = sf.name + "!" + + val context = sf.context + val program = sf.program + + private val thisFactory = this + + override def free() { + sf.free() + } + + override def recoverInterrupt() { + sf.recoverInterrupt() + } + + def getNewSolver() : Solver = { + new Solver { + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + private def fail(because : String) : Nothing = { throw new Exception("Not supported in DNFSolvers : " + because) } + + def push() : Unit = fail("push()") + def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { fail("Multiple assertCnstr(...).") } + theConstraint = Some(expression) + } + + def interrupt() { fail("interrupt()") } + + def recoverInterrupt() { fail("recoverInterrupt()") } + + def check : Option[Boolean] = theConstraint.map { expr => + import context.reporter + + val simpleSolver = SimpleSolverAPI(sf) + + var result : Option[Boolean] = None + + def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } + + // info("Before NNF:\n" + expr) + + val nnfed = nnf(expr, false) + + // info("After NNF:\n" + nnfed) + + val dnfed = dnf(nnfed) + + // info("After DNF:\n" + dnfed) + + val candidates : Seq[Expr] = dnfed match { + case Or(es) => es + case elze => Seq(elze) + } + + info("# conjuncts : " + candidates.size) + + var done : Boolean = false + + for(candidate <- candidates if !done) { + simpleSolver.solveSAT(candidate) match { + case (Some(false), _) => + result = Some(false) + + case (Some(true), m) => + result = Some(true) + theModel = Some(m) + done = true + + case (None, m) => + result = None + theModel = Some(m) + done = true + } + } + result + } getOrElse { + Some(true) + } + + def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { + fail("checkAssumptions(assumptions)") + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } + } + } + + private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { + case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") + case Not(Implies(l,r)) => nnf(And(l, Not(r)), flip) + case Implies(l, r) => nnf(Or(Not(l), r), flip) + case Not(Iff(l, r)) => nnf(Or(And(l, Not(r)), And(Not(l), r)), flip) + case Iff(l, r) => nnf(Or(And(l, r), And(Not(l), Not(r))), flip) + case And(es) if flip => Or(es.map(e => nnf(e, true))) + case And(es) => And(es.map(e => nnf(e, false))) + case Or(es) if flip => And(es.map(e => nnf(e, true))) + case Or(es) => Or(es.map(e => nnf(e, false))) + case Not(e) if flip => nnf(e, false) + case Not(e) => nnf(e, true) + case LessThan(l,r) if flip => GreaterEquals(l,r) + case GreaterThan(l,r) if flip => LessEquals(l,r) + case LessEquals(l,r) if flip => GreaterThan(l,r) + case GreaterEquals(l,r) if flip => LessThan(l,r) + case elze if flip => Not(elze) + case elze => elze + } + + // fun pushC (And(p,Or(q,r))) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(Or(q,r),p)) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(p,q)) = And(pushC(p),pushC(q)) + // | pushC (Literal(l)) = Literal(l) + // | pushC (Or(p,q)) = Or(pushC(p),pushC(q)) + + private def dnf(expr : Expr) : Expr = expr match { + case And(es) => + val (ors, lits) = es.partition(_.isInstanceOf[Or]) + if(!ors.isEmpty) { + val orHead = ors.head.asInstanceOf[Or] + val orTail = ors.tail + Or(orHead.exprs.map(oe => dnf(And(filterObvious(lits ++ (oe +: orTail)))))) + } else { + expr + } + + case Or(es) => + Or(es.map(dnf(_))) + + case _ => expr + } + + private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { + var pos : List[Identifier] = Nil + var neg : List[Identifier] = Nil + + for(e <- exprs) e match { + case Variable(id) => pos = id :: pos + case Not(Variable(id)) => neg = id :: neg + case _ => ; + } + + val both : Set[Identifier] = pos.toSet intersect neg.toSet + if(!both.isEmpty) { + Seq(BooleanLiteral(false)) + } else { + exprs + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala new file mode 100644 index 0000000000000000000000000000000000000000..143b1c7557ea76553b1ec84d739a9e4c104d144e --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala @@ -0,0 +1,311 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers.combinators + +import purescala.Common._ +import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ + +import evaluators._ + +import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} + +class FunctionTemplate private( + val funDef : FunDef, + val activatingBool : Identifier, + condVars : Set[Identifier], + exprVars : Set[Identifier], + guardedExprs : Map[Identifier,Seq[Expr]], + isRealFunDef : Boolean) { + + private val funDefArgsIDs : Seq[Identifier] = funDef.args.map(_.id) + + private val asClauses : Seq[Expr] = { + (for((b,es) <- guardedExprs; e <- es) yield { + Implies(Variable(b), e) + }).toSeq + } + + val blockers : Map[Identifier,Set[FunctionInvocation]] = { + val idCall = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + + Map((for((b, es) <- guardedExprs) yield { + val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall + if(calls.isEmpty) { + None + } else { + Some((b, calls)) + } + }).flatten.toSeq : _*) + } + + private def idToFreshID(id : Identifier) : Identifier = { + FreshIdentifier(id.name, true).setType(id.getType) + } + + // We use a cache to create the same boolean variables. + private val cache : MutableMap[Seq[Expr],Map[Identifier,Expr]] = MutableMap.empty + + def instantiate(aVar : Identifier, args : Seq[Expr]) : (Seq[Expr], Map[Identifier,Set[FunctionInvocation]]) = { + assert(args.size == funDef.args.size) + + val (wasHit,baseIDSubstMap) = cache.get(args) match { + case Some(m) => (true,m) + case None => + val newMap : Map[Identifier,Expr] = + (exprVars ++ condVars).map(id => id -> Variable(idToFreshID(id))).toMap ++ + (funDefArgsIDs zip args) + cache(args) = newMap + (false, newMap) + } + + val idSubstMap : Map[Identifier,Expr] = baseIDSubstMap + (activatingBool -> Variable(aVar)) + val exprSubstMap : Map[Expr,Expr] = idSubstMap.map(p => (Variable(p._1), p._2)) + + val newClauses = asClauses.map(replace(exprSubstMap, _)) + + val newBlockers = blockers.map { case (id, funs) => + val bp = if (id == activatingBool) { + aVar + } else { + // That's not exactly safe... + idSubstMap(id).asInstanceOf[Variable].id + } + + val newFuns = funs.map(fi => fi.copy(args = fi.args.map(replace(exprSubstMap, _)))) + + bp -> newFuns + } + + (newClauses, newBlockers) + } + + override def toString : String = { + "Template for def " + funDef.id + "(" + funDef.args.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + funDef.returnType + " is :\n" + + " * Activating boolean : " + activatingBool + "\n" + + " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + + " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" + + " * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" + + " * Block-map : " + blockers.toString + } +} + +object FunctionTemplate { + val splitAndOrImplies = false + + def mkTemplate(funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + val condVars : MutableSet[Identifier] = MutableSet.empty + val exprVars : MutableSet[Identifier] = MutableSet.empty + + // Represents clauses of the form: + // id => expr && ... && expr + val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty + + def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { + assert(expr.getType == BooleanType) + if(guardedExprs.isDefinedAt(guardVar)) { + val prev : Seq[Expr] = guardedExprs(guardVar) + guardedExprs(guardVar) = expr +: prev + } else { + guardedExprs(guardVar) = Seq(expr) + } + } + + // Group elements that satisfy p toghether + // List(a, a, a, b, c, a, a), with p = _ == a will produce: + // List(List(a,a,a), List(b), List(c), List(a, a)) + def groupWhile[T](p: T => Boolean, l: Seq[T]): Seq[Seq[T]] = { + var res: Seq[Seq[T]] = Nil + + var c = l + + while(!c.isEmpty) { + val (span, rest) = c.span(p) + + if (span.isEmpty) { + res = res :+ Seq(rest.head) + c = rest.tail + } else { + res = res :+ span + c = rest + } + } + + res + } + + def rec(pathVar : Identifier, expr : Expr) : Expr = { + expr match { + case l @ Let(i, e, b) => + val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) + exprVars += newExpr + val re = rec(pathVar, e) + storeGuarded(pathVar, Equals(Variable(newExpr), re)) + val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b)) + rb + + case l @ LetTuple(is, e, b) => + val tuple : Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType))) + exprVars += tuple + val re = rec(pathVar, e) + storeGuarded(pathVar, Equals(Variable(tuple), re)) + + val mapping = for ((id, i) <- is.zipWithIndex) yield { + val newId = FreshIdentifier("ti", true).setType(id.getType) + exprVars += newId + storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1))) + + (Variable(id) -> Variable(newId)) + } + + val rb = rec(pathVar, replace(mapping.toMap, b)) + rb + + case m : MatchExpr => sys.error("MatchExpr's should have been eliminated.") + + case i @ Implies(lhs, rhs) => + if (splitAndOrImplies) { + if (containsFunctionCalls(i)) { + rec(pathVar, IfExpr(lhs, rhs, BooleanLiteral(true))) + } else { + i + } + } else { + Implies(rec(pathVar, lhs), rec(pathVar, rhs)) + } + + case a @ And(parts) => + if (splitAndOrImplies) { + if (containsFunctionCalls(a)) { + val partitions = groupWhile((e: Expr) => !containsFunctionCalls(e), parts) + + val ifExpr = partitions.map(And(_)).reduceRight{ (a: Expr, b: Expr) => IfExpr(a, b, BooleanLiteral(false)) } + + rec(pathVar, ifExpr) + } else { + a + } + } else { + And(parts.map(rec(pathVar, _))) + } + + case o @ Or(parts) => + if (splitAndOrImplies) { + if (containsFunctionCalls(o)) { + val partitions = groupWhile((e: Expr) => !containsFunctionCalls(e), parts) + + val ifExpr = partitions.map(Or(_)).reduceRight{ (a: Expr, b: Expr) => IfExpr(a, BooleanLiteral(true), b) } + + rec(pathVar, ifExpr) + } else { + o + } + } else { + Or(parts.map(rec(pathVar, _))) + } + + case i @ IfExpr(cond, thenn, elze) => { + if(!containsFunctionCalls(cond) && !containsFunctionCalls(thenn) && !containsFunctionCalls(elze)) { + i + } else { + val newBool1 : Identifier = FreshIdentifier("b", true).setType(BooleanType) + val newBool2 : Identifier = FreshIdentifier("b", true).setType(BooleanType) + val newExpr : Identifier = FreshIdentifier("e", true).setType(i.getType) + + condVars += newBool1 + condVars += newBool2 + + exprVars += newExpr + + val crec = rec(pathVar, cond) + val trec = rec(newBool1, thenn) + val erec = rec(newBool2, elze) + + storeGuarded(pathVar, Or(Variable(newBool1), Variable(newBool2))) + storeGuarded(pathVar, Or(Not(Variable(newBool1)), Not(Variable(newBool2)))) + // TODO can we improve this? i.e. make it more symmetrical? + // Probably it's symmetrical enough to Z3. + storeGuarded(pathVar, Iff(Variable(newBool1), crec)) + storeGuarded(newBool1, Equals(Variable(newExpr), trec)) + storeGuarded(newBool2, Equals(Variable(newExpr), erec)) + Variable(newExpr) + } + } + + case c @ Choose(ids, cond) => + val cid = FreshIdentifier("choose", true).setType(c.getType) + exprVars += cid + + val m: Map[Expr, Expr] = if (ids.size == 1) { + Map(Variable(ids.head) -> Variable(cid)) + } else { + ids.zipWithIndex.map{ case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i+1) }.toMap + } + + storeGuarded(pathVar, replace(m, cond)) + Variable(cid) + + case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType) + case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType) + case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType) + case t : Terminal => t + } + } + + // The precondition if it exists. + val prec : Option[Expr] = funDef.precondition.map(p => matchToIfThenElse(p)) + + val newBody : Option[Expr] = funDef.body.map(b => matchToIfThenElse(b)) + + val invocation : Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + + val invocationEqualsBody : Option[Expr] = newBody match { + case Some(body) if isRealFunDef => + val b : Expr = Equals(invocation, body) + + Some(if(prec.isDefined) { + Implies(prec.get, b) + } else { + b + }) + + case _ => + None + } + + val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType) + + if (isRealFunDef) { + val finalPred : Option[Expr] = invocationEqualsBody.map(expr => rec(activatingBool, expr)) + finalPred.foreach(p => storeGuarded(activatingBool, p)) + } else { + val newFormula = rec(activatingBool, newBody.get) + storeGuarded(activatingBool, newFormula) + } + + // Now the postcondition. + funDef.postcondition match { + case Some((id, post)) => + val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) + + val postHolds : Expr = + if(funDef.hasPrecondition) { + Implies(prec.get, newPost) + } else { + newPost + } + + val finalPred2 : Expr = rec(activatingBool, postHolds) + storeGuarded(activatingBool, finalPred2) + case None => + + } + + new FunctionTemplate(funDef, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), +isRealFunDef) + } +} diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a8fbb37a2301dcb9bef9518ecb7d4b5c87f961b --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala @@ -0,0 +1,81 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +/** This is for solvers that operate by rewriting formulas into equisatisfiable ones. + * They are essentially defined by two methods, one for preprocessing of the expressions, + * and one for reconstructing the models. */ +abstract class RewritingSolverFactory[S <: Solver,T](val sf : SolverFactory[S]) extends SolverFactory[Solver] { + val context = sf.context + val program = sf.program + + override def free() { + sf.free() + } + + override def recoverInterrupt() { + sf.recoverInterrupt() + } + + /** The type T is used to encode any meta information useful, for instance, to reconstruct + * models. */ + def rewriteCnstr(expression : Expr) : (Expr,T) + + def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] + + def getNewSolver() : Solver = { + new Solver { + val underlying : Solver = sf.getNewSolver() + + private def fail(because : String) : Nothing = { + throw new Exception("Not supported in RewritingSolvers : " + because) + } + + def push() : Unit = fail("push()") + def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") + + private var storedMeta : List[T] = Nil + + def assertCnstr(expression : Expr) { + context.reporter.info("Asked to solve this in BAPA<:\n" + expression) + val (rewritten, meta) = rewriteCnstr(expression) + storedMeta = meta :: storedMeta + underlying.assertCnstr(rewritten) + } + + def interrupt() { + underlying.interrupt() + } + + def recoverInterrupt() { + underlying.recoverInterrupt() + } + + def check : Option[Boolean] = { + underlying.check + } + + def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { + fail("checkAssumptions(assumptions)") + } + + def getModel : Map[Identifier,Expr] = { + storedMeta match { + case Nil => fail("reconstructing model without meta-information.") + case m :: _ => reconstructModel(underlying.getModel, m) + } + } + + def getUnsatCore : Set[Expr] = { + fail("getUnsatCore") + } + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala b/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala index ce9e50712cb3b05b2c0c2f5936dd905f631c5ee4..dc670bfbe578c467786acd6f9675a708af10608b 100644 --- a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala +++ b/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala @@ -9,8 +9,6 @@ import purescala.Definitions._ import purescala.Trees._ import purescala.TypeTrees._ -import scala.sys.error - class TimeoutSolverFactory[S <: Solver](val sf: SolverFactory[S], val timeoutMs: Long) extends SolverFactory[Solver] { val description = sf.description + ", with "+timeoutMs+"ms timeout" val name = sf.name + "+to" diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..b2f8862f6a15d0cea2237fe9276c99730da2f7f5 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala @@ -0,0 +1,190 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.TreeOps._ + +import scala.collection.mutable.{Map=>MutableMap} + +class UnrollingSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { + val description = "Unrolling loop around a base solver." + val name = sf.name + "*" + + val context = sf.context + val program = sf.program + + // Yes, a hardcoded constant. Sue me. + val MAXUNROLLINGS : Int = 3 + + private val thisFactory = this + + override def free() { + sf.free() + } + + override def recoverInterrupt() { + sf.recoverInterrupt() + } + + def getNewSolver() : Solver = { + new Solver { + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + private def fail(because : String) : Nothing = { + throw new Exception("Not supported in UnrollingSolvers : " + because) + } + + def push() : Unit = fail("push()") + def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { + fail("Multiple assertCnstr(...).") + } + theConstraint = Some(expression) + } + + def interrupt() { fail("interrupt()") } + + def recoverInterrupt() { fail("recoverInterrupt()") } + + def check : Option[Boolean] = theConstraint.map { expr => + import context.reporter + + val simpleSolver = SimpleSolverAPI(sf) + + def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } + + info("Check called on " + expr + "...") + + val template = getTemplate(expr) + + val aVar : Identifier = template.activatingBool + var allClauses : Seq[Expr] = Nil + var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + def fullOpenExpr : Expr = { + // And(Variable(aVar), And(allClauses.reverse)) + // Let's help the poor underlying guy a little bit... + // Note that I keep aVar around, because it may negate one of the blockers, and we can't miss that! + And(Variable(aVar), replace(Map(Variable(aVar) -> BooleanLiteral(true)), And(allClauses.reverse))) + } + + def fullClosedExpr : Expr = { + val blockedVars : Seq[Expr] = allBlockers.toSeq.map(p => Variable(p._1)) + + And( + replace(blockedVars.map(v => (v -> BooleanLiteral(false))).toMap, fullOpenExpr), + And(blockedVars.map(Not(_))) + ) + } + + def unrollOneStep() { + val blockersBefore = allBlockers + + var newClauses : List[Seq[Expr]] = Nil + var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { + val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) + newClauses = nc :: newClauses + newBlockers = newBlockers ++ nb + } + + allClauses = newClauses.flatten ++ allClauses + allBlockers = newBlockers + } + + val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) + + allClauses = nc.reverse + allBlockers = nb + + var unrollingCount : Int = 0 + var done : Boolean = false + var result : Option[Boolean] = None + + // We're now past the initial step. + while(!done && unrollingCount < MAXUNROLLINGS) { + info("At lvl : " + unrollingCount) + val closed : Expr = fullClosedExpr + + info("Going for SAT with this:\n" + closed) + + simpleSolver.solveSAT(closed) match { + case (Some(false), _) => + val open = fullOpenExpr + info("Was UNSAT... Going for UNSAT with this:\n" + open) + simpleSolver.solveSAT(open) match { + case (Some(false), _) => + info("Was UNSAT... Done !") + done = true + result = Some(false) + + case _ => + info("Was SAT or UNKNOWN. Let's unroll !") + unrollingCount += 1 + unrollOneStep() + } + + case (Some(true), model) => + info("WAS SAT ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + + case (None, model) => + info("WAS UNKNOWN ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + } + } + result + + } getOrElse { + Some(true) + } + + def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { + fail("checkAssumptions(assumptions)") + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } + } + } + + private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty + private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty + + private def getTemplate(funDef: FunDef): FunctionTemplate = { + funDefTemplateCache.getOrElse(funDef, { + val res = FunctionTemplate.mkTemplate(funDef, true) + funDefTemplateCache += funDef -> res + res + }) + } + + private def getTemplate(body: Expr): FunctionTemplate = { + exprTemplateCache.getOrElse(body, { + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) + fakeFunDef.body = Some(body) + + val res = FunctionTemplate.mkTemplate(fakeFunDef, false) + exprTemplateCache += body -> res + res + }) + } +} diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 6e8dbd5f799d04e1c0b03bc8f7eb06e792169883..3bc56101590feed24655d46d55296605d7a8b3c2 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -455,7 +455,13 @@ trait AbstractZ3Solver extends SolverFactory[Solver] { case IntLiteral(v) => z3.mkInt(v, intSort) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case UnitLiteral => unitValue - case Equals(l, r) => z3.mkEq(rec(l), rec(r)) + case Equals(l, r) => { + //if(l.getType != r.getType) + // println("Warning : wrong types in equality for " + l + " == " + r) + z3.mkEq(rec( l ), rec( r ) ) + } + + //case Equals(l, r) => z3.mkEq(rec(l), rec(r)) case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) case Minus(l, r) => z3.mkSub(rec(l), rec(r)) case Times(l, r) => z3.mkMul(rec(l), rec(r)) diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index bdbcc6c304c22513c707d2f457a5eaf3559d3996..3a9d16e32aee4c0faca94c1592a2ff790fe862ec 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -11,6 +11,8 @@ import purescala.TypeTrees._ import solvers._ import solvers.z3._ +import solvers.bapaminmax._ +import solvers.combinators._ import scala.collection.mutable.{Set => MutableSet} @@ -68,7 +70,11 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { for((funDef, vcs) <- vcs.toSeq.sortWith((a,b) => a._1 < b._1); vcInfo <- vcs if !interruptManager.isInterrupted()) { val funDef = vcInfo.funDef val vc = vcInfo.condition + + val time0 : Long = System.currentTimeMillis + val time1 = System.currentTimeMillis + reporter.info("Now considering '" + vcInfo.kind + "' VC for " + funDef.id + "...") reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====") reporter.debug(simplifyLets(vc)) @@ -142,9 +148,11 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { val reporter = ctx.reporter - val fairZ3 = new FairZ3SolverFactory(ctx, program) + lazy val fairZ3 = new FairZ3SolverFactory(ctx, program) - val baseSolvers : Seq[SolverFactory[Solver]] = fairZ3 :: Nil + val baseSolvers : Seq[SolverFactory[Solver]] = { + fairZ3 :: Nil + } val solvers: Seq[SolverFactory[Solver]] = timeout match { case Some(t) =>