diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala index 3bc01a2e103d9b82ac19748800007c9be9859775..f541d4178a2bc7719d7ed887120278bf6a322bd0 100644 --- a/src/cp/CallTransformation.scala +++ b/src/cp/CallTransformation.scala @@ -43,7 +43,7 @@ trait CallTransformation println("Here is the extracted FunDef:") println(fd) - val codeGen = new CodeGenerator(unit, currentOwner) + val codeGen = new CodeGenerator(unit, currentOwner, tree.pos) fd.body match { case None => println("Could not extract choose predicate: " + funBody); super.transform(tree) @@ -86,7 +86,7 @@ trait CallTransformation } case cd @ ClassDef(mods, name, tparams, impl) if (cd.symbol.isModuleClass && tparams.isEmpty && !cd.symbol.isSynthetic) => { - val codeGen = new CodeGenerator(unit, currentOwner) + val codeGen = new CodeGenerator(unit, currentOwner, tree.pos) val ((e2sSym, e2sCode), (e2sCastSym,e2sCastCode)) = codeGen.exprToScalaMethods(cd.symbol, prog) exprToScalaSym = e2sSym diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala index 5110648c1ba929c6b129d732bc4ccf99ff7d4c0e..639ecc10ddd69a1938dc24192cfff59b9dffc5b7 100644 --- a/src/cp/CodeGeneration.scala +++ b/src/cp/CodeGeneration.scala @@ -1,10 +1,11 @@ package cp import purescala.Trees._ +import purescala.TypeTrees.classDefToClassType import purescala.Definitions._ trait CodeGeneration { - self: CallTransformation => + self: CPComponent => import global._ import CODE._ @@ -12,7 +13,8 @@ trait CodeGeneration { private lazy val exceptionClass = definitions.getClass("java.lang.Exception") - private lazy val mapFunction = definitions.getMember(definitions.ListClass, "map") + private lazy val listMapFunction = definitions.getMember(definitions.ListClass, "map") + private lazy val listApplyFunction = definitions.getMember(definitions.ListClass, "apply") private lazy val mapClass = definitions.getClass("scala.collection.immutable.Map") @@ -32,9 +34,12 @@ trait CodeGeneration { private lazy val definitionsModule = definitions.getModule("purescala.Definitions") private lazy val programClass = definitions.getClass("purescala.Definitions.Program") + private lazy val caseClassDefClass = definitions.getClass("purescala.Definitions.CaseClassDef") + private lazy val idField = definitions.getMember(caseClassDefClass, "id") private lazy val commonModule = definitions.getModule("purescala.Common") private lazy val identifierClass = definitions.getClass("purescala.Common.Identifier") + private lazy val nameField = definitions.getMember(identifierClass, "name") private lazy val treesModule = definitions.getModule("purescala.Trees") private lazy val exprClass = definitions.getClass("purescala.Trees.Expr") @@ -48,7 +53,7 @@ trait CodeGeneration { private lazy val defaultReporter = definitions.getClass("purescala.DefaultReporter") - class CodeGenerator(unit : CompilationUnit, owner : Symbol) { + class CodeGenerator(unit : CompilationUnit, owner : Symbol, defaultPos : Position) { /* Assign the program read from file `filename` to a new variable and * return the code and the symbol for the variable */ @@ -125,15 +130,53 @@ trait CodeGeneration { owner.info.decls.enter(castMethodSym) // the following is for the recursive method - val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) - val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe) + val intSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.IntClass.tpe) + val booleanSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "value")).setInfo(definitions.BooleanClass.tpe) - val definedCaseClasses : Seq[CaseClassDef] = prog.definedClasses.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) + // ATTENTION the info might need module instead of class.. + val ccdBinderSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "ccd")).setInfo(caseClassDefClass.tpe) + val argsBinderSym = methodSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "args")).setInfo(typeRef(NoPrefix, definitions.SeqClass, List(exprClass.tpe))) - val matchExpr = (methodSym ARG 0) MATCH ( + val definedCaseClasses : Seq[CaseClassDef] = prog.definedClasses.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) + val dccSyms = definedCaseClasses map (reverseClassesToClasses(_)) + + val caseClassMatchCases : List[CaseDef] = ((definedCaseClasses zip dccSyms) map { + case (ccd, scalaSym) => + (CASE(caseClassModule APPLY ((ccdBinderSym BIND WILD()), (argsBinderSym BIND WILD()))) IF ((ccdBinderSym DOT idField DOT nameField).setPos(defaultPos) ANY_== LIT(ccd.id.name).setPos(defaultPos))) ==> + New(TypeTree(scalaSym.tpe), List({ + (ccd.fields.zipWithIndex map { + case (VarDecl(id, tpe), idx) => + val typeArg = tpe match { + case purescala.TypeTrees.BooleanType => definitions.BooleanClass + case purescala.TypeTrees.Int32Type => definitions.IntClass + case c : purescala.TypeTrees.ClassType => reverseClassesToClasses(c.classDef) + case _ => scala.Predef.error("Cannot generate method using type : " + tpe) + } + println("args binder sym and its type:") + println(argsBinderSym) + println(argsBinderSym.tpe) + println("here is the list apply fun") + println(listApplyFunction) + println(listApplyFunction.tpe) + Apply( + TypeApply( + Ident(castMethodSym), + List(TypeTree(typeArg.tpe)) + ), + List( + // ugly hack to make typer happy :( + ((argsBinderSym DOT listApplyFunction) APPLY LIT(idx)) AS (exprClass.tpe) + ) + ) + }).toList + })) + }).toList + + val matchExpr = (methodSym ARG 0) MATCH ( List( CASE((intLiteralModule) APPLY (intSym BIND WILD())) ==> ID(intSym) , - CASE((booleanLiteralModule) APPLY (booleanSym BIND WILD())) ==> ID(booleanSym) , - DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term")) + CASE((booleanLiteralModule) APPLY (booleanSym BIND WILD())) ==> ID(booleanSym)) ::: + caseClassMatchCases ::: + List(DEFAULT ==> THROW(exceptionClass, LIT("Cannot convert FunCheck expression to Scala term"))) : _* ) // the following is for the casting method @@ -147,8 +190,8 @@ trait CodeGeneration { // TODO type of map function cannot be resolved. def invokeMap(mapFunSym : Symbol, listSym : Symbol) : (Tree, Symbol) = { val newListSym = owner.newValue(NoPosition, unit.fresh.newName(NoPosition, "nl")).setInfo(typeRef(NoPrefix, definitions.ListClass, List(definitions.AnyClass.tpe))) - val assignment = VAL(newListSym) === ((listSym DOT mapFunction) APPLY ID(listSym)) - // val assignment = VAL(newListSym) === (listSym DOT (TypeApply(ID(mapFunction), List(TypeTree(definitions.AnyClass.tpe)))) APPLY ID(listSym)) + val assignment = VAL(newListSym) === ((listSym DOT listMapFunction) APPLY ID(listSym)) + // val assignment = VAL(newListSym) === (listSym DOT (TypeApply(ID(listMapFunction), List(TypeTree(definitions.AnyClass.tpe)))) APPLY ID(listSym)) (assignment, newListSym) } diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index d24c2f398bd87578be30461fc2d6eb9d08374768..2322e1aa5e1b84671606495d0d8ae06533498d25 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -26,7 +26,7 @@ trait CodeExtraction extends Extractors { private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = scala.collection.mutable.Map.empty[Symbol,FunDef] - private val reverseClassesToClasses: scala.collection.mutable.Map[ClassTypeDef,Symbol] = + val reverseClassesToClasses: scala.collection.mutable.Map[ClassTypeDef,Symbol] = scala.collection.mutable.Map.empty[ClassTypeDef,Symbol] protected def stopIfErrors: Unit = { @@ -304,19 +304,6 @@ trait CodeExtraction extends Extractors { fd } - /* - def groundExprToScala(expr : Expr) : Tree = { - val converted = expr match { - case IntLiteral(v) => Literal(Constant(v)) - case BooleanLiteral(v) => Literal(Constant(v)) - case StringLiteral(v) => Literal(Constant(v)) - case CaseClass(cd,args) => New(Ident(reverseClassesToClasses(cd)), List(args.map(groundExprToScala(_)).toList)) - case _ => scala.Predef.error("Cannot convert to Scala : " + expr) - } - converted - } - */ - /** An exception thrown when non-purescala compatible code is encountered. */ sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception