diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index a56acce566ee2629702b902b7ba5acd774a4ee32..a45e05982f864d6bda25425299fc1a25b8c5b29b 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -196,8 +196,32 @@ object Analysis { e.isInstanceOf[MatchExpr] } - def rewritePM(e: Expr) : Expr = { - val MatchExpr(scrutinee, cases) = e.asInstanceOf[MatchExpr] + def rewritePM(e: Expr) : Expr = e.asInstanceOf[MatchExpr] match { + case SimplePatternMatching(scrutinee, classType, casesInfo) => { + val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) + val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) + val lle : List[(Variable,List[Expr])] = casesInfo.map(cseInfo => { + val (ccd, newPID, argIDs, rhs) = cseInfo + val newPVar = Variable(newPID) + val argVars = argIDs.map(Variable(_)) + val (rewrittenRHS, moreExtras) = rewriteSimplePatternMatching(rhs) + (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rewrittenRHS))) ::: moreExtras.toList) + }).toList + val (newPVars, newExtras) = lle.unzip + extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras + newVar + } + case _ => e + } + + val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) + (cleanerTree, extras.reverse) + } + + object SimplePatternMatching { + // (scrutinee, classtypedef, list((caseclassdef, variable, list(variable), rhs))) + def unapply(e: MatchExpr) : Option[(Expr,ClassType,Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)])] = { + val MatchExpr(scrutinee, cases) = e val sType = scrutinee.getType if(sType.isInstanceOf[AbstractClassType]) { @@ -205,61 +229,65 @@ object Analysis { if(cases.size == cCD.knownChildren.size && cases.forall(!_.hasGuard)) { var seen = Set.empty[ClassTypeDef] - val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) - val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) - - val lle : List[(Variable,List[Expr])] = cases.map(cse => cse match { - case SimpleCase(CaseClassPattern(binder, ccd, subPats), rhs) if subPats.forall(_.isInstanceOf[WildcardPattern]) => { - seen = seen + ccd - val newPVar = if(binder.isDefined) { - Variable(binder.get) - } else { - Variable(FreshIdentifier("cse", true)).setType(CaseClassType(ccd)) + var lle : List[(CaseClassDef,Identifier,List[Identifier],Expr)] = Nil + for(cse <- cases) { + cse match { + case SimpleCase(CaseClassPattern(binder, ccd, subPats), rhs) if subPats.forall(_.isInstanceOf[WildcardPattern]) => { + seen = seen + ccd + + val patID : Identifier = if(binder.isDefined) { + binder.get + } else { + FreshIdentifier("cse", true).setType(CaseClassType(ccd)) + } + + val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { + p._2.binder.get + } else { + FreshIdentifier("pat", true).setType(p._1.tpe) + }).toList + + lle = (ccd, patID, argIDs, rhs) :: lle } - val argVars = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { - Variable(p._2.binder.get) - } else { - Variable(FreshIdentifier("pat", true)).setType(p._1.tpe) - }) - val (rewrittenRHS, moreExtras) = rewriteSimplePatternMatching(rhs) - (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rewrittenRHS))) ::: moreExtras.toList) + case _ => ; } - case _ => (null,Nil) - }).toList + } + lle = lle.reverse if(seen.size == cases.size) { - val (newPVars, newExtras) = lle.unzip - extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras - newVar + Some((scrutinee, sType.asInstanceOf[AbstractClassType], lle)) } else { - e + None } } else { - e + None } } else { val cCD = sType.asInstanceOf[CaseClassType].classDef if(cases.size == 1 && !cases(0).hasGuard) { - e + val SimpleCase(pat,rhs) = cases(0).asInstanceOf[SimpleCase] + pat match { + case CaseClassPattern(binder, ccd, subPats) if (ccd == cCD && subPats.forall(_.isInstanceOf[WildcardPattern])) => { + val patID : Identifier = if(binder.isDefined) { + binder.get + } else { + FreshIdentifier("cse", true).setType(CaseClassType(ccd)) + } + + val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { + p._2.binder.get + } else { + FreshIdentifier("pat", true).setType(p._1.tpe) + }).toList + + Some((scrutinee, CaseClassType(cCD), List((cCD, patID, argIDs, rhs)))) + } + case _ => None + } } else { - e + None } } } - - val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) - // println("******************") - // println("rewrote: " + expr) - // println(" *** to ***") - // println(cleanerTree) - // println(" ** with side conds ** ") - // println(extras.reverse) - // println("******************") - (cleanerTree, extras.reverse) - // val theExtras = extras.reverse - // val onExtras: Seq[(Expr,Seq[Expr])] = theExtras.map(rewriteSimplePatternMatching(_)) - // // the "moreExtras" should be cleaned up due to the recursive call.. - // val (rewrittenExtras, moreExtras) = onExtras.unzip - // (cleanerTree, rewrittenExtras ++ moreExtras.flatten) } }