diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 5b01232c47ac3a663ea3564f85386425539f881d..759af0eaf0e745e062a0c06d01f86ba240ccdf8e 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -170,7 +170,9 @@ class CompilationUnit(val ctx: LeonContext, case f @ purescala.Extractors.FiniteLambda(dflt, els) => val l = new leon.codegen.runtime.FiniteLambda(exprToJVM(dflt)) - for ((UnwrapTuple(ks),v) <- els) { + + for ((k,v) <- els) { + val ks = unwrapTuple(k, f.getType.asInstanceOf[FunctionType].from.size) // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(exprToJVM _).toArray).asInstanceOf[leon.codegen.runtime.Tuple] val vJvm = exprToJVM(v) @@ -213,7 +215,8 @@ class CompilationUnit(val ctx: LeonContext, CaseClass(cct, (fields zip cct.fieldsTypes).map { case (e, tpe) => jvmToExpr(e, tpe) }) - case (tpl: runtime.Tuple, UnwrapTupleType(stpe)) => + case (tpl: runtime.Tuple, tpe) => + val stpe = unwrapTupleType(tpe, tpl.getArity()) val elems = stpe.zipWithIndex.map { case (tpe, i) => jvmToExpr(tpl.get(i), tpe) } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index d1fd76888b47e3dd0db830bf72cfd114150f7b96..5858407944a4c228c9c6f65d5122c9c423ad2690 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -196,10 +196,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (AnyPattern[Expr, TypeTree](), false) } - case (t: codegen.runtime.Tuple, tt@UnwrapTupleType(parts)) => + case (t: codegen.runtime.Tuple, tpe) => val r = t.__getRead() - val c = getConstructors(tt)(0) + val parts = unwrapTupleType(tpe, t.getArity()) + + val c = getConstructors(tpe)(0) val elems = for (i <- 0 until t.getArity) yield { if (((r >> i) & 1) == 1) { @@ -229,7 +231,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val ttype = tupleTypeWrap(argorder.map(_.getType)) val tid = FreshIdentifier("tup", ttype) - val map = argorder.zipWithIndex.map{ case (id, i) => (id -> tupleSelect(Variable(tid), i+1)) }.toMap + val map = argorder.zipWithIndex.map{ case (id, i) => (id -> tupleSelect(Variable(tid), i+1, argorder.size)) }.toMap val newExpr = replaceFromIDs(map, expression) @@ -325,7 +327,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { def computeNext(): Option[Seq[Expr]] = { //return None while(total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) { - val model = it.next//.asInstanceOf[Tuple] + val model = it.next if (model eq null) { total = maxEnumerated @@ -357,8 +359,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { it.skipIsomorphic() } - val UnwrapTuple(exprs) = model - return Some(exprs); + return Some(unwrapTuple(model, ins.size)); } //if (total % 1000 == 0) { diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 056acc8a6367bc3595a4709080496203964a7da7..ffa2524f3d9fc654291958a815896672b1462418 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1020,7 +1020,8 @@ trait CodeExtraction extends ASTExtractors { val oute = extractTree(out) val rc = cases.map(extractMatchCase(_)) - val UnwrapTuple(ines) = ine + // @mk: FIXME: this whole sanity checking is very dodgy at best. + val ines = unwrapTuple(ine, ine.isInstanceOf[Tuple]) // @mk We untuple all tuples ines foreach { case v @ Variable(_) if currentFunDef.params.map{ _.toVariable } contains v => case LeonThis(_) => @@ -1061,10 +1062,10 @@ trait CodeExtraction extends ASTExtractors { tupleExpr.getType match { case TupleType(tpes) if tpes.size >= index => - tupleSelect(tupleExpr, index) + tupleSelect(tupleExpr, index, true) case _ => - outOfSubsetError(current, "Invalid tupple access") + outOfSubsetError(current, "Invalid tuple access") } case ExValDef(vs, tpt, bdy) => diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 02eae14fa5b5b91b16e2964f41f16ee1fdb0abe9..04ba31d95ea091f36f73c87e458f01328f57155b 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -11,31 +11,26 @@ object Constructors { import TypeTreeOps._ import Common._ import TypeTrees._ - import purescala.Extractors.UnwrapTupleType - def tupleSelect(t: Expr, index: Int) = t match { - case Tuple(es) => - // @mk FIXME: Notice tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> x. This seems wrong. - es(index-1) - case _ if t.getType.isInstanceOf[TupleType] => + // If isTuple, the whole expression is returned. This is to avoid a situation + // like tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> x, which is not expected. + // Instead, tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> Tuple(x,y). + def tupleSelect(t: Expr, index: Int, isTuple: Boolean): Expr = t match { + case Tuple(es) if isTuple => es(index-1) + case _ if t.getType.isInstanceOf[TupleType] && isTuple => TupleSelect(t, index) - case _ if (index == 1) => - // For cases like tupleSelect(tupleWrap(Seq(x)), 1) -> x - t + case other if !isTuple => other case _ => - sys.error(s"Trying to construct TupleSelect with non-tuple $t and index $index!=1") + sys.error(s"Calling tupleSelect on non-tuple $t") } + def tupleSelect(t: Expr, index: Int, originalSize: Int): Expr = tupleSelect(t, index, originalSize > 1) + def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match { case Nil => body case x :: Nil => - if (isSubtypeOf(value.getType, x.getType) || !value.getType.isInstanceOf[TupleType]) { - // This is for cases where we build it like: letTuple(List(x), tupleWrap(List(z))) - Let(x, value, body) - } else { - Let(x, tupleSelect(value, 1), body) - } + Let(x, value, body) case xs => require( value.getType.isInstanceOf[TupleType], @@ -218,8 +213,7 @@ object Constructors { } def finiteLambda(default: Expr, els: Seq[(Expr, Expr)], inputTypes: Seq[TypeTree]): Lambda = { - val UnwrapTupleType(argTypes) = els.headOption.map{_._1.getType}.getOrElse(tupleTypeWrap(inputTypes)) - val args = argTypes map { argType => ValDef(FreshIdentifier("x", argType, true)) } + val args = inputTypes map { tpe => ValDef(FreshIdentifier("x", tpe, true)) } if (els.isEmpty) { Lambda(args, default) } else { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index fc53487c86d8a874d6687f2233687cca4bf206e9..0d9508a586e65310416610831857f20da424d180 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -27,7 +27,7 @@ object Extractors { case SetMax(s) => Some((s,SetMax)) case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) - case TupleSelect(t, i) => Some((t, tupleSelect(_, i))) + case TupleSelect(t, i) => Some((t, tupleSelect(_, i, t.getType.asInstanceOf[TupleType].dimension))) case ArrayLength(a) => Some((a, ArrayLength)) case Lambda(args, body) => Some((body, Lambda(args, _))) case Forall(args, body) => Some((body, Forall(args, _))) @@ -287,12 +287,19 @@ object Extractors { object FiniteLambda { def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { + val inSize = lambda.getType.asInstanceOf[FunctionType].from.size lambda match { case Lambda(args, Let(theMapVar, FiniteMap(pairs), IfExpr( - MapIsDefinedAt(Variable(theMapVar1), UnwrapTuple(args2)), - MapGet(Variable(theMapVar2), UnwrapTuple(args3)), + MapIsDefinedAt(Variable(theMapVar1), targs2), + MapGet(Variable(theMapVar2), targs3), default - ))) if (args map { x: ValDef => x.toVariable }) == args2 && args2 == args3 && theMapVar == theMapVar1 && theMapVar == theMapVar2 => + ))) if { + val args2 = unwrapTuple(targs2, inSize) + val args3 = unwrapTuple(targs3, inSize) + (args map { x: ValDef => x.toVariable }) == args2 && + args2 == args3 && theMapVar == theMapVar1 && + theMapVar == theMapVar2 + } => Some(default, pairs) case Lambda(args, default) if (variablesOf(default) & args.toSet.map{x: ValDef => x.id}).isEmpty => Some(default, Seq()) @@ -379,26 +386,29 @@ object Extractors { } } - object UnwrapTuple { - def unapply(e: Expr): Option[Seq[Expr]] = Option(e) map { - case Tuple(subs) => subs - case other => Seq(other) - } + def unwrapTuple(e: Expr, isTuple: Boolean): Seq[Expr] = e.getType match { + case TupleType(subs) if isTuple => + for (ind <- 1 to subs.size) yield { tupleSelect(e, ind, isTuple) } + case _ if !isTuple => Seq(e) + case tp => sys.error(s"Calling unwrapTuple on non-tuple $e of type $tp") } + def unwrapTuple(e: Expr, expectedSize: Int): Seq[Expr] = unwrapTuple(e, expectedSize > 1) - object UnwrapTupleType { - def unapply(tp: TypeTree) = Option(tp) map { - case TupleType(subs) => subs - case other => Seq(other) - } + def unwrapTupleType(tp: TypeTree, isTuple: Boolean): Seq[TypeTree] = tp match { + case TupleType(subs) if isTuple => subs + case tp if !isTuple => Seq(tp) + case tp => sys.error(s"Calling unwrapTupleType on $tp") } + def unwrapTupleType(tp: TypeTree, expectedSize: Int): Seq[TypeTree] = + unwrapTupleType(tp, expectedSize > 1) - object UnwrapTuplePattern { - def unapply(p: Pattern): Option[Seq[Pattern]] = Option(p) map { - case TuplePattern(_,subs) => subs - case other => Seq(other) - } + def unwrapTuplePattern(p: Pattern, isTuple: Boolean): Seq[Pattern] = p match { + case TuplePattern(_, subs) if isTuple => subs + case tp if !isTuple => Seq(tp) + case tp => sys.error(s"Calling unwrapTuplePattern on $p") } + def unwrapTuplePattern(p: Pattern, expectedSize: Int): Seq[Pattern] = + unwrapTuplePattern(p, expectedSize > 1) object LetPattern { def apply(patt : Pattern, value: Expr, body: Expr) : Expr = { diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index a87805871840bed4bca0e6ca497246bb49c110e7..cb02eed026b0677173665ba7e4a7812a1725c70e 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -476,10 +476,10 @@ object TreeOps { def normalizeExpression(expr: Expr) : Expr = { def rec(e: Expr): Option[Expr] = e match { case TupleSelect(Let(id, v, b), ts) => - Some(Let(id, v, tupleSelect(b, ts))) + Some(Let(id, v, tupleSelect(b, ts, true))) case TupleSelect(LetTuple(ids, v, b), ts) => - Some(letTuple(ids, v, tupleSelect(b, ts))) + Some(letTuple(ids, v, tupleSelect(b, ts, true))) case IfExpr(c, thenn, elze) if (thenn == elze) && isDeterministic(e) => Some(thenn) @@ -580,7 +580,7 @@ object TreeOps { case l @ LetTuple(ids, tExpr: Terminal, body) if isDeterministic(body) => val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => (v -> tupleSelect(tExpr, i + 1).copiedFrom(v)) + case (v,i) => (v -> tupleSelect(tExpr, i + 1, true).copiedFrom(v)) } Some(replace(substMap, body)) @@ -611,7 +611,7 @@ object TreeOps { Some(body) } else if(total == 1) { val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => (v -> tupleSelect(tExpr, i + 1).copiedFrom(v)) + case (v,i) => (v -> tupleSelect(tExpr, i + 1, ids.size).copiedFrom(v)) } Some(replace(substMap, body)) @@ -719,7 +719,7 @@ object TreeOps { case TuplePattern(_, subps) => val TupleType(subts) = in.getType val subExprs = (subps zip subts).zipWithIndex map { - case ((p, t), index) => p.binder.map(_.toVariable).getOrElse(tupleSelect(in, index+1)) + case ((p, t), index) => p.binder.map(_.toVariable).getOrElse(tupleSelect(in, index+1, subps.size)) } // Special case to get rid of (a,b) match { case (c,d) => .. } @@ -803,7 +803,7 @@ object TreeOps { case TuplePattern(ob, subps) => { val TupleType(tpes) = in.getType assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1), p)} + val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1, subps.size), p)} and(bind(ob, in) +: subTests: _*) } case LiteralPattern(ob,lit) => and(Equals(in,lit), bind(ob,in)) @@ -832,7 +832,7 @@ object TreeOps { val TupleType(tpes) = in.getType assert(tpes.size == subps.size) - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1), p)} + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) b match { case Some(id) => map + (id -> in) diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 1a10635ff8ea64ca7f8eb150fde2cae91e255d9c..45dab972edb83cd11a52487bfb998d6a25f22d4f 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -10,7 +10,7 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.DefOps._ import purescala.Constructors._ -import purescala.Extractors.UnwrapTuple +import purescala.Extractors.unwrapTuple import purescala.ScalaPrinter import evaluators._ import solvers._ @@ -412,7 +412,6 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout import bonsai._ import bonsai.enumerators._ import utils.ExpressionGrammars.ValueGrammar - import purescala.Extractors.UnwrapTuple val maxEnumerated = 1000 val maxValid = 400 @@ -420,7 +419,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams(checkContracts = true)) val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions _) - val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map{ case UnwrapTuple(is) => is } + val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) val filtering: Seq[Expr] => Boolean = fd.precondition match { case None => @@ -493,7 +492,9 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val inputsMap = (p.as zip inputs).toMap (e.eval(s1, inputsMap), e.eval(s2, inputsMap)) match { - case (EvaluationResults.Successful(UnwrapTuple(r1)), EvaluationResults.Successful(UnwrapTuple(r2))) => + case (EvaluationResults.Successful(tr1), EvaluationResults.Successful(tr2)) => + val r1 = unwrapTuple(tr1, p.xs.size) + val r2 = unwrapTuple(tr2, p.xs.size) Some((InOutExample(inputs, r1), InOutExample(inputs, r2))) case _ => None diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 4a13379f31581bf539656222a13a4a9a3e1c7014..741c750f3398e7c546914ffdc05b13ac9b541e3f 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -275,7 +275,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { val m: Map[Expr, Expr] = if (ids.size == 1) { Map(Variable(ids.head) -> Variable(cid)) } else { - ids.zipWithIndex.map{ case (id, i) => Variable(id) -> tupleSelect(Variable(cid), i+1) }.toMap + ids.zipWithIndex.map{ case (id, i) => Variable(id) -> tupleSelect(Variable(cid), i+1, ids.size) }.toMap } storeGuarded(pathVar, replace(m, cond)) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 1ef7d00d9a4edff5c6cbea09485ea52cf0be2c4e..b1626ec6c908958ed7b69fb09f01ba7e39f0dbc7 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -504,7 +504,7 @@ trait AbstractZ3Solver case LetTuple(ids, e, b) => { var ix = 1 z3Vars = z3Vars ++ ids.map((id) => { - val entry = (id -> rec(tupleSelect(e, ix))) + val entry = (id -> rec(tupleSelect(e, ix, ids.size))) ix += 1 entry }) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index fade70b9463cdc951a0383d38e253287b9b8b0a8..679b6c5697a5b0cdabf6729c8ffce3fc9b73537c 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -151,9 +151,8 @@ class ExamplesFinder(ctx: LeonContext, program: Program) { // We will instantiate them according to a simple grammar to get them. val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions _) val values = enum.iterator(tupleTypeWrap(freeVars.map{ _.getType })) - val instantiations = values map { - case UnwrapTuple(ins) => - (freeVars zip ins).toMap + val instantiations = values.map { + v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap } def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match { diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index d538a5a1a54e9171cd32b7b06029a225fab37936..5bf161e7b52ef822d68b7e25fbb9951b89de8770 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -44,7 +44,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust term.getType match { case TupleType(ts) => val t = FreshIdentifier("t", term.getType, true) - val newTerm = Let(t, term, tupleWrap(indices.map(i => tupleSelect(t.toVariable, i+1)))) + val newTerm = Let(t, term, tupleWrap(indices.map(i => tupleSelect(t.toVariable, i+1, indices.size)))) Solution(pre, defs, newTerm) case _ => diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 85c534e3f80ceb5871e4aaf57d4758a1d621de9b..c44e7005952e863a4fa7db0052d620bad0260cbb 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -105,16 +105,9 @@ class Synthesizer(val context : LeonContext, val ret = tupleTypeWrap(problem.xs.map(_.getType)) val res = Variable(FreshIdentifier("res", ret)) - val mapPost: Map[Expr, Expr] = - if (problem.xs.size > 1) { - problem.xs.zipWithIndex.map{ case (id, i) => - Variable(id) -> tupleSelect(res, i+1) - }.toMap - } else { - problem.xs.map{ case id => - Variable(id) -> res - }.toMap - } + val mapPost: Map[Expr, Expr] = problem.xs.zipWithIndex.map{ case (id, i) => + Variable(id) -> tupleSelect(res, i+1, problem.xs.size) + }.toMap val fd = new FunDef(FreshIdentifier(ci.fd.id.name+"_final", alwaysShowUniqueID = true), Nil, ret, problem.as.map(ValDef(_)), DefType.MethodDef) fd.precondition = Some(and(problem.pc, sol.pre)) diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index e53716e38dadadf04aa3f37bbd3278ce368dae1f..58e04d4237fb4387868ca482190f41cde17a7fe9 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -38,7 +38,7 @@ trait ChainComparator { self : StructuralSize with TerminationChecker => }) case TupleType(tpes) => powerSetToFunSet((0 until tpes.length).flatMap { case index => - rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1))) + rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true))) }) case _ => Set((e: Expr) => e) } @@ -54,7 +54,7 @@ trait ChainComparator { self : StructuralSize with TerminationChecker => }.toSet case TupleType(tpes) => (0 until tpes.length).flatMap { case index => - rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1))) + rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true))) }.toSet case _ => Set((e: Expr) => e) } diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index eae1f53d0a559fd1e35700fc4d7b501b7976075f..69c22f0a2b133b40dc8cc66afa2bcb7dee94287e 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -64,7 +64,7 @@ trait StructuralSize { }) FunctionInvocation(TypedFunDef(fd, ct.tps), Seq(expr)) case TupleType(argTypes) => argTypes.zipWithIndex.map({ - case (_, index) => size(tupleSelect(expr, index + 1)) + case (_, index) => size(tupleSelect(expr, index + 1, true)) }).foldLeft[Expr](InfiniteIntegerLiteral(0))(Plus(_,_)) case _ => InfiniteIntegerLiteral(0) } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index e47339018ed7c9737345bbc36dc460a940a18529..2cca31c7499383cf7aaba8feeefbc56bf97631e9 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -53,7 +53,7 @@ object UnitElimination extends TransformationPhase { } private def simplifyType(tpe: TypeTree): TypeTree = tpe match { - case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false }) + case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ _ == UnitType }) case t => t } @@ -72,12 +72,10 @@ object UnitElimination extends TransformationPhase { } case ts@TupleSelect(t, index) => { val TupleType(tpes) = t.getType - val selectionType = tpes(index-1) - val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){ - case ((nbUnit, newIndex), (tpe, i)) => - if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex) - } - tupleSelect(removeUnit(t), newIndex) + val simpleTypes = tpes map simplifyType + val newArity = tpes.count(_ != UnitType) + val newIndex = simpleTypes.take(index).filter(_ != UnitType).size + tupleSelect(removeUnit(t), newIndex, newArity) } case Let(id, e, b) => { if(id.getType == UnitType) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 12cb3552c4f340263ca0c76755d27769e7f5da7a..67faab2e3ef351cbf6f1dfa591480cef25fca4da 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -70,15 +70,15 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq val resId = FreshIdentifier("res", iteRType) - val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val iteType = if(modifiedVars.isEmpty) resId.getType else tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) + val freshIds = modifiedVars.map( { _.freshen }) + val iteType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - val thenVal = if(modifiedVars.isEmpty) tRes else tupleWrap(tRes +: modifiedVars.map(vId => tFun.get(vId) match { + val thenVal = tupleWrap(tRes +: modifiedVars.map(vId => tFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) - val elseVal = if(modifiedVars.isEmpty) eRes else tupleWrap(eRes +: modifiedVars.map(vId => eFun.get(vId) match { + val elseVal = tupleWrap(eRes +: modifiedVars.map(vId => eFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) @@ -87,16 +87,13 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val scope = ((body: Expr) => { val tupleId = FreshIdentifier("t", iteType) - cScope( - Let(tupleId, iteExpr, - if(freshIds.isEmpty) - Let(resId, tupleId.toVariable, body) - else - Let(resId, tupleSelect(tupleId.toVariable, 1), - freshIds.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, - tupleSelect(tupleId.toVariable, id._2 + 2), - b)))).copiedFrom(expr)) + cScope( Let(tupleId, iteExpr, Let( + resId, + tupleSelect(tupleId.toVariable, 1, modifiedVars.nonEmpty), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 2, true), b) + )) + ).copiedFrom(expr)) }) (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) @@ -110,10 +107,10 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).intersect(varInScope).toSeq val resId = FreshIdentifier("res", m.getType) val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val matchType = if(modifiedVars.isEmpty) resId.getType else tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) + val matchType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) val csesVals = csesRes.zip(csesFun).map{ - case (cRes, cFun) => if(modifiedVars.isEmpty) cRes else tupleWrap(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case (cRes, cFun) => tupleWrap(cRes +: modifiedVars.map(vId => cFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) @@ -130,14 +127,13 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val tupleId = FreshIdentifier("t", matchType) scrutScope( Let(tupleId, matchE, - if(freshIds.isEmpty) - Let(resId, tupleId.toVariable, body) - else - Let(resId, tupleSelect(tupleId.toVariable, 1), - freshIds.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, - tupleSelect(tupleId.toVariable, id._2 + 2), - b))))) + Let(resId, tupleSelect(tupleId.toVariable, 1, freshIds.nonEmpty), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 2, true), b) + ) + ) + ) + ) }) (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) @@ -155,7 +151,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap val whileFunValDefs = whileFunVars.map(ValDef(_)) - val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else tupleTypeWrap(whileFunVars.map(_.getType)) + val whileFunReturnType = tupleTypeWrap(whileFunVars.map(_.getType)) val whileFunDef = new FunDef(FreshIdentifier(parent.id.name), Nil, whileFunReturnType, whileFunValDefs,DefType.MethodDef).setPos(wh) wasLoop += whileFunDef @@ -170,10 +166,9 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val resVar = Variable(FreshIdentifier("res", whileFunReturnType)) val whileFunVars2ResultVars: Map[Expr, Expr] = - if(whileFunVars.size == 1) - Map(whileFunVars.head.toVariable -> resVar) - else - whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, tupleSelect(resVar, i+1)) }.toMap + whileFunVars.zipWithIndex.map{ case (v, i) => + (v.toVariable, tupleSelect(resVar, i+1, whileFunVars.size)) + }.toMap val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id => (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap @@ -190,20 +185,16 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef case None => BooleanLiteral(true) }))) - val finalVars = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) + val finalVars = modifiedVars.map(_.freshen) val finalScope = ((body: Expr) => { val tupleId = FreshIdentifier("t", whileFunReturnType) - LetDef( - whileFunDef, - Let(tupleId, - FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh), - if(finalVars.size == 1) - Let(finalVars.head, tupleId.toVariable, body) - else - finalVars.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, - tupleSelect(tupleId.toVariable, id._2 + 1), - b)))) + LetDef( whileFunDef, Let( + tupleId, + FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh), + finalVars.zipWithIndex.foldLeft(body){(b, id) => + Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 1, finalVars.size), b) + } + )) }) (UnitLiteral(), finalScope, modifiedVars.zip(finalVars).toMap)