diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index e996e7bf84e86217037a8929f929f2c42c206ed7..5817434edcdb27acb74a4239406905398a9b7209 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -538,7 +538,7 @@ trait CodeExtraction extends Extractors { case Some(vd) => vd.id } - CaseClassSelector(selector, fieldID).setType(fieldID.getType) + CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) } // default behaviour is to complain :) diff --git a/src/orderedsets/RPrettyPrinter.scala b/src/orderedsets/RPrettyPrinter.scala index 059bb06d9940d9f9021799232e824a323ee4cd52..f2b7768ba175acb61390453c8363670d4d6db1e9 100644 --- a/src/orderedsets/RPrettyPrinter.scala +++ b/src/orderedsets/RPrettyPrinter.scala @@ -82,7 +82,7 @@ object RPrettyPrinter { case BooleanLiteral(v) => text(v toString) case StringLiteral(s) => "\"" :: s :: "\"" case CaseClass(ct, args) => ct.id :: ppVarary(args, ",") - case CaseClassSelector(cc, id) => pp(cc) :: "." :: id + case CaseClassSelector(_, cc, id) => pp(cc) :: "." :: id case FunctionInvocation(fd, args) => fd.id :: ppVarary(args, ",") case Plus(l, r) => ppBinary(l, r, "+") diff --git a/src/orderedsets/TreeOperations.scala b/src/orderedsets/TreeOperations.scala index 64d8a01f57bb5ae8af2b50fe7ae8c608514bae9d..d7f1614ed9b199468bfef1b83f4c15778da6609a 100644 --- a/src/orderedsets/TreeOperations.scala +++ b/src/orderedsets/TreeOperations.scala @@ -88,9 +88,9 @@ object TreeOperations { case c@CaseClass(cd, _args) => rewrite_*(_args, args => context(CaseClass(cd, args) setType c.getType)) - case c@CaseClassSelector(_cc, sel) => + case c@CaseClassSelector(ccd, _cc, sel) => rewrite(_cc, cc => - context(CaseClassSelector(cc, sel) setType c.getType)) + context(CaseClassSelector(ccd, cc, sel) setType c.getType)) case f@FiniteSet(_elems) => rewrite_*(_elems, elems => context(FiniteSet(elems) setType f.getType)) diff --git a/src/orderedsets/Unifier.scala b/src/orderedsets/Unifier.scala index f78ede598d92850f847906ff1d717d46b9830220..f0034dc6b47ae4924a0af2173971acb0e88c2ba9 100644 --- a/src/orderedsets/Unifier.scala +++ b/src/orderedsets/Unifier.scala @@ -107,7 +107,7 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] { def expr2term(expr: Expr): Term = expr match { case v@Variable(id) => Var(v) case CaseClass(ccdef, args) => Fun(ccdef, args map expr2term) - case CaseClassSelector(ex, sel) => + case CaseClassSelector(_, ex, sel) => val CaseClassType(ccdef) = ex.getType val args = ccdef.fields map freshVar("Sel") equalities += expr2term(ex) -> Fun(ccdef, args) @@ -162,7 +162,7 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] { println("Inequalities were checked to hold\n") println("--- Output of the unifier (Substitution table) ---") - val map1 = map.filterKeys{_.getType != NoType} + val map1 = map.filterKeys{_.getType != Untyped} for ((x, t) <- map1.toList sortWith byName) println(" " + x + " = " + pp(t)) if (map1.isEmpty) println(" (empty table)") @@ -441,4 +441,4 @@ trait Unifier[VarName >: Null, FunName >: Null] { (rec(function), frontier) } -} \ No newline at end of file +} diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala index ede3dba52801df5de5803f2f22302c7f6502e944..eec35bcd2b9f2f47150fb6aa13f2fb6cc611220f 100644 --- a/src/orderedsets/UnifierMain.scala +++ b/src/orderedsets/UnifierMain.scala @@ -174,11 +174,11 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { val bad = new ArrayBuffer[Expr]() // Formulas of unknown logic // TODO: Allow literals in unifier ? def isGood(expr: Expr) = expr match { - case Variable(_) | CaseClass(_, _) | CaseClassSelector(_, _) => true + case Variable(_) | CaseClass(_, _) | CaseClassSelector(_, _, _) => true case _ => false } def isBad(expr: Expr) = expr match { - case CaseClass(_, _) | CaseClassSelector(_, _) => false + case CaseClass(_, _) | CaseClassSelector(_, _, _) => false case _ => true } def purifyGood(expr: Expr) = if (isGood(expr)) None else { diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index fc5e602ca9c722c7598c7853a5ec42db79884c08..399c907331961a495875ef040b89e909e225b1bc 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -15,13 +15,10 @@ class Analysis(val program: Program) { Extensions.loadAll val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions - val solverExtensions: Seq[Solver] = loadedSolverExtensions - val trivialSolver = new Solver(reporter) { - val description = "Trivial" - override val shortDescription = "trivial" - def solve(e: Expr) = throw new Exception("trivial solver should not be called.") - } + val trivialSolver = new TrivialSolver(reporter) // This one you can't disable :D + val solverExtensions: Seq[Solver] = trivialSolver +: loadedSolverExtensions + solverExtensions.foreach(_.setProgram(program)) val defaultTactic = new DefaultTactic(reporter) defaultTactic.setProgram(program) @@ -35,10 +32,9 @@ class Analysis(val program: Program) { def analyse : Unit = { if(solverExtensions.size > 0) { reporter.info("Running verification condition generation...") - // checkVerificationConditions val list = generateVerificationConditions - list.foreach(e => println(e.infoLine)) + checkVerificationConditions(list : _*) } else { reporter.warning("No solver specified. Cannot test verification conditions.") } @@ -49,10 +45,13 @@ class Analysis(val program: Program) { }) } - def generateVerificationConditions : List[VerificationCondition] = { + private def generateVerificationConditions : List[VerificationCondition] = { var allVCs: Seq[VerificationCondition] = Seq.empty + val analysedFunctions: MutableSet[String] = MutableSet.empty for(funDef <- program.definedFunctions.toList.sortWith((fd1, fd2) => fd1.id.name < fd2.id.name) if (Settings.functionsToAnalyse.isEmpty || Settings.functionsToAnalyse.contains(funDef.id.name))) { + analysedFunctions += funDef.id.name + val tactic: Tactic = if(funDef.annotations.contains("induct")) { inductionTactic @@ -68,169 +67,80 @@ class Analysis(val program: Program) { } } + val notFound: Set[String] = Settings.functionsToAnalyse -- analysedFunctions + notFound.foreach(fn => reporter.error("Did not find function \"" + fn + "\" though it was marked for analysis.")) + allVCs.toList } - def checkVerificationConditions : Unit = { - // just for the summary: - var verificationConditionInfos: List[VerificationCondition] = Nil - - var analysedFunctions: MutableSet[String] = MutableSet.empty + def checkVerificationCondition(vc: VerificationCondition) : Unit = checkVerificationConditions(vc) + def checkVerificationConditions(vcs: VerificationCondition*) : Unit = { + for(vcInfo <- vcs) { + val funDef = vcInfo.funDef + val vc = vcInfo.condition - solverExtensions.foreach(_.setProgram(program)) + reporter.info("Verification condition (post) for ==== " + funDef.id + " ====") + if(true || Settings.unrollingLevel == 0) { + reporter.info(simplifyLets(vc)) + } else { + reporter.info("(not showing unrolled VCs)") + } - for(funDef <- program.definedFunctions.toList.sortWith((fd1,fd2) => fd1.id.name < fd2.id.name)) if (Settings.functionsToAnalyse.isEmpty || Settings.functionsToAnalyse.contains(funDef.id.name)) { - analysedFunctions += funDef.id.name - if(funDef.body.isDefined) { - val vcInfo = defaultTactic.generatePostconditions(funDef).head - val vc = vcInfo.condition - // val vc = postconditionVC(funDef) - // val vcInfo = new VerificationCondition(vc, funDef, VCKind.Postcondition, defaultTactic) - verificationConditionInfos = vcInfo :: verificationConditionInfos - - if(vc == BooleanLiteral(false)) { - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(trivialSolver) - vcInfo.time = Some(0L) - } else if(vc == BooleanLiteral(true)) { - if(funDef.hasPostcondition) { - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(trivialSolver) - vcInfo.time = Some(0L) - } + // try all solvers until one returns a meaningful answer + var superseeded : Set[String] = Set.empty[String] + solverExtensions.find(se => { + reporter.info("Trying with solver: " + se.shortDescription) + if(superseeded(se.shortDescription) || superseeded(se.description)) { + reporter.info("Solver was superseeded. Skipping.") + false } else { - reporter.info("Verification condition (post) for ==== " + funDef.id + " ====") - if(true || Settings.unrollingLevel == 0) { - reporter.info(simplifyLets(vc)) - } else { - reporter.info("(not showing unrolled VCs)") - } + superseeded = superseeded ++ Set(se.superseeds: _*) + + val t1 = System.nanoTime + val solverResult = se.solve(vc) + val t2 = System.nanoTime + val dt = ((t2 - t1) / 1000000) / 1000.0 + + solverResult match { + case None => false + case Some(true) => { + reporter.info("==== VALID ====") - // try all solvers until one returns a meaningful answer - var superseeded : Set[String] = Set.empty[String] - solverExtensions.find(se => { - reporter.info("Trying with solver: " + se.shortDescription) - if(superseeded(se.shortDescription) || superseeded(se.description)) { - reporter.info("Solver was superseeded. Skipping.") - false - } else { - superseeded = superseeded ++ Set(se.superseeds: _*) - - val t1 = System.nanoTime - val solverResult = se.solve(vc) - val t2 = System.nanoTime - val dt = ((t2 - t1) / 1000000) / 1000.0 - - solverResult match { - case None => false - case Some(true) => { - reporter.info("==== VALID ====") - - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(se) - vcInfo.time = Some(dt) - - true - } - case Some(false) => { - reporter.error("==== INVALID ====") - - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(se) - vcInfo.time = Some(dt) - - true - } - } + vcInfo.value = Some(true) + vcInfo.solvedWith = Some(se) + vcInfo.time = Some(dt) + + true } - }) match { - case None => { - reporter.warning("No solver could prove or disprove the verification condition.") + case Some(false) => { + reporter.error("==== INVALID ====") + + vcInfo.value = Some(false) + vcInfo.solvedWith = Some(se) + vcInfo.time = Some(dt) + + true } - case _ => - } + } } - } else { - if(funDef.postcondition.isDefined) { - reporter.warning(funDef, "Could not verify postcondition: function implementation is unknown.") + }) match { + case None => { + reporter.warning("No solver could prove or disprove the verification condition.") } - } + case _ => + } + } - if(verificationConditionInfos.size > 0) { - verificationConditionInfos = verificationConditionInfos.reverse + if(vcs.size > 0) { val summaryString = ( VerificationCondition.infoHeader + - verificationConditionInfos.map(_.infoLine).mkString("\n", "\n", "\n") + + vcs.map(_.infoLine).mkString("\n", "\n", "\n") + VerificationCondition.infoFooter ) reporter.info(summaryString) } else { - reporter.info("No verification conditions were generated.") - } - - val notFound: Set[String] = Settings.functionsToAnalyse -- analysedFunctions - notFound.foreach(fn => reporter.error("Did not find function \"" + fn + "\" though it was marked for analysis.")) - } - - def postconditionVC(functionDefinition: FunDef) : Expr = { - assert(functionDefinition.body.isDefined) - val prec = functionDefinition.precondition - val post = functionDefinition.postcondition - val body = functionDefinition.body.get - - if(post.isEmpty) { - BooleanLiteral(true) - } else { - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPost = Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post.get)) - val withPrec = if(prec.isEmpty) { - bodyAndPost - } else { - Implies(prec.get, bodyAndPost) - } - - import Analysis._ - - if(Settings.experimental) { - reporter.info("Raw:") - reporter.info(withPrec) - reporter.info("Raw, expanded:") - reporter.info(expandLets(withPrec)) - } - reporter.info(" - inlining...") - val expr0 = inlineNonRecursiveFunctions(program, withPrec) - if(Settings.experimental) { - reporter.info("Inlined:") - reporter.info(expr0) - reporter.info("Inlined, expanded:") - reporter.info(expandLets(expr0)) - } - reporter.info(" - unrolling...") - val expr1 = unrollRecursiveFunctions(program, expr0, Settings.unrollingLevel) - if(Settings.experimental) { - reporter.info("Unrolled:") - reporter.info(expr1) - reporter.info("Unrolled, expanded:") - reporter.info(expandLets(expr1)) - } - reporter.info(" - inlining contracts...") - val expr2 = inlineContracts(expr1) - if(Settings.experimental) { - reporter.info("Contract'ed:") - reporter.info(expr2) - reporter.info("Contract'ed, expanded:") - reporter.info(expandLets(expr2)) - } - reporter.info(" - converting pattern-matching...") - val expr3 = rewriteSimplePatternMatching(expr2) - if(Settings.experimental) { - reporter.info("Pattern'ed:") - reporter.info(expr3) - reporter.info("Pattern'ed, expanded:") - reporter.info(expandLets(expr3)) - } - expr3 + reporter.info("No verification conditions were analyzed.") } } } diff --git a/src/purescala/Common.scala b/src/purescala/Common.scala index 7e326ca3c865a17ac0646e3aa9468cb8ddb0fa16..fcdea7487f7f26d0baa5a2badcd9f26b8ed0bc00 100644 --- a/src/purescala/Common.scala +++ b/src/purescala/Common.scala @@ -3,7 +3,7 @@ package purescala object Common { import TypeTrees.Typed - // the type is left blank (NoType) for Identifiers that are not variables + // the type is left blank (Untyped) for Identifiers that are not variables class Identifier private[Common](val name: String, val id: Int, alwaysShowUniqueID: Boolean = false) extends Typed { override def equals(other: Any): Boolean = { if(other == null || !other.isInstanceOf[Identifier]) diff --git a/src/purescala/DefaultTactic.scala b/src/purescala/DefaultTactic.scala index e351adcfa001a40ef963784c87efff39567f0450..b725f10720379bce6ecb9177c02db73554bcf680 100644 --- a/src/purescala/DefaultTactic.scala +++ b/src/purescala/DefaultTactic.scala @@ -25,61 +25,66 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { val post = functionDefinition.postcondition val body = functionDefinition.body.get - val theExpr = if(post.isEmpty) { - BooleanLiteral(true) + if(post.isEmpty) { + Seq.empty } else { - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPost = Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post.get)) - val withPrec = if(prec.isEmpty) { - bodyAndPost - } else { - Implies(prec.get, bodyAndPost) - } + val theExpr = { + val resFresh = FreshIdentifier("result", true).setType(body.getType) + val bodyAndPost = Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post.get)) + val withPrec = if(prec.isEmpty) { + bodyAndPost + } else { + Implies(prec.get, bodyAndPost) + } - import Analysis._ - - if(Settings.experimental) { - reporter.info("Raw:") - reporter.info(withPrec) - reporter.info("Raw, expanded:") - reporter.info(expandLets(withPrec)) - } - reporter.info(" - inlining...") - val expr0 = inlineNonRecursiveFunctions(program, withPrec) - if(Settings.experimental) { - reporter.info("Inlined:") - reporter.info(expr0) - reporter.info("Inlined, expanded:") - reporter.info(expandLets(expr0)) - } - reporter.info(" - unrolling...") - val expr1 = unrollRecursiveFunctions(program, expr0, Settings.unrollingLevel) - if(Settings.experimental) { - reporter.info("Unrolled:") - reporter.info(expr1) - reporter.info("Unrolled, expanded:") - reporter.info(expandLets(expr1)) + import Analysis._ + + if(Settings.experimental) { + reporter.info("Raw:") + reporter.info(withPrec) + reporter.info("Raw, expanded:") + reporter.info(expandLets(withPrec)) + } + reporter.info(" - inlining...") + val expr0 = inlineNonRecursiveFunctions(program, withPrec) + if(Settings.experimental) { + reporter.info("Inlined:") + reporter.info(expr0) + reporter.info("Inlined, expanded:") + reporter.info(expandLets(expr0)) + } + reporter.info(" - unrolling...") + val expr1 = unrollRecursiveFunctions(program, expr0, Settings.unrollingLevel) + if(Settings.experimental) { + reporter.info("Unrolled:") + reporter.info(expr1) + reporter.info("Unrolled, expanded:") + reporter.info(expandLets(expr1)) + } + reporter.info(" - inlining contracts...") + val expr2 = inlineContracts(expr1) + if(Settings.experimental) { + reporter.info("Contract'ed:") + reporter.info(expr2) + reporter.info("Contract'ed, expanded:") + reporter.info(expandLets(expr2)) + } + reporter.info(" - converting pattern-matching...") + val expr3 = if(Settings.useNewPatternMatchingTranslator) { + matchToIfThenElse(expr2) + } else { + rewriteSimplePatternMatching(expr2) + } + if(Settings.experimental) { + reporter.info("Pattern'ed:") + reporter.info(expr3) + reporter.info("Pattern'ed, expanded:") + reporter.info(expandLets(expr3)) + } + expr3 } - reporter.info(" - inlining contracts...") - val expr2 = inlineContracts(expr1) - if(Settings.experimental) { - reporter.info("Contract'ed:") - reporter.info(expr2) - reporter.info("Contract'ed, expanded:") - reporter.info(expandLets(expr2)) - } - reporter.info(" - converting pattern-matching...") - val expr3 = rewriteSimplePatternMatching(expr2) - if(Settings.experimental) { - reporter.info("Pattern'ed:") - reporter.info(expr3) - reporter.info("Pattern'ed, expanded:") - reporter.info(expandLets(expr3)) - } - expr3 + Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this)) } - - Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this)) } def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = { diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index c073757ce13a1a089a675e550204d1bafa1b215a..8c81b5e8f16d090ea6d3c219d73fa2ac49dfa56d 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -80,13 +80,19 @@ object PrettyPrinter { case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") - case CaseClass(ct, args) => { + case CaseClass(cd, args) => { var nsb = sb - nsb.append(ct.id) + nsb.append(cd.id) nsb = ppNary(nsb, args, "(", ", ", ")", lvl) nsb } - case CaseClassSelector(cc, id) => pp(cc, sb, lvl).append("." + id) + case CaseClassInstanceOf(cd, e) => { + var nsb = sb + nsb = pp(e, nsb, lvl) + nsb.append(".isInstanceOf[" + cd.id + "]") + nsb + } + case CaseClassSelector(_, cc, id) => pp(cc, sb, lvl).append("." + id) case FunctionInvocation(fd, args) => { var nsb = sb nsb.append(fd.id) @@ -186,13 +192,21 @@ object PrettyPrinter { case ResultVariable() => sb.append("#res") case Not(expr) => ppUnary(sb, expr, "\u00AC(", ")", lvl) // \neg + case e @ Error(desc) => { + var nsb = sb + nsb.append("error(\"" + desc + "\")[") + nsb = pp(e.getType, nsb, lvl) + nsb.append("]") + nsb + } + case _ => sb.append("Expr?") } // TYPE TREES // all type trees are printed in-line private def pp(tpe: TypeTree, sb: StringBuffer, lvl: Int): StringBuffer = tpe match { - case NoType => sb.append("???") + case Untyped => sb.append("???") case Int32Type => sb.append("Int") case BooleanType => sb.append("Boolean") case SetType(bt) => pp(bt, sb.append("Set["), lvl).append("]") diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index f15671b320786af27fb173d7c32d036b56762011..9594b704aa7984dbbb73d2a2dcf55b97b14cdeda 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -12,13 +12,20 @@ object Trees { override def toString: String = PrettyPrinter(this) } - sealed trait Terminal + sealed trait Terminal { + self: Expr => + } + + /* This describes computational errors (unmatched case, taking min of an + * empty set, division by zero, etc.). It should always be typed according to + * the expected type. */ + case class Error(description: String) extends Expr with Terminal /* Like vals */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { binder.markAsLetBinder val et = body.getType - if(et != NoType) + if(et != Untyped) setType(et) } @@ -85,10 +92,15 @@ object Trees { /* Propositional logic */ object And { - def apply(exprs: Seq[Expr]) : Expr = exprs.size match { - case 0 => BooleanLiteral(true) - case 1 => exprs.head - case _ => new And(exprs) + def apply(exprs: Seq[Expr]) : Expr = { + val newExprs = exprs.filter(_ != BooleanLiteral(true)) + if(newExprs.contains(BooleanLiteral(false))) { + BooleanLiteral(false) + } else newExprs.size match { + case 0 => BooleanLiteral(true) + case 1 => newExprs.head + case _ => new And(newExprs) + } } def apply(l: Expr, r: Expr): Expr = (l,r) match { @@ -107,10 +119,15 @@ object Trees { } object Or { - def apply(exprs: Seq[Expr]) : Expr = exprs.size match { - case 0 => BooleanLiteral(false) - case 1 => exprs.head - case _ => new Or(exprs) + def apply(exprs: Seq[Expr]) : Expr = { + val newExprs = exprs.filter(_ != BooleanLiteral(false)) + if(newExprs.contains(BooleanLiteral(true))) { + BooleanLiteral(true) + } else newExprs.size match { + case 0 => BooleanLiteral(false) + case 1 => newExprs.head + case _ => new Or(newExprs) + } } def apply(l: Expr, r: Expr): Expr = (l,r) match { @@ -219,8 +236,11 @@ object Trees { case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType { val fixedType = CaseClassType(classDef) } - case class CaseClassSelector(caseClass: Expr, selector: Identifier) extends Expr with FixedType { - val fixedType = caseClass.getType.asInstanceOf[CaseClassType].classDef.fields.find(_.id == selector).get.getType + case class CaseClassInstanceOf(classDef: CaseClassDef, expr: Expr) extends Expr with FixedType { + val fixedType = BooleanType + } + case class CaseClassSelector(classDef: CaseClassDef, caseClass: Expr, selector: Identifier) extends Expr with FixedType { + val fixedType = classDef.fields.find(_.id == selector).get.getType } /* Arithmetic */ @@ -315,7 +335,8 @@ object Trees { case Cdr(t) => Some((t,Cdr)) case SetMin(s) => Some((s,SetMin)) case SetMax(s) => Some((s,SetMax)) - case CaseClassSelector(e, sel) => Some((e, CaseClassSelector(_, sel))) + case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) + case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) case _ => None } } @@ -388,10 +409,15 @@ object Trees { searchAndReplaceDFS(substs.get)(expr) } + // Can't just be overloading because of type erasure :'( + def replaceFromIDs(substs: Map[Identifier,Expr], expr: Expr) : Expr = { + replace(substs.map(p => (Variable(p._1) -> p._2)), expr) + } + def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { case Some(newExpr) => { - if(newExpr.getType == NoType) { + if(newExpr.getType == Untyped) { Settings.reporter.error("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) } if(ex == newExpr) @@ -473,7 +499,7 @@ object Trees { case None => ex case Some(newEx) => { somethingChanged = true - if(newEx.getType == NoType) { + if(newEx.getType == Untyped) { Settings.reporter.warning("REPLACING WITH AN UNTYPED EXPRESSION !") } newEx @@ -916,4 +942,68 @@ object Trees { }) } + /** Rewrites all pattern-matching expressions into if-then-else expressions, + * with additional error conditions. Does not introduce additional variables. + * */ + def matchToIfThenElse(expr: Expr) : Expr = { + def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { + case WildcardPattern(None) => Map.empty + case WildcardPattern(Some(id)) => Map(id -> in) + case InstanceOfPattern(None, _) => Map.empty + case InstanceOfPattern(Some(id), _) => Map(id -> in) + case CaseClassPattern(b, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) + b match { + case Some(id) => Map(id -> in) ++ together + case None => together + } + } + } + + def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match { + case WildcardPattern(_) => BooleanLiteral(true) + case InstanceOfPattern(_,_) => scala.Predef.error("InstanceOfPattern not yet supported.") + case CaseClassPattern(_, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = And(subTests) + And(CaseClassInstanceOf(ccd, in), together) + } + } + + def rewritePM(e: Expr) : Option[Expr] = e match { + case m @ MatchExpr(scrut, cases) => { + println("Rewriting the following PM: " + e) + + val condsAndRhs = for(cse <- cases) yield { + // println("For this case: " + cse) + // println("Map: " + mapForPattern(scrut, cse.pattern)) + // println("Cond: " + conditionForPattern(scrut, cse.pattern)) + val map = mapForPattern(scrut, cse.pattern) + val patCond = conditionForPattern(scrut, cse.pattern) + val realCond = cse.theGuard match { + case Some(g) => And(patCond, replaceFromIDs(map, g)) + case None => patCond + } + val newRhs = replaceFromIDs(map, cse.rhs) + (realCond, newRhs) + } + + val bigIte = condsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(m.getType))((p1, ex) => { + IfExpr(p1._1, p1._2, ex) + }) + println(condsAndRhs) + println(bigIte) + + Some(e) + } + case _ => None + } + + searchAndReplaceDFS(rewritePM)(expr) + } } diff --git a/src/purescala/TrivialSolver.scala b/src/purescala/TrivialSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..688e680629b2e191f137b6b9e05ccd41dc36384a --- /dev/null +++ b/src/purescala/TrivialSolver.scala @@ -0,0 +1,21 @@ +package purescala + +import z3.scala._ +import Common._ +import Definitions._ +import Extensions._ +import Trees._ +import TypeTrees._ + +class TrivialSolver(reporter: Reporter) extends Solver(reporter) { + val description = "Solver for syntactically trivial formulas" + override val shortDescription = "trivial" + + def solve(expression: Expr) : Option[Boolean] = expression match { + case BooleanLiteral(v) => Some(v) + case Not(BooleanLiteral(v)) => Some(!v) + case Or(exs) if exs.contains(BooleanLiteral(true)) => Some(true) + case And(exs) if exs.contains(BooleanLiteral(false)) => Some(false) + case _ => None + } +} diff --git a/src/purescala/TypeTrees.scala b/src/purescala/TypeTrees.scala index 46ecb5d2be0887b499746b81d70a1b0602568bd9..69ee4a306b98dc5b4dc9663f00873dcd4c99b253 100644 --- a/src/purescala/TypeTrees.scala +++ b/src/purescala/TypeTrees.scala @@ -11,7 +11,7 @@ object TypeTrees { private var _type: Option[TypeTree] = None def getType: TypeTree = _type match { - case None => NoType + case None => Untyped case Some(t) => t } @@ -67,6 +67,11 @@ object TypeTrees { } case (o1, o2) if (o1 == o2) => o1 + case (o1,NoType) => o1 + case (NoType,o2) => o2 + case (o1,AnyType) => AnyType + case (AnyType,o2) => AnyType + case _ => scala.Predef.error("Asking for lub of unrelated types: " + t1 + " and " + t2) } @@ -76,8 +81,9 @@ object TypeTrees { case object InfiniteSize extends TypeSize def domainSize(typeTree: TypeTree) : TypeSize = typeTree match { - case NoType => FiniteSize(0) + case Untyped => FiniteSize(0) case AnyType => InfiniteSize + case NoType => FiniteSize(0) case BooleanType => FiniteSize(2) case Int32Type => InfiniteSize case ListType(_) => InfiniteSize @@ -105,9 +111,9 @@ object TypeTrees { case c: ClassType => InfiniteSize } - case object NoType extends TypeTree - + case object Untyped extends TypeTree case object AnyType extends TypeTree + case object NoType extends TypeTree // This is the type of errors (ie. subtype of anything) case object BooleanType extends TypeTree case object Int32Type extends TypeTree diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index f85158b626694ff5cdf078349bc11b4d8c04aabc..258fc36104c5967b1f62cb898e9f0d315539ed45 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -407,7 +407,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) with Z3ModelReconstr val constructor = adtConstructors(cd) constructor(args.map(rec(_)): _*) } - case c@CaseClassSelector(cc, sel) => { + case c@CaseClassSelector(_, cc, sel) => { val selector = adtFieldSelectors(sel) selector(rec(cc)) }