diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index 294bc4fefbf2ec0fb5f47d11bc8f7d3f5b5d4b04..35cad52f2d938fc1efd71c975b4a50d7b865c54e 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -11,42 +11,45 @@ object ArrayTransformation extends Pass { def apply(pgm: Program): Program = { + fd2fd = Map() + id2id = Map() + val allFuns = pgm.definedFunctions - val newFuns: Seq[Definition] = allFuns.map(fd => { + + val newFuns: Seq[FunDef] = allFuns.map(fd => { if(fd.hasImplementation) { - val body = fd.body.get - id2id = Map() val args = fd.args - val newFd = - if(args.exists(vd => containsArrayType(vd.tpe))) { - println("args has array") - val newArgs = args.map(vd => { - val freshId = FreshIdentifier(vd.id.name).setType(TupleType(Seq(vd.tpe, Int32Type))) - id2id += (vd.id -> freshId) - val newTpe = transform(vd.tpe) - VarDecl(freshId, newTpe) - }) - val freshFunName = FreshIdentifier(fd.id.name) - val freshFunDef = new FunDef(freshFunName, fd.returnType, newArgs) - freshFunDef.fromLoop = fd.fromLoop - freshFunDef.parent = fd.parent - freshFunDef.precondition = fd.precondition - freshFunDef.postcondition = fd.postcondition - freshFunDef.addAnnotation(fd.annotations.toSeq:_*) - freshFunDef - } else fd - val newBody = transform(body) - newFd.body = Some(newBody) - newFd + if(args.exists(vd => containsArrayType(vd.tpe))) { + val newArgs = args.map(vd => { + val freshId = FreshIdentifier(vd.id.name).setType(TupleType(Seq(vd.tpe, Int32Type))) + id2id += (vd.id -> freshId) + val newTpe = transform(vd.tpe) + VarDecl(freshId, newTpe) + }) + val freshFunName = FreshIdentifier(fd.id.name) + val freshFunDef = new FunDef(freshFunName, fd.returnType, newArgs) + freshFunDef.fromLoop = fd.fromLoop + freshFunDef.parent = fd.parent + freshFunDef.precondition = fd.precondition + freshFunDef.postcondition = fd.postcondition + freshFunDef.addAnnotation(fd.annotations.toSeq:_*) + fd2fd += (fd -> freshFunDef) + freshFunDef + } else fd } else fd }) + println(newFuns) + + allFuns.zip(newFuns).foreach{ case (ofd, nfd) => ofd.body.map(body => { + val newBody = transform(body) + nfd.body = Some(newBody) + })} val Program(id, ObjectDef(objId, _, invariants)) = pgm val allClasses: Seq[Definition] = pgm.definedClasses Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) } - private var id2id: Map[Identifier, Identifier] = Map() private def transform(tpe: TypeTree): TypeTree = tpe match { case ArrayType(base) => TupleType(Seq(ArrayType(transform(base)), Int32Type)) @@ -59,6 +62,8 @@ object ArrayTransformation extends Pass { case t => false } + private var id2id: Map[Identifier, Identifier] = Map() + private var fd2fd: Map[FunDef, FunDef] = Map() private def transform(expr: Expr): Expr = expr match { case fill@ArrayFill(length, default) => { @@ -139,8 +144,40 @@ object ArrayTransformation extends Pass { val tpe = csesRec.head.rhs.getType MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m) } - - //case LetDef(fd, b) => + case LetDef(fd, b) => { + val newFd = if(fd.hasImplementation) { + val body = fd.body.get + val args = fd.args + val newFd = + if(args.exists(vd => containsArrayType(vd.tpe))) { + val newArgs = args.map(vd => { + val freshId = FreshIdentifier(vd.id.name).setType(TupleType(Seq(vd.tpe, Int32Type))) + id2id += (vd.id -> freshId) + val newTpe = transform(vd.tpe) + VarDecl(freshId, newTpe) + }) + val freshFunName = FreshIdentifier(fd.id.name) + val freshFunDef = new FunDef(freshFunName, fd.returnType, newArgs) + freshFunDef.fromLoop = fd.fromLoop + freshFunDef.parent = fd.parent + freshFunDef.precondition = fd.precondition + freshFunDef.postcondition = fd.postcondition + freshFunDef.addAnnotation(fd.annotations.toSeq:_*) + fd2fd += (fd -> freshFunDef) + freshFunDef + } else fd + val newBody = transform(body) + newFd.body = Some(newBody) + newFd + } else fd + val rb = transform(b) + LetDef(newFd, rb) + } + case FunctionInvocation(fd, args) => { + val rargs = args.map(transform) + val rfd = fd2fd.get(fd).getOrElse(fd) + FunctionInvocation(rfd, rargs) + } case n @ NAryOperator(args, recons) => recons(args.map(transform)).setType(n.getType) case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)).setType(b.getType) diff --git a/testcases/regression/valid/Array5.scala b/testcases/regression/valid/Array5.scala new file mode 100644 index 0000000000000000000000000000000000000000..1f36c1a760c263fe7b79b4498a65610b6265a45f --- /dev/null +++ b/testcases/regression/valid/Array5.scala @@ -0,0 +1,20 @@ +import leon.Utils._ + +object Array4 { + + def foo(a: Array[Int]): Int = { + var i = 0 + var sum = 0 + (while(i < a.length) { + sum = sum + a(i) + i = i + 1 + }) invariant(i >= 0) + sum + } + + def bar(): Int = { + val a = Array.fill(5)(5) + foo(a) + } + +}