diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 7f592aeb72a53e07f2edabac4776e1c118c9c5b4..5cba1395e4eb27ae8c711e7a4dde9c77de17f673 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -16,6 +16,8 @@ trait CodeExtraction extends Extractors { import StructuralExtractors._ import ExpressionExtractors._ + private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private val varSubsts: scala.collection.mutable.Map[Identifier,Function0[Expr]] = scala.collection.mutable.Map.empty[Identifier,Function0[Expr]] def extractCode(unit: CompilationUnit): Program = { @@ -26,7 +28,7 @@ trait CodeExtraction extends Extractors { case None => stopIfErrors; scala.Predef.error("unreachable error.") } - def st2ps(tree: Tree): purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { + def st2ps(tree: Type): purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { case Some(tt) => tt case None => stopIfErrors; scala.Predef.error("unreachable error.") } @@ -90,7 +92,7 @@ trait CodeExtraction extends Extractors { var reqCont: Option[Expr] = None var ensCont: Option[Expr] = None - val ps = params.map(p => VarDecl(p.name.toString, st2ps(p.tpt))) + val ps = params.map(p => VarDecl(p.name.toString, st2ps(p.tpt.tpe))) realBody match { case ExEnsuredExpression(body2, resId, contract) => { @@ -111,7 +113,7 @@ trait CodeExtraction extends Extractors { case _ => ; } - FunDef(name, st2ps(tpt), ps, s2ps(realBody), reqCont, ensCont) + FunDef(name, st2ps(tpt.tpe), ps, s2ps(realBody), reqCont, ensCont) } // THE EXTRACTION CODE STARTS HERE @@ -140,7 +142,7 @@ trait CodeExtraction extends Extractors { } } - def toPureScalaType(unit: CompilationUnit)(typeTree: Tree): Option[purescala.TypeTrees.TypeTree] = { + def toPureScalaType(unit: CompilationUnit)(typeTree: Type): Option[purescala.TypeTrees.TypeTree] = { try { Some(scalaType2PureScala(unit, false)(typeTree)) } catch { @@ -187,20 +189,27 @@ trait CodeExtraction extends Extractors { rec(tree) } - private def scalaType2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): purescala.TypeTrees.TypeTree = { - tree match { - case tt: TypeTree if tt.tpe == IntClass.tpe => Int32Type - case tt: TypeTree if tt.tpe == BooleanClass.tpe => BooleanType - case tt: TypeTree => tt.tpe match { - case TypeRef(_,sym,_) if sym == setTraitSym => XXXXXX - } + private def scalaType2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Type): purescala.TypeTrees.TypeTree = { - case tt => { + def rec(tr: Type): purescala.TypeTrees.TypeTree = tr match { + case tpe if tpe == IntClass.tpe => Int32Type + case tpe if tpe == BooleanClass.tpe => BooleanType + case TypeRef(_, sym, btt :: Nil) if sym == setTraitSym => SetType(rec(btt)) + + case _ => { if(!silent) { - unit.error(tree.pos, "Could not extract type as PureScala. [" + tt + "]") + unit.error(NoPosition, "Could not extract type as PureScala. [" + tr + "]") } - throw ImpureCodeEncounteredException(tree) + throw ImpureCodeEncounteredException(null) } + // case tt => { + // if(!silent) { + // unit.error(tree.pos, "This does not appear to be a type tree: [" + tt + "]") + // } + // throw ImpureCodeEncounteredException(tree) + // } } + + rec(tree) } } diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index 4d8f89bff786a65ef2fe3ea8c780fb6ea139b0cb..5442e0c8cc32c85d21e63227d55e69c2111943f7 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -102,6 +102,7 @@ object PrettyPrinter { private def pp(tpe: TypeTree, sb: StringBuffer): StringBuffer = tpe match { case Int32Type => sb.append("Int") case BooleanType => sb.append("Boolean") + case SetType(bt) => pp(bt, sb.append("Set[")).append("]") case _ => sb.append("Type?") }