diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index a05dd1f289ed4829ef0f29a49f31635f23bf5c1a..4d3dab3d0f84fa215a588712bfe82e57f63abb51 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -17,6 +17,7 @@ trait CodeExtraction extends Extractors { import StructuralExtractors._ import ExpressionExtractors._ + private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") private lazy val mapTraitSym = definitions.getClass("scala.collection.immutable.Map") private lazy val multisetTraitSym = try { @@ -497,505 +498,508 @@ trait CodeExtraction extends Extractors { case _ => (tr, None) } - var handleRest = true - val psExpr = nextExpr match { - case ExTuple(tpes, exprs) => { - val tupleType = TupleType(tpes.map(tpe => scalaType2PureScala(unit, silent)(tpe))) - val tupleExprs = exprs.map(e => rec(e)) - Tuple(tupleExprs).setType(tupleType) - } - case ExTupleExtract(tuple, index) => { - val tupleExpr = rec(tuple) - val TupleType(tpes) = tupleExpr.getType - if(tpes.size < index) - throw ImpureCodeEncounteredException(tree) - else - TupleSelect(tupleExpr, index).setType(tpes(index-1)) - } - case ExValDef(vs, tpt, bdy) => { - val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) - val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) - val valTree = rec(bdy) - handleRest = false - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (Variable(newID) -> Some(currentFunDef)) - case Some(_) => - unit.error(nextExpr.pos, "Cannot alias array") - throw ImpureCodeEncounteredException(nextExpr) - } - } - val restTree = rest match { - case Some(rst) => { - varSubsts(vs) = (() => Variable(newID)) - val res = rec(rst) - varSubsts.remove(vs) - res + val e2: Option[Expr] = nextExpr match { + case ExParameterlessMethodCall(t,n) => { + val selector = rec(t) + val selType = selector.getType + + if(!selType.isInstanceOf[CaseClassType]) + None + else { + + val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef + + val fieldID = selDef.fields.find(_.id.name == n.toString) match { + case None => { + if(!silent) + unit.error(tr.pos, "Invalid method or field invocation (not a case class arg?)") + throw ImpureCodeEncounteredException(tr) + } + case Some(vd) => vd.id } - case None => UnitLiteral + + Some(CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType)) } - val res = Let(newID, valTree, restTree) - res } - case dd@ExFunctionDef(n, p, t, b) => { - val funDef = extractFunSig(n, p, t).setPosInfo(dd.pos.line, dd.pos.column) - defsToDefs += (dd.symbol -> funDef) - val oldMutableVarSubst = mutableVarSubsts.toMap //take an immutable snapshot of the map - val oldCurrentFunDef = currentFunDef - mutableVarSubsts.clear //reseting the visible mutable vars, we do not handle mutable variable closure in nested functions - val funDefWithBody = extractFunDef(funDef, b) - mutableVarSubsts ++= oldMutableVarSubst - currentFunDef = oldCurrentFunDef - val restTree = rest match { - case Some(rst) => rec(rst) - case None => UnitLiteral + case _ => None + } + var handleRest = true + val psExpr = e2 match { + case Some(e3) => e3 + case None => nextExpr match { + case ExTuple(tpes, exprs) => { + val tupleType = TupleType(tpes.map(tpe => scalaType2PureScala(unit, silent)(tpe))) + val tupleExprs = exprs.map(e => rec(e)) + Tuple(tupleExprs).setType(tupleType) } - defsToDefs.remove(dd.symbol) - handleRest = false - LetDef(funDefWithBody, restTree) - } - case ExVarDef(vs, tpt, bdy) => { - val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) - //binderTpe match { - // case ArrayType(_) => - // unit.error(tree.pos, "Cannot declare array variables, only val are alllowed") - // throw ImpureCodeEncounteredException(tree) - // case _ => - //} - val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) - val valTree = rec(bdy) - mutableVarSubsts += (vs -> (() => Variable(newID))) - handleRest = false - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (Variable(newID) -> Some(currentFunDef)) - case Some(_) => - unit.error(nextExpr.pos, "Cannot alias array") - throw ImpureCodeEncounteredException(nextExpr) + case ExTupleExtract(tuple, index) => { + val tupleExpr = rec(tuple) + val TupleType(tpes) = tupleExpr.getType + if(tpes.size < index) + throw ImpureCodeEncounteredException(tree) + else + TupleSelect(tupleExpr, index).setType(tpes(index-1)) + } + case ExValDef(vs, tpt, bdy) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val valTree = rec(bdy) + handleRest = false + if(valTree.getType.isInstanceOf[ArrayType]) { + getOwner(valTree) match { + case None => + owners += (Variable(newID) -> Some(currentFunDef)) + case Some(_) => + unit.error(nextExpr.pos, "Cannot alias array") + throw ImpureCodeEncounteredException(nextExpr) + } + } + val restTree = rest match { + case Some(rst) => { + varSubsts(vs) = (() => Variable(newID)) + val res = rec(rst) + varSubsts.remove(vs) + res + } + case None => UnitLiteral } + val res = Let(newID, valTree, restTree) + res } - val restTree = rest match { - case Some(rst) => { - varSubsts(vs) = (() => Variable(newID)) - val res = rec(rst) - varSubsts.remove(vs) - res + case dd@ExFunctionDef(n, p, t, b) => { + val funDef = extractFunSig(n, p, t).setPosInfo(dd.pos.line, dd.pos.column) + defsToDefs += (dd.symbol -> funDef) + val oldMutableVarSubst = mutableVarSubsts.toMap //take an immutable snapshot of the map + val oldCurrentFunDef = currentFunDef + mutableVarSubsts.clear //reseting the visible mutable vars, we do not handle mutable variable closure in nested functions + val funDefWithBody = extractFunDef(funDef, b) + mutableVarSubsts ++= oldMutableVarSubst + currentFunDef = oldCurrentFunDef + val restTree = rest match { + case Some(rst) => rec(rst) + case None => UnitLiteral } - case None => UnitLiteral + defsToDefs.remove(dd.symbol) + handleRest = false + LetDef(funDefWithBody, restTree) } - val res = LetVar(newID, valTree, restTree) - res - } - - case ExAssign(sym, rhs) => mutableVarSubsts.get(sym) match { - case Some(fun) => { - val Variable(id) = fun() - val rhsTree = rec(rhs) - if(rhsTree.getType.isInstanceOf[ArrayType]) { - getOwner(rhsTree) match { + case ExVarDef(vs, tpt, bdy) => { + val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe) + //binderTpe match { + // case ArrayType(_) => + // unit.error(tree.pos, "Cannot declare array variables, only val are alllowed") + // throw ImpureCodeEncounteredException(tree) + // case _ => + //} + val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) + val valTree = rec(bdy) + mutableVarSubsts += (vs -> (() => Variable(newID))) + handleRest = false + if(valTree.getType.isInstanceOf[ArrayType]) { + getOwner(valTree) match { case None => + owners += (Variable(newID) -> Some(currentFunDef)) case Some(_) => unit.error(nextExpr.pos, "Cannot alias array") throw ImpureCodeEncounteredException(nextExpr) } } - Assignment(id, rhsTree) - } - case None => { - unit.error(tr.pos, "Undeclared variable.") - throw ImpureCodeEncounteredException(tr) + val restTree = rest match { + case Some(rst) => { + varSubsts(vs) = (() => Variable(newID)) + val res = rec(rst) + varSubsts.remove(vs) + res + } + case None => UnitLiteral + } + val res = LetVar(newID, valTree, restTree) + res } - } - case wh@ExWhile(cond, body) => { - val condTree = rec(cond) - val bodyTree = rec(body) - While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) - } - case wh@ExWhileWithInvariant(cond, body, inv) => { - val condTree = rec(cond) - val bodyTree = rec(body) - val invTree = rec(inv) - val w = While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) - w.invariant = Some(invTree) - w - } - case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) - case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) - case ExUnitLiteral() => UnitLiteral - - case ExTyped(e,tpt) => rec(e) - case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { - case Some(fun) => fun() - case None => mutableVarSubsts.get(sym) match { - case Some(fun) => fun() + case ExAssign(sym, rhs) => mutableVarSubsts.get(sym) match { + case Some(fun) => { + val Variable(id) = fun() + val rhsTree = rec(rhs) + if(rhsTree.getType.isInstanceOf[ArrayType]) { + getOwner(rhsTree) match { + case None => + case Some(_) => + unit.error(nextExpr.pos, "Cannot alias array") + throw ImpureCodeEncounteredException(nextExpr) + } + } + Assignment(id, rhsTree) + } case None => { - unit.error(tr.pos, "Unidentified variable.") + unit.error(tr.pos, "Undeclared variable.") throw ImpureCodeEncounteredException(tr) } } - } - case epsi@ExEpsilonExpression(tpe, varSym, predBody) => { - val pstpe = scalaType2PureScala(unit, silent)(tpe) - val previousVarSubst: Option[Function0[Expr]] = varSubsts.get(varSym) //save the previous in case of nested epsilon - varSubsts(varSym) = (() => EpsilonVariable((epsi.pos.line, epsi.pos.column)).setType(pstpe)) - val c1 = rec(predBody) - previousVarSubst match { - case Some(f) => varSubsts(varSym) = f - case None => varSubsts.remove(varSym) + case wh@ExWhile(cond, body) => { + val condTree = rec(cond) + val bodyTree = rec(body) + While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) } - if(containsEpsilon(c1)) { - unit.error(epsi.pos, "Usage of nested epsilon is not allowed.") - throw ImpureCodeEncounteredException(epsi) + case wh@ExWhileWithInvariant(cond, body, inv) => { + val condTree = rec(cond) + val bodyTree = rec(body) + val invTree = rec(inv) + val w = While(condTree, bodyTree).setPosInfo(wh.pos.line,wh.pos.column) + w.invariant = Some(invTree) + w } - Epsilon(c1).setType(pstpe).setPosInfo(epsi.pos.line, epsi.pos.column) - } - case ExSomeConstruction(tpe, arg) => { - // println("Got Some !" + tpe + ":" + arg) - val underlying = scalaType2PureScala(unit, silent)(tpe) - OptionSome(rec(arg)).setType(OptionType(underlying)) - } - case ExCaseClassConstruction(tpt, args) => { - val cctype = scalaType2PureScala(unit, silent)(tpt.tpe) - if(!cctype.isInstanceOf[CaseClassType]) { - if(!silent) { - unit.error(tr.pos, "Construction of a non-case class.") + + case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type) + case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType) + case ExUnitLiteral() => UnitLiteral + + case ExTyped(e,tpt) => rec(e) + case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { + case Some(fun) => fun() + case None => mutableVarSubsts.get(sym) match { + case Some(fun) => fun() + case None => { + unit.error(tr.pos, "Unidentified variable.") + throw ImpureCodeEncounteredException(tr) + } } - throw ImpureCodeEncounteredException(tree) } - val nargs = args.map(rec(_)) - val cct = cctype.asInstanceOf[CaseClassType] - CaseClass(cct.classDef, nargs).setType(cct) - } - case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) - case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) - case ExNot(e) => Not(rec(e)).setType(BooleanType) - case ExUMinus(e) => UMinus(rec(e)).setType(Int32Type) - case ExPlus(l, r) => Plus(rec(l), rec(r)).setType(Int32Type) - case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) - case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) - case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) - case ExMod(l, r) => Modulo(rec(l), rec(r)).setType(Int32Type) - case ExEquals(l, r) => { - val rl = rec(l) - val rr = rec(r) - ((rl.getType,rr.getType) match { - case (SetType(_), SetType(_)) => SetEquals(rl, rr) - case (BooleanType, BooleanType) => Iff(rl, rr) - case (_, _) => Equals(rl, rr) - }).setType(BooleanType) - } - case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) - case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) - case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) - case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType) - case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) - case ExFiniteSet(tt, args) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - FiniteSet(args.map(rec(_))).setType(SetType(underlying)) - } - case ExFiniteMultiset(tt, args) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) - } - case ExEmptySet(tt) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptySet(underlying).setType(SetType(underlying)) - } - case ExEmptyMultiset(tt) => { - val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptyMultiset(underlying).setType(MultisetType(underlying)) - } - case ExEmptyMap(ft, tt) => { - val fromUnderlying = scalaType2PureScala(unit, silent)(ft.tpe) - val toUnderlying = scalaType2PureScala(unit, silent)(tt.tpe) - EmptyMap(fromUnderlying, toUnderlying).setType(MapType(fromUnderlying, toUnderlying)) - } - case ExSetMin(t) => { - val set = rec(t) - if(!set.getType.isInstanceOf[SetType]) { - if(!silent) unit.error(t.pos, "Min should be computed on a set.") - throw ImpureCodeEncounteredException(tree) + case epsi@ExEpsilonExpression(tpe, varSym, predBody) => { + val pstpe = scalaType2PureScala(unit, silent)(tpe) + val previousVarSubst: Option[Function0[Expr]] = varSubsts.get(varSym) //save the previous in case of nested epsilon + varSubsts(varSym) = (() => EpsilonVariable((epsi.pos.line, epsi.pos.column)).setType(pstpe)) + val c1 = rec(predBody) + previousVarSubst match { + case Some(f) => varSubsts(varSym) = f + case None => varSubsts.remove(varSym) + } + if(containsEpsilon(c1)) { + unit.error(epsi.pos, "Usage of nested epsilon is not allowed.") + throw ImpureCodeEncounteredException(epsi) + } + Epsilon(c1).setType(pstpe).setPosInfo(epsi.pos.line, epsi.pos.column) } - SetMin(set).setType(set.getType.asInstanceOf[SetType].base) - } - case ExSetMax(t) => { - val set = rec(t) - if(!set.getType.isInstanceOf[SetType]) { - if(!silent) unit.error(t.pos, "Max should be computed on a set.") - throw ImpureCodeEncounteredException(tree) + case ExSomeConstruction(tpe, arg) => { + // println("Got Some !" + tpe + ":" + arg) + val underlying = scalaType2PureScala(unit, silent)(tpe) + OptionSome(rec(arg)).setType(OptionType(underlying)) } - SetMax(set).setType(set.getType.asInstanceOf[SetType].base) - } - case ExUnion(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetUnion(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + case ExCaseClassConstruction(tpt, args) => { + val cctype = scalaType2PureScala(unit, silent)(tpt.tpe) + if(!cctype.isInstanceOf[CaseClassType]) { + if(!silent) { + unit.error(tr.pos, "Construction of a non-case class.") + } throw ImpureCodeEncounteredException(tree) } + val nargs = args.map(rec(_)) + val cct = cctype.asInstanceOf[CaseClassType] + CaseClass(cct.classDef, nargs).setType(cct) } - } - case ExIntersection(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetIntersection(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) + case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) + case ExNot(e) => Not(rec(e)).setType(BooleanType) + case ExUMinus(e) => UMinus(rec(e)).setType(Int32Type) + case ExPlus(l, r) => Plus(rec(l), rec(r)).setType(Int32Type) + case ExMinus(l, r) => Minus(rec(l), rec(r)).setType(Int32Type) + case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type) + case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type) + case ExMod(l, r) => Modulo(rec(l), rec(r)).setType(Int32Type) + case ExEquals(l, r) => { + val rl = rec(l) + val rr = rec(r) + ((rl.getType,rr.getType) match { + case (SetType(_), SetType(_)) => SetEquals(rl, rr) + case (BooleanType, BooleanType) => Iff(rl, rr) + case (_, _) => Equals(rl, rr) + }).setType(BooleanType) + } + case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType) + case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType) + case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType) + case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType) + case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) + case ExFiniteSet(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteSet(args.map(rec(_))).setType(SetType(underlying)) + } + case ExFiniteMultiset(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) + } + case ExEmptySet(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptySet(underlying).setType(SetType(underlying)) + } + case ExEmptyMultiset(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMultiset(underlying).setType(MultisetType(underlying)) + } + case ExEmptyMap(ft, tt) => { + val fromUnderlying = scalaType2PureScala(unit, silent)(ft.tpe) + val toUnderlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMap(fromUnderlying, toUnderlying).setType(MapType(fromUnderlying, toUnderlying)) + } + case ExSetMin(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Min should be computed on a set.") throw ImpureCodeEncounteredException(tree) } + SetMin(set).setType(set.getType.asInstanceOf[SetType].base) } - } - case ExSetContains(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => ElementOfSet(rr, rl) - case _ => { - if(!silent) unit.error(tree.pos, ".contains on non set expression.") + case ExSetMax(t) => { + val set = rec(t) + if(!set.getType.isInstanceOf[SetType]) { + if(!silent) unit.error(t.pos, "Max should be computed on a set.") throw ImpureCodeEncounteredException(tree) } + SetMax(set).setType(set.getType.asInstanceOf[SetType].base) } - } - case ExSetSubset(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SubsetOf(rl, rr) - case _ => { - if(!silent) unit.error(tree.pos, "Subset on non set expression.") - throw ImpureCodeEncounteredException(tree) + case ExUnion(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetUnion(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExSetMinus(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - rl.getType match { - case s @ SetType(_) => SetDifference(rl, rr).setType(s) - case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) - case _ => { - if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") - throw ImpureCodeEncounteredException(tree) + case ExIntersection(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetIntersection(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExSetCard(t) => { - val rt = rec(t) - rt.getType match { - case s @ SetType(_) => SetCardinality(rt) - case m @ MultisetType(_) => MultisetCardinality(rt) - case _ => { - if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") - throw ImpureCodeEncounteredException(tree) + case ExSetContains(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => ElementOfSet(rr, rl) + case _ => { + if(!silent) unit.error(tree.pos, ".contains on non set expression.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExMultisetToSet(t) => { - val rt = rec(t) - rt.getType match { - case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) - case _ => { - if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") - throw ImpureCodeEncounteredException(tree) + case ExSetSubset(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SubsetOf(rl, rr) + case _ => { + if(!silent) unit.error(tree.pos, "Subset on non set expression.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case up@ExUpdated(m,f,t) => { - val rm = rec(m) - val rf = rec(f) - val rt = rec(t) - rm.getType match { - case MapType(ft, tt) => { - val newSingleton = SingletonMap(rf, rt).setType(rm.getType) - MapUnion(rm, FiniteMap(Seq(newSingleton)).setType(rm.getType)).setType(rm.getType) + case ExSetMinus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + rl.getType match { + case s @ SetType(_) => SetDifference(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } - case ArrayType(bt) => { - ArrayUpdated(rm, rf, rt).setType(rm.getType).setPosInfo(up.pos.line, up.pos.column) + } + case ExSetCard(t) => { + val rt = rec(t) + rt.getType match { + case s @ SetType(_) => SetCardinality(rt) + case m @ MultisetType(_) => MultisetCardinality(rt) + case _ => { + if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } } - case _ => { - if (!silent) unit.error(tree.pos, "updated can only be applied to maps.") - throw ImpureCodeEncounteredException(tree) + } + case ExMultisetToSet(t) => { + val rt = rec(t) + rt.getType match { + case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) + case _ => { + if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") + throw ImpureCodeEncounteredException(tree) + } } } - } - case ExMapIsDefinedAt(m,k) => { - val rm = rec(m) - val rk = rec(k) - MapIsDefinedAt(rm, rk) - } + case up@ExUpdated(m,f,t) => { + val rm = rec(m) + val rf = rec(f) + val rt = rec(t) + rm.getType match { + case MapType(ft, tt) => { + val newSingleton = SingletonMap(rf, rt).setType(rm.getType) + MapUnion(rm, FiniteMap(Seq(newSingleton)).setType(rm.getType)).setType(rm.getType) + } + case ArrayType(bt) => { + ArrayUpdated(rm, rf, rt).setType(rm.getType).setPosInfo(up.pos.line, up.pos.column) + } + case _ => { + if (!silent) unit.error(tree.pos, "updated can only be applied to maps.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExMapIsDefinedAt(m,k) => { + val rm = rec(m) + val rk = rec(k) + MapIsDefinedAt(rm, rk) + } - case ExPlusPlusPlus(t1,t2) => { - val rl = rec(t1) - val rr = rec(t2) - MultisetPlus(rl, rr).setType(rl.getType) - } - case app@ExApply(lhs,args) => { - val rlhs = rec(lhs) - val rargs = args map rec - rlhs.getType match { - case MapType(_,tt) => - assert(rargs.size == 1) - MapGet(rlhs, rargs.head).setType(tt).setPosInfo(app.pos.line, app.pos.column) - case FunctionType(fts, tt) => { - rlhs match { - case Variable(id) => - AnonymousFunctionInvocation(id, rargs).setType(tt) - case _ => { - if (!silent) unit.error(tree.pos, "apply on non-variable or non-map expression") - throw ImpureCodeEncounteredException(tree) + case ExPlusPlusPlus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + MultisetPlus(rl, rr).setType(rl.getType) + } + case app@ExApply(lhs,args) => { + val rlhs = rec(lhs) + val rargs = args map rec + rlhs.getType match { + case MapType(_,tt) => + assert(rargs.size == 1) + MapGet(rlhs, rargs.head).setType(tt).setPosInfo(app.pos.line, app.pos.column) + case FunctionType(fts, tt) => { + rlhs match { + case Variable(id) => + AnonymousFunctionInvocation(id, rargs).setType(tt) + case _ => { + if (!silent) unit.error(tree.pos, "apply on non-variable or non-map expression") + throw ImpureCodeEncounteredException(tree) + } } } + case ArrayType(bt) => { + assert(rargs.size == 1) + ArraySelect(rlhs, rargs.head).setType(bt).setPosInfo(app.pos.line, app.pos.column) + } + case _ => { + if (!silent) unit.error(tree.pos, "apply on unexpected type") + throw ImpureCodeEncounteredException(tree) + } } - case ArrayType(bt) => { - assert(rargs.size == 1) - ArraySelect(rlhs, rargs.head).setType(bt).setPosInfo(app.pos.line, app.pos.column) + } + // for now update only occurs on Array. later we might have to distinguished depending on the type of the lhs + case update@ExUpdate(lhs, index, newValue) => { + val lhsRec = rec(lhs) + lhsRec match { + case Variable(_) => + case _ => { + unit.error(tree.pos, "array update only works on variables") + throw ImpureCodeEncounteredException(tree) + } } - case _ => { - if (!silent) unit.error(tree.pos, "apply on unexpected type") - throw ImpureCodeEncounteredException(tree) + getOwner(lhsRec) match { + case Some(Some(fd)) if fd != currentFunDef => + unit.error(nextExpr.pos, "cannot update an array that is not defined locally") + throw ImpureCodeEncounteredException(nextExpr) + case Some(None) => + unit.error(nextExpr.pos, "cannot update an array that is not defined locally") + throw ImpureCodeEncounteredException(nextExpr) + case Some(_) => + case None => sys.error("This array: " + lhsRec + " should have had an owner") } + val indexRec = rec(index) + val newValueRec = rec(newValue) + ArrayUpdate(lhsRec, indexRec, newValueRec).setPosInfo(update.pos.line, update.pos.column) } - } - // for now update only occurs on Array. later we might have to distinguished depending on the type of the lhs - case update@ExUpdate(lhs, index, newValue) => { - val lhsRec = rec(lhs) - lhsRec match { - case Variable(_) => - case _ => { - unit.error(tree.pos, "array update only works on variables") - throw ImpureCodeEncounteredException(tree) - } + case ExArrayLength(t) => { + val rt = rec(t) + ArrayLength(rt) } - getOwner(lhsRec) match { - case Some(Some(fd)) if fd != currentFunDef => - unit.error(nextExpr.pos, "cannot update an array that is not defined locally") - throw ImpureCodeEncounteredException(nextExpr) - case Some(None) => - unit.error(nextExpr.pos, "cannot update an array that is not defined locally") - throw ImpureCodeEncounteredException(nextExpr) - case Some(_) => - case None => sys.error("This array: " + lhsRec + " should have had an owner") + case ExArrayClone(t) => { + val rt = rec(t) + ArrayClone(rt) } - val indexRec = rec(index) - val newValueRec = rec(newValue) - ArrayUpdate(lhsRec, indexRec, newValueRec).setPosInfo(update.pos.line, update.pos.column) - } - case ExArrayLength(t) => { - val rt = rec(t) - ArrayLength(rt) - } - case ExArrayClone(t) => { - val rt = rec(t) - ArrayClone(rt) - } - case ExArrayFill(baseType, length, defaultValue) => { - val underlying = scalaType2PureScala(unit, silent)(baseType.tpe) - val lengthRec = rec(length) - val defaultValueRec = rec(defaultValue) - ArrayFill(lengthRec, defaultValueRec).setType(ArrayType(underlying)) - } - case ExIfThenElse(t1,t2,t3) => { - val r1 = rec(t1) - if(containsLetDef(r1)) { - unit.error(t1.pos, "Condition of if-then-else expression should not contain nested function definition") - throw ImpureCodeEncounteredException(t1) + case ExArrayFill(baseType, length, defaultValue) => { + val underlying = scalaType2PureScala(unit, silent)(baseType.tpe) + val lengthRec = rec(length) + val defaultValueRec = rec(defaultValue) + ArrayFill(lengthRec, defaultValueRec).setType(ArrayType(underlying)) } - val r2 = rec(t2) - val r3 = rec(t3) - val lub = leastUpperBound(r2.getType, r3.getType) - lub match { - case Some(lub) => IfExpr(r1, r2, r3).setType(lub) - case None => - unit.error(nextExpr.pos, "Both branches of ifthenelse have incompatible types") + case ExIfThenElse(t1,t2,t3) => { + val r1 = rec(t1) + if(containsLetDef(r1)) { + unit.error(t1.pos, "Condition of if-then-else expression should not contain nested function definition") throw ImpureCodeEncounteredException(t1) + } + val r2 = rec(t2) + val r3 = rec(t3) + val lub = leastUpperBound(r2.getType, r3.getType) + lub match { + case Some(lub) => IfExpr(r1, r2, r3).setType(lub) + case None => + unit.error(nextExpr.pos, "Both branches of ifthenelse have incompatible types") + throw ImpureCodeEncounteredException(t1) + } } - } - case ExIsInstanceOf(tt, cc) => { - val ccRec = rec(cc) - val checkType = scalaType2PureScala(unit, silent)(tt.tpe) - checkType match { - case CaseClassType(ccd) => { - val rootType: ClassTypeDef = if(ccd.parent != None) ccd.parent.get else ccd - if(!ccRec.getType.isInstanceOf[ClassType]) { - unit.error(tr.pos, "isInstanceOf can only be used with a case class") - throw ImpureCodeEncounteredException(tr) - } else { - val testedExprType = ccRec.getType.asInstanceOf[ClassType].classDef - val testedExprRootType: ClassTypeDef = if(testedExprType.parent != None) testedExprType.parent.get else testedExprType - - if(rootType != testedExprRootType) { - unit.error(tr.pos, "isInstanceOf can only be used with compatible case classes") + case ExIsInstanceOf(tt, cc) => { + val ccRec = rec(cc) + val checkType = scalaType2PureScala(unit, silent)(tt.tpe) + checkType match { + case CaseClassType(ccd) => { + val rootType: ClassTypeDef = if(ccd.parent != None) ccd.parent.get else ccd + if(!ccRec.getType.isInstanceOf[ClassType]) { + unit.error(tr.pos, "isInstanceOf can only be used with a case class") throw ImpureCodeEncounteredException(tr) } else { - CaseClassInstanceOf(ccd, ccRec) + val testedExprType = ccRec.getType.asInstanceOf[ClassType].classDef + val testedExprRootType: ClassTypeDef = if(testedExprType.parent != None) testedExprType.parent.get else testedExprType + + if(rootType != testedExprRootType) { + unit.error(tr.pos, "isInstanceOf can only be used with compatible case classes") + throw ImpureCodeEncounteredException(tr) + } else { + CaseClassInstanceOf(ccd, ccRec) + } } } - } - case _ => { - unit.error(tr.pos, "isInstanceOf can only be used with a case class") - throw ImpureCodeEncounteredException(tr) + case _ => { + unit.error(tr.pos, "isInstanceOf can only be used with a case class") + throw ImpureCodeEncounteredException(tr) + } } } - } - case lc @ ExLocalCall(sy,nm,ar) => { - if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { - if(!silent) - unit.error(tr.pos, "Invoking an invalid function.") - throw ImpureCodeEncounteredException(tr) - } - val fd = defsToDefs(sy) - FunctionInvocation(fd, ar.map(rec(_))).setType(fd.returnType).setPosInfo(lc.pos.line,lc.pos.column) - } - case pm @ ExPatternMatching(sel, cses) => { - val rs = rec(sel) - val rc = cses.map(rewriteCaseDef(_)) - val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) - MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) - } - - // this one should stay after all others, cause it also catches UMinus - // and Not, for instance. - case ExParameterlessMethodCall(t,n) => { - val selector = rec(t) - val selType = selector.getType - - if(!selType.isInstanceOf[CaseClassType]) { - if(!silent) - unit.error(tr.pos, "Invalid method or field invocation (not purescala?)") - throw ImpureCodeEncounteredException(tr) - } - - val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef - - val fieldID = selDef.fields.find(_.id.name == n.toString) match { - case None => { + case lc @ ExLocalCall(sy,nm,ar) => { + if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { if(!silent) - unit.error(tr.pos, "Invalid method or field invocation (not a case class arg?)") + unit.error(tr.pos, "Invoking an invalid function.") throw ImpureCodeEncounteredException(tr) } - case Some(vd) => vd.id + val fd = defsToDefs(sy) + FunctionInvocation(fd, ar.map(rec(_))).setType(fd.returnType).setPosInfo(lc.pos.line,lc.pos.column) + } + case pm @ ExPatternMatching(sel, cses) => { + val rs = rec(sel) + val rc = cses.map(rewriteCaseDef(_)) + val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) + MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) } - CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) - } - - // default behaviour is to complain :) - case _ => { - if(!silent) { - println(tr) - reporter.info(tr.pos, "Could not extract as PureScala.", true) + + // default behaviour is to complain :) + case _ => { + if(!silent) { + println(tr) + reporter.info(tr.pos, "Could not extract as PureScala.", true) + } + throw ImpureCodeEncounteredException(tree) } - throw ImpureCodeEncounteredException(tree) } } diff --git a/testcases/regression/valid/Field1.scala b/testcases/regression/valid/Field1.scala new file mode 100644 index 0000000000000000000000000000000000000000..116557ab7b883d01a10168aeaf529d5300ee5f19 --- /dev/null +++ b/testcases/regression/valid/Field1.scala @@ -0,0 +1,11 @@ +object Field1 { + + abstract sealed class A + case class B(size: Int) extends A + + def foo(): Int = { + val b = B(3) + b.size + } ensuring(_ == 3) + +} diff --git a/testcases/regression/valid/Field2.scala b/testcases/regression/valid/Field2.scala new file mode 100644 index 0000000000000000000000000000000000000000..9a96610235a68754b84e58f73f5e435ce642ebc9 --- /dev/null +++ b/testcases/regression/valid/Field2.scala @@ -0,0 +1,11 @@ +object Field2 { + + abstract sealed class A + case class B(length: Int) extends A + + def foo(): Int = { + val b = B(3) + b.length + } ensuring(_ == 3) + +}