diff --git a/src/cp/CPComponent.scala b/src/cp/CPComponent.scala index 9fa4549a843f413bdab0c2027690c8f1e62bf349..ad87928df85077df03eadfcce7a059db6d7c9244 100644 --- a/src/cp/CPComponent.scala +++ b/src/cp/CPComponent.scala @@ -38,7 +38,7 @@ class CPComponent(val global: Global, val pluginInstance: CPPlugin) } // new ForeachTreeTraverser(plop).traverse(unit.body) - val prog: purescala.Definitions.Program = extractCode(unit) + val prog: purescala.Definitions.Program = extractCode(unit, true) val filename = writeProgram(prog) println("Program extracted and written into: " + filename) diff --git a/src/cp/CPPlugin.scala b/src/cp/CPPlugin.scala index 4329315f5c6cbaf6b27da02637edcb52934bf5f0..d5f307a1919c6dc0b3b4376fb12f42510c603894 100644 --- a/src/cp/CPPlugin.scala +++ b/src/cp/CPPlugin.scala @@ -15,6 +15,8 @@ class CPPlugin(val global: Global) extends PluginBase { var stopAfterAnalysis: Boolean = true var stopAfterExtraction: Boolean = false + silentlyTolerateNonPureBodies = true + /** The help message displaying the options for that plugin. */ override val optionsHelp: Option[String] = Some( " -P:funcheck:uniqid When pretty-printing funcheck trees, show identifiers IDs" + "\n" + diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala index f541d4178a2bc7719d7ed887120278bf6a322bd0..2b7ec2a05d1fcb2304d2b9156033d66ef4be997d 100644 --- a/src/cp/CallTransformation.scala +++ b/src/cp/CallTransformation.scala @@ -28,7 +28,8 @@ trait CallTransformation var exprToScalaCode : Tree = null var exprToScalaCastSym : Symbol = null var exprToScalaCastCode : Tree = null - + var scalaToExprSym : Symbol = null + var scalaToExprCode : Tree = null override def transform(tree: Tree) : Tree = { tree match { @@ -56,8 +57,18 @@ trait CallTransformation val (exprAssignment, exprSym) = codeGen.assignExpr(exprFilename) val (outputVarListAssignment, outputVarListSym) = codeGen.assignOutputVarList(outputVarListFilename) + // compute input variables and assert equalities + val inputVars = variablesOf(b).filter{ v => !outputVarList.contains(v.name) } + println("here are the input vars: " + inputVars) + val inputVarListFilename = writeObject((inputVars map (iv => Variable(iv))).toList) + val equalities : List[Tree] = (for (iv <- inputVars) yield { + codeGen.inputEquality(inputVarListFilename, iv, scalaToExprSym) + }).toList + + val (andExprAssignment, andExprSym) = codeGen.assignAndExpr((ID(exprSym) :: equalities) : _*) + // invoke solver and get ordered list of assignments - val (solverInvocation, outcomeTupleSym) = codeGen.invokeSolver(progSym, exprSym) + val (solverInvocation, outcomeTupleSym) = codeGen.invokeSolver(progSym, andExprSym) val (modelAssignment, modelSym) = codeGen.assignModel(outcomeTupleSym) // TODO generate all correct e2s invocations @@ -77,7 +88,7 @@ trait CallTransformation New(tupleTypeTree,List(returnExpressions map (Ident(_)))) } - val code = BLOCK(List(programAssignment, exprAssignment, outputVarListAssignment) ::: solverInvocation ::: List(modelAssignment) ::: valueAssignments ::: List(returnExpr) : _*) + val code = BLOCK(List(programAssignment, exprAssignment, outputVarListAssignment, andExprAssignment) ::: solverInvocation ::: List(modelAssignment) ::: valueAssignments ::: List(returnExpr) : _*) typer.typed(atOwner(currentOwner) { code @@ -89,18 +100,24 @@ trait CallTransformation val codeGen = new CodeGenerator(unit, currentOwner, tree.pos) val ((e2sSym, e2sCode), (e2sCastSym,e2sCastCode)) = codeGen.exprToScalaMethods(cd.symbol, prog) + val (s2eCode, s2eSym) = codeGen.scalaToExprMethod(cd.symbol, prog, programFilename) exprToScalaSym = e2sSym exprToScalaCode = e2sCode exprToScalaCastSym = e2sCastSym exprToScalaCastCode = e2sCastCode + scalaToExprSym = s2eSym + scalaToExprCode = s2eCode atOwner(tree.symbol) { treeCopy.ClassDef(tree, transformModifiers(mods), name, transformTypeDefs(tparams), impl match { case Template(parents, self, body) => treeCopy.Template(impl, transformTrees(parents), - transformValDef(self), typer.typed(atOwner(currentOwner) {exprToScalaCode}) :: - typer.typed(atOwner(currentOwner) {exprToScalaCastCode}) :: transformStats(body, tree.symbol)) + transformValDef(self), + typer.typed(atOwner(currentOwner) {exprToScalaCode}) :: + typer.typed(atOwner(currentOwner) {exprToScalaCastCode}) :: + typer.typed(atOwner(currentOwner) {scalaToExprCode}) :: + transformStats(body, tree.symbol)) }) } } @@ -109,7 +126,6 @@ trait CallTransformation } } } - } object CallTransformation { @@ -138,4 +154,7 @@ object CallTransformation { outcomeTuple._1 } + def inputVar(inputVarList : List[Variable], varName : String) : Variable = + inputVarList.find(_.id.name == varName).getOrElse(scala.Predef.error("Could not find input variable '" + varName + "' in list " + inputVarList)) + } diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala index 639ecc10ddd69a1938dc24192cfff59b9dffc5b7..ab331ceba0ee008b9ee2f830a0cb452275d52744 100644 --- a/src/cp/CodeGeneration.scala +++ b/src/cp/CodeGeneration.scala @@ -3,6 +3,7 @@ package cp import purescala.Trees._ import purescala.TypeTrees.classDefToClassType import purescala.Definitions._ +import purescala.Common.Identifier trait CodeGeneration { self: CPComponent => @@ -14,7 +15,8 @@ trait CodeGeneration { private lazy val exceptionClass = definitions.getClass("java.lang.Exception") private lazy val listMapFunction = definitions.getMember(definitions.ListClass, "map") - private lazy val listApplyFunction = definitions.getMember(definitions.ListClass, "apply") + private lazy val listClassApplyFunction = definitions.getMember(definitions.ListClass, "apply") + private lazy val listModuleApplyFunction = definitions.getMember(definitions.ListModule, "apply") private lazy val mapClass = definitions.getClass("scala.collection.immutable.Map") @@ -24,18 +26,21 @@ trait CodeGeneration { private lazy val outputAssignmentsFunction = definitions.getMember(callTransformationModule, "outputAssignments") private lazy val modelFunction = definitions.getMember(callTransformationModule, "model") private lazy val modelValueFunction = definitions.getMember(callTransformationModule, "modelValue") + private lazy val inputVarFunction = definitions.getMember(callTransformationModule, "inputVar") private lazy val serializationModule = definitions.getModule("cp.Serialization") private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram") private lazy val getExprFunction = definitions.getMember(serializationModule, "getExpr") private lazy val getOutputVarListFunction = definitions.getMember(serializationModule, "getOutputVarList") + private lazy val getInputVarListFunction = definitions.getMember(serializationModule, "getInputVarList") private lazy val purescalaPackage = definitions.getModule("purescala") - private lazy val definitionsModule = definitions.getModule("purescala.Definitions") - private lazy val programClass = definitions.getClass("purescala.Definitions.Program") - private lazy val caseClassDefClass = definitions.getClass("purescala.Definitions.CaseClassDef") - private lazy val idField = definitions.getMember(caseClassDefClass, "id") + private lazy val definitionsModule = definitions.getModule("purescala.Definitions") + private lazy val programClass = definitions.getClass("purescala.Definitions.Program") + private lazy val caseClassDefFunction = definitions.getMember(programClass, "caseClassDef") + private lazy val caseClassDefClass = definitions.getClass("purescala.Definitions.CaseClassDef") + private lazy val idField = definitions.getMember(caseClassDefClass, "id") private lazy val commonModule = definitions.getModule("purescala.Common") private lazy val identifierClass = definitions.getClass("purescala.Common.Identifier") @@ -44,8 +49,13 @@ trait CodeGeneration { private lazy val treesModule = definitions.getModule("purescala.Trees") private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") private lazy val intLiteralModule = definitions.getModule("purescala.Trees.IntLiteral") + private lazy val intLiteralClass = definitions.getClass("purescala.Trees.IntLiteral") private lazy val booleanLiteralModule = definitions.getModule("purescala.Trees.BooleanLiteral") + private lazy val booleanLiteralClass = definitions.getClass("purescala.Trees.BooleanLiteral") private lazy val caseClassModule = definitions.getModule("purescala.Trees.CaseClass") + private lazy val caseClassClass = definitions.getClass("purescala.Trees.CaseClass") + private lazy val andClass = definitions.getClass("purescala.Trees.And") + private lazy val equalsClass = definitions.getClass("purescala.Trees.Equals") private lazy val fairZ3SolverClass = definitions.getClass("purescala.FairZ3Solver") private lazy val restartAndDecideWithModel = definitions.getMember(fairZ3SolverClass, "restartAndDecideWithModel") @@ -133,7 +143,6 @@ trait CodeGeneration { val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe) - // ATTENTION the info might need module instead of class.. val ccdBinderSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "ccd")).setInfo(caseClassDefClass.tpe) val argsBinderSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "args")).setInfo(typeRef(NoPrefix, definitions.SeqClass, List(exprClass.tpe))) @@ -152,20 +161,14 @@ trait CodeGeneration { case c : purescala.TypeTrees.ClassType => reverseClassesToClasses(c.classDef) case _ => scala.Predef.error("Cannot generate method using type : " + tpe) } - println("args binder sym and its type:") - println(argsBinderSym) - println(argsBinderSym.tpe) - println("here is the list apply fun") - println(listApplyFunction) - println(listApplyFunction.tpe) Apply( TypeApply( Ident(castMethodSym), List(TypeTree(typeArg.tpe)) ), List( - // ugly hack to make typer happy :( - ((argsBinderSym DOT listApplyFunction) APPLY LIT(idx)) AS (exprClass.tpe) + // cast hack to make typer happy :( + ((argsBinderSym DOT listClassApplyFunction) APPLY LIT(idx)) AS (exprClass.tpe) ) ) }).toList @@ -176,7 +179,7 @@ trait CodeGeneration { CASE((intLiteralModule) APPLY (intSym BIND WILD())) ==> ID(intSym) , CASE((booleanLiteralModule) APPLY (booleanSym BIND WILD())) ==> ID(booleanSym)) ::: caseClassMatchCases ::: - List(DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))) : _* + List(DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))) : _* ) // the following is for the casting method @@ -185,24 +188,76 @@ trait CodeGeneration { ((methodSym, DEF(methodSym) === matchExpr), (castMethodSym, DEF(castMethodSym) === castBody)) } - /* Declare a new list variable and generate the code for assigning the - * result of applying the function on the input list */ - // TODO type of map function cannot be resolved. - def invokeMap(mapFunSym : Symbol, listSym : Symbol) : (Tree, Symbol) = { - val newListSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "nl")).setInfo(typeRef(NoPrefix, definitions.ListClass, List(definitions.AnyClass.tpe))) - val assignment = VAL(newListSym) === ((listSym DOT listMapFunction) APPLY ID(listSym)) - // val assignment = VAL(newListSym) === (listSym DOT (TypeApply(ID(listMapFunction), List(TypeTree(definitions.AnyClass.tpe)))) APPLY ID(listSym)) - (assignment, newListSym) + /* Generate the method for converting ground Scala terms into funcheck + * expressions */ + def scalaToExprMethod(owner : Symbol, prog : Program, programFilename : String) : (Tree, Symbol) = { + val methodSym = owner.newMethod(NoPosition, unit.fresh.newName(NoPosition, "scalaToExpr")) + methodSym setInfo (MethodType(methodSym newSyntheticValueParams (List(definitions.AnyClass.tpe)), exprClass.tpe)) + owner.info.decls.enter(methodSym) + + val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) + val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe) + + val definedCaseClasses : Seq[CaseClassDef] = prog.definedClasses.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) + val dccSyms = definedCaseClasses map (reverseClassesToClasses(_)) + + val caseClassMatchCases = ((definedCaseClasses zip dccSyms) map { + case (ccd, scalaSym) => + /* + val binderSyms = (ccd.fields.map { + case VarDecl(id, tpe) => + methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, id.name)).setInfo(definitions.AnyClass.tpe) + }).toList + */ + + val scalaBinderSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "cc")).setInfo(scalaSym.tpe) + + val memberSyms = (ccd.fields.map { + case VarDecl(id, tpe) => + scalaSym.info.member(id.name) + }).toList + + // CASE(scalaSym APPLY (binderSyms map (_ BIND WILD()))) ==> + CASE(scalaBinderSym BIND Typed(WILD(), TypeTree(scalaSym.tpe))) ==> + New( + TypeTree(caseClassClass.tpe), + List( + List( + (((cpPackage DOT serializationModule DOT getProgramFunction) APPLY LIT(programFilename)) DOT caseClassDefFunction) APPLY LIT(scalaSym.nameString), + listModuleApplyFunction APPLY (memberSyms map { + case ms => methodSym APPLY (scalaBinderSym DOT ms) + }) + ) + ) + ) + }).toList + + val matchExpr = (methodSym ARG 0) MATCH ( List( + CASE(intSym BIND Typed(WILD(), TypeTree(definitions.IntClass.tpe))) ==> NEW(ID(intLiteralClass), ID(intSym)) , + CASE(booleanSym BIND Typed(WILD(), TypeTree(definitions.BooleanClass.tpe))) ==> NEW(ID(booleanLiteralClass), ID(booleanSym))) ::: + caseClassMatchCases ::: + List(DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert Scala term to FunCheck expression"))) : _* + ) + + (DEF(methodSym) === matchExpr, methodSym) } - def invokeOutputAssignments(outputVarListSym : Symbol, modelSym : Symbol) : (Tree, Symbol) = { - val assignmentListSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "as")).setInfo(typeRef(NoPrefix, definitions.ListClass, List(definitions.AnyClass.tpe))) - val assignment = VAL(assignmentListSym) === (outputAssignmentsFunction APPLY (ID(outputVarListSym), ID(modelSym))) - (assignment, assignmentListSym) + def inputEquality(inputVarListFilename : String, varId : Identifier, scalaToExprSym : Symbol) : Tree = { + NEW( + ID(equalsClass), + // retrieve input variable list and get corresponding variable + (cpPackage DOT callTransformationModule DOT inputVarFunction) APPLY + ((cpPackage DOT serializationModule DOT getInputVarListFunction) APPLY LIT(inputVarListFilename), LIT(varId.name)), + // invoke s2e on var symbol + scalaToExprSym APPLY ID(reverseVarSubsts(Variable(varId))) + ) } - def invokeMethod(methodSym : Symbol, argSyms : Symbol*) : Tree = { - methodSym APPLY (argSyms map (ID(_))).toList + def assignAndExpr(exprs : Tree*) : (Tree, Symbol) = { + val andSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "andExpr")).setInfo(exprClass.tpe) + val statement = VAL(andSym) === NEW(ID(andClass), listModuleApplyFunction APPLY (exprs.toList)) + (statement, andSym) } + } } diff --git a/src/cp/Serialization.scala b/src/cp/Serialization.scala index e5e0b782a0efce44824a1377c274824b76064210..d4fb5d8dd0beb99779d6b2c5f1b58268e03c57ed 100644 --- a/src/cp/Serialization.scala +++ b/src/cp/Serialization.scala @@ -13,6 +13,7 @@ trait Serialization { private var cachedProgram : Option[Program] = None private val exprMap = new scala.collection.mutable.HashMap[String,Expr]() private val outputVarListMap = new scala.collection.mutable.HashMap[String,List[String]]() + private val inputVarListMap = new scala.collection.mutable.HashMap[String,List[Variable]]() def programFileName(prog : Program) : String = { prog.mainObject.id.toString + fileSuffix @@ -57,31 +58,24 @@ trait Serialization { recovered } - private def readProgram(filename : String) : Program = { - readObject[Program](filename) - } - - private def readExpr(filename : String) : Expr = { - readObject[Expr](filename) - } - - private def readOutputVarList(filename : String) : List[String] = { - readObject[List[String]](filename) - } - def getProgram(filename : String) : Program = cachedProgram match { - case None => val r = readProgram(filename); cachedProgram = Some(r); r + case None => val r = readObject[Program](filename); cachedProgram = Some(r); r case Some(cp) => cp } def getExpr(filename : String) : Expr = exprMap.get(filename) match { case Some(e) => e - case None => val e = readExpr(filename); exprMap += (filename -> e); e + case None => val e = readObject[Expr](filename); exprMap += (filename -> e); e } def getOutputVarList(filename : String) : List[String] = outputVarListMap.get(filename) match { case Some(l) => l - case None => val l = readOutputVarList(filename); outputVarListMap += (filename -> l); l + case None => val l = readObject[List[String]](filename); outputVarListMap += (filename -> l); l + } + + def getInputVarList(filename : String) : List[Variable] = inputVarListMap.get(filename) match { + case Some(l) => l + case None => val l = readObject[List[Variable]](filename); inputVarListMap += (filename -> l); l } } diff --git a/src/funcheck/AnalysisComponent.scala b/src/funcheck/AnalysisComponent.scala index 78dafc2b0e2a3be40d99af3faaf3fdfa4d8549ac..49cb952a50a1011227382ad9eb9f994f4661f2c1 100644 --- a/src/funcheck/AnalysisComponent.scala +++ b/src/funcheck/AnalysisComponent.scala @@ -25,7 +25,7 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin) //global ref to freshName creator fresh = unit.fresh - val prog: purescala.Definitions.Program = extractCode(unit) + val prog: purescala.Definitions.Program = extractCode(unit, false) if(pluginInstance.stopAfterExtraction) { println("Extracted program for " + unit + ": ") println(prog) diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 2322e1aa5e1b84671606495d0d8ae06533498d25..f139f1582ec169eb9b45ec0cdc6ff0ac34e3e867 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -19,13 +19,15 @@ trait CodeExtraction extends Extractors { private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") private lazy val multisetTraitSym = definitions.getClass("scala.collection.immutable.Multiset") - private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = + val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] - private val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = + val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = scala.collection.mutable.Map.empty[Symbol,ClassTypeDef] private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = scala.collection.mutable.Map.empty[Symbol,FunDef] + val reverseVarSubsts: scala.collection.mutable.Map[Expr,Symbol] = + scala.collection.mutable.Map.empty[Expr,Symbol] val reverseClassesToClasses: scala.collection.mutable.Map[ClassTypeDef,Symbol] = scala.collection.mutable.Map.empty[ClassTypeDef,Symbol] @@ -35,7 +37,7 @@ trait CodeExtraction extends Extractors { } } - def extractCode(unit: CompilationUnit): Program = { + def extractCode(unit: CompilationUnit, skipNonPureInstructions: Boolean): Program = { import scala.collection.mutable.HashMap def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { @@ -252,7 +254,7 @@ trait CodeExtraction extends Extractors { } val bodyAttempt = try { - Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody)) + Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies, skipNonPureInstructions)(realBody)) } catch { case e: ImpureCodeEncounteredException => None } @@ -270,6 +272,7 @@ trait CodeExtraction extends Extractors { // Reverse map for Scala class symbols reverseClassesToClasses ++= classesToClasses.map{ case (a, b) => (b, a) } + reverseVarSubsts ++= varSubsts.map{ case (a, b) => (b(), a) } val programName: Identifier = unit.body match { case PackageDef(name, _) => FreshIdentifier(name.toString) @@ -281,14 +284,12 @@ trait CodeExtraction extends Extractors { } def extractPredicate(unit: CompilationUnit, params: Seq[ValDef], body: Tree) : FunDef = { - def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { - case Some(ex) => ex - case None => stopIfErrors; scala.Predef.error("unreachable error.") - } - - def st2ps(tree: Type): purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { - case Some(tt) => tt - case None => stopIfErrors; scala.Predef.error("unreachable error.") + def st2ps(tree: Type): purescala.TypeTrees.TypeTree = { + try { + scalaType2PureScala(unit, true)(tree) + } catch { + case ImpureCodeEncounteredException(_) => stopIfErrors; scala.Predef.error("unreachable error.") + } } val newParams = params.map(p => { @@ -299,7 +300,7 @@ trait CodeExtraction extends Extractors { }) val fd = new FunDef(FreshIdentifier("predicate"), BooleanType, newParams) - val bodyAttempt = toPureScala(unit)(body) + val bodyAttempt = try { Some(scala2PureScala(unit, true, false)(body)) } catch { case ImpureCodeEncounteredException(_) => None } fd.body = bodyAttempt fd } @@ -310,7 +311,7 @@ trait CodeExtraction extends Extractors { /** Attempts to convert a scalac AST to a pure scala AST. */ def toPureScala(unit: CompilationUnit)(tree: Tree): Option[Expr] = { try { - Some(scala2PureScala(unit, false)(tree)) + Some(scala2PureScala(unit, false, false)(tree)) } catch { case ImpureCodeEncounteredException(_) => None } @@ -327,7 +328,7 @@ trait CodeExtraction extends Extractors { /** Forces conversion from scalac AST to purescala AST, throws an Exception * if impossible. If not in 'silent mode', non-pure AST nodes are reported as * errors. */ - private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = { + private def scala2PureScala(unit: CompilationUnit, silent: Boolean, skipNonPureInstructions: Boolean)(tree: Tree): Expr = { def rewriteCaseDef(cd: CaseDef): MatchCase = { def pat2pat(p: Tree): Pattern = p match { case Ident(nme.WILDCARD) => WildcardPattern(None) @@ -581,6 +582,10 @@ trait CodeExtraction extends Extractors { CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) } + case ExSkipTree(rest) if skipNonPureInstructions => { + rec(rest) + } + // default behaviour is to complain :) case _ => { if(!silent) { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 79ace4187aa2f9b2e1cf4da10a2cc18f4a0722af..9d744cd494ff91e50b56c6434197742b5c6451cc 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -76,6 +76,18 @@ trait Extractors { } } + object ExSkipTree { + /** Skips the first tree in a block */ + def unapply(tree: Block): Option[Tree] = tree match { + case Block(t :: ts, expr) => + if (ts.isEmpty) + Some(expr) + else + Some(Block(ts, expr)) + case _ => None + } + } + object ExObjectDef { /** Matches an object with no type parameters, and regardless of its * visibility. Does not match on the automatically generated companion diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index 2384edf71847be45e377a62bf212c17a96129dee..547ac316eb81dbd608b623bf8066327adbabe409 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -47,6 +47,7 @@ object Definitions { def transitiveCallees(f1: FunDef) = mainObject.transitiveCallees(f1) def isRecursive(f1: FunDef) = mainObject.isRecursive(f1) def isCatamorphism(f1: FunDef) = mainObject.isCatamorphism(f1) + def caseClassDef(name: String) = mainObject.caseClassDef(name) } /** Objects work as containers for class definitions, functions (def's) and @@ -56,6 +57,9 @@ object Definitions { lazy val definedClasses : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]) + def caseClassDef(caseClassName : String) : CaseClassDef = + definedClasses.find(ctd => ctd.id.name == caseClassName).getOrElse(scala.Predef.error("Asking for non-existent case class def: " + caseClassName)).asInstanceOf[CaseClassDef] + lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) lazy val (callGraph, callers, callees) = {