diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index c6c6dafc208259d07046a0afec7b42d54dab5a5d..db6a7fe2fead73cc496bcba0721709dcaca75f2e 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -7,117 +7,64 @@ import purescala.TypeTrees._ object ArrayTransformation extends Pass { - val description = "Add bound checking for array access and remove side effect array update operations" + val description = "Add bound checking for array access and remove array update with side effect" def apply(pgm: Program): Program = { - fd2fd = Map() - id2id = Map() - val allFuns = pgm.definedFunctions - - val newFuns: Seq[FunDef] = allFuns.map(fd => { - if(fd.hasImplementation) { - val args = fd.args - if(args.exists(vd => containsArrayType(vd.tpe)) || containsArrayType(fd.returnType)) { - val newArgs = args.map(vd => { - val freshId = FreshIdentifier(vd.id.name).setType(transform(vd.tpe)) - id2id += (vd.id -> freshId) - val newTpe = transform(vd.tpe) - VarDecl(freshId, newTpe) - }) - val freshFunName = FreshIdentifier(fd.id.name) - val freshFunDef = new FunDef(freshFunName, transform(fd.returnType), newArgs) - fd2fd += (fd -> freshFunDef) - freshFunDef.fromLoop = fd.fromLoop - freshFunDef.parent = fd.parent - freshFunDef.addAnnotation(fd.annotations.toSeq:_*) - freshFunDef - } else fd - } else fd + allFuns.foreach(fd => { + id2FreshId = Map() + fd.precondition = fd.precondition.map(transform) + fd.body = fd.body.map(transform) + fd.postcondition = fd.postcondition.map(transform) }) - - allFuns.zip(newFuns).foreach{ case (ofd, nfd) => ofd.body.map(body => { - nfd.precondition = ofd.precondition.map(transform) - nfd.postcondition = ofd.postcondition.map(transform) - val newBody = transform(body) - nfd.body = Some(newBody) - })} - - val Program(id, ObjectDef(objId, _, invariants)) = pgm - val allClasses: Seq[Definition] = pgm.definedClasses - Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) + pgm } - private def transform(tpe: TypeTree): TypeTree = tpe match { - case ArrayType(base) => TupleType(Seq(ArrayType(transform(base)), Int32Type)) - case TupleType(tpes) => TupleType(tpes.map(transform)) - case t => t - } - private def containsArrayType(tpe: TypeTree): Boolean = tpe match { - case ArrayType(base) => true - case TupleType(tpes) => tpes.exists(containsArrayType) - case t => false - } - - private var id2id: Map[Identifier, Identifier] = Map() - private var fd2fd: Map[FunDef, FunDef] = Map() + private var id2FreshId = Map[Identifier, Identifier]() - private def transform(expr: Expr): Expr = expr match { - case fill@ArrayFill(length, default) => { - var rLength = transform(length) - val rDefault = transform(default) - val rFill = ArrayMake(rDefault).setType(fill.getType) - Tuple(Seq(rFill, rLength)).setType(TupleType(Seq(fill.getType, Int32Type))) - } + def transform(expr: Expr): Expr = expr match { case sel@ArraySelect(a, i) => { - val ar = transform(a) - val ir = transform(i) - val length = TupleSelect(ar, 2).setType(Int32Type) - IfExpr( - And(GreaterEquals(ir, IntLiteral(0)), LessThan(ir, length)), - ArraySelect(TupleSelect(ar, 1).setType(ArrayType(sel.getType)), ir).setType(sel.getType).setPosInfo(sel), + val ra = transform(a) + val ri = transform(i) + val length = ArrayLength(ra) + val res = IfExpr( + And(LessEquals(IntLiteral(0), ri), LessThan(ri, length)), + ArraySelect(ra, ri).setType(sel.getType).setPosInfo(sel), Error("Index out of bound").setType(sel.getType).setPosInfo(sel) ).setType(sel.getType) + res } case up@ArrayUpdate(a, i, v) => { - val ar = transform(a) - val ir = transform(i) - val vr = transform(v) - val Variable(id) = ar - val length = TupleSelect(ar, 2).setType(Int32Type) - val array = TupleSelect(ar, 1).setType(ArrayType(v.getType)) - IfExpr( - And(GreaterEquals(i, IntLiteral(0)), LessThan(i, length)), - Assignment( - id, - Tuple(Seq( - ArrayUpdated(array, ir, vr).setType(array.getType).setPosInfo(up), - length) - ).setType(TupleType(Seq(array.getType, Int32Type)))), + val ra = transform(a) + val ri = transform(i) + val rv = transform(v) + val Variable(id) = ra + val length = ArrayLength(ra) + val array = TupleSelect(ra, 1).setType(ArrayType(v.getType)) + val res = IfExpr( + And(LessEquals(IntLiteral(0), ri), LessThan(ri, length)), + Assignment(id, ArrayUpdated(ra, ri, rv).setType(a.getType).setPosInfo(up)), Error("Index out of bound").setType(UnitType).setPosInfo(up) ).setType(UnitType) - } - case ArrayLength(a) => { - val ar = transform(a) - TupleSelect(ar, 2).setType(Int32Type) + res } case Let(i, v, b) => { - val vr = transform(v) v.getType match { case ArrayType(_) => { - val freshIdentifier = FreshIdentifier("t").setType(vr.getType) - id2id += (i -> freshIdentifier) - val br = transform(b) - LetVar(freshIdentifier, vr, br) - } - case _ => { - val br = transform(b) - Let(i, vr, br) + val freshIdentifier = FreshIdentifier("t").setType(v.getType) + id2FreshId += (i -> freshIdentifier) + LetVar(freshIdentifier, transform(v), transform(b)) } + case _ => Let(i, transform(v), transform(b)) } } + case Variable(i) => { + val freshId = id2FreshId.get(i).getOrElse(i) + Variable(freshId) + } + case LetVar(id, e, b) => { val er = transform(e) val br = transform(b) @@ -146,48 +93,191 @@ object ArrayTransformation extends Pass { MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m) } case LetDef(fd, b) => { - val newFd = if(fd.hasImplementation) { - val body = fd.body.get - val args = fd.args - val newFd = - if(args.exists(vd => containsArrayType(vd.tpe)) || containsArrayType(fd.returnType)) { - val newArgs = args.map(vd => { - val freshId = FreshIdentifier(vd.id.name).setType(transform(vd.tpe)) - id2id += (vd.id -> freshId) - val newTpe = transform(vd.tpe) - VarDecl(freshId, newTpe) - }) - val freshFunName = FreshIdentifier(fd.id.name) - val freshFunDef = new FunDef(freshFunName, transform(fd.returnType), newArgs) - fd2fd += (fd -> freshFunDef) - freshFunDef.fromLoop = fd.fromLoop - freshFunDef.parent = fd.parent - freshFunDef.precondition = fd.precondition.map(transform) - freshFunDef.postcondition = fd.postcondition.map(transform) - freshFunDef.addAnnotation(fd.annotations.toSeq:_*) - freshFunDef - } else fd - val newBody = transform(body) - newFd.body = Some(newBody) - newFd - } else fd + fd.precondition = fd.precondition.map(transform) + fd.body = fd.body.map(transform) + fd.postcondition = fd.postcondition.map(transform) val rb = transform(b) - LetDef(newFd, rb) - } - case FunctionInvocation(fd, args) => { - val rargs = args.map(transform) - val rfd = fd2fd.get(fd).getOrElse(fd) - FunctionInvocation(rfd, rargs) + LetDef(fd, rb) } - case n @ NAryOperator(args, recons) => recons(args.map(transform)).setType(n.getType) case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)).setType(b.getType) case u @ UnaryOperator(a, recons) => recons(transform(a)).setType(u.getType) - case v @ Variable(id) => if(id2id.isDefinedAt(id)) Variable(id2id(id)) else v case (t: Terminal) => t case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled) - } + //val newFuns: Seq[FunDef] = allFuns.map(fd => { + // if(fd.hasImplementation) { + // val args = fd.args + // if(args.exists(vd => containsArrayType(vd.tpe)) || containsArrayType(fd.returnType)) { + // val newArgs = args.map(vd => { + // val freshId = FreshIdentifier(vd.id.name).setType(transform(vd.tpe)) + // id2id += (vd.id -> freshId) + // val newTpe = transform(vd.tpe) + // VarDecl(freshId, newTpe) + // }) + // val freshFunName = FreshIdentifier(fd.id.name) + // val freshFunDef = new FunDef(freshFunName, transform(fd.returnType), newArgs) + // fd2fd += (fd -> freshFunDef) + // freshFunDef.fromLoop = fd.fromLoop + // freshFunDef.parent = fd.parent + // freshFunDef.addAnnotation(fd.annotations.toSeq:_*) + // freshFunDef + // } else fd + // } else fd + //}) + + //allFuns.zip(newFuns).foreach{ case (ofd, nfd) => ofd.body.map(body => { + // nfd.precondition = ofd.precondition.map(transform) + // nfd.postcondition = ofd.postcondition.map(transform) + // val newBody = transform(body) + // nfd.body = Some(newBody) + //})} + + //val Program(id, ObjectDef(objId, _, invariants)) = pgm + //val allClasses: Seq[Definition] = pgm.definedClasses + //Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) + + + //private def transform(tpe: TypeTree): TypeTree = tpe match { + // case ArrayType(base) => TupleType(Seq(ArrayType(transform(base)), Int32Type)) + // case TupleType(tpes) => TupleType(tpes.map(transform)) + // case t => t + //} + //private def containsArrayType(tpe: TypeTree): Boolean = tpe match { + // case ArrayType(base) => true + // case TupleType(tpes) => tpes.exists(containsArrayType) + // case t => false + //} + + //private var id2id: Map[Identifier, Identifier] = Map() + //private var fd2fd: Map[FunDef, FunDef] = Map() + + //private def transform(expr: Expr): Expr = expr match { + // case fill@ArrayFill(length, default) => { + // var rLength = transform(length) + // val rDefault = transform(default) + // val rFill = ArrayMake(rDefault).setType(fill.getType) + // Tuple(Seq(rFill, rLength)).setType(TupleType(Seq(fill.getType, Int32Type))) + // } + // case sel@ArraySelect(a, i) => { + // val ar = transform(a) + // val ir = transform(i) + // val length = TupleSelect(ar, 2).setType(Int32Type) + // IfExpr( + // And(GreaterEquals(ir, IntLiteral(0)), LessThan(ir, length)), + // ArraySelect(TupleSelect(ar, 1).setType(ArrayType(sel.getType)), ir).setType(sel.getType).setPosInfo(sel), + // Error("Index out of bound").setType(sel.getType).setPosInfo(sel) + // ).setType(sel.getType) + // } + // case up@ArrayUpdate(a, i, v) => { + // val ar = transform(a) + // val ir = transform(i) + // val vr = transform(v) + // val Variable(id) = ar + // val length = TupleSelect(ar, 2).setType(Int32Type) + // val array = TupleSelect(ar, 1).setType(ArrayType(v.getType)) + // IfExpr( + // And(GreaterEquals(i, IntLiteral(0)), LessThan(i, length)), + // Assignment( + // id, + // Tuple(Seq( + // ArrayUpdated(array, ir, vr).setType(array.getType).setPosInfo(up), + // length) + // ).setType(TupleType(Seq(array.getType, Int32Type)))), + // Error("Index out of bound").setType(UnitType).setPosInfo(up) + // ).setType(UnitType) + // } + // case ArrayLength(a) => { + // val ar = transform(a) + // TupleSelect(ar, 2).setType(Int32Type) + // } + // case Let(i, v, b) => { + // val vr = transform(v) + // v.getType match { + // case ArrayType(_) => { + // val freshIdentifier = FreshIdentifier("t").setType(vr.getType) + // id2id += (i -> freshIdentifier) + // val br = transform(b) + // LetVar(freshIdentifier, vr, br) + // } + // case _ => { + // val br = transform(b) + // Let(i, vr, br) + // } + // } + // } + // case LetVar(id, e, b) => { + // val er = transform(e) + // val br = transform(b) + // LetVar(id, er, br) + // } + // case wh@While(c, e) => { + // val newWh = While(transform(c), transform(e)) + // newWh.invariant = wh.invariant.map(i => transform(i)) + // newWh.setPosInfo(wh) + // } + + // case ite@IfExpr(c, t, e) => { + // val rc = transform(c) + // val rt = transform(t) + // val re = transform(e) + // IfExpr(rc, rt, re).setType(rt.getType) + // } + + // case m @ MatchExpr(scrut, cses) => { + // val scrutRec = transform(scrut) + // val csesRec = cses.map{ + // case SimpleCase(pat, rhs) => SimpleCase(pat, transform(rhs)) + // case GuardedCase(pat, guard, rhs) => GuardedCase(pat, transform(guard), transform(rhs)) + // } + // val tpe = csesRec.head.rhs.getType + // MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m) + // } + // case LetDef(fd, b) => { + // val newFd = if(fd.hasImplementation) { + // val body = fd.body.get + // val args = fd.args + // val newFd = + // if(args.exists(vd => containsArrayType(vd.tpe)) || containsArrayType(fd.returnType)) { + // val newArgs = args.map(vd => { + // val freshId = FreshIdentifier(vd.id.name).setType(transform(vd.tpe)) + // id2id += (vd.id -> freshId) + // val newTpe = transform(vd.tpe) + // VarDecl(freshId, newTpe) + // }) + // val freshFunName = FreshIdentifier(fd.id.name) + // val freshFunDef = new FunDef(freshFunName, transform(fd.returnType), newArgs) + // fd2fd += (fd -> freshFunDef) + // freshFunDef.fromLoop = fd.fromLoop + // freshFunDef.parent = fd.parent + // freshFunDef.precondition = fd.precondition.map(transform) + // freshFunDef.postcondition = fd.postcondition.map(transform) + // freshFunDef.addAnnotation(fd.annotations.toSeq:_*) + // freshFunDef + // } else fd + // val newBody = transform(body) + // newFd.body = Some(newBody) + // newFd + // } else fd + // val rb = transform(b) + // LetDef(newFd, rb) + // } + // case FunctionInvocation(fd, args) => { + // val rargs = args.map(transform) + // val rfd = fd2fd.get(fd).getOrElse(fd) + // FunctionInvocation(rfd, rargs) + // } + + // case n @ NAryOperator(args, recons) => recons(args.map(transform)).setType(n.getType) + // case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)).setType(b.getType) + // case u @ UnaryOperator(a, recons) => recons(transform(a)).setType(u.getType) + + // case v @ Variable(id) => if(id2id.isDefinedAt(id)) Variable(id2id(id)) else v + // case (t: Terminal) => t + // case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled) + + //} + } diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala index 3a1882b12bb8933b84e731a0a382562aa00e52b2..c26d7ca3d404774908fc3e1d80ad01cd04921546 100644 --- a/src/main/scala/leon/Evaluator.scala +++ b/src/main/scala/leon/Evaluator.scala @@ -257,9 +257,16 @@ object Evaluator { case b @ BooleanLiteral(_) => b case u @ UnitLiteral => u - case f @ ArrayMake(default) => { + case f @ ArrayFill(length, default) => { val rDefault = rec(ctx, default) - ArrayMake(rDefault) + val rLength = rec(ctx, length) + ArrayFill(rLength, rDefault) + } + case ArrayLength(a) => { + var ra = rec(ctx, a) + while(!ra.isInstanceOf[ArrayFill]) + ra = ra.asInstanceOf[ArrayUpdated].array + ra.asInstanceOf[ArrayFill].length } case ArrayUpdated(a, i, v) => { val ra = rec(ctx, a) @@ -273,7 +280,7 @@ object Evaluator { var ra = rec(ctx, a) var found = false var result: Option[Expr] = None - while(!ra.isInstanceOf[ArrayMake] && !found) { + while(!ra.isInstanceOf[ArrayFill] && !found) { val ArrayUpdated(ra2, IntLiteral(i), v) = ra if(index == i) { result = Some(v) @@ -284,7 +291,7 @@ object Evaluator { } result match { case Some(r) => r - case None => ra.asInstanceOf[ArrayMake].defaultValue + case None => ra.asInstanceOf[ArrayFill].defaultValue } } diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index 4dbbaa2b58e8067b1d5abe013260cd4d253a211a..58cd49ff44051d396fda5fceb22569da4f55d17d 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -98,7 +98,6 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S private var boolSort: Z3Sort = null private var setSorts: Map[TypeTree, Z3Sort] = Map.empty private var mapSorts: Map[TypeTree, Z3Sort] = Map.empty - private var arraySorts: Map[TypeTree, Z3Sort] = Map.empty private var unitSort: Z3Sort = null private var unitValue: Z3AST = null @@ -111,6 +110,11 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S protected[leon] var tupleConstructors: Map[TypeTree, Z3FuncDecl] = Map.empty protected[leon] var tupleSelectors: Map[TypeTree, Seq[Z3FuncDecl]] = Map.empty + private var arraySorts: Map[TypeTree, Z3Sort] = Map.empty + protected[leon] var arrayTupleCons: Map[TypeTree, Z3FuncDecl] = Map.empty + protected[leon] var arrayTupleSelectorArray: Map[TypeTree, Z3FuncDecl] = Map.empty + protected[leon] var arrayTupleSelectorLength: Map[TypeTree, Z3FuncDecl] = Map.empty + private var reverseTupleConstructors: Map[Z3FuncDecl, TupleType] = Map.empty private var reverseTupleSelectors: Map[Z3FuncDecl, (TupleType, Int)] = Map.empty @@ -419,11 +423,16 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case at @ ArrayType(base) => arraySorts.get(at) match { case Some(s) => s case None => { - val fromSort = typeToSort(Int32Type) + val intSort = typeToSort(Int32Type) val toSort = typeToSort(base) - val as = z3.mkArraySort(fromSort, toSort) - arraySorts += (at -> as) - as + val as = z3.mkArraySort(intSort, toSort) + val tupleSortSymbol = z3.mkFreshStringSymbol("Array") + val (arrayTupleSort, arrayTupleCons_, Seq(arrayTupleSelectorArray_, arrayTupleSelectorLength_)) = z3.mkTupleSort(tupleSortSymbol, as, intSort) + arraySorts += (at -> arrayTupleSort) + arrayTupleCons += (at -> arrayTupleCons_) + arrayTupleSelectorArray += (at -> arrayTupleSelectorArray_) + arrayTupleSelectorLength += (at -> arrayTupleSelectorLength_) + arrayTupleSort } } case ft @ FunctionType(fts, tt) => funSorts.get(ft) match { @@ -443,7 +452,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case Some(s) => s case None => { val tpesSorts = tpes.map(typeToSort) - val sortSymbol = z3.mkFreshStringSymbol("TupleSort") + val sortSymbol = z3.mkFreshStringSymbol("Tuple") val (tupleSort, consTuple, projsTuple) = z3.mkTupleSort(sortSymbol, tpesSorts: _*) tupleSorts += (tt -> tupleSort) tupleConstructors += (tt -> consTuple) @@ -1086,15 +1095,37 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) } - case a@ArrayMake(default) => { - val ArrayType(base) = a.getType - z3.mkConstArray(typeToSort(base), rec(default)) + case fill@ArrayFill(length, default) => { + val at@ArrayType(base) = fill.getType + typeToSort(at) + val cons = arrayTupleCons(at) + val ar = z3.mkConstArray(typeToSort(base), rec(default)) + val res = cons(ar, rec(length)) + res } - case ArraySelect(ar, index) => { - z3.mkSelect(rec(ar), rec(index)) + case ArraySelect(a, index) => { + typeToSort(a.getType) + val ar = rec(a) + val getArray = arrayTupleSelectorArray(a.getType) + val res = z3.mkSelect(getArray(ar), rec(index)) + res } - case ArrayUpdated(ar, index, newVal) => { - z3.mkStore(rec(ar), rec(index), rec(newVal)) + case ArrayUpdated(a, index, newVal) => { + typeToSort(a.getType) + val ar = rec(a) + val getArray = arrayTupleSelectorArray(a.getType) + val getLength = arrayTupleSelectorLength(a.getType) + val cons = arrayTupleCons(a.getType) + val store = z3.mkStore(getArray(ar), rec(index), rec(newVal)) + val res = cons(store, getLength(ar)) + res + } + case ArrayLength(a) => { + typeToSort(a.getType) + val ar = rec(a) + val getLength = arrayTupleSelectorLength(a.getType) + val res = getLength(ar) + res } case AnonymousFunctionInvocation(id, args) => id.getType match { case ft @ FunctionType(fts, tt) => { @@ -1153,14 +1184,19 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S (if (elems.isEmpty) EmptySet(dt) else FiniteSet(elems.toSeq)).setType(expType.get) } } - case Some(ArrayType(dt)) => - model.getArrayValue(t) match { + case Some(ArrayType(dt)) => { + val Z3AppAST(decl, args) = z3.getASTKind(t) + assert(args.size == 2) + val length = rec(args(1), Some(Int32Type)) + val array = model.getArrayValue(args(0)) match { case None => throw new CantTranslateException(t) case Some((map, elseValue)) => - map.foldLeft(ArrayMake(rec(elseValue, Some(dt))): Expr) { + map.foldLeft(ArrayFill(length, rec(elseValue, Some(dt))): Expr) { case (acc, (key, value)) => ArrayUpdated(acc, rec(key, Some(Int32Type)), rec(value, Some(dt))) } } + array + } case other => if(t == unitValue) UnitLiteral diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 10ead38f574111bfd42850296bcb11002c1e9d47..33fb76780b3b6af4838db1355b269044c22d5c29 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -826,7 +826,7 @@ object Trees { case LetVar(_, _, _) => false case LetDef(_, _) => false case ArrayUpdate(_, _, _) => false - case ArrayFill(_, _) => false + case ArrayMake(_) => false case Epsilon(_) => false case _ => true } @@ -838,7 +838,7 @@ object Trees { case LetVar(_, _, _) => false case LetDef(_, _) => false case ArrayUpdate(_, _, _) => false - case ArrayFill(_, _) => false + case ArrayMake(_) => false case Epsilon(_) => false case _ => b }