diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 221de11457d3945c78b825a073ee38e15225eee3..0519d109dfae6e0bc8224f9e94f2f9edd7e367c1 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -50,7 +50,7 @@ class Analysis(val program: Program) { } else { reporter.info("Verification condition (post) for ==== " + funDef.id + " ====") if(Settings.unrollingLevel == 0) { - reporter.info(vc) + reporter.info(simplifyLets(vc)) } else { reporter.info("(not showing unrolled VCs)") } @@ -246,6 +246,7 @@ object Analysis { var extras : List[Expr] = Nil def rewritePM(e: Expr) : Option[Expr] = e match { + case NotSoSimplePatternMatching(_) => None case SimplePatternMatching(scrutinee, classType, casesInfo) => Some({ val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) @@ -257,7 +258,7 @@ object Analysis { (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rewrittenRHS))) ::: moreExtras.toList) }).toList val (newPVars, newExtras) = lle.unzip - extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras + extras = Let(scrutAsLetID, scrutinee, And(/*Or(newPVars.map(Equals(Variable(scrutAsLetID), _))),*/BooleanLiteral(true), And(newExtras.flatten))) :: extras newVar }) case _ => None diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index c920f9ee6ad6d1172740358489ba411d841eba6b..94bf03bf73d6392ebb101b0f803f21bf734e0a3f 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -26,7 +26,9 @@ object Trees { val fixedType = funDef.returnType } case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr - case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr + case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { + def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType] + } sealed abstract class MatchCase { val pattern: Pattern @@ -335,84 +337,6 @@ object Trees { searchAndReplace(substs.get(_))(expr) } - // the replacement map should be understood as follows: - // - on each subexpression, checkFun checks whether it should be replaced - // - repFun is applied is checkFun succeeded - // - if the result of repFun is different from its argument and recursive - // is set to true, search/replace is reapplied on the result. - // def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr, recursive: Boolean=true) : Expr = { - // def rec(ex: Expr, skip: Expr = null) : Expr = ex match { - // case _ if (ex != skip && checkFun(ex)) => { - // val newExpr = repFun(ex) - // if(newExpr.getType == NoType) { - // Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) - // } - // if(ex == newExpr) - // if(recursive) rec(ex, ex) else ex - // else - // if(recursive) rec(newExpr) else newExpr - // } - // case l @ Let(i,e,b) => { - // val re = rec(e) - // val rb = rec(b) - // if(re != e || rb != b) - // Let(i, re, rb).setType(l.getType) - // else - // l - // } - // case n @ NAryOperator(args, recons) => { - // var change = false - // val rargs = args.map(a => { - // val ra = rec(a) - // if(ra != a) { - // change = true - // ra - // } else { - // a - // } - // }) - // if(change) - // recons(rargs).setType(n.getType) - // else - // n - // } - // case b @ BinaryOperator(t1,t2,recons) => { - // val r1 = rec(t1) - // val r2 = rec(t2) - // if(r1 != t1 || r2 != t2) - // recons(r1,r2).setType(b.getType) - // else - // b - // } - // case u @ UnaryOperator(t,recons) => { - // val r = rec(t) - // if(r != t) - // recons(r).setType(u.getType) - // else - // u - // } - // case i @ IfExpr(t1,t2,t3) => { - // val r1 = rec(t1) - // val r2 = rec(t2) - // val r3 = rec(t3) - // if(r1 != t1 || r2 != t2 || r3 != t3) - // IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) - // else - // i - // } - // case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) - // case t if t.isInstanceOf[Terminal] => t - // case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled) - // } - - // def inCase(cse: MatchCase) : MatchCase = cse match { - // case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) - // case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) - // } - - // rec(expr) - // } - def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { case Some(newExpr) => { @@ -666,4 +590,51 @@ object Trees { } } } + + object NotSoSimplePatternMatching { + def coversType(tpe: ClassTypeDef, patterns: Seq[Pattern]) : Boolean = { + if(patterns.isEmpty) { + false + } else if(patterns.exists(_.isInstanceOf[WildcardPattern])) { + true + } else { + val allSubtypes: Seq[CaseClassDef] = tpe match { + case acd @ AbstractClassDef(_,_) => acd.knownDescendents.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) + case ccd: CaseClassDef => List(ccd) + } + + var seen: Set[CaseClassDef] = Set.empty + var secondLevel: Map[(CaseClassDef,Int),List[Pattern]] = Map.empty + + for(pat <- patterns) if (pat.isInstanceOf[CaseClassPattern]) { + val pattern: CaseClassPattern = pat.asInstanceOf[CaseClassPattern] + val ccd: CaseClassDef = pattern.caseClassDef + seen = seen + ccd + + for((subPattern,i) <- (pattern.subPatterns.zipWithIndex)) { + val seenSoFar = secondLevel.getOrElse((ccd,i), Nil) + secondLevel = secondLevel + ((ccd,i) -> (subPattern :: seenSoFar)) + } + } + + allSubtypes.forall(ccd => { + seen(ccd) && ccd.fields.zipWithIndex.forall(p => p._1.tpe match { + case t: ClassType => coversType(t.classDef, secondLevel.getOrElse((ccd, p._2), Nil)) + case _ => true + }) + }) + } + } + + def unapply(pm : MatchExpr) : Option[MatchExpr] = pm match { + case MatchExpr(scrutinee, cases) if cases.forall(_.isInstanceOf[SimpleCase]) => { + val allPatterns = cases.map(_.pattern) + Settings.reporter.info("This might be a complete pattern-matching expression:") + Settings.reporter.info(pm) + Settings.reporter.info("Covered? " + coversType(pm.scrutineeClassType.classDef, allPatterns)) + None + } + case _ => None + } + } } diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index efefcea57df0cb0c31b4e9ed51878243a1c7801c..440d18efdc7002f4b89dc0e54ba815f51fc646ba 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -234,14 +234,17 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } private var abstractedFormula = false - def solve(vc: Expr) : Option[Boolean] = { + override def isSat(vc: Expr) = decide(vc, false) + def solve(vc: Expr) = decide(vc, true) + def decide(vc: Expr, forValidity: Boolean) : Option[Boolean] = { abstractedFormula = false if(neverInitialized) { reporter.error("Z3 Solver was not initialized with a PureScala Program.") None } - val result = toZ3Formula(z3, negate(vc)) match { + val toConvert = if(forValidity) negate(vc) else vc + val result = toZ3Formula(z3, toConvert) match { case None => None // means it could not be translated case Some(z3f) => { z3.push diff --git a/testcases/Test.scala b/testcases/Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..e07f5eb71eef1e0f91e0c9f4f02728991140e2b7 --- /dev/null +++ b/testcases/Test.scala @@ -0,0 +1,26 @@ + +object Test { + sealed abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + def append(value: Int, list: List) : List = list match { + case Nil() => Cons(value, Nil()) + case Cons(x, xs) => Cons(x, append(value, xs)) + } + + def isSorted(list: List) : Boolean = list match { + case Nil() => true + case Cons(x, Nil()) => true + case Cons(x, c @ Cons(y, ys)) => x <= y && isSorted(c) + } + + def isSorted2(list: List) : Boolean = list match { + case Cons(x, c @ Cons(y, ys)) => x <= y && isSorted2(c) + case _ => true + } + + def sameSorted(list: List) : Boolean = { + isSorted(list) == isSorted2(list) + } ensuring(r=>r) +}