Skip to content
Snippets Groups Projects
Commit 959719aa authored by Régis Blanc's avatar Régis Blanc
Browse files

add support for array in argument list

parent 4f050d7a
No related branches found
No related tags found
No related merge requests found
......@@ -12,15 +12,54 @@ object ArrayTransformation extends Pass {
def apply(pgm: Program): Program = {
val allFuns = pgm.definedFunctions
allFuns.foreach(fd => fd.body.map(body => {
val newBody = transform(body)
fd.body = Some(newBody)
}))
pgm
val newFuns: Seq[Definition] = allFuns.map(fd => {
if(fd.hasImplementation) {
val body = fd.body.get
id2id = Map()
val args = fd.args
val newFd =
if(args.exists(vd => containsArrayType(vd.tpe))) {
println("args has array")
val newArgs = args.map(vd => {
val freshId = FreshIdentifier(vd.id.name).setType(TupleType(Seq(vd.tpe, Int32Type)))
id2id += (vd.id -> freshId)
val newTpe = transform(vd.tpe)
VarDecl(freshId, newTpe)
})
val freshFunName = FreshIdentifier(fd.id.name)
val freshFunDef = new FunDef(freshFunName, fd.returnType, newArgs)
freshFunDef.fromLoop = fd.fromLoop
freshFunDef.parent = fd.parent
freshFunDef.precondition = fd.precondition
freshFunDef.postcondition = fd.postcondition
freshFunDef.addAnnotation(fd.annotations.toSeq:_*)
freshFunDef
} else fd
val newBody = transform(body)
newFd.body = Some(newBody)
newFd
} else fd
})
val Program(id, ObjectDef(objId, _, invariants)) = pgm
val allClasses: Seq[Definition] = pgm.definedClasses
Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants))
}
private var id2id: Map[Identifier, Identifier] = Map()
private def transform(tpe: TypeTree): TypeTree = tpe match {
case ArrayType(base) => TupleType(Seq(ArrayType(transform(base)), Int32Type))
case TupleType(tpes) => TupleType(tpes.map(transform))
case t => t
}
private def containsArrayType(tpe: TypeTree): Boolean = tpe match {
case ArrayType(base) => true
case TupleType(tpes) => tpes.exists(containsArrayType)
case t => false
}
private def transform(expr: Expr): Expr = expr match {
case fill@ArrayFill(length, default) => {
var rLength = transform(length)
......
object Array4 {
def foo(a: Array[Int]): Int = {
a(2)
} ensuring(_ == 3)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment