diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index b922b454089552b421bce14077234319040f7ff2..07d14895e8db9379758ce82cf0e97ff66d3517a2 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -64,41 +64,45 @@ object ImperativeCodeElimination extends Pass { b)))) }) - (resId.toVariable, scope, Map(modifiedVars.zip(freshIds):_*)) + (resId.toVariable, scope, modifiedVars.zip(freshIds).toMap) } case While(cond, body) => { val (_, bodyScope, bodyFun) = toFunction(body) val modifiedVars: Seq[Identifier] = bodyFun.keys.toSeq - val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) - val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType)) - val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) - val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls) - - val modifiedVars2WhileFunVars: Map[Expr, Expr] = modifiedVars.zip(whileFunVars).map(p => (p._1.toVariable, p._2.toVariable)).toMap - val whileFunCond = replace(modifiedVars2WhileFunVars, cond) - val whileFunRecursiveCall = replace(modifiedVars2WhileFunVars, bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => bodyFun(id).toVariable)))) - val whileFunBaseCase = (if(whileFunVars.size == 1) whileFunVars.head.toVariable else Tuple(whileFunVars.map(_.toVariable))).setType(whileFunReturnType) - val whileFunBody = IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType) - whileFunDef.body = Some(whileFunBody) - - val finalVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) - val finalScope = ((body: Expr) => { - val tupleId = FreshIdentifier("t").setType(whileFunReturnType) - LetDef( - whileFunDef, - Let(tupleId, - FunctionInvocation(whileFunDef, modifiedVars.map(_.toVariable)), - if(finalVars.size == 1) - Let(finalVars.head, tupleId.toVariable, body) - else - finalVars.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), - b)))) - }) - - (UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap) + if(modifiedVars.isEmpty) + (UnitLiteral, (b: Expr) => b, Map()) + else { + val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType)) + val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) + val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls) + + val modifiedVars2WhileFunVars: Map[Expr, Expr] = modifiedVars.zip(whileFunVars).map(p => (p._1.toVariable, p._2.toVariable)).toMap + val whileFunCond = replace(modifiedVars2WhileFunVars, cond) + val whileFunRecursiveCall = replace(modifiedVars2WhileFunVars, bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => bodyFun(id).toVariable)))) + val whileFunBaseCase = (if(whileFunVars.size == 1) whileFunVars.head.toVariable else Tuple(whileFunVars.map(_.toVariable))).setType(whileFunReturnType) + val whileFunBody = IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType) + whileFunDef.body = Some(whileFunBody) + + val finalVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val finalScope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(whileFunReturnType) + LetDef( + whileFunDef, + Let(tupleId, + FunctionInvocation(whileFunDef, modifiedVars.map(_.toVariable)), + if(finalVars.size == 1) + Let(finalVars.head, tupleId.toVariable, body) + else + finalVars.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), + b)))) + }) + + (UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap) + } } case Block(head::exprs, expr) => { val (_, headScope, headFun) = toFunction(head)