diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index dd1fee89e2abbb84d31ae835db0b39d1066616cc..a61605287496ac48d57036f27940fbe3065a9f2f 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -12,15 +12,54 @@ object ArrayTransformation extends Pass { def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions - allFuns.foreach(fd => fd.body.map(body => { - val newBody = transform(body) - fd.body = Some(newBody) - })) - pgm + val newFuns: Seq[Definition] = allFuns.map(fd => { + if(fd.hasImplementation) { + val body = fd.body.get + id2id = Map() + val args = fd.args + val newFd = + if(args.exists(vd => containsArrayType(vd.tpe))) { + println("args has array") + val newArgs = args.map(vd => { + val freshId = FreshIdentifier(vd.id.name).setType(TupleType(Seq(vd.tpe, Int32Type))) + id2id += (vd.id -> freshId) + val newTpe = transform(vd.tpe) + VarDecl(freshId, newTpe) + }) + val freshFunName = FreshIdentifier(fd.id.name) + val freshFunDef = new FunDef(freshFunName, fd.returnType, newArgs) + freshFunDef.fromLoop = fd.fromLoop + freshFunDef.parent = fd.parent + freshFunDef.precondition = fd.precondition + freshFunDef.postcondition = fd.postcondition + freshFunDef.addAnnotation(fd.annotations.toSeq:_*) + freshFunDef + } else fd + val newBody = transform(body) + newFd.body = Some(newBody) + newFd + } else fd + }) + + val Program(id, ObjectDef(objId, _, invariants)) = pgm + val allClasses: Seq[Definition] = pgm.definedClasses + Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) } private var id2id: Map[Identifier, Identifier] = Map() + private def transform(tpe: TypeTree): TypeTree = tpe match { + case ArrayType(base) => TupleType(Seq(ArrayType(transform(base)), Int32Type)) + case TupleType(tpes) => TupleType(tpes.map(transform)) + case t => t + } + private def containsArrayType(tpe: TypeTree): Boolean = tpe match { + case ArrayType(base) => true + case TupleType(tpes) => tpes.exists(containsArrayType) + case t => false + } + + private def transform(expr: Expr): Expr = expr match { case fill@ArrayFill(length, default) => { var rLength = transform(length) diff --git a/testcases/regression/valid/Array4.scala b/testcases/regression/valid/Array4.scala new file mode 100644 index 0000000000000000000000000000000000000000..479f243ba3962304e8be552924ed4a5275622db4 --- /dev/null +++ b/testcases/regression/valid/Array4.scala @@ -0,0 +1,7 @@ +object Array4 { + + def foo(a: Array[Int]): Int = { + a(2) + } ensuring(_ == 3) + +}