From bac2daf1e5f2beacc07f46e720f7f1cc5a12d2af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Wed, 9 May 2012 19:43:29 +0200 Subject: [PATCH] correct bug when writing if(..) a(i) = .. else 0 --- src/main/scala/leon/ArrayTransformation.scala | 6 +++--- src/main/scala/leon/FairZ3Solver.scala | 2 +- .../leon/ImperativeCodeElimination.scala | 5 +++++ .../scala/leon/plugin/CodeExtraction.scala | 12 ++++++++--- .../scala/leon/purescala/Definitions.scala | 2 +- .../scala/leon/purescala/PrettyPrinter.scala | 8 ++++---- src/main/scala/leon/purescala/Trees.scala | 8 +++++++- src/main/scala/leon/purescala/TypeTrees.scala | 20 +++++++++---------- testcases/regression/error/Array3.scala | 11 ++++++++++ 9 files changed, 51 insertions(+), 23 deletions(-) create mode 100644 testcases/regression/error/Array3.scala diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index db6a7fe2f..6c4382ff7 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -42,10 +42,9 @@ object ArrayTransformation extends Pass { val rv = transform(v) val Variable(id) = ra val length = ArrayLength(ra) - val array = TupleSelect(ra, 1).setType(ArrayType(v.getType)) val res = IfExpr( And(LessEquals(IntLiteral(0), ri), LessThan(ri, length)), - Assignment(id, ArrayUpdated(ra, ri, rv).setType(a.getType).setPosInfo(up)), + Assignment(id, ArrayUpdated(ra, ri, rv).setType(ra.getType).setPosInfo(up)), Error("Index out of bound").setType(UnitType).setPosInfo(up) ).setType(UnitType) res @@ -53,7 +52,7 @@ object ArrayTransformation extends Pass { case Let(i, v, b) => { v.getType match { case ArrayType(_) => { - val freshIdentifier = FreshIdentifier("t").setType(v.getType) + val freshIdentifier = FreshIdentifier("t").setType(i.getType) id2FreshId += (i -> freshIdentifier) LetVar(freshIdentifier, transform(v), transform(b)) } @@ -74,6 +73,7 @@ object ArrayTransformation extends Pass { val newWh = While(transform(c), transform(e)) newWh.invariant = wh.invariant.map(i => transform(i)) newWh.setPosInfo(wh) + newWh } case ite@IfExpr(c, t, e) => { diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index 58cd49ff4..90f95153e 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -1237,7 +1237,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S val r1 = rargs(1) val r2 = rargs(2) try { - IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType)) + IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType).get) } catch { case e => { println("I was asking for lub because of this.") diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index f6937463e..0bffa9459 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -49,6 +49,11 @@ object ImperativeCodeElimination extends Pass { val (cRes, cScope, cFun) = toFunction(cond) val (tRes, tScope, tFun) = toFunction(tExpr) val (eRes, eScope, eFun) = toFunction(eExpr) + if(tRes.getType != eRes.getType) { + println("PROBLEM: tres: " + tRes + " has type: " + tRes.getType + " then was: " + tExpr) + println("PROBLEM: eres: " + eRes + " has type: " + eRes.getType + " else was: " + eExpr) + assert(false) + } val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq val resId = FreshIdentifier("res").setType(ite.getType) diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 6162db6b1..0d05298b0 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -816,7 +816,7 @@ trait CodeExtraction extends Extractors { } val indexRec = rec(index) val newValueRec = rec(newValue) - ArrayUpdate(lhsRec, indexRec, newValueRec).setType(newValueRec.getType).setPosInfo(update.pos.line, update.pos.column) + ArrayUpdate(lhsRec, indexRec, newValueRec).setPosInfo(update.pos.line, update.pos.column) } case ExArrayLength(t) => { val rt = rec(t) @@ -836,7 +836,13 @@ trait CodeExtraction extends Extractors { } val r2 = rec(t2) val r3 = rec(t3) - IfExpr(r1, r2, r3).setType(leastUpperBound(r2.getType, r3.getType)) + val lub = leastUpperBound(r2.getType, r3.getType) + lub match { + case Some(lub) => IfExpr(r1, r2, r3).setType(lub) + case None => + unit.error(tree.pos, "Both branches of ifthenelse have incompatible types") + throw ImpureCodeEncounteredException(t1) + } } case ExIsInstanceOf(tt, cc) => { @@ -879,7 +885,7 @@ trait CodeExtraction extends Extractors { case pm @ ExPatternMatching(sel, cses) => { val rs = rec(sel) val rc = cses.map(rewriteCaseDef(_)) - val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_)) + val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column) } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index f73530c40..bc497dcdd 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -30,7 +30,7 @@ object Definitions { case _ => false } - def toVariable : Variable = Variable(id)//.setType(tpe) + def toVariable : Variable = Variable(id).setType(tpe) } type VarDecls = Seq[VarDecl] diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index c4b6b4fc9..ae42766f9 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -66,7 +66,7 @@ object PrettyPrinter { } private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match { - case Variable(id) => sb.append(id) + case Variable(id) => sb.append(id + "#" + id.getType) case DeBruijnIndex(idx) => sb.append("_" + idx) case Let(b,d,e) => { //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") @@ -139,10 +139,10 @@ object PrettyPrinter { sb.append("\n") } - case Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")", lvl) - case TupleSelect(t, i) => { + case t@Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")#" + t.getType, lvl) + case s@TupleSelect(t, i) => { pp(t, sb, lvl) - sb.append("._" + i) + sb.append("._" + i + "#" + s.getType) sb } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 33fb76780..8938048fb 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -76,7 +76,13 @@ object Trees { } case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr - case class Tuple(exprs: Seq[Expr]) extends Expr + case class Tuple(exprs: Seq[Expr]) extends Expr { + val subTpes = exprs.map(_.getType) + if(!subTpes.exists(_ == Untyped)) { + setType(TupleType(subTpes)) + } + + } case class TupleSelect(tuple: Expr, index: Int) extends Expr object MatchExpr { diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 0410fd11a..01aea89f9 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -18,7 +18,7 @@ object TypeTrees { def setType(tt: TypeTree): self.type = _type match { case None => _type = Some(tt); this - case Some(o) if o != tt => scala.sys.error("Resetting type information.") + case Some(o) if o != tt => scala.sys.error("Resetting type information! Type [" + o + "] is modified to [" + tt) case _ => this } } @@ -47,7 +47,7 @@ object TypeTrees { case other => other } - def leastUpperBound(t1: TypeTree, t2: TypeTree): TypeTree = (t1,t2) match { + def leastUpperBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { case (c1: ClassType, c2: ClassType) => { import scala.collection.immutable.Set var c: ClassTypeDef = c1.classDef @@ -72,19 +72,19 @@ object TypeTrees { } if(found.isEmpty) { - scala.sys.error("Asking for lub of unrelated class types : " + t1 + " and " + t2) + None } else { - classDefToClassType(found.get) + Some(classDefToClassType(found.get)) } } - case (o1, o2) if (o1 == o2) => o1 - case (o1,BottomType) => o1 - case (BottomType,o2) => o2 - case (o1,AnyType) => AnyType - case (AnyType,o2) => AnyType + case (o1, o2) if (o1 == o2) => Some(o1) + case (o1,BottomType) => Some(o1) + case (BottomType,o2) => Some(o2) + case (o1,AnyType) => Some(AnyType) + case (AnyType,o2) => Some(AnyType) - case _ => scala.sys.error("Asking for lub of unrelated types: " + t1 + " and " + t2) + case _ => None } // returns the number of distinct values that inhabit a type diff --git a/testcases/regression/error/Array3.scala b/testcases/regression/error/Array3.scala new file mode 100644 index 000000000..451dd9db0 --- /dev/null +++ b/testcases/regression/error/Array3.scala @@ -0,0 +1,11 @@ +object Array3 { + + def foo(): Int = { + val a = Array.fill(5)(5) + if(a.length > 2) + a(1) = 2 + else 0 + 0 + } + +} -- GitLab