diff --git a/src/orderedsets/ASTUtil.scala b/src/orderedsets/ASTUtil.scala index 42749ed4bc46b30516c4d7532ca2cb09095257e6..9d7281c86d5c341242013d0693d06dae50d16a16 100644 --- a/src/orderedsets/ASTUtil.scala +++ b/src/orderedsets/ASTUtil.scala @@ -2,6 +2,7 @@ package orderedsets import AST._ import Primitives._ +import Symbol._ import scala.collection.mutable.ListBuffer case class IllegalTerm(a: Any) extends Exception(a + " should not be present in the formula to be converted.") @@ -167,4 +168,4 @@ object SetsToFormTest extends Application { And(ASTUtil.analyze(fs, ASTUtil.split(fs))).print } } -*/ \ No newline at end of file +*/ diff --git a/src/orderedsets/Main.scala b/src/orderedsets/Main.scala index 560245434f57ba41e996af5fe3ed0946e6007a71..9ee7a6e7f93eee32a6e73de9556328313554e111 100644 --- a/src/orderedsets/Main.scala +++ b/src/orderedsets/Main.scala @@ -1,16 +1,83 @@ package orderedsets import purescala.Reporter +import purescala.TypeTrees._ +import purescala.Common._ import purescala.Extensions._ import purescala.Trees._ +import Primitives._ class Main(reporter: Reporter) extends Solver(reporter) { val description = "Solver for constraints on ordered sets" + def convertToSetTerm( expr : Expr ): AST.Term = expr match { + // TODO: Use id.getType as Symbol's type, this _has_ to be a set variable + case Variable(id) if id.getType == SetType(Int32Type) => Symbol(id.name, Symbol.SetType) + case EmptySet(_) => AST.emptyset + case FiniteSet(elems) => { reporter.error(expr, "TODO!"); error("TODO!") } + case SetCardinality(set) => convertToSetTerm(set).card + case SetIntersection(set1, set2) => convertToSetTerm(set1) ** convertToSetTerm(set2) + case SetUnion(set1, set2) => convertToSetTerm(set1) ++ convertToSetTerm(set2) + // TODO: Confirm the order of operation + case SetDifference(set1, set2) => convertToSetTerm(set1) -- convertToSetTerm(set2) + case SetMin(set) => convertToSetTerm(set).inf + case SetMax(set) => convertToSetTerm(set).sup + case _ => reporter.error(expr, "Not a Set expression!"); error("Not a Set expression!") + } + + def convertToIntTerm( expr : Expr ): AST.Term = expr match { + case IntLiteral(v) => AST.Lit(IntLit(v)) + case Variable(id) if id.getType == Int32Type => Symbol(id.name, Symbol.IntType) + case Plus(lhs, rhs) => convertToIntTerm(lhs) + convertToIntTerm(rhs) + // TODO: Confirm order of operation + case Minus(lhs, rhs) => convertToIntTerm(lhs) - convertToIntTerm(rhs) + // TODO: Checking for linearity? + case Times(lhs, rhs) => convertToIntTerm(lhs) * convertToIntTerm(rhs) + // TODO: Throwing own exception here? + case Division(_, _) => reporter.error(expr, "Division is not supported."); error("Division is not supported.") + case UMinus(e) => AST.zero - convertToIntTerm(e) + case _ => reporter.error(expr, "Not an integer expression!"); error("Not an integer expression.") + } + + def convertToAST( expr : Expr ): AST.Formula = expr match { + case BooleanLiteral(true) => AST.True + case BooleanLiteral(false) => AST.False + // TODO: Use id.getType as Symbol's type, this _has_ to be a set variable + case Variable(id) if id.getType == BooleanType => Symbol(id.name, Symbol.BoolType) + + case Or(exprs) => AST.Or((exprs map convertToAST).toList) + case And(exprs) => AST.And((exprs map convertToAST).toList) + case Not(expr) => !convertToAST(expr) + case Implies(expr1, expr2) => !(convertToAST(expr1)) || convertToAST(expr2) + + // Set Formulas + case ElementOfSet(elem, set) => convertToIntTerm(elem) selem convertToSetTerm(set) + case SetEquals(set1, set2) => convertToSetTerm(set1) seq convertToSetTerm(set2) + // Is this a formula or a boolean function? Purification? + // TODO: Nicer way to write this? + // case IsEmptySet(set) => AST.Op(ITE, (convertToSetTerm(set).card === 0)::AST.True::AST.False::Nil) + case IsEmptySet(set) => convertToSetTerm(set).card === 0 + case SubsetOf(set1, set2) => convertToSetTerm(set1) subseteq convertToSetTerm(set2) + + // Integer Formulas + case LessThan(lhs, rhs) => convertToIntTerm(lhs) < convertToIntTerm(rhs) + case LessEquals(lhs, rhs) => convertToIntTerm(lhs) <= convertToIntTerm(rhs) + case GreaterThan(lhs, rhs) => convertToIntTerm(lhs) > convertToIntTerm(rhs) + case GreaterEquals(lhs, rhs) => convertToIntTerm(lhs) >= convertToIntTerm(rhs) + case Equals(lhs, rhs) => convertToIntTerm(lhs) === convertToIntTerm(rhs) + + case _ => reporter.error(expr, "Cannot be handled by Ordered BAPA."); error("Cannot be handled") + } + // checks for V-A-L-I-D-I-T-Y ! // Some(true) means formula is valid (negation is unsat) // None means you don't know. def solve(expr: Expr) : Option[Boolean] = { + try { + reporter.info("OrdBAPA conversion = " + convertToAST(expr).toString) + } catch { + case e => reporter.info(e.toString) + } reporter.info(expr, "I have no idea how to solve this :(") None } diff --git a/src/orderedsets/Symbol.scala b/src/orderedsets/Symbol.scala index c3c2e329c2e9409121ce04cb5d76b39550d2c8b9..c9ca0cd1ac8551121668f764d2953142a9866c82 100644 --- a/src/orderedsets/Symbol.scala +++ b/src/orderedsets/Symbol.scala @@ -3,11 +3,7 @@ package orderedsets import AST.{TermVar} import scala.collection.mutable.{HashMap => MutableMap} import z3.scala._ - -sealed abstract class Type -case object IntType extends Type -case object SetType extends Type -case object BoolType extends Type +import Symbol._ class Symbol private(val name: String, val tpe: Type) { override def toString: String = name @@ -37,6 +33,11 @@ class Symbol private(val name: String, val tpe: Type) { } object Symbol { + sealed abstract class Type + case object IntType extends Type + case object SetType extends Type + case object BoolType extends Type + private val counters = new MutableMap[String, Int]() private val interned = new MutableMap[String, Symbol]() @@ -56,6 +57,8 @@ object Symbol { sym } + def apply(name: String, tpe: Type) = Symbol.lookup(name, tpe) + def apply(name: String): Symbol = name.charAt(0) match { case c if c.isUpper => lookup(name, SetType) case c if c.isLower => lookup(name, IntType) diff --git a/src/orderedsets/Z3LibConverter.scala b/src/orderedsets/Z3LibConverter.scala index 5c6642aa4cb03ba973da0aa2e005e997bb4a3990..795b3d34341c3dfe077307434a461014c086ea95 100644 --- a/src/orderedsets/Z3LibConverter.scala +++ b/src/orderedsets/Z3LibConverter.scala @@ -4,6 +4,7 @@ import AST._ import Primitives._ import z3.scala._ import scala.collection.mutable.ArrayBuffer +import Symbol._ abstract sealed class Z3Output; case class Z3Failure(e: Exception) extends Z3Output