Skip to content
Snippets Groups Projects
Commit d64511f2 authored by Robin Steiger's avatar Robin Steiger
Browse files

Inferred (dis)equalities on element variables and other changes.

parent fdc3b193
Branches
Tags
No related merge requests found
package orderedsets
object DNF {
import purescala.Trees._
import TreeOperations.dnf
// Printer (both && and || are printed as ? on windows..)
def pp(expr: Expr): String = expr match {
case And(es) => (es map pp).mkString("( ", " & ", " )")
case Or(es) => (es map pp).mkString("( ", " | ", " )")
case Not(e) => "!(" + pp(e) + ")"
case _ => expr.toString
}
// Tests
import purescala.Common._
implicit def str2id(str: String): Identifier = FreshIdentifier(str)
val a = Variable("a")
val b = Variable("b")
val c = Variable("c")
val x = Variable("x")
val y = Variable("y")
val z = Variable("z")
val t1 = And(Or(a, b), Or(x, y))
val t2 = Implies(x, Iff(y, z))
def main(args: Array[String]) {
test(t1)
test(t2)
test(Not(t1))
test(Not(t2))
}
def test(before: Expr) {
val after = Or((dnf(before) map And.apply).toSeq)
println
println("Before dnf : " + pp(before))
//println("After dnf : " + pp(after))
println("After dnf : ")
for (and <- dnf(before)) println(" " + pp(And(and)))
}
}
\ No newline at end of file
......@@ -29,7 +29,7 @@ class Main(reporter: Reporter) extends Solver(reporter) {
//reporter.info("Sets: " + ExprToASTConverter.getSetTypes(expr))
try {
// Negate formula
(Some(!solve(!ExprToASTConverter(expr))), {
(Some(!solve(!ExprToASTConverter(expr, reporter))), {
val sets = ExprToASTConverter.getSetTypes(expr)
if (sets.size > 1)
......@@ -38,7 +38,7 @@ class Main(reporter: Reporter) extends Solver(reporter) {
})._1
} catch {
case ConversionException(badExpr, msg) =>
reporter.info(badExpr, msg)
reporter.info(badExpr, msg + " in " + badExpr.getClass)
None
case IncompleteException(msg) =>
reporter.info(msg)
......@@ -121,7 +121,7 @@ object ExprToASTConverter {
def makeEq(v: Variable, t: Expr) = v.getType match {
case Int32Type => Equals(v, t)
case tpe if isSetType(tpe) => SetEquals(v, t)
case _ => throw (new ConversionException(v, "is of type " + v.getType + " and cannot be handled by OrdBapa"))
case _ => throw (new ConversionException(v, "type " + v.getType))
}
private def toSetTerm(expr: Expr): AST.Term = expr match {
......@@ -133,8 +133,8 @@ object ExprToASTConverter {
case SetIntersection(set1, set2) => toSetTerm(set1) ** toSetTerm(set2)
case SetUnion(set1, set2) => toSetTerm(set1) ++ toSetTerm(set2)
case SetDifference(set1, set2) => toSetTerm(set1) -- toSetTerm(set2)
case Variable(_) => throw ConversionException(expr, "is a variable of type " + expr.getType + " and cannot be converted to bapa< set variable")
case _ => throw ConversionException(expr, "Cannot convert to bapa< set term")
case Variable(_) => throw ConversionException(expr, "type " + expr.getType)
case _ => throw ConversionException(expr, "bad set term")
}
private def toIntTerm(expr: Expr): AST.Term = expr match {
......@@ -148,7 +148,7 @@ object ExprToASTConverter {
case SetCardinality(e) => toSetTerm(e).card
case SetMin(set) if set.getType == SetType(Int32Type) => toSetTerm(set).inf
case SetMax(set) if set.getType == SetType(Int32Type) => toSetTerm(set).sup
case _ => throw ConversionException(expr, "Cannot convert to bapa< int term")
case _ => throw ConversionException(expr, "bad int term")
}
private def toFormula(expr: Expr): AST.Formula = expr match {
......@@ -159,6 +159,10 @@ object ExprToASTConverter {
case And(exprs) => AST.And((exprs map toFormula).toList)
case Not(expr) => !toFormula(expr)
case Implies(expr1, expr2) => !(toFormula(expr1)) || toFormula(expr2)
case Iff(expr1, expr2) =>
val f1 = toFormula(expr1)
val f2 = toFormula(expr2)
(f1 && f2) || (!f1 && !f2)
// Set Formulas
case ElementOfSet(elem, set) => toIntTerm(elem) selem toSetTerm(set)
......@@ -171,10 +175,13 @@ object ExprToASTConverter {
case LessEquals(lhs, rhs) => toIntTerm(lhs) <= toIntTerm(rhs)
case GreaterThan(lhs, rhs) => toIntTerm(lhs) > toIntTerm(rhs)
case GreaterEquals(lhs, rhs) => toIntTerm(lhs) >= toIntTerm(rhs)
case Equals(lhs, rhs) if lhs.getType == Int32Type && rhs.getType == Int32Type => toIntTerm(lhs) === toIntTerm(rhs)
case Equals(lhs, rhs) => (lhs.getType, rhs.getType) match {
case (Int32Type, Int32Type) => toIntTerm(lhs) === toIntTerm(rhs)
case types => throw ConversionException(expr, "types " + types)
}
// Assuming the formula to be True
case _ => throw ConversionException(expr, "Cannot convert to bapa< formula")
case _ => throw ConversionException(expr, "bad formula")
}
def getSetTypes(expr: Expr): Set[TypeTree] = expr match {
......@@ -204,7 +211,16 @@ object ExprToASTConverter {
case _ => Set.empty[TypeTree]
}
def apply(expr: Expr) = {
def apply(expr: Expr, reporter: Reporter) = {
def toRelaxedFormula(expr: Expr): AST.Formula =
try {
toFormula(expr)
} catch {
case ConversionException(badExpr, msg) =>
//reporter.warning("BAPA was relaxed : " + msg + " in " + badExpr.getClass + "\n" + rpp(badExpr))
formulaRelaxed = true
AST.True // Assuming the formula to be True
}
formulaRelaxed = false;
expr match {
case And(exprs) => AST.And((exprs map toRelaxedFormula).toList)
......@@ -212,13 +228,5 @@ object ExprToASTConverter {
}
}
private def toRelaxedFormula(expr: Expr): AST.Formula =
try {
toFormula(expr)
} catch {
case ConversionException(_, _) =>
formulaRelaxed = true
// Assuming the formula to be True
AST.True
}
}
......@@ -186,7 +186,7 @@ object NormalForms {
case Not(Predicate(comp, terms)) =>
rewriteNonPure_*(terms, ts => Not(Predicate(comp, ts)) :: Nil)
case And(_) | Or(_) if isAtom(f) =>
case And(_) | Or(_) if !isAtom(f) =>
error("A simplified conjunction cannot contain " + f)
case _ => List(f)
......
......@@ -8,6 +8,8 @@ import Common._
import TypeTrees._
import Definitions._
import RPrettyPrinter.rpp
object TreeOperations {
def dnf(expr: Expr): Stream[Seq[Expr]] = expr match {
case And(Nil) => Stream(Nil)
......@@ -35,7 +37,6 @@ object TreeOperations {
searchAndReplace({case FunctionInvocation(_, Seq(Variable(id))) => varSet += id; None; case _ => None})(l._4)
varSet.subsetOf(l._3.toSet)
}
val c = program.callees(f)
if (f.hasImplementation && f.args.size == 1 && c.size == 1 && c.head == f) f.body.get match {
case SimplePatternMatching(scrut, _, lstMatches)
......@@ -45,7 +46,7 @@ object TreeOperations {
None
}
}
// 'Lazy' rewriter
//
// Hoists if expressions to the top level and
......@@ -53,15 +54,18 @@ object TreeOperations {
//
// The implementation is totally brain-teasing
def rewrite(expr: Expr): Expr =
rewrite(expr, ex => ex)
Simplifier(rewrite(expr, ex => ex))
private def rewrite(expr: Expr, context: Expr => Expr): Expr = expr match {
// Convert to nnf
case Not(e@(And(_) | Or(_) | Iff(_, _) | Implies(_, _) | IfExpr(_, _, _))) =>
rewrite(negate(e), context)
case IfExpr(_c, _t, _e) =>
rewrite(_c, c =>
rewrite(_t, t =>
rewrite(_e, e =>
Or(And(c, context(t)), And(negate(c), context(e)))
)))
)))
case And(_exs) =>
rewrite_*(_exs, exs =>
context(And(exs)))
......@@ -69,7 +73,7 @@ object TreeOperations {
rewrite_*(_exs, exs =>
context(Or(exs)))
case Not(_ex) =>
rewrite(_ex, ex =>
rewrite(_ex, ex =>
context(Not(ex)))
case f@FunctionInvocation(fd, _args) =>
rewrite_*(_args, args =>
......@@ -78,8 +82,8 @@ object TreeOperations {
rewrite(_t, t =>
context(recons(t) setType u.getType))
case b@BinaryOperator(_t1, _t2, recons) =>
rewrite(_t1, t1 =>
rewrite(_t2, t2 =>
rewrite(_t1, t1 =>
rewrite(_t2, t2 =>
context(recons(t1, t2) setType b.getType)))
case c@CaseClass(cd, _args) =>
rewrite_*(_args, args =>
......@@ -90,29 +94,48 @@ object TreeOperations {
case f@FiniteSet(_elems) =>
rewrite_*(_elems, elems =>
context(FiniteSet(elems) setType f.getType))
case Terminal() =>
case _: Terminal =>
context(expr)
case _ => // Missed case
error("Unsupported case in rewrite : " + expr.getClass)
}
private def rewrite_*(exprs: Seq[Expr], context: Seq[Expr] => Expr): Expr =
exprs match {
case Nil => context(Nil)
case _t :: _ts =>
rewrite(_t, t => rewrite_*(_ts, ts => context(t +: ts)))
}
// This should rather be a marker interface, but I don't want
// to change Trees.scala without Philippe's permission.
object Terminal {
def unapply(expr: Expr): Boolean = expr match {
case Variable(_) | ResultVariable() | OptionNone(_) | EmptySet(_) | EmptyMultiset(_) | EmptyMap(_, _) | NilList(_) => true
case _: Literal[_] => true
case _ => false
object Simplifier {
private val True = BooleanLiteral(true)
private val False = BooleanLiteral(false)
def apply(expr: Expr) = simplify(expr)
def simplify(expr: Expr): Expr = expr match {
case Not(ex) => negate(ex)
case And(exs) => And(simplify(exs, True, False) flatMap flatAnd)
case Or(exs) => Or(simplify(exs, False, True) flatMap flatOr)
case _ => expr
}
private def simplify(exprs: Seq[Expr], neutral: Expr, absorbing: Expr): Seq[Expr] = {
val exs = (exprs map simplify) filterNot {_ == neutral}
if (exs contains absorbing) Seq(absorbing)
else exs
}
private def flatAnd(f: Expr) = f match {
case And(fs) => fs
case _ => Seq(f)
}
private def flatOr(f: Expr) = f match {
case Or(fs) => fs
case _ => Seq(f)
}
}
}
......@@ -86,7 +86,7 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] {
def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType))
def unify(conjunction: Seq[Expr]): Map[Variable, Expr] = {
def unify(conjunction: Seq[Expr]): (Seq[Expr], Map[Variable, Expr]) = {
val equalities = new ArrayBuffer[(Term, Term)]()
val inequalities = new ArrayBuffer[(Var, Var)]()
......@@ -168,13 +168,38 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] {
if (map1.isEmpty) println(" (empty table)")
println
*/
table mapValues term2expr
// Extract element equalities and disequalities
val elementFormula = new ArrayBuffer[Expr]()
for ((e1, term) <- table; if isElementType(e1)) {
term2expr(term) match {
case e2@Variable(_) =>
//println(" " + e1 + ": " + e1.getType +
// " -> " + e2 + ": " + e2.getType)
elementFormula += Equals(e1, e2)
case expr =>
//println(" " + e1 + ": " + e1.getType +
// " -> " + expr + ": " + expr.getType)
//println("UNEXPECTED: " + term)
error("Unexpected " + expr)
}
}
for ((Var(e1), Var(e2)) <- inequalities; if isElementType(e1))
elementFormula += Not(Equals(e1, e2))
(elementFormula.toSeq, table mapValues term2expr)
}
def term2expr(term: Term): Expr = term match {
case Var(v) => v
case Fun(cd, args) => CaseClass(cd, args map term2expr)
}
def isElementType(typed: Typed) = typed.getType match {
case AbstractClassType(_) | CaseClassType(_) => false
case _ => true
}
}
......
......@@ -52,7 +52,7 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
val noAlphas = And(restFormula flatMap expandAlphas(varMap))
reporter.info("The resulting formula is\n" + rpp(noAlphas))
tryAllSolvers(noAlphas)
} catch {
} catch {
case ex@ConversionException(badExpr, msg) =>
reporter.info("Conjunction " + counter + " is UNKNOWN, could not be parsed")
throw ex
......@@ -88,36 +88,38 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
}
}
def tryAllSolvers(f : Expr): Unit = {
for(solver <- Extensions.loadedSolverExtensions; if solver != this) {
reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.")
solver.isUnsat(f) match {
case Some(true) =>
reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable")
return
case Some(false) =>
reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable")
throw (new SatException(null))
case None =>
reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula")
}
}; throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.") }
def tryAllSolvers(f: Expr): Unit = {
for (solver <- Extensions.loadedSolverExtensions; if solver != this) {
reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.")
solver.isUnsat(f) match {
case Some(true) =>
reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable")
return
case Some(false) =>
reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable")
throw (new SatException(null))
case None =>
reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula")
}
};
throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.")
}
def checkIsSupported(expr: Expr) {
def check(ex: Expr): Option[Expr] = ex match {
case Let(_, _, _) | MatchExpr(_, _) =>
throw ConversionException(ex, "Unifier does not support this expression")
case IfExpr(_, _, _) =>
println
println("--- BEFORE ---")
println(rpp(expr))
println
println("--- AFTER ---")
println(rpp(rewrite(expr)))
println
println
println("--- BEFORE ---")
println(rpp(expr))
println
println("--- AFTER ---")
println(rpp(rewrite(expr)))
println
throw ConversionException(ex, "Unifier does not support this expression")
case _ => None
}
......@@ -139,12 +141,12 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
*/
// The substitution table
val substTable = ADTUnifier.unify(treeEquations)
val (elementFormula, substTable) = ADTUnifier.unify(treeEquations)
// The substitution function (returns identity if unmapped)
def subst(v: Variable): Expr = substTable getOrElse (v, v)
(subst, rest)
(subst, elementFormula ++ rest)
}
......@@ -239,7 +241,9 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
//reporter.warning("Result:\n" + rpp(res))
Some(res)
}
case _ => error("Bad argument/substitution to catamorphism: " + substArg(arg))
case badArg =>
println(rpp(badArg))
error("Bad argument/substitution to catamorphism")
}
case None => // Not a catamorphism
warning("Function " + fd.id + " is not a catamorphism.")
......
package orderedsets
import TreeOperations._
import purescala._
import Trees._
import Common._
import TypeTrees._
import Definitions._
object getAlpha {
var program: Program = null
def setProgram(p: Program) = program = p
def isAlpha(varMap: Variable => Expr)(t: Expr): Option[Expr] = t match {
case FunctionInvocation(fd, Seq(v@Variable(_))) => asCatamorphism(program, fd) match {
case None => None
case Some(lstMatch) => varMap(v) match {
case CaseClass(cd, args) => {
val (_, _, ids, rhs) = lstMatch.find(_._1 == cd).get
val repMap = Map(ids.map(id => Variable(id): Expr).zip(args): _*)
Some(searchAndReplace(repMap.get)(rhs))
}
case u@Variable(_) => {
val c = Variable(FreshIdentifier("Coll", true)).setType(t.getType)
// TODO: Keep track of these variables for M1(t, c)
Some(c)
}
case _ => error("Bad substitution")
}
case _ => None
}
case _ => None
}
def apply(t: Expr, varMap: Variable => Expr): Expr = {
searchAndReplace(isAlpha(varMap))(t)
}
// def solve(e : Expr): Option[Boolean] = {
// searchAndReplace(isAlpha(x => x))(e)
// None
// }
}
import scala.collection.immutable.Set
// import scala.collection.immutable.Multiset
//import scala.collection.immutable.Multiset
object BinarySearchTree {
sealed abstract class Tree
......@@ -15,56 +15,56 @@ object BinarySearchTree {
sealed abstract class Triple
case class SortedTriple(min: Option, max: Option, sorted: Boolean) extends Triple
def isSorted(tree: Tree): SortedTriple = tree match {
def isSorted(tree: Tree): SortedTriple = tree match {
case Leaf() => SortedTriple(None(), None(), true)
case Node(l,v,r) => isSorted(l) match {
case SortedTriple(minl,maxl,sortl) => if (!sortl) SortedTriple(None(), None(), false)
else minl match {
case None() => maxl match {
case None() => isSorted(r) match {
case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false)
else minr match {
case None() => maxr match {
case None() => SortedTriple(Some(v),Some(v),true)
case Some(maxrv) => SortedTriple(None(),None(),false)
}
case Some(minrv) => maxr match {
case Some(maxrv) => if (minrv > v) SortedTriple(Some(v),Some(maxrv),true) else SortedTriple(None(),None(),false)
case None() => SortedTriple(None(),None(),false)
}
}
}
case Some(maxlv) => SortedTriple(None(),None(),false)
}
case Some(minlv) => maxl match {
case Some(maxlv) => isSorted(r) match {
case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false)
else minr match {
case None() => maxr match {
case None() => if (maxlv <= v) SortedTriple(Some(minlv),Some(v),true) else SortedTriple(None(),None(),false)
case Some(maxrv) => SortedTriple(None(),None(),false)
}
case Some(minrv) => maxr match {
case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv),Some(maxrv),true) else SortedTriple(None(),None(),false)
case None() => SortedTriple(None(),None(),false)
}
}
}
case None() => SortedTriple(None(),None(),false)
}
}
case Node(l, v, r) => isSorted(l) match {
case SortedTriple(minl, maxl, sortl) => if (!sortl) SortedTriple(None(), None(), false)
else minl match {
case None() => maxl match {
case None() => isSorted(r) match {
case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false)
else minr match {
case None() => maxr match {
case None() => SortedTriple(Some(v), Some(v), true)
case Some(maxrv) => SortedTriple(None(), None(), false)
}
case Some(minrv) => maxr match {
case Some(maxrv) => if (minrv > v) SortedTriple(Some(v), Some(maxrv), true) else SortedTriple(None(), None(), false)
case None() => SortedTriple(None(), None(), false)
}
}
}
case Some(maxlv) => SortedTriple(None(), None(), false)
}
case Some(minlv) => maxl match {
case Some(maxlv) => isSorted(r) match {
case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false)
else minr match {
case None() => maxr match {
case None() => if (maxlv <= v) SortedTriple(Some(minlv), Some(v), true) else SortedTriple(None(), None(), false)
case Some(maxrv) => SortedTriple(None(), None(), false)
}
case Some(minrv) => maxr match {
case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv), Some(maxrv), true) else SortedTriple(None(), None(), false)
case None() => SortedTriple(None(), None(), false)
}
}
}
case None() => SortedTriple(None(), None(), false)
}
}
}
}
def treeMin(tree: Node): Int = {
require(isSorted(tree).sorted)
tree match {
case Node(left, v, _) => left match {
case Leaf() => v
case n@Node(_, _, _) => treeMin(n)
case Node(left, v, _) => left match {
case Leaf() => v
case n@Node(_, _, _) => treeMin(n)
}
}
}} ensuring (_ == contents(tree).min)
} ensuring (_ == contents(tree).min)
def treeMax(tree: Node): Int = {
require(isSorted(tree).sorted)
......@@ -77,6 +77,19 @@ object BinarySearchTree {
} ensuring (_ == contents(tree).max)
def insert(tree: Tree, value: Int): Node = {
tree match {
case Leaf() => Node(Leaf(), value, Leaf())
case n@Node(l, v, r) => if (v < value) {
Node(l, v, insert(r, value))
} else if (v > value) {
Node(insert(l, value), v, r)
} else {
n
}
}
} ensuring (contents(_) == contents(tree) ++ Set(value))
def insertSorted(tree: Tree, value: Int): Node = {
require(isSorted(tree).sorted)
tree match {
case Leaf() => Node(Leaf(), value, Leaf())
......@@ -97,18 +110,18 @@ object BinarySearchTree {
}
} ensuring (contents(_) == contents(tree) ++ Set(0))
/*
def remove(tree: Tree, value: Int) : Node = (tree match {
case Leaf() => Node(Leaf(), value, Leaf())
case n @ Node(l, v, r) => if(v < value) {
Node(l, v, insert(r, value))
} else if(v > value) {
Node(insert(l, value), v, r)
} else {
n
}
}) ensuring (contents(_) == contents(tree) -- Set(value))
*/
/*
def remove(tree: Tree, value: Int) : Node = (tree match {
case Leaf() => Node(Leaf(), value, Leaf())
case n @ Node(l, v, r) => if(v < value) {
Node(l, v, insert(r, value))
} else if(v > value) {
Node(insert(l, value), v, r)
} else {
n
}
}) ensuring (contents(_) == contents(tree) -- Set(value))
*/
def dumbInsertWithOrder(tree: Tree): Node = {
tree match {
......@@ -139,20 +152,21 @@ object BinarySearchTree {
Node(mkInfiniteTree(x), x, mkInfiniteTree(x))
} ensuring (res =>
res.left != Leaf() && res.right != Leaf()
)
)
def contains(tree: Tree, value: Int): Boolean = {
require(isSorted(tree).sorted)
tree match {
case Leaf() => false
case n@Node(l, v, r) => if (v < value) {
contains(r, value)
} else if (v > value) {
contains(l, value)
} else {
true
case Leaf() => false
case n@Node(l, v, r) => if (v < value) {
contains(r, value)
} else if (v > value) {
contains(l, value)
} else {
true
}
}
} } ensuring( _ || !(contents(tree) == contents(tree) ++ Set(value)))
} ensuring (_ || !(contents(tree) == contents(tree) ++ Set(value)))
def contents(tree: Tree): Set[Int] = tree match {
case Leaf() => Set.empty[Int]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment