From 37a4180ab8efb6397193cf235a272ff47d95fa12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com> Date: Mon, 11 Apr 2011 13:00:23 +0000 Subject: [PATCH] Revert classes in 'funcheck' package to how they were before the CP implementation, copy necessary classes into 'cp' package --- src/cp/Annotations.scala | 5 + src/cp/CPComponent.scala | 1 - src/cp/CPPlugin.scala | 5 +- src/cp/CodeExtraction.scala | 636 +++++++++++++++++++++++++++ src/cp/Extractors.scala | 470 ++++++++++++++++++++ src/funcheck/AnalysisComponent.scala | 8 +- src/funcheck/CodeExtraction.scala | 22 +- src/funcheck/Extractors.scala | 14 +- src/funcheck/FunCheckPlugin.scala | 3 +- src/funcheck/PluginBase.scala | 11 - 10 files changed, 1129 insertions(+), 46 deletions(-) create mode 100644 src/cp/Annotations.scala create mode 100644 src/cp/CodeExtraction.scala create mode 100644 src/cp/Extractors.scala delete mode 100644 src/funcheck/PluginBase.scala diff --git a/src/cp/Annotations.scala b/src/cp/Annotations.scala new file mode 100644 index 000000000..84955b79b --- /dev/null +++ b/src/cp/Annotations.scala @@ -0,0 +1,5 @@ +package cp + +object Annotations { + class purescala extends StaticAnnotation +} diff --git a/src/cp/CPComponent.scala b/src/cp/CPComponent.scala index 673780571..22c2c91ce 100644 --- a/src/cp/CPComponent.scala +++ b/src/cp/CPComponent.scala @@ -2,7 +2,6 @@ package cp import scala.tools.nsc._ import scala.tools.nsc.plugins._ -import funcheck.CodeExtraction class CPComponent(val global: Global, val pluginInstance: CPPlugin) extends PluginComponent diff --git a/src/cp/CPPlugin.scala b/src/cp/CPPlugin.scala index d5f307a19..c1e57f811 100644 --- a/src/cp/CPPlugin.scala +++ b/src/cp/CPPlugin.scala @@ -3,10 +3,9 @@ package cp import scala.tools.nsc import scala.tools.nsc.{Global,Phase} import scala.tools.nsc.plugins.{Plugin,PluginComponent} -import funcheck.PluginBase /** This class is the entry point for the plugin. */ -class CPPlugin(val global: Global) extends PluginBase { +class CPPlugin(val global: Global) extends Plugin { import global._ val name = "constraint-programming" @@ -15,7 +14,7 @@ class CPPlugin(val global: Global) extends PluginBase { var stopAfterAnalysis: Boolean = true var stopAfterExtraction: Boolean = false - silentlyTolerateNonPureBodies = true + var silentlyTolerateNonPureBodies = true /** The help message displaying the options for that plugin. */ override val optionsHelp: Option[String] = Some( diff --git a/src/cp/CodeExtraction.scala b/src/cp/CodeExtraction.scala new file mode 100644 index 000000000..b870e74d1 --- /dev/null +++ b/src/cp/CodeExtraction.scala @@ -0,0 +1,636 @@ +package cp + +import scala.tools.nsc._ +import scala.tools.nsc.plugins._ + +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.Common._ + +trait CodeExtraction extends Extractors { + // self: AnalysisComponent => + + import global._ + import global.definitions._ + import StructuralExtractors._ + import ExpressionExtractors._ + + private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private lazy val multisetTraitSym = definitions.getClass("scala.collection.immutable.Multiset") + + private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = + scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] + private val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = + scala.collection.mutable.Map.empty[Symbol,ClassTypeDef] + private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = + scala.collection.mutable.Map.empty[Symbol,FunDef] + + private val reverseVarSubsts_ : scala.collection.mutable.Map[Expr,Symbol] = + scala.collection.mutable.Map.empty[Expr,Symbol] + private val reverseClassesToClasses_ : scala.collection.mutable.Map[ClassTypeDef,Symbol] = + scala.collection.mutable.Map.empty[ClassTypeDef,Symbol] + + def reverseVarSubsts: scala.collection.immutable.Map[Expr,Symbol] = + scala.collection.immutable.Map() ++ reverseVarSubsts_ + + def reverseClassesToClasses: scala.collection.immutable.Map[ClassTypeDef,Symbol] = + scala.collection.immutable.Map() ++ reverseClassesToClasses_ + + protected def stopIfErrors: Unit = { + if(reporter.hasErrors) { + throw new Exception("There were errors.") + } + } + + def extractCode(unit: CompilationUnit, skipNonPureInstructions: Boolean): Program = { + import scala.collection.mutable.HashMap + + def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { + case Some(ex) => ex + case None => stopIfErrors; scala.Predef.error("unreachable error.") + } + + def st2ps(tree: Type): purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { + case Some(tt) => tt + case None => stopIfErrors; scala.Predef.error("unreachable error.") + } + + def extractTopLevelDef: ObjectDef = { + val top = unit.body match { + case p @ PackageDef(name, lst) if lst.size == 0 => { + unit.error(p.pos, "No top-level definition found.") + None + } + + case PackageDef(name, lst) if lst.size > 1 => { + unit.error(lst(1).pos, "Too many top-level definitions.") + None + } + + case PackageDef(name, lst) => { + assert(lst.size == 1) + lst(0) match { + case ExObjectDef(n, templ) => Some(extractObjectDef(n.toString, templ)) + case other @ _ => unit.error(other.pos, "Expected: top-level single object.") + None + } + } + } + + stopIfErrors + top.get + } + + def extractObjectDef(nameStr: String, tmpl: Template): ObjectDef = { + // we assume that the template actually corresponds to an object + // definition. Typically it should have been obtained from the proper + // extractor (ExObjectDef) + + var classDefs: List[ClassTypeDef] = Nil + var objectDefs: List[ObjectDef] = Nil + var funDefs: List[FunDef] = Nil + + val scalaClassSyms: scala.collection.mutable.Map[Symbol,Identifier] = + scala.collection.mutable.Map.empty[Symbol,Identifier] + val scalaClassArgs: scala.collection.mutable.Map[Symbol,Seq[(String,Tree)]] = + scala.collection.mutable.Map.empty[Symbol,Seq[(String,Tree)]] + val scalaClassNames: scala.collection.mutable.Set[String] = + scala.collection.mutable.Set.empty[String] + + // we need the new type definitions before we can do anything... + tmpl.body.foreach(t => + t match { + case ExAbstractClass(o2, sym) => { + if(scalaClassNames.contains(o2)) { + unit.error(t.pos, "A class with the same name already exists.") + } else { + scalaClassSyms += (sym -> FreshIdentifier(o2)) + scalaClassNames += o2 + } + } + case ExCaseClass(o2, sym, args) => { + if(scalaClassNames.contains(o2)) { + unit.error(t.pos, "A class with the same name already exists.") + } else { + scalaClassSyms += (sym -> FreshIdentifier(o2)) + scalaClassNames += o2 + scalaClassArgs += (sym -> args) + } + } + case _ => ; + } + ) + + stopIfErrors + + + scalaClassSyms.foreach(p => { + if(p._1.isAbstractClass) { + classesToClasses += (p._1 -> new AbstractClassDef(p._2)) + } else if(p._1.isCase) { + classesToClasses += (p._1 -> new CaseClassDef(p._2)) + } + }) + + classesToClasses.foreach(p => { + val superC: List[ClassTypeDef] = p._1.tpe.baseClasses.filter(bcs => scalaClassSyms.exists(pp => pp._1 == bcs) && bcs != p._1).map(s => classesToClasses(s)).toList + + val superAC: List[AbstractClassDef] = superC.map(c => { + if(!c.isInstanceOf[AbstractClassDef]) { + unit.error(p._1.pos, "Class is inheriting from non-abstract class.") + null + } else { + c.asInstanceOf[AbstractClassDef] + } + }).filter(_ != null) + + if(superAC.length > 1) { + unit.error(p._1.pos, "Multiple inheritance.") + } + + if(superAC.length == 1) { + p._2.setParent(superAC.head) + } + + if(p._2.isInstanceOf[CaseClassDef]) { + // this should never fail + val ccargs = scalaClassArgs(p._1) + p._2.asInstanceOf[CaseClassDef].fields = ccargs.map(cca => { + val cctpe = st2ps(cca._2.tpe) + VarDecl(FreshIdentifier(cca._1).setType(cctpe), cctpe) + }) + } + }) + + classDefs = classesToClasses.valuesIterator.toList + + // end of class (type) extraction + + // we now extract the function signatures. + tmpl.body.foreach( + _ match { + case ExMainFunctionDef() => ; + case dd @ ExFunctionDef(n,p,t,b) => { + val mods = dd.mods + val funDef = extractFunSig(n, p, t).setPosInfo(dd.pos.line, dd.pos.column) + if(mods.isPrivate) funDef.addAnnotation("private") + for(a <- dd.symbol.annotations) { + a.atp.safeToString match { + case "funcheck.Annotations.induct" => funDef.addAnnotation("induct") + case _ => ; + } + } + defsToDefs += (dd.symbol -> funDef) + } + case _ => ; + } + ) + + // then their bodies. + tmpl.body.foreach( + _ match { + case ExMainFunctionDef() => ; + case dd @ ExFunctionDef(n,p,t,b) => { + val fd = defsToDefs(dd.symbol) + defsToDefs(dd.symbol) = extractFunDef(fd, b) + } + case _ => ; + } + ) + + funDefs = defsToDefs.valuesIterator.toList + + // we check nothing else is polluting the object. + tmpl.body.foreach( + _ match { + case ExCaseClassSyntheticJunk() => ; + // case ExObjectDef(o2, t2) => { objectDefs = extractObjectDef(o2, t2) :: objectDefs } + case ExAbstractClass(_,_) => ; + case ExCaseClass(_,_,_) => ; + case ExConstructorDef() => ; + case ExMainFunctionDef() => ; + case ExFunctionDef(_,_,_,_) => ; + case tree => { unit.error(tree.pos, "Don't know what to do with this. Not purescala?"); println(tree) } + } + ) + + val name: Identifier = FreshIdentifier(nameStr) + val theDef = new ObjectDef(name, objectDefs.reverse ::: classDefs ::: funDefs, Nil) + + theDef + } + + def extractFunSig(nameStr: String, params: Seq[ValDef], tpt: Tree): FunDef = { + val newParams = params.map(p => { + val ptpe = st2ps(p.tpt.tpe) + val newID = FreshIdentifier(p.name.toString).setType(ptpe) + varSubsts(p.symbol) = (() => Variable(newID)) + VarDecl(newID, ptpe) + }) + new FunDef(FreshIdentifier(nameStr), st2ps(tpt.tpe), newParams) + } + + def extractFunDef(funDef: FunDef, body: Tree): FunDef = { + var realBody = body + var reqCont: Option[Expr] = None + var ensCont: Option[Expr] = None + + realBody match { + case ExEnsuredExpression(body2, resSym, contract) => { + varSubsts(resSym) = (() => ResultVariable().setType(funDef.returnType)) + val c1 = s2ps(contract) + // varSubsts.remove(resSym) + realBody = body2 + ensCont = Some(c1) + } + case ExHoldsExpression(body2) => { + realBody = body2 + ensCont = Some(ResultVariable().setType(BooleanType)) + } + case _ => ; + } + + realBody match { + case ExRequiredExpression(body3, contract) => { + realBody = body3 + reqCont = Some(s2ps(contract)) + } + case _ => ; + } + + val bodyAttempt = try { + Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies, skipNonPureInstructions)(realBody)) + } catch { + case e: ImpureCodeEncounteredException => None + } + + funDef.body = bodyAttempt + funDef.precondition = reqCont + funDef.postcondition = ensCont + funDef + } + + // THE EXTRACTION CODE STARTS HERE + val topLevelObjDef: ObjectDef = extractTopLevelDef + + stopIfErrors + + // Reverse map for Scala class symbols + reverseClassesToClasses_ ++= classesToClasses.map{ case (a, b) => (b, a) } + reverseVarSubsts_ ++= varSubsts.map{ case (a, b) => (b(), a) } + + val programName: Identifier = unit.body match { + case PackageDef(name, _) => FreshIdentifier(name.toString) + case _ => FreshIdentifier("<program>") + } + + //Program(programName, ObjectDef("Object", Nil, Nil)) + Program(programName, topLevelObjDef) + } + + def extractPredicate(unit: CompilationUnit, params: Seq[ValDef], body: Tree) : FunDef = { + def st2ps(tree: Type): purescala.TypeTrees.TypeTree = { + try { + scalaType2PureScala(unit, true)(tree) + } catch { + case ImpureCodeEncounteredException(_) => stopIfErrors; scala.Predef.error("unreachable error.") + } + } + + val newParams = params.map(p => { + val ptpe = st2ps(p.tpt.tpe) + val newID = FreshIdentifier(p.name.toString).setType(ptpe) + varSubsts(p.symbol) = (() => Variable(newID)) + VarDecl(newID, ptpe) + }) + val fd = new FunDef(FreshIdentifier("predicate"), BooleanType, newParams) + + val bodyAttempt = try { Some(scala2PureScala(unit, true, false)(body)) } catch { case ImpureCodeEncounteredException(_) => None } + fd.body = bodyAttempt + fd + } + + /** An exception thrown when non-purescala compatible code is encountered. */ + sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception + + /** Attempts to convert a scalac AST to a pure scala AST. */ + def toPureScala(unit: CompilationUnit)(tree: Tree): Option[Expr] = { + try { + Some(scala2PureScala(unit, false, false)(tree)) + } catch { + case ImpureCodeEncounteredException(_) => None + } + } + + def toPureScalaType(unit: CompilationUnit)(typeTree: Type): Option[purescala.TypeTrees.TypeTree] = { + try { + Some(scalaType2PureScala(unit, false)(typeTree)) + } catch { + case ImpureCodeEncounteredException(_) => None + } + } + + /** Forces conversion from scalac AST to purescala AST, throws an Exception + * if impossible. If not in 'silent mode', non-pure AST nodes are reported as + * errors. */ + private def scala2PureScala(unit: CompilationUnit, silent: Boolean, skipNonPureInstructions: Boolean)(tree: Tree): Expr = { + def rewriteCaseDef(cd: CaseDef): MatchCase = { + def pat2pat(p: Tree): Pattern = p match { + case Ident(nme.WILDCARD) => WildcardPattern(None) + case b @ Bind(name, Ident(nme.WILDCARD)) => { + val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe)) + varSubsts(b.symbol) = (() => Variable(newID)) + WildcardPattern(Some(newID)) + } + case a @ Apply(fn, args) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => { + val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] + assert(args.size == cd.fields.size) + CaseClassPattern(None, cd, args.map(pat2pat(_))) + } + case b @ Bind(name, a @ Apply(fn, args)) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => { + val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe)) + varSubsts(b.symbol) = (() => Variable(newID)) + val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] + assert(args.size == cd.fields.size) + CaseClassPattern(Some(newID), cd, args.map(pat2pat(_))) + } + case _ => { + if(!silent) + unit.error(p.pos, "Unsupported pattern.") + throw ImpureCodeEncounteredException(p) + } + } + + if(cd.guard == EmptyTree) { + SimpleCase(pat2pat(cd.pat), rec(cd.body)) + } else { + GuardedCase(pat2pat(cd.pat), rec(cd.guard), rec(cd.body)) + } + } + + def rec(tr: Tree): Expr = tr match { + case ExValDef(vs, tpt, bdy, rst) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val oldSubsts = varSubsts + val valTree = rec(bdy) + varSubsts(vs) = (() => Variable(newID)) + val restTree = rec(rst) + varSubsts.remove(vs) + Let(newID, valTree, restTree) + } + case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) + case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) + case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { + case Some(fun) => fun() + case None => { + unit.error(tr.pos, "Unidentified variable.") + throw ImpureCodeEncounteredException(tr) + } + } + case ExCaseClassConstruction(tpt, args) => { + val cctype = scalaType2PureScala(unit, silent)(tpt.tpe) + if(!cctype.isInstanceOf[CaseClassType]) { + if(!silent) { + unit.error(tr.pos, "Construction of a non-case class.") + } + throw ImpureCodeEncounteredException(tree) + } + val nargs = args.map(rec(_)) + val cct = cctype.asInstanceOf[CaseClassType] + CaseClass(cct.classDef, nargs).setType(cct) + } + case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) + case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) + case ExNot(e) => Not(rec(e)).setType(BooleanType) + case ExUMinus(e) => UMinus(rec(e)).setType(Int32Type) + case ExPlus(l, r) => Plus(rec(l), rec(r)).setType(Int32Type) + case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) + case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) + case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) + case ExEquals(l, r) => { + val rl = rec(l) + val rr = rec(r) + ((rl.getType,rr.getType) match { + case (SetType(_), SetType(_)) => SetEquals(rl, rr) + case (BooleanType, BooleanType) => Iff(rl, rr) + case (_, _) => Equals(rl, rr) + }).setType(BooleanType) + } + case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) + case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) + case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) + case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType) + case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) + case ExFiniteSet(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteSet(args.map(rec(_))).setType(SetType(underlying)) + } + case ExFiniteMultiset(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) + } + case ExEmptySet(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptySet(underlying).setType(SetType(underlying)) + } + case ExEmptyMultiset(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMultiset(underlying).setType(MultisetType(underlying)) + } + case ExSetMin(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Min should be computed on a set.") + throw ImpureCodeEncounteredException(tree) + } + SetMin(set).setType(set.getType.asInstanceOf[SetType].base) + } + case ExSetMax(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Max should be computed on a set.") + throw ImpureCodeEncounteredException(tree) + } + SetMax(set).setType(set.getType.asInstanceOf[SetType].base) + } + case ExUnion(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetUnion(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExIntersection(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetIntersection(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExSetContains(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => ElementOfSet(rr, rl) + case _ => { + if(!silent) unit.error(tree.pos, ".contains on non set expression.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExSetSubset(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SubsetOf(rl, rr) + case _ => { + if(!silent) unit.error(tree.pos, "Subset on non set expression.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExSetMinus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetDifference(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExSetCard(t) => { + val rt = rec(t) + rt.getType match { + case s @ SetType(_) => SetCardinality(rt) + case m @ MultisetType(_) => MultisetCardinality(rt) + case _ => { + if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExMultisetToSet(t) => { + val rt = rec(t) + rt.getType match { + case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) + case _ => { + if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") + throw ImpureCodeEncounteredException(tree) + } + } + } + + case ExPlusPlusPlus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + MultisetPlus(rl, rr).setType(rl.getType) + } + case ExIfThenElse(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + IfExpr(r1, r2, r3).setType(leastUpperBound(r2.getType, r3.getType)) + } + case lc @ ExLocalCall(sy,nm,ar) => { + if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { + if(!silent) + unit.error(tr.pos, "Invoking an invalid function.") + throw ImpureCodeEncounteredException(tr) + } + val fd = defsToDefs(sy) + FunctionInvocation(fd, ar.map(rec(_))).setType(fd.returnType).setPosInfo(lc.pos.line,lc.pos.column) + } + case pm @ ExPatternMatching(sel, cses) => { + val rs = rec(sel) + val rc = cses.map(rewriteCaseDef(_)) + val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_)) + MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) + } + + // this one should stay after all others, cause it also catches UMinus + // and Not, for instance. + case ExParameterlessMethodCall(t,n) => { + val selector = rec(t) + val selType = selector.getType + + if(!selType.isInstanceOf[CaseClassType]) { + if(!silent) + unit.error(tr.pos, "Invalid method or field invocation (not purescala?)") + throw ImpureCodeEncounteredException(tr) + } + + val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef + + val fieldID = selDef.fields.find(_.id.name == n.toString) match { + case None => { + if(!silent) + unit.error(tr.pos, "Invalid method or field invocation (not a case class arg?)") + throw ImpureCodeEncounteredException(tr) + } + case Some(vd) => vd.id + } + + CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) + } + + case ExSkipTree(rest) if skipNonPureInstructions => { + rec(rest) + } + + // default behaviour is to complain :) + case _ => { + if(!silent) { + println(tr) + reporter.info(tr.pos, "Could not extract as PureScala.", true) + } + throw ImpureCodeEncounteredException(tree) + } + } + rec(tree) + } + + private def scalaType2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Type): purescala.TypeTrees.TypeTree = { + + def rec(tr: Type): purescala.TypeTrees.TypeTree = tr match { + case tpe if tpe == IntClass.tpe => Int32Type + case tpe if tpe == BooleanClass.tpe => BooleanType + case TypeRef(_, sym, btt :: Nil) if sym == setTraitSym => SetType(rec(btt)) + case TypeRef(_, sym, btt :: Nil) if sym == multisetTraitSym => MultisetType(rec(btt)) + case TypeRef(_, sym, Nil) if classesToClasses.keySet.contains(sym) => classDefToClassType(classesToClasses(sym)) + + case _ => { + if(!silent) { + unit.error(NoPosition, "Could not extract type as PureScala. [" + tr + "]") + } + throw ImpureCodeEncounteredException(null) + } + // case tt => { + // if(!silent) { + // unit.error(tree.pos, "This does not appear to be a type tree: [" + tt + "]") + // } + // throw ImpureCodeEncounteredException(tree) + // } + } + + rec(tree) + } + + def mkPosString(pos: scala.tools.nsc.util.Position) : String = { + pos.line + "," + pos.column + } +} diff --git a/src/cp/Extractors.scala b/src/cp/Extractors.scala new file mode 100644 index 000000000..76beda41d --- /dev/null +++ b/src/cp/Extractors.scala @@ -0,0 +1,470 @@ +package cp + +import scala.tools.nsc._ + +/** Contains extractors to pull-out interesting parts of the Scala ASTs. */ +trait Extractors { + val global: Global + val pluginInstance: CPPlugin + + import global._ + import global.definitions._ + + private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private lazy val multisetTraitSym = definitions.getClass("scala.collection.immutable.Multiset") + + object StructuralExtractors { + object ScalaPredef { + /** Extracts method calls from scala.Predef. */ + def unapply(tree: Select): Option[String] = tree match { + case Select(Select(This(scalaName),predefName),symName) + if("scala".equals(scalaName.toString) && "Predef".equals(predefName.toString)) => + Some(symName.toString) + case _ => None + } + } + + object ExEnsuredExpression { + /** Extracts the 'ensuring' contract from an expression. */ + def unapply(tree: Apply): Option[(Tree,Symbol,Tree)] = tree match { + case Apply( + Select( + Apply( + TypeApply( + ScalaPredef("any2Ensuring"), + TypeTree() :: Nil), + body :: Nil), + ensuringName), + (Function((vd @ ValDef(_, _, _, EmptyTree)) :: Nil, contractBody)) :: Nil) + if("ensuring".equals(ensuringName.toString)) => Some((body, vd.symbol, contractBody)) + case _ => None + } + } + + object ExHoldsExpression { + def unapply(tree: Select) : Option[Tree] = tree match { + case Select(Apply(Select(Select(funcheckIdent, utilsName), any2IsValidName), realExpr :: Nil), holdsName) if ( + utilsName.toString == "Utils" && + any2IsValidName.toString == "any2IsValid" && + holdsName.toString == "holds") => Some(realExpr) + case _ => None + } + } + + object ExRequiredExpression { + /** Extracts the 'require' contract from an expression (only if it's the + * first call in the block). */ + def unapply(tree: Block): Option[(Tree,Tree)] = tree match { + case Block(Apply(ScalaPredef("require"), contractBody :: Nil) :: rest, body) => + if(rest.isEmpty) + Some((body,contractBody)) + else + Some((Block(rest,body),contractBody)) + case _ => None + } + } + + object ExValDef { + /** Extracts val's in the head of blocks. */ + def unapply(tree: Block): Option[(Symbol,Tree,Tree,Tree)] = tree match { + case Block((vd @ ValDef(_, _, tpt, rhs)) :: rest, expr) => + if(rest.isEmpty) + Some((vd.symbol, tpt, rhs, expr)) + else + Some((vd.symbol, tpt, rhs, Block(rest, expr))) + case _ => None + } + } + + object ExSkipTree { + /** Skips the first tree in a block */ + def unapply(tree: Block): Option[Tree] = tree match { + case Block(t :: ts, expr) => + if (ts.isEmpty) + Some(expr) + else + Some(Block(ts, expr)) + case _ => None + } + } + + object ExObjectDef { + /** Matches an object with no type parameters, and regardless of its + * visibility. Does not match on the automatically generated companion + * objects of case classes (or any synthetic class). */ + def unapply(cd: ClassDef): Option[(String,Template)] = cd match { + case ClassDef(_, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => { + Some((name.toString, impl)) + } + case _ => None + } + } + + object ExAbstractClass { + /** Matches an abstract class or a trait with no type parameters, no + * constrctor args (in the case of a class), no implementation details, + * no abstract members. */ + def unapply(cd: ClassDef): Option[(String,Symbol)] = cd match { + // abstract class + case ClassDef(_, name, tparams, impl) if (cd.symbol.isAbstractClass && tparams.isEmpty && impl.body.size == 1) => Some((name.toString, cd.symbol)) + + case _ => None + } + } + + object ExCaseClass { + def unapply(cd: ClassDef): Option[(String,Symbol,Seq[(String,Tree)])] = cd match { + case ClassDef(_, name, tparams, impl) if (cd.symbol.isCase && !cd.symbol.isAbstractClass && tparams.isEmpty && impl.body.size >= 8) => { + val constructor: DefDef = impl.children.find(child => child match { + case ExConstructorDef() => true + case _ => false + }).get.asInstanceOf[DefDef] + + val args = constructor.vparamss(0).map(vd => (vd.name.toString, vd.tpt)) + + Some((name.toString, cd.symbol, args)) + } + case _ => None + } + } + + object ExCaseClassSyntheticJunk { + def unapply(cd: ClassDef): Boolean = cd match { + case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic && cd.symbol.isFinal) => true + case _ => false + } + } + + object ExConstructorDef { + def unapply(dd: DefDef): Boolean = dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if(name == nme.CONSTRUCTOR && tparams.isEmpty && vparamss.size == 1) => true + case _ => false + } + } + + object ExMainFunctionDef { + def unapply(dd: DefDef): Boolean = dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if(name.toString == "main" && tparams.isEmpty && vparamss.size == 1 && vparamss(0).size == 1) => { + true + } + case _ => false + } + } + + object ExFunctionDef { + /** Matches a function with a single list of arguments, no type + * parameters and regardless of its visibility. */ + def unapply(dd: DefDef): Option[(String,Seq[ValDef],Tree,Tree)] = dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if(tparams.isEmpty && vparamss.size == 1 && name != nme.CONSTRUCTOR) => Some((name.toString, vparamss(0), tpt, rhs)) + case _ => None + } + } + } + + object ExpressionExtractors { + object ExIfThenElse { + def unapply(tree: If): Option[(Tree,Tree,Tree)] = tree match { + case If(t1,t2,t3) => Some((t1,t2,t3)) + case _ => None + } + } + + object ExBooleanLiteral { + def unapply(tree: Literal): Option[Boolean] = tree match { + case Literal(Constant(true)) => Some(true) + case Literal(Constant(false)) => Some(false) + case _ => None + } + } + + object ExInt32Literal { + def unapply(tree: Literal): Option[Int] = tree match { + case Literal(c @ Constant(i)) if c.tpe == IntClass.tpe => Some(c.intValue) + case _ => None + } + } + + object ExCaseClassConstruction { + def unapply(tree: Apply): Option[(Tree,Seq[Tree])] = tree match { + case Apply(s @ Select(New(tpt), n), args) if (n == nme.CONSTRUCTOR) => { + Some((tpt, args)) + } + case _ => None + } + } + + object ExIdentifier { + def unapply(tree: Ident): Option[(Symbol,Tree)] = tree match { + case i: Ident => Some((i.symbol, i)) + case _ => None + } + } + + object ExIntIdentifier { + def unapply(tree: Ident): Option[String] = tree match { + case i: Ident if i.symbol.tpe == IntClass.tpe => Some(i.symbol.name.toString) + case _ => None + } + } + + object ExAnd { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_and) => + Some((lhs,rhs)) + case _ => None + } + } + + object ExOr { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_or) => + Some((lhs,rhs)) + case _ => None + } + } + + object ExNot { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n == nme.UNARY_!) => Some(t) + case _ => None + } + } + + object ExEquals { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.EQ) => Some((lhs,rhs)) + case _ => None + } + } + + object ExNotEquals { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.NE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExLessThan { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.LT) => Some((lhs,rhs)) + case _ => None + } + } + + object ExLessEqThan { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.LE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExGreaterThan { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.GT) => Some((lhs,rhs)) + case _ => None + } + } + + object ExGreaterEqThan { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.GE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExUMinus { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n == nme.UNARY_-) => Some(t) + case _ => None + } + } + + object ExPlus { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.ADD) => Some((lhs,rhs)) + case _ => None + } + } + + object ExMinus { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.SUB) => Some((lhs,rhs)) + case _ => None + } + } + + object ExTimes { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.MUL) => Some((lhs,rhs)) + case _ => None + } + } + + object ExDiv { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.DIV) => Some((lhs,rhs)) + case _ => None + } + } + + object ExLocalCall { + def unapply(tree: Apply): Option[(Symbol,String,List[Tree])] = tree match { + case a @ Apply(Select(This(_), nme), args) => Some((a.symbol, nme.toString, args)) + case _ => None + } + } + + // used for case classes selectors. + object ExParameterlessMethodCall { + def unapply(tree: Select): Option[(Tree,Name)] = tree match { + case Select(lhs, n) => Some((lhs, n)) + case _ => None + } + } + + object ExPatternMatching { + def unapply(tree: Match): Option[(Tree,List[CaseDef])] = + if(tree != null) Some((tree.selector, tree.cases)) else None + } + + object ExSetMin { + def unapply(tree: Apply) : Option[Tree] = tree match { + case Apply( + TypeApply(Select(setTree, minName), typeTree :: Nil), + ordering :: Nil) if minName.toString == "min" && typeTree.tpe == IntClass.tpe => Some(setTree) + case _ => None + } + } + + object ExSetMax { + def unapply(tree: Apply) : Option[Tree] = tree match { + case Apply( + TypeApply(Select(setTree, maxName), typeTree :: Nil), + ordering :: Nil) if maxName.toString == "max" && typeTree.tpe == IntClass.tpe => Some(setTree) + case _ => None + } + } + + object ExEmptySet { + def unapply(tree: TypeApply): Option[Tree] = tree match { + case TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Set" && emptyName.toString == "empty" + ) => Some(theTypeTree) + case _ => None + } + } + + object ExEmptyMultiset { + def unapply(tree: TypeApply): Option[Tree] = tree match { + case TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "empty" + ) => Some(theTypeTree) + case _ => None + } + } + + object ExFiniteSet { + def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { + case Apply( + TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil), args) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Set" && emptyName.toString == "apply" + )=> Some(theTypeTree, args) + case _ => None + } + } + + object ExFiniteMultiset { + def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { + case Apply( + TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil), args) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "apply" + )=> Some(theTypeTree, args) + case _ => None + } + } + + object ExUnion { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.PLUSPLUS) => Some((lhs,rhs)) + case _ => None + } + } + + object ExPlusPlusPlus { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "$plus$plus$plus") => Some((lhs,rhs)) + case _ => None + } + } + + object ExIntersection { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == encode("**")) => Some((lhs,rhs)) + case _ => None + } + } + + object ExSetContains { + def unapply(tree: Apply) : Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "contains") => Some((lhs,rhs)) + case _ => None + } + } + + object ExSetSubset { + def unapply(tree: Apply) : Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "subsetOf") => Some((lhs,rhs)) + case _ => None + } + } + + object ExSetMinus { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == encode("--")) => Some((lhs,rhs)) + case _ => None + } + } + + object ExSetCard { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n.toString == "size") => Some(t) + case _ => None + } + } + + object ExMultisetToSet { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n.toString == "toSet") => Some(t) + case _ => None + } + } + } +} diff --git a/src/funcheck/AnalysisComponent.scala b/src/funcheck/AnalysisComponent.scala index 49cb952a5..c0bbb4594 100644 --- a/src/funcheck/AnalysisComponent.scala +++ b/src/funcheck/AnalysisComponent.scala @@ -18,6 +18,12 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin) /** this is initialized when the Funcheck phase starts*/ var fresh: scala.tools.nsc.util.FreshNameCreator = null + protected def stopIfErrors: Unit = { + if(reporter.hasErrors) { + throw new Exception("There were errors.") + } + } + def newPhase(prev: Phase) = new AnalysisPhase(prev) class AnalysisPhase(prev: Phase) extends StdPhase(prev) { @@ -25,7 +31,7 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin) //global ref to freshName creator fresh = unit.fresh - val prog: purescala.Definitions.Program = extractCode(unit, false) + val prog: purescala.Definitions.Program = extractCode(unit) if(pluginInstance.stopAfterExtraction) { println("Extracted program for " + unit + ": ") println(prog) diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index f8fc276d6..66112aee5 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -9,7 +9,7 @@ import purescala.TypeTrees._ import purescala.Common._ trait CodeExtraction extends Extractors { - // self: AnalysisComponent => + self: AnalysisComponent => import global._ import global.definitions._ @@ -37,13 +37,7 @@ trait CodeExtraction extends Extractors { def reverseClassesToClasses: scala.collection.immutable.Map[ClassTypeDef,Symbol] = scala.collection.immutable.Map() ++ reverseClassesToClasses_ - protected def stopIfErrors: Unit = { - if(reporter.hasErrors) { - throw new Exception("There were errors.") - } - } - - def extractCode(unit: CompilationUnit, skipNonPureInstructions: Boolean): Program = { + def extractCode(unit: CompilationUnit): Program = { import scala.collection.mutable.HashMap def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { @@ -260,7 +254,7 @@ trait CodeExtraction extends Extractors { } val bodyAttempt = try { - Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies, skipNonPureInstructions)(realBody)) + Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody)) } catch { case e: ImpureCodeEncounteredException => None } @@ -306,7 +300,7 @@ trait CodeExtraction extends Extractors { }) val fd = new FunDef(FreshIdentifier("predicate"), BooleanType, newParams) - val bodyAttempt = try { Some(scala2PureScala(unit, true, false)(body)) } catch { case ImpureCodeEncounteredException(_) => None } + val bodyAttempt = try { Some(scala2PureScala(unit, true)(body)) } catch { case ImpureCodeEncounteredException(_) => None } fd.body = bodyAttempt fd } @@ -317,7 +311,7 @@ trait CodeExtraction extends Extractors { /** Attempts to convert a scalac AST to a pure scala AST. */ def toPureScala(unit: CompilationUnit)(tree: Tree): Option[Expr] = { try { - Some(scala2PureScala(unit, false, false)(tree)) + Some(scala2PureScala(unit, false)(tree)) } catch { case ImpureCodeEncounteredException(_) => None } @@ -334,7 +328,7 @@ trait CodeExtraction extends Extractors { /** Forces conversion from scalac AST to purescala AST, throws an Exception * if impossible. If not in 'silent mode', non-pure AST nodes are reported as * errors. */ - private def scala2PureScala(unit: CompilationUnit, silent: Boolean, skipNonPureInstructions: Boolean)(tree: Tree): Expr = { + private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = { def rewriteCaseDef(cd: CaseDef): MatchCase = { def pat2pat(p: Tree): Pattern = p match { case Ident(nme.WILDCARD) => WildcardPattern(None) @@ -588,10 +582,6 @@ trait CodeExtraction extends Extractors { CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) } - case ExSkipTree(rest) if skipNonPureInstructions => { - rec(rest) - } - // default behaviour is to complain :) case _ => { if(!silent) { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 9d744cd49..3b540a735 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -5,7 +5,7 @@ import scala.tools.nsc._ /** Contains extractors to pull-out interesting parts of the Scala ASTs. */ trait Extractors { val global: Global - val pluginInstance: PluginBase + val pluginInstance: FunCheckPlugin import global._ import global.definitions._ @@ -76,18 +76,6 @@ trait Extractors { } } - object ExSkipTree { - /** Skips the first tree in a block */ - def unapply(tree: Block): Option[Tree] = tree match { - case Block(t :: ts, expr) => - if (ts.isEmpty) - Some(expr) - else - Some(Block(ts, expr)) - case _ => None - } - } - object ExObjectDef { /** Matches an object with no type parameters, and regardless of its * visibility. Does not match on the automatically generated companion diff --git a/src/funcheck/FunCheckPlugin.scala b/src/funcheck/FunCheckPlugin.scala index 4bbc7c3b8..f9ed09384 100644 --- a/src/funcheck/FunCheckPlugin.scala +++ b/src/funcheck/FunCheckPlugin.scala @@ -6,7 +6,7 @@ import scala.tools.nsc.plugins.{Plugin,PluginComponent} import purescala.Definitions.Program /** This class is the entry point for the plugin. */ -class FunCheckPlugin(val global: Global, val actionAfterExtraction : Option[Program=>Unit] = None) extends PluginBase { +class FunCheckPlugin(val global: Global, val actionAfterExtraction : Option[Program=>Unit] = None) extends Plugin { import global._ val name = "funcheck" @@ -14,6 +14,7 @@ class FunCheckPlugin(val global: Global, val actionAfterExtraction : Option[Prog var stopAfterAnalysis: Boolean = true var stopAfterExtraction: Boolean = false + var silentlyTolerateNonPureBodies: Boolean = false /** The help message displaying the options for that plugin. */ override val optionsHelp: Option[String] = Some( diff --git a/src/funcheck/PluginBase.scala b/src/funcheck/PluginBase.scala deleted file mode 100644 index 62a5e83e3..000000000 --- a/src/funcheck/PluginBase.scala +++ /dev/null @@ -1,11 +0,0 @@ -package funcheck - -import scala.tools.nsc -import scala.tools.nsc.{Global,Phase} -import scala.tools.nsc.plugins.{Plugin,PluginComponent} - -abstract class PluginBase extends Plugin { - import global._ - - var silentlyTolerateNonPureBodies: Boolean = false -} -- GitLab