/* Copyright 2009-2014 EPFL, Lausanne */

package leon
package solvers.z3

import purescala.Common._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._

import evaluators._

import z3.scala._

import scala.collection.mutable.{Set => MutableSet, Map => MutableMap}

/** A call of `tfd` whose arguments are already Z3 terms. */
case class Z3FunctionInvocation(tfd: TypedFunDef, args: Seq[Z3AST]) {
  override def toString = tfd.signature + args.mkString("(", ",", ")")
}

/** Guarded clauses describing one typed function definition.
  *
  * The clauses are built by [[FunctionTemplate.mkTemplate]] and translated to Z3
  * once, at construction time. Each call to `instantiate` produces a copy of these
  * Z3 clauses in which the formal parameters, the activating boolean, the control
  * booleans and the auxiliary expression variables have been substituted.
  *
  * @param solver         solver used for Z3 translation, fresh Z3 constants and ground evaluation
  * @param tfd            the function this template describes
  * @param activatingBool guard under which every top-level clause of the template holds
  * @param condVars       control booleans introduced for branches (shown as "Control booleans")
  * @param exprVars       auxiliary variables introduced while decomposing the body
  * @param guardedExprs   for each guard `b`, the expressions `e` asserted as `b => e`
  * @param isRealFunDef   false for "fake" templates generated from FairZ3Solver; those are
  *                       never ground-evaluated
  */
class FunctionTemplate private(
  solver: FairZ3Solver,
  val tfd: TypedFunDef,
  activatingBool: Identifier,
  condVars: Set[Identifier],
  exprVars: Set[Identifier],
  guardedExprs: Map[Identifier, Seq[Expr]],
  isRealFunDef: Boolean) {

  // NOTE(review): private and not referenced anywhere in this file.
  private def isTerminatingForAllInputs: Boolean =
    isRealFunDef && !tfd.hasPrecondition && solver.getTerminator.terminates(tfd.fd).isGuaranteed

  private val z3 = solver.z3

  // Every guarded expression written as an explicit implication `b => e`.
  private val asClauses: Seq[Expr] =
    (for ((b, es) <- guardedExprs; e <- es) yield Implies(Variable(b), e)).toSeq

  // NB: the order of the fresh-constant creations below is kept as-is on purpose.
  val z3ActivatingBool = solver.idToFreshZ3Id(activatingBool)

  private val z3FunDefArgs = tfd.params.map(ad => solver.idToFreshZ3Id(ad.id))

  private val zippedCondVars = condVars.map(id => (id, solver.idToFreshZ3Id(id)))
  private val zippedExprVars = exprVars.map(id => (id, solver.idToFreshZ3Id(id)))
  private val zippedFunDefArgs = tfd.params.map(_.id) zip z3FunDefArgs

  /** Z3 constant standing for each identifier of the template. */
  val idToZ3Ids: Map[Identifier, Z3AST] =
    Map(activatingBool -> z3ActivatingBool) ++ zippedCondVars ++ zippedExprVars ++ zippedFunDefArgs

  /** The template clauses, translated to Z3 over the constants of `idToZ3Ids`. */
  val asZ3Clauses: Seq[Z3AST] = asClauses.map { cl =>
    solver.toZ3Formula(cl, idToZ3Ids).getOrElse(
      sys.error("Could not translate to z3. Did you forget --xlang? @" + cl.getPos))
  }

  // For each guard, the function calls appearing in the expressions it guards,
  // minus the call of this very function on its own formal parameters.
  // Guards that end up with no calls are dropped.
  private val blockers: Map[Identifier, Set[FunctionInvocation]] = {
    val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable))
    guardedExprs
      .map { case (b, es) =>
        b -> (es.foldLeft(Set.empty[FunctionInvocation])((s, e) => s ++ functionCallsOf(e)) - idCall)
      }
      .filter { case (_, calls) => calls.nonEmpty }
  }

  /** `blockers`, with guards and call arguments translated to Z3. */
  val z3Blockers: Map[Z3AST, Set[Z3FunctionInvocation]] = blockers.map {
    case (b, funs) =>
      idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.tfd, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get)))
  }

  // We use a cache to create the same boolean variables: instantiating twice with
  // the same arguments reuses the same fresh constants for condVars/exprVars.
  private val cache: MutableMap[Seq[Z3AST], Map[Z3AST, Z3AST]] = MutableMap.empty

  /** Instantiates the template for the call `tfd(args)` guarded by `aVar`.
    *
    * @param aVar literal that replaces the template's activating boolean
    * @param args Z3 terms substituted for the formal parameters; must match their count
    * @return the instantiated clauses, and for each blocker literal the calls it guards.
    *         When ground evaluation applies, this is a single equality `f(args) == value`
    *         with no blockers.
    * @throws Exception if ground evaluation is attempted and does not succeed
    */
  def instantiate(aVar: Z3AST, args: Seq[Z3AST]): (Seq[Z3AST], Map[Z3AST, Set[Z3FunctionInvocation]]) = {
    assert(args.size == tfd.params.size)

    groundInstantiation(args) match {
      case Some(equality) => (Seq(equality), Map.empty)
      case None           => symbolicInstantiation(aVar, args)
    }
  }

  // When ground evaluation is enabled and every argument has a ground value, evaluates
  // the call and returns the Z3 equality `f(args) == result`; otherwise None.
  // The "isRealFunDef" part is to prevent evaluation of "fake" function templates,
  // as generated from FairZ3Solver.
  private def groundInstantiation(args: Seq[Z3AST]): Option[Z3AST] = {
    if (!(solver.evalGroundApps && isRealFunDef)) {
      None
    } else {
      val grounded = args.view.map(solver.asGround)
      if (!grounded.forall(_.isDefined)) {
        None
      } else {
        val leonArgs = grounded.map(_.get).force
        val invocation = FunctionInvocation(tfd, leonArgs)
        solver.getEvaluator.eval(invocation) match {
          case EvaluationResults.Successful(result) =>
            val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*)
            val z3Value = solver.toZ3Formula(result).get
            Some(z3.mkEq(z3Invocation, z3Value))
          case _ =>
            throw new Exception("Evaluation of ground term should have succeeded.")
        }
      }
    }
  }

  // Substitutes fresh (per-args cached) constants for condVars/exprVars, `args` for the
  // formal parameters and `aVar` for the activating boolean, in clauses and blockers.
  private def symbolicInstantiation(aVar: Z3AST, args: Seq[Z3AST]): (Seq[Z3AST], Map[Z3AST, Set[Z3FunctionInvocation]]) = {
    val baseIDSubstMap: Map[Z3AST, Z3AST] = cache.getOrElseUpdate(args, {
      (zippedExprVars ++ zippedCondVars).map(p => p._2 -> solver.idToFreshZ3Id(p._1)).toMap ++ (z3FunDefArgs zip args)
    })

    val idSubstMap: Map[Z3AST, Z3AST] = baseIDSubstMap + (z3ActivatingBool -> aVar)
    val (from, to) = idSubstMap.unzip
    val (fromArray, toArray) = (from.toArray, to.toArray)

    val newClauses = asZ3Clauses.map(z3.substitute(_, fromArray, toArray))
    val newBlockers = z3Blockers.map { case (b, funs) =>
      val bp = if (b == z3ActivatingBool) aVar else idSubstMap(b)
      val newFuns = funs.map(fi => fi.copy(args = fi.args.map(z3.substitute(_, fromArray, toArray))))
      bp -> newFuns
    }

    (newClauses, newBlockers)
  }

  override def toString: String = {
    "Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" +
    " * Activating boolean : " + activatingBool + "\n" +
    " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" +
    " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" +
    " * \"Clauses\" : " + "\n " + asClauses.mkString("\n ") + "\n" +
    " * Block-map : " + blockers.toString
  }
}

object FunctionTemplate {

  /** Builds the template of `tfd`.
    *
    * For a real function definition, the clause `prec ==> (f(params) == body)` is decomposed
    * under a fresh activating boolean. For a fake one, the body itself is decomposed and
    * asserted. The postcondition is added in both cases, with its result variable replaced by
    * `f(params)` and wrapped in `prec ==> …` when a precondition exists.
    *
    * @param solver       solver the template is built for
    * @param tfd          the function to describe
    * @param isRealFunDef false for the "fake" templates generated from FairZ3Solver;
    *                     such templates must have a body (`tfd.body.get` is used)
    */
  def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate = {

    val condVars: MutableSet[Identifier] = MutableSet.empty
    val exprVars: MutableSet[Identifier] = MutableSet.empty

    // Represents clauses of the form:
    //    id => expr && ... && expr
    val guardedExprs: MutableMap[Identifier, Seq[Expr]] = MutableMap.empty

    // Records `guardVar => expr`; expressions are prepended to the guard's list.
    def storeGuarded(guardVar: Identifier, expr: Expr): Unit = {
      assert(expr.getType == BooleanType)
      guardedExprs(guardVar) = expr +: guardedExprs.getOrElse(guardVar, Seq.empty)
    }

    // True when `e` contains a construct that `rec` must lift out into guarded clauses.
    def requireDecomposition(e: Expr) = {
      exists {
        case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true
        case _ => false
      }(e)
    }

    // Rewrites `expr` under path condition `pathVar`, storing side conditions as guarded
    // clauses and introducing fresh variables/booleans for lets, tuples, choose and ifs.
    def rec(pathVar: Identifier, expr: Expr): Expr = expr match {
      case Assert(cond, _, body) =>
        storeGuarded(pathVar, rec(pathVar, cond))
        rec(pathVar, body)

      case Ensuring(body, id, post) =>
        rec(pathVar, Let(id, body, Assert(post, None, Variable(id))))

      case Let(i, value, body) =>
        val newExpr: Identifier = FreshIdentifier("lt", true).setType(i.getType)
        exprVars += newExpr
        val re = rec(pathVar, value)
        storeGuarded(pathVar, Equals(Variable(newExpr), re))
        rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), body))

      case LetTuple(is, value, body) =>
        val tuple: Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType)))
        exprVars += tuple
        val re = rec(pathVar, value)
        storeGuarded(pathVar, Equals(Variable(tuple), re))

        val mapping = for ((id, i) <- is.zipWithIndex) yield {
          val newId = FreshIdentifier("ti", true).setType(id.getType)
          exprVars += newId
          storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i + 1)))
          (Variable(id) -> Variable(newId))
        }

        rec(pathVar, replace(mapping.toMap, body))

      case _: MatchExpr =>
        sys.error("MatchExpr's should have been eliminated.")

      // NOTE(review): every operand of =>, && and || is unrolled under the same pathVar,
      // so an Assert or Choose nested in a right-hand operand is guarded by pathVar alone,
      // not by the short-circuit condition that would normally protect it. Confirm intended.
      case Implies(lhs, rhs) =>
        Implies(rec(pathVar, lhs), rec(pathVar, rhs))

      case And(parts) =>
        And(parts.map(rec(pathVar, _)))

      case Or(parts) =>
        Or(parts.map(rec(pathVar, _)))

      case i @ IfExpr(cond, thenn, elze) =>
        if (!requireDecomposition(i)) {
          i
        } else {
          val newBool1: Identifier = FreshIdentifier("b", true).setType(BooleanType)
          val newBool2: Identifier = FreshIdentifier("b", true).setType(BooleanType)
          val newExpr: Identifier = FreshIdentifier("e", true).setType(i.getType)

          condVars += newBool1
          condVars += newBool2
          exprVars += newExpr

          val crec = rec(pathVar, cond)
          val trec = rec(newBool1, thenn)
          val erec = rec(newBool2, elze)

          // Under pathVar, exactly one of the two branch booleans holds.
          storeGuarded(pathVar, Or(Variable(newBool1), Variable(newBool2)))
          storeGuarded(pathVar, Or(Not(Variable(newBool1)), Not(Variable(newBool2))))
          // TODO can we improve this? i.e. make it more symmetrical?
          // Probably it's symmetrical enough to Z3.
          storeGuarded(pathVar, Iff(Variable(newBool1), crec))
          storeGuarded(newBool1, Equals(Variable(newExpr), trec))
          storeGuarded(newBool2, Equals(Variable(newExpr), erec))
          Variable(newExpr)
        }

      case c @ Choose(ids, cond) =>
        val cid = FreshIdentifier("choose", true).setType(c.getType)
        exprVars += cid

        // A single chosen id maps to the fresh variable itself; several map to its tuple fields.
        val m: Map[Expr, Expr] = if (ids.size == 1) {
          Map(Variable(ids.head) -> Variable(cid))
        } else {
          ids.zipWithIndex.map { case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i + 1) }.toMap
        }

        storeGuarded(pathVar, replace(m, cond))
        Variable(cid)

      case n @ NAryOperator(args, r) =>
        r(args.map(a => rec(pathVar, a))).setType(n.getType)

      case b @ BinaryOperator(a1, a2, r) =>
        r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType)

      case u @ UnaryOperator(a, r) =>
        r(rec(pathVar, a)).setType(u.getType)

      case t: Terminal =>
        t
    }

    // The precondition if it exists.
    val prec: Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p))

    val newBody: Option[Expr] = tfd.body.map(b => matchToIfThenElse(b))

    // Wraps `e` as `prec ==> e` when there is a precondition; used for body and postcondition.
    def underPrecondition(e: Expr): Expr = prec match {
      case Some(p) => Implies(p, e)
      case None    => e
    }

    val invocation: Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable))

    val invocationEqualsBody: Option[Expr] = newBody match {
      case Some(body) if isRealFunDef => Some(underPrecondition(Equals(invocation, body)))
      case _                          => None
    }

    val activatingBool: Identifier = FreshIdentifier("start", true).setType(BooleanType)

    if (isRealFunDef) {
      invocationEqualsBody.foreach(expr => storeGuarded(activatingBool, rec(activatingBool, expr)))
    } else {
      storeGuarded(activatingBool, rec(activatingBool, newBody.get))
    }

    // Now the postcondition.
    tfd.postcondition.foreach { case (id, post) =>
      val newPost: Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post))
      storeGuarded(activatingBool, rec(activatingBool, underPrecondition(newPost)))
    }

    new FunctionTemplate(solver, tfd, activatingBool, condVars.toSet, exprVars.toSet, guardedExprs.toMap, isRealFunDef)
  }
}