From 3377476747fe890967db8377362e4cf0458fa0a5 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Fri, 27 Feb 2015 18:27:07 +0100 Subject: [PATCH] Eliminate rewriteTuples. Use constructors/extractors. Use constructors/extractors everywhere (except solvers). Eliminate any Tuples with <2 elements from Leon. Deprecate rewriteTuples. Rewrite FiniteLambda. Not necessarily better, just clearer. --- .../scala/leon/codegen/CompilationUnit.scala | 20 +- .../codegen/runtime/ChooseEntryPoint.scala | 6 +- .../scala/leon/datagen/VanuatooDataGen.scala | 25 +- .../leon/evaluators/RecursiveEvaluator.scala | 7 +- .../scala/leon/purescala/Constructors.scala | 55 ++--- .../scala/leon/purescala/Extractors.scala | 31 ++- src/main/scala/leon/purescala/TreeOps.scala | 226 +++++++++--------- src/main/scala/leon/purescala/Trees.scala | 16 +- src/main/scala/leon/purescala/TypeTrees.scala | 19 +- src/main/scala/leon/repair/Repairman.scala | 12 +- .../solvers/smtlib/SMTLIBCVC4Target.scala | 6 +- .../leon/solvers/smtlib/SMTLIBTarget.scala | 4 +- .../solvers/templates/TemplateGenerator.scala | 2 +- .../leon/solvers/z3/AbstractZ3Solver.scala | 12 +- .../scala/leon/synthesis/ConvertHoles.scala | 2 +- .../leon/synthesis/ConvertWithOracles.scala | 2 +- src/main/scala/leon/synthesis/Rules.scala | 2 +- src/main/scala/leon/synthesis/Solution.scala | 2 +- .../scala/leon/termination/ChainBuilder.scala | 2 +- .../leon/termination/ChainComparator.scala | 8 +- .../leon/termination/ChainProcessor.scala | 5 +- .../leon/termination/LoopProcessor.scala | 4 +- .../leon/termination/RelationProcessor.scala | 2 +- .../scala/leon/termination/Strengthener.scala | 12 +- src/main/scala/leon/utils/Simplifiers.scala | 2 - .../scala/leon/utils/UnitElimination.scala | 10 +- .../xlang/ImperativeCodeElimination.scala | 30 +-- 27 files changed, 245 insertions(+), 279 deletions(-) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index c522be854..0e15bc6db 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -170,13 +170,11 @@ class CompilationUnit(val ctx: LeonContext, case f @ purescala.Extractors.FiniteLambda(dflt, els) => val l = new leon.codegen.runtime.FiniteLambda(exprToJVM(dflt)) - for ((k,v) <- els) { - val jvmK = if (f.getType.from.size == 1) { - exprToJVM(Tuple(Seq(k))) - } else { - exprToJVM(k) - } - l.add(jvmK.asInstanceOf[leon.codegen.runtime.Tuple], exprToJVM(v)) + for ((UnwrapTuple(ks),v) <- els) { + // Force tuple even with 1/0 elems. + val kJvm = tupleConstructor.newInstance(ks.map(exprToJVM _).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + val vJvm = exprToJVM(v) + l.add(kJvm,vJvm) } l @@ -215,11 +213,11 @@ class CompilationUnit(val ctx: LeonContext, CaseClass(cct, (fields zip cct.fieldsTypes).map { case (e, tpe) => jvmToExpr(e, tpe) }) - case (tpl: runtime.Tuple, TupleType(stpe)) => + case (tpl: runtime.Tuple, UnwrapTupleType(stpe)) => val elems = stpe.zipWithIndex.map { case (tpe, i) => jvmToExpr(tpl.get(i), tpe) } - Tuple(elems) + tupleWrap(elems) case (gv @ GenericValue(gtp, id), tp: TypeParameter) => if (gtp == tp) gv @@ -292,10 +290,10 @@ class CompilationUnit(val ctx: LeonContext, mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, true)) e.getType match { - case Int32Type | BooleanType => + case Int32Type | BooleanType | UnitType => ch << IRETURN - case IntegerType | UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: FunctionType | _: TypeParameter => + case IntegerType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: FunctionType | _: TypeParameter => ch << ARETURN case other => diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala index b871b2652..2835a8f41 100644 --- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala +++ b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala @@ -65,11 +65,7 @@ object ChooseEntryPoint { val valModel = valuateWithModel(model) _ val res = p.xs.map(valModel) - val leonRes = if (res.size > 1) { - LeonTuple(res) - } else { - res(0) - } + val leonRes = tupleWrap(res) val total = System.currentTimeMillis-tStart; diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 9b12b5beb..3b29ef4ac 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -88,7 +88,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case tt @ TupleType(parts) => constructors.getOrElse(tt, { - val cs = List(Constructor[Expr, TypeTree](parts, tt, s => Tuple(s), tt.toString)) + val cs = List(Constructor[Expr, TypeTree](parts, tt, s => tupleWrap(s), tt.toString)) constructors += tt -> cs cs }) @@ -110,10 +110,10 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val subs = (1 to size).flatMap(_ => from :+ to).toList Constructor[Expr, TypeTree](subs, ft, { s => val args = from.map(tpe => FreshIdentifier("x", tpe, true)) - val argsTuple = Tuple(args.map(_.toVariable)) + val argsTuple = tupleWrap(args.map(_.toVariable)) val grouped = s.grouped(from.size + 1).toSeq val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => - IfExpr(Equals(argsTuple, Tuple(t.init)), t.last, elze) + IfExpr(Equals(argsTuple, tupleWrap(t.init)), t.last, elze) } Lambda(args.map(id => ValDef(id, id.getType)), body) }, ft.toString + "@" + size) @@ -196,7 +196,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (AnyPattern[Expr, TypeTree](), false) } - case (t: codegen.runtime.Tuple, tt @ TupleType(parts)) => + case (t: codegen.runtime.Tuple, tt@UnwrapTupleType(parts)) => val r = t.__getRead() val c = getConstructors(tt)(0) @@ -221,21 +221,21 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { type InstrumentedResult = (EvaluationResults.Result, Option[vanuatoo.Pattern[Expr, TypeTree]]) - def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Tuple=>InstrumentedResult] = { + def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException try { - val ttype = TupleType(argorder.map(_.getType)) + val ttype = tupleTypeWrap(argorder.map(_.getType)) val tid = FreshIdentifier("tup", ttype) - val map = argorder.zipWithIndex.map{ case (id, i) => (id -> TupleSelect(Variable(tid), i+1)) }.toMap + val map = argorder.zipWithIndex.map{ case (id, i) => (id -> tupleSelect(Variable(tid), i+1)) }.toMap val newExpr = replaceFromIDs(map, expression) val ce = unit.compileExpression(newExpr, Seq(tid)) - Some((args : Tuple) => { + Some((args : Expr) => { try { val monitor = new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations) val jvmArgs = ce.argsToJVM(Seq(args), monitor ) @@ -268,7 +268,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { }) } catch { case t: Throwable => - ctx.reporter.warning("Error while compiling expression: "+t.getMessage) + ctx.reporter.warning("Error while compiling expression: "+t.getMessage); t.printStackTrace None } } @@ -299,7 +299,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val maxIsomorphicModels = maxValid+1; - val it = gen.enumerate(TupleType(ins.map(_.getType))) + val it = gen.enumerate(tupleTypeWrap(ins.map(_.getType))) return new Iterator[Seq[Expr]] { var total = 0 @@ -325,7 +325,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { def computeNext(): Option[Seq[Expr]] = { //return None while(total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) { - val model = it.next.asInstanceOf[Tuple] + val model = it.next//.asInstanceOf[Tuple] if (model eq null) { total = maxEnumerated @@ -357,7 +357,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { it.skipIsomorphic() } - return Some(model.exprs); + val UnwrapTuple(exprs) = model + return Some(exprs); } //if (total % 1000 == 0) { diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index b62d2ffc5..546548244 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -537,12 +537,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val valModel = valuateWithModel(model) _ val res = p.xs.map(valModel) - val leonRes = if (res.size > 1) { - Tuple(res) - } else { - res(0) - } - + val leonRes = tupleWrap(res) val total = System.currentTimeMillis-tStart; ctx.reporter.debug("Synthesis took "+total+"ms") diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index f37136d4d..fd1058926 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -11,12 +11,16 @@ object Constructors { import TypeTreeOps._ import Common._ import TypeTrees._ + import purescala.Extractors.UnwrapTupleType def tupleSelect(t: Expr, index: Int) = t match { case Tuple(es) => es(index-1) - case _ => + case _ if t.getType.isInstanceOf[TupleType] => TupleSelect(t, index) + case _ => + if (index == 1) t + else sys.error(s"Trying to construct TupleSelect with non-tuple $t and index $index!=1") } def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match { @@ -38,14 +42,6 @@ object Constructors { Extractors.LetPattern(TuplePattern(None,binders map { b => WildcardPattern(Some(b)) }), value, body) } - def tupleChoose(ch: Choose): Expr = { - if (ch.vars.size > 1) { - ch - } else { - Tuple(Seq(ch)) - } - } - def tupleWrap(es: Seq[Expr]): Expr = es match { case Seq() => UnitLiteral() case Seq(elem) => elem @@ -218,32 +214,21 @@ object Constructors { else NonemptyArray(els, defaultLength) } - def finiteLambda(dflt: Expr, els: Seq[(Expr, Expr)], tpe: FunctionType): Lambda = { - val args = tpe.from.zipWithIndex.map { case (tpe, idx) => - ValDef(FreshIdentifier(s"x${idx + 1}", tpe), tpe) - } - - assume(els.isEmpty || !tpe.from.isEmpty, "Can't provide finite mapping for lambda without parameters") - - lazy val (tupleArgs, tupleKey) = if (tpe.from.size > 1) { - val tpArgs = Tuple(args.map(_.toVariable)) - val key = (x: Expr) => x - (tpArgs, key) - } else { // note that value is lazy, so if tpe.from.size == 0, foldRight will never access (tupleArgs, tupleKey) - val tpArgs = args.head.toVariable - val key = (x: Expr) => { - if (isSubtypeOf(x.getType, tpe.from.head)) x - else if (isSubtypeOf(x.getType, TupleType(tpe.from))) x.asInstanceOf[Tuple].exprs.head - else throw new RuntimeException("Can't determine key tuple state : " + x + " of " + tpe) - } - (tpArgs, key) - } - - val body = els.toSeq.foldRight(dflt) { case ((k, v), elze) => - IfExpr(Equals(tupleArgs, tupleKey(k)), v, elze) + def finiteLambda(default: Expr, els: Seq[(Expr, Expr)], inputTypes: Seq[TypeTree]): Lambda = { + val UnwrapTupleType(argTypes) = els.headOption.map{_._1.getType}.getOrElse(tupleTypeWrap(inputTypes)) + val args = argTypes map { argType => ValDef(FreshIdentifier("x", argType, true), argType) } + if (els.isEmpty) { + Lambda(args, default) + } else { + val theMap = NonemptyMap(els) + val theMapVar = FreshIdentifier("pairs", theMap.getType, true) + val argsAsExpr = tupleWrap(args map { _.toVariable }) + val body = Let(theMapVar, theMap, IfExpr( + MapIsDefinedAt(Variable(theMapVar), argsAsExpr), + MapGet(Variable(theMapVar), argsAsExpr), + default + )) + Lambda(args, body) } - - Lambda(args, body) } - } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 291aa8487..fc53487c8 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -287,19 +287,17 @@ object Extractors { object FiniteLambda { def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { - val args = lambda.args.map(_.toVariable) - lazy val argsTuple = if (lambda.args.size > 1) Tuple(args) else args.head - - def rec(body: Expr): Option[(Expr, Seq[(Expr, Expr)])] = body match { - case _ : IntLiteral | _ : UMinus | _ : BooleanLiteral | _ : GenericValue | _ : Tuple | - _ : CaseClass | FiniteArray(_, _, _) | FiniteSet(_) | FiniteMap(_) | _ : Lambda => - Some(body -> Seq.empty) - case IfExpr(Equals(tpArgs, key), expr, elze) if tpArgs == argsTuple => - rec(elze).map { case (dflt, mapping) => dflt -> ((key -> expr) +: mapping) } + lambda match { + case Lambda(args, Let(theMapVar, FiniteMap(pairs), IfExpr( + MapIsDefinedAt(Variable(theMapVar1), UnwrapTuple(args2)), + MapGet(Variable(theMapVar2), UnwrapTuple(args3)), + default + ))) if (args map { x: ValDef => x.toVariable }) == args2 && args2 == args3 && theMapVar == theMapVar1 && theMapVar == theMapVar2 => + Some(default, pairs) + case Lambda(args, default) if (variablesOf(default) & args.toSet.map{x: ValDef => x.id}).isEmpty => + Some(default, Seq()) case _ => None } - - rec(lambda.body) } } @@ -382,14 +380,21 @@ object Extractors { } object UnwrapTuple { - def unapply(e : Expr) : Option[Seq[Expr]] = Option(e) map { + def unapply(e: Expr): Option[Seq[Expr]] = Option(e) map { case Tuple(subs) => subs case other => Seq(other) } } + object UnwrapTupleType { + def unapply(tp: TypeTree) = Option(tp) map { + case TupleType(subs) => subs + case other => Seq(other) + } + } + object UnwrapTuplePattern { - def unapply(p : Pattern) : Option[Seq[Pattern]] = Option(p) map { + def unapply(p: Pattern): Option[Seq[Pattern]] = Option(p) map { case TuplePattern(_,subs) => subs case other => Seq(other) } diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 345708fef..95bb91aa5 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -727,7 +727,7 @@ object TreeOps { case Tuple(ts) => ts zip subExprs case _ => - Seq(in -> Tuple(subExprs)) + Seq(in -> tupleWrap(subExprs)) } subst0 ++ ((subExprs zip subps) flatMap { @@ -1255,114 +1255,6 @@ object TreeOps { def collect(e: Expr, path: Seq[Expr]) = matcher(e).map(_ -> and(path: _*)) } - /** - * Eliminates tuples of arity 0 and 1. - * Used to simplify synthesis solutions - * - * Only rewrites local fundefs. - */ - def rewriteTuples(expr: Expr) : Expr = { - def mapType(tt : TypeTree) : Option[TypeTree] = tt match { - case TupleType(ts) => ts.size match { - case 0 => Some(UnitType) - case 1 => Some(ts(0)) - case _ => - val tss = ts.map(mapType) - if(tss.exists(_.isDefined)) { - Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2)))) - } else { - None - } - } - case SetType(t) => mapType(t).map(SetType(_)) - case MultisetType(t) => mapType(t).map(MultisetType(_)) - case ArrayType(t) => mapType(t).map(ArrayType(_)) - case MapType(f,t) => - val (f2,t2) = (mapType(f),mapType(t)) - if(f2.isDefined || t2.isDefined) { - Some(MapType(f2.getOrElse(f), t2.getOrElse(t))) - } else { - None - } - case ft : FunctionType => None // FIXME - - case a : AbstractClassType => None - case cct : CaseClassType => - // This is really just one big assertion. We don't rewrite class defs. - val fieldTypes = cct.fields.map(_.tpe) - if(fieldTypes.exists(t => t match { - case TupleType(ts) if ts.size <= 1 => true - case _ => false - })) { - scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.") - } else { - None - } - case Untyped | BooleanType | Int32Type | IntegerType | UnitType | TypeParameter(_) => None - } - - var idMap = Map[Identifier, Identifier]() - var funDefMap = Map.empty[FunDef,FunDef] - - def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { - case Some(fd) => fd - case None => - if(funDef.params.map(vd => mapType(vd.tpe)).exists(_.isDefined)) { - scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,") - } - val newFD = mapType(funDef.returnType) match { - case None => funDef - case Some(rt) => - val fd = new FunDef(FreshIdentifier(funDef.id.name, alwaysShowUniqueID = true), funDef.tparams, rt, funDef.params, funDef.defType) - // These will be taken care of in the recursive traversal. - fd.body = funDef.body - fd.precondition = funDef.precondition - funDef.postcondition match { - case Some((id, post)) => - val freshId = FreshIdentifier(id.name, rt, true) - idMap += id -> freshId - fd.postcondition = Some((freshId, post)) - case None => - fd.postcondition = None - } - fd - } - funDefMap = funDefMap.updated(funDef, newFD) - newFD - } - - import synthesis.Witnesses.Terminating - - def pre(e : Expr) : Expr = e match { - case Tuple(Seq()) => UnitLiteral() - case Variable(id) if idMap contains id => Variable(idMap(id)) - - case Error(tpe, err) => Error(mapType(tpe).getOrElse(e.getType), err).copiedFrom(e) - case Tuple(Seq(s)) => pre(s) - - case ts @ TupleSelect(t, 1) => t.getType match { - case TupleOneType(_) => pre(t) - case _ => ts - } - - case LetTuple(bs, v, bdy) if bs.size == 1 => - Let(bs(0), v, bdy) - - case l @ LetDef(fd, bdy) => - LetDef(fd2fd(fd), bdy) - - case FunctionInvocation(tfd, args) => - FunctionInvocation(fd2fd(tfd.fd).typed(tfd.tps), args) - - case Terminating(tfd, args) => - Terminating(fd2fd(tfd.fd).typed(tfd.tps), args) - - case _ => e - } - - simplePreTransform(pre)(expr) - } - def patternSize(p: Pattern): Int = p match { case wp: WildcardPattern => 1 @@ -2120,7 +2012,114 @@ object TreeOps { @deprecated("Use exists instead", "Leon 0.2.1") def contains(e: Expr, matcher: Expr => Boolean): Boolean = exists(matcher)(e) - + + /** + * Eliminates tuples of arity 0 and 1. + * Used to simplify synthesis solutions + * + * Only rewrites local fundefs. + */ + @deprecated("Use purescala.Constructors.tuple* and purescala.Extractors.Unwrap* " + + "to avoid creation of tuples of size 0 and 1", "Leon 3.0.0" + ) + def rewriteTuples(expr: Expr) : Expr = { + def mapType(tt : TypeTree) : Option[TypeTree] = tt match { + case TupleType(ts) => ts.size match { + case 0 => Some(UnitType) + case 1 => Some(ts(0)) + case _ => + val tss = ts.map(mapType) + if(tss.exists(_.isDefined)) { + Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2)))) + } else { + None + } + } + case SetType(t) => mapType(t).map(SetType(_)) + case MultisetType(t) => mapType(t).map(MultisetType(_)) + case ArrayType(t) => mapType(t).map(ArrayType(_)) + case MapType(f,t) => + val (f2,t2) = (mapType(f),mapType(t)) + if(f2.isDefined || t2.isDefined) { + Some(MapType(f2.getOrElse(f), t2.getOrElse(t))) + } else { + None + } + case ft : FunctionType => None // FIXME + + case a : AbstractClassType => None + case cct : CaseClassType => + // This is really just one big assertion. We don't rewrite class defs. + val fieldTypes = cct.fields.map(_.tpe) + if(fieldTypes.exists(t => t match { + case TupleType(ts) if ts.size <= 1 => true + case _ => false + })) { + scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.") + } else { + None + } + case Untyped | BooleanType | Int32Type | IntegerType | UnitType | TypeParameter(_) => None + } + + var idMap = Map[Identifier, Identifier]() + var funDefMap = Map.empty[FunDef,FunDef] + + def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { + case Some(fd) => fd + case None => + if(funDef.params.map(vd => mapType(vd.tpe)).exists(_.isDefined)) { + scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,") + } + val newFD = mapType(funDef.returnType) match { + case None => funDef + case Some(rt) => + val fd = new FunDef(FreshIdentifier(funDef.id.name, alwaysShowUniqueID = true), funDef.tparams, rt, funDef.params, funDef.defType) + // These will be taken care of in the recursive traversal. + fd.body = funDef.body + fd.precondition = funDef.precondition + funDef.postcondition match { + case Some((id, post)) => + val freshId = FreshIdentifier(id.name, rt, true) + idMap += id -> freshId + fd.postcondition = Some((freshId, post)) + case None => + fd.postcondition = None + } + fd + } + funDefMap = funDefMap.updated(funDef, newFD) + newFD + } + + import synthesis.Witnesses.Terminating + + def pre(e : Expr) : Expr = e match { + case Tuple(Seq()) => println("Tuple0!"); UnitLiteral() + case Variable(id) if idMap contains id => Variable(idMap(id)) + + case Error(tpe, err) => Error(mapType(tpe).getOrElse(e.getType), err).copiedFrom(e) + case Tuple(Seq(s)) => println("Tuple1!"); pre(s) + + case LetTuple(bs, v, bdy) if bs.size == 1 => + Let(bs(0), v, bdy) + + case l @ LetDef(fd, bdy) => + LetDef(fd2fd(fd), bdy) + + case FunctionInvocation(tfd, args) => + FunctionInvocation(fd2fd(tfd.fd).typed(tfd.tps), args) + + case Terminating(tfd, args) => + Terminating(fd2fd(tfd.fd).typed(tfd.tps), args) + + case _ => e + } + + simplePreTransform(pre)(expr) + } + + /* * Transforms complicated Ifs into multiple nested if blocks * It will decompose every OR clauses, and it will group AND clauses checking @@ -2257,11 +2256,8 @@ object TreeOps { val (scrutinees, patterns) = scrutSet.toSeq.map(s => (s, computePatternFor(conditions(s), s))).unzip - val (scrutinee, pattern) = if (scrutinees.size > 1) { - (Tuple(scrutinees), TuplePattern(None, patterns)) - } else { - (scrutinees.head, patterns.head) - } + val scrutinee = tupleWrap(scrutinees) + val pattern = tuplePatternWrap(patterns) // We use searchAndReplace to replace the biggest match first // (topdown). diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index a3ce2caf7..ea8def16d 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -52,7 +52,7 @@ object Trees { case class Choose(vars: List[Identifier], pred: Expr, var impl: Option[Expr] = None) extends Expr with NAryExtractable { require(!vars.isEmpty) - val getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType + val getType = tupleTypeWrap(vars.map(_.getType)) def extract = { Some((Seq(pred)++impl, (es: Seq[Expr]) => Choose(vars, es.head, es.tail.headOption).setPos(this))) @@ -115,11 +115,21 @@ object Trees { val getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped) } - case class Tuple(exprs: Seq[Expr]) extends Expr { + + /* + * If you are not sure about the requirement you should use + * the tupleWrap in purescala.Constructors + */ + case class Tuple (exprs: Seq[Expr]) extends Expr { + require(exprs.size >= 2) val getType = TupleType(exprs.map(_.getType)) } - // Index is 1-based, first element of tuple is 1. + /* + * Index is 1-based, first element of tuple is 1. + * If you are not sure that tuple has a TupleType, + * you should use tupleSelect in pureScala.Constructors + */ case class TupleSelect(tuple: Expr, index: Int) extends Expr { require(index >= 1) diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index bad181324..1f5d88df3 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -50,18 +50,13 @@ object TypeTrees { def freshen = TypeParameter(id.freshen) } - case class TupleType(val bases: Seq[TypeTree]) extends TypeTree { + /* + * If you are not sure about the requirement, + * you should use tupleTypeWrap in purescala.Constructors + */ + case class TupleType (val bases: Seq[TypeTree]) extends TypeTree { lazy val dimension: Int = bases.length - } - - object TupleOneType { - def unapply(tt : TupleType) : Option[TypeTree] = if(tt == null) None else { - if(tt.bases.size == 1) { - Some(tt.bases.head) - } else { - None - } - } + require(dimension >= 2) } case class SetType(base: TypeTree) extends TypeTree @@ -129,7 +124,7 @@ object TypeTrees { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) - case TupleType(ts) => Some((ts, TupleType(_))) + case TupleType(ts) => Some((ts, Constructors.tupleTypeWrap(_))) case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) case MultisetType(t) => Some((Seq(t), ts => MultisetType(ts.head))) diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 237c129e4..1a10635ff 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -10,6 +10,7 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.DefOps._ import purescala.Constructors._ +import purescala.Extractors.UnwrapTuple import purescala.ScalaPrinter import evaluators._ import solvers._ @@ -478,13 +479,6 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val e = new DefaultEvaluator(ctx, program) - def unwrap(e: Expr) = if (p.xs.size > 1) { - val Tuple(es) = e - es - } else { - Seq(e) - } - if (s1 == s2) { None } else { @@ -499,8 +493,8 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val inputsMap = (p.as zip inputs).toMap (e.eval(s1, inputsMap), e.eval(s2, inputsMap)) match { - case (EvaluationResults.Successful(r1), EvaluationResults.Successful(r2)) => - Some((InOutExample(inputs, unwrap(r1)), InOutExample(inputs, unwrap(r2)))) + case (EvaluationResults.Successful(UnwrapTuple(r1)), EvaluationResults.Successful(UnwrapTuple(r2))) => + Some((InOutExample(inputs, r1), InOutExample(inputs, r2))) case _ => None } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index f7e98281a..3c6fb95d8 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -56,7 +56,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { RawArrayValue(k, Map(), fromSMT(elem, v)) case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), ft @ FunctionType(from,to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, ft) + finiteLambda(fromSMT(elem, to), Seq.empty, from) case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) @@ -64,7 +64,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), ft @ FunctionType(from,to)) => val FiniteLambda(dflt, mapping) = fromSMT(arr, tpe) - finiteLambda(dflt, mapping :+ (fromSMT(key, TupleType(from)) -> fromSMT(elem, to)), ft) + finiteLambda(dflt, mapping :+ (fromSMT(key, TupleType(from)) -> fromSMT(elem, to)), from) case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => finiteSet(elems.map(fromSMT(_, base)).toSet, base) @@ -83,7 +83,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ // one I've witnessed up to now. Don't know why this is happening... case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), ft @ FunctionType(from, to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, ft) + finiteLambda(fromSMT(elem, to), Seq.empty, from) case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 64503feee..d13a366e8 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -82,7 +82,7 @@ trait SMTLIBTarget { r case ft @ FunctionType(from, to) => - finiteLambda(r.default, r.elems.toSeq, ft) + finiteLambda(r.default, r.elems.toSeq, from) case _ => unsupported("Unable to extract from raw array for "+tpe) @@ -549,7 +549,7 @@ trait SMTLIBTarget { CaseClass(cct, rargs) case tt: TupleType => val rargs = args.zip(tt.bases).map(fromSMT) - Tuple(rargs) + tupleWrap(rargs) case ArrayType(baseType) => val IntLiteral(size) = fromSMT(args(0), Int32Type) diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 48c029f60..6c31bb28d 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -275,7 +275,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { val m: Map[Expr, Expr] = if (ids.size == 1) { Map(Variable(ids.head) -> Variable(cid)) } else { - ids.zipWithIndex.map{ case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i+1) }.toMap + ids.zipWithIndex.map{ case (id, i) => Variable(id) -> tupleSelect(Variable(cid), i+1) }.toMap } storeGuarded(pathVar, replace(m, cond)) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 6eca49c3a..a91cfb3fe 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -470,7 +470,7 @@ trait AbstractZ3Solver case ft @ FunctionType(from, to) => sorts.toZ3OrCompute(ft) { - val fromSort = typeToSort(TupleType(from)) + val fromSort = typeToSort(tupleTypeWrap(from)) val toSort = typeToSort(to) z3.mkArraySort(fromSort, toSort) @@ -625,7 +625,7 @@ trait AbstractZ3Solver z3.mkApp(functionDefToDecl(tfd), args.map(rec(_)): _*) case fa @ Application(caller, args) => - z3.mkSelect(rec(caller), rec(Tuple(args))) + z3.mkSelect(rec(caller), rec(tupleWrap(args))) case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) @@ -778,7 +778,7 @@ trait AbstractZ3Solver case LeonType(tp: TupleType) => val rargs = args.map(rec) - Tuple(rargs) + tupleWrap(rargs) case LeonType(at @ ArrayType(dt)) => assert(args.size == 2) @@ -815,13 +815,13 @@ trait AbstractZ3Solver finiteMap(values, kt, vt) } - case LeonType(tpe @ FunctionType(fts, tt)) => + case LeonType(FunctionType(fts, tt)) => model.getArrayValue(t) match { case None => throw new CantTranslateException(t) case Some((map, elseZ3Value)) => val leonElseValue = rec(elseZ3Value) val leonMap = map.toSeq.map(p => rec(p._1) -> rec(p._2)) - finiteLambda(leonElseValue, leonMap, tpe) + finiteLambda(leonElseValue, leonMap, fts) } case LeonType(tpe @ SetType(dt)) => @@ -899,7 +899,7 @@ trait AbstractZ3Solver } else { z3.getSort(t) match { case LeonType(t : TupleType) => - Tuple(args.map(rec)) + tupleWrap(args.map(rec)) case _ => import Z3DeclKind._ diff --git a/src/main/scala/leon/synthesis/ConvertHoles.scala b/src/main/scala/leon/synthesis/ConvertHoles.scala index d7a15f584..d0118c318 100644 --- a/src/main/scala/leon/synthesis/ConvertHoles.scala +++ b/src/main/scala/leon/synthesis/ConvertHoles.scala @@ -83,7 +83,7 @@ object ConvertHoles extends LeonPhase[Program, Program] { BooleanLiteral(true) } - letTuple(holes, tupleChoose(Choose(cids, pred)), withoutHoles) + letTuple(holes, Choose(cids, pred), withoutHoles) } else withoutHoles diff --git a/src/main/scala/leon/synthesis/ConvertWithOracles.scala b/src/main/scala/leon/synthesis/ConvertWithOracles.scala index 0d7948e41..f8072b71a 100644 --- a/src/main/scala/leon/synthesis/ConvertWithOracles.scala +++ b/src/main/scala/leon/synthesis/ConvertWithOracles.scala @@ -54,7 +54,7 @@ object ConvertWithOracle extends LeonPhase[Program, Program] { BooleanLiteral(true) } - Some(letTuple(os, tupleChoose(Choose(chooseOs, pred)), b)) + Some(letTuple(os, Choose(chooseOs, pred), b)) case None => None } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 399e045a5..3ae642fc3 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -139,7 +139,7 @@ case object RulePriorityDefault extends RulePriority(2) trait RuleDSL { this: Rule => - def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replaceFromIDs(Map(what), in) + def subst(what: (Identifier, Expr), in: Expr): Expr = replaceFromIDs(Map(what), in) def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in) val forward: List[Solution] => Option[Solution] = { ss => ss.headOption } diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index a3388ac6c..d538a5a1a 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -44,7 +44,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust term.getType match { case TupleType(ts) => val t = FreshIdentifier("t", term.getType, true) - val newTerm = Let(t, term, tupleWrap(indices.map(i => TupleSelect(t.toVariable, i+1)))) + val newTerm = Let(t, term, tupleWrap(indices.map(i => tupleSelect(t.toVariable, i+1)))) Solution(pre, defs, newTerm) case _ => diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index 8380bf4d6..1f84bd921 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -128,7 +128,7 @@ trait ChainBuilder extends RelationBuilder { self: TerminationChecker with Stren def decreasing(relations: List[Relation]): Boolean = { val constraints = relations.map(relation => relationConstraints.get(relation).getOrElse { val Relation(funDef, path, FunctionInvocation(fd, args), _) = relation - val (e1, e2) = (Tuple(funDef.params.map(_.toVariable)), Tuple(args)) + val (e1, e2) = (tupleWrap(funDef.params.map(_.toVariable)), tupleWrap(args)) val constraint = if (solver.definitiveALL(implies(andJoin(path), self.softDecreasing(e1, e2)))) { if (solver.definitiveALL(implies(andJoin(path), self.sizeDecreasing(e1, e2)))) { StrongDecreasing diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index f1d146e7d..3b1a6aaee 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -25,10 +25,10 @@ trait ChainComparator { self : StructuralSize with TerminationChecker => private def flatTypesPowerset(tpe: TypeTree): Set[Expr => Expr] = { def powerSetToFunSet(l: TraversableOnce[Expr => Expr]): Set[Expr => Expr] = { - l.toSet.subsets.filter(_.nonEmpty).map((reconss : Set[Expr => Expr]) => reconss.toSeq match { - case Seq(x) => x - case seq => (e: Expr) => Tuple(seq.map(r => r(e))) - }).toSet + l.toSet.subsets.filter(_.nonEmpty).map{ + (reconss : Set[Expr => Expr]) => (e : Expr) => + tupleWrap(reconss.toSeq map { f => f(e) }) + }.toSet } def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index 54508cdfc..6570865e9 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -9,6 +9,7 @@ import purescala.TypeTrees._ import purescala.Common._ import purescala.Extractors._ import purescala.Definitions._ +import purescala.Constructors.tupleWrap import scala.collection.mutable.{Map => MutableMap} @@ -38,11 +39,11 @@ class ChainProcessor(val checker: TerminationChecker with ChainBuilder with Chai def exprs(fd: FunDef): (Expr, Seq[(Seq[Expr], Expr)], Set[Chain]) = { val fdChains = chainsMap(fd)._2 - val e1 = Tuple(fd.params.map(_.toVariable)) + val e1 = tupleWrap(fd.params.map(_.toVariable)) val e2s = fdChains.toSeq.map { chain => val freshParams = chain.finalParams.map(arg => FreshIdentifier(arg.id.name, arg.id.getType, true)) val finalBindings = (chain.finalParams.map(_.id) zip freshParams).toMap - (chain.loop(finalSubst = finalBindings), Tuple(freshParams.map(_.toVariable))) + (chain.loop(finalSubst = finalBindings), tupleWrap(freshParams.map(_.toVariable))) } (e1, e2s, fdChains) diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index 99b9d9efc..4ff739a06 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -33,8 +33,8 @@ class LoopProcessor(val checker: TerminationChecker with ChainBuilder with Stren val finalBindings = (chain.funDef.params.map(_.id) zip freshParams).toMap val path = chain.loop(finalSubst = finalBindings) - val srcTuple = Tuple(chain.funDef.params.map(_.toVariable)) - val resTuple = Tuple(freshParams.map(_.toVariable)) + val srcTuple = tupleWrap(chain.funDef.params.map(_.toVariable)) + val resTuple = tupleWrap(freshParams.map(_.toVariable)) definitiveSATwithModel(andJoin(path :+ Equals(srcTuple, resTuple))) match { case Some(map) => diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 86d02fff6..a4f871ead 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -27,7 +27,7 @@ class RelationProcessor( val formulas = problem.funDefs.map({ funDef => funDef -> checker.getRelations(funDef).collect({ case Relation(_, path, FunctionInvocation(tfd, args), _) if problem.funDefs(tfd.fd) => - val (e1, e2) = (Tuple(funDef.params.map(_.toVariable)), Tuple(args)) + val (e1, e2) = (tupleWrap(funDef.params.map(_.toVariable)), tupleWrap(args)) def constraint(expr: Expr) = implies(andJoin(path.toSeq), expr) val greaterThan = checker.sizeDecreasing(e1, e2) val greaterEquals = checker.softDecreasing(e1, e2) diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala index 5d0395c7b..b45d0f874 100644 --- a/src/main/scala/leon/termination/Strengthener.scala +++ b/src/main/scala/leon/termination/Strengthener.scala @@ -25,7 +25,7 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela val (res, postcondition) = { val (res, post) = old.getOrElse(FreshIdentifier("res", funDef.returnType) -> BooleanLiteral(true)) val args = funDef.params.map(_.toVariable) - val sizePost = cmp(Tuple(funDef.params.map(_.toVariable)), res.toVariable) + val sizePost = cmp(tupleWrap(funDef.params.map(_.toVariable)), res.toVariable) (res, and(post, sizePost)) } @@ -69,8 +69,8 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela def applicationConstraint(fd: FunDef, id: Identifier, arg: Expr, args: Seq[Expr]): Expr = arg match { case Lambda(fargs, body) => appConstraint.get(fd -> id) match { - case Some(StrongDecreasing) => self.sizeDecreasing(Tuple(args), Tuple(fargs.map(_.toVariable))) - case Some(WeakDecreasing) => self.softDecreasing(Tuple(args), Tuple(fargs.map(_.toVariable))) + case Some(StrongDecreasing) => self.sizeDecreasing(tupleWrap(args), tupleWrap(fargs.map(_.toVariable))) + case Some(WeakDecreasing) => self.softDecreasing(tupleWrap(args), tupleWrap(fargs.map(_.toVariable))) case _ => BooleanLiteral(true) } case _ => BooleanLiteral(true) @@ -84,14 +84,14 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela val appCollector = new CollectorWithPaths[(Identifier,Expr,Expr)] { def collect(e: Expr, path: Seq[Expr]): Option[(Identifier, Expr, Expr)] = e match { - case Application(Variable(id), args) => Some((id, andJoin(path), Tuple(args))) + case Application(Variable(id), args) => Some((id, andJoin(path), tupleWrap(args))) case _ => None } } val applications = appCollector.traverse(funDef).distinct - val funDefArgTuple = Tuple(funDef.params.map(_.toVariable)) + val funDefArgTuple = tupleWrap(funDef.params.map(_.toVariable)) val allFormulas = for ((id, path, appArgs) <- applications) yield { val soft = Implies(path, self.softDecreasing(funDefArgTuple, appArgs)) @@ -118,7 +118,7 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela val fiCollector = new CollectorWithPaths[(Expr, Expr, Seq[(Identifier,(FunDef, Identifier))])] { def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Expr, Seq[(Identifier,(FunDef, Identifier))])] = e match { case FunctionInvocation(tfd, args) if (funDefHOArgs intersect args.collect({ case Variable(id) => id }).toSet).nonEmpty => - Some((andJoin(path), Tuple(args), (args zip tfd.fd.params).collect { + Some((andJoin(path), tupleWrap(args), (args zip tfd.fd.params).collect { case (Variable(id), vd) if funDefHOArgs(id) => id -> ((tfd.fd, vd.id)) })) case _ => None diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala index 76fd42ce6..1659f4811 100644 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ b/src/main/scala/leon/utils/Simplifiers.scala @@ -18,7 +18,6 @@ object Simplifiers { simplifyLets _, simplifyPaths(uninterpretedZ3)(_), simplifyArithmetic _, - rewriteTuples _, evalGround(ctx, p), normalizeExpression _ ) @@ -42,7 +41,6 @@ object Simplifiers { val simplifiers = List[Expr => Expr]( simplifyTautologies(uninterpretedZ3)(_), simplifyArithmetic _, - rewriteTuples _, evalGround(ctx, p), normalizeExpression _ ) diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index d32ac6fb3..a85505567 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -53,11 +53,7 @@ object UnitElimination extends TransformationPhase { } private def simplifyType(tpe: TypeTree): TypeTree = tpe match { - case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match { - case Seq() => UnitType - case Seq(tpe) => tpe - case tpes => TupleType(tpes) - } + case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false }) case t => t } @@ -72,7 +68,7 @@ object UnitElimination extends TransformationPhase { case t@Tuple(args) => { val TupleType(tpes) = t.getType val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip - Tuple(newArgs.map(removeUnit)) + tupleWrap(newArgs.map(removeUnit)) // @mk: FIXME this may actually return a Unit, is that cool? } case ts@TupleSelect(t, index) => { val TupleType(tpes) = t.getType @@ -89,7 +85,7 @@ object UnitElimination extends TransformationPhase { else { id.getType match { case TupleType(tpes) if tpes.exists(_ == UnitType) => { - val newTupleType = TupleType(tpes.filterNot(_ == UnitType)) + val newTupleType = tupleTypeWrap(tpes.filterNot(_ == UnitType)) val freshId = FreshIdentifier(id.name, newTupleType) id2FreshId += (id -> freshId) val newBody = removeUnit(b) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index a7de9904e..ad34c5f48 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -71,14 +71,14 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq val resId = FreshIdentifier("res", iteRType) val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val iteType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + val iteType = if(modifiedVars.isEmpty) resId.getType else tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - val thenVal = if(modifiedVars.isEmpty) tRes else Tuple(tRes +: modifiedVars.map(vId => tFun.get(vId) match { + val thenVal = if(modifiedVars.isEmpty) tRes else tupleWrap(tRes +: modifiedVars.map(vId => tFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) - val elseVal = if(modifiedVars.isEmpty) eRes else Tuple(eRes +: modifiedVars.map(vId => eFun.get(vId) match { + val elseVal = if(modifiedVars.isEmpty) eRes else tupleWrap(eRes +: modifiedVars.map(vId => eFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) @@ -92,10 +92,10 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef if(freshIds.isEmpty) Let(resId, tupleId.toVariable, body) else - Let(resId, TupleSelect(tupleId.toVariable, 1), + Let(resId, tupleSelect(tupleId.toVariable, 1), freshIds.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 2), + tupleSelect(tupleId.toVariable, id._2 + 2), b)))).copiedFrom(expr)) }) @@ -110,10 +110,10 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).intersect(varInScope).toSeq val resId = FreshIdentifier("res", m.getType) val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + val matchType = if(modifiedVars.isEmpty) resId.getType else tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) val csesVals = csesRes.zip(csesFun).map{ - case (cRes, cFun) => if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case (cRes, cFun) => if(modifiedVars.isEmpty) cRes else tupleWrap(cRes +: modifiedVars.map(vId => cFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable })) @@ -133,10 +133,10 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef if(freshIds.isEmpty) Let(resId, tupleId.toVariable, body) else - Let(resId, TupleSelect(tupleId.toVariable, 1), + Let(resId, tupleSelect(tupleId.toVariable, 1), freshIds.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 2), + tupleSelect(tupleId.toVariable, id._2 + 2), b))))) }) @@ -155,7 +155,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap val whileFunValDefs = whileFunVars.map(id => ValDef(id, id.getType)) - val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) + val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else tupleTypeWrap(whileFunVars.map(_.getType)) val whileFunDef = new FunDef(FreshIdentifier(parent.id.name), Nil, whileFunReturnType, whileFunValDefs,DefType.MethodDef).setPos(wh) wasLoop += whileFunDef @@ -163,11 +163,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val whileFunRecursiveCall = replaceNames(condFun, bodyScope(FunctionInvocation(whileFunDef.typed, modifiedVars.map(id => condBodyFun(id).toVariable)).setPos(wh))) val whileFunBaseCase = - (if(whileFunVars.size == 1) - condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable - else - Tuple(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable)) - ) + tupleWrap(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable)) val whileFunBody = replaceNames(modifiedVars2WhileFunVars, condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase))) whileFunDef.body = Some(whileFunBody) @@ -177,7 +173,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef if(whileFunVars.size == 1) Map(whileFunVars.head.toVariable -> resVar) else - whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1)) }.toMap + whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, tupleSelect(resVar, i+1)) }.toMap val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id => (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap @@ -206,7 +202,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef else finalVars.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 1), + tupleSelect(tupleId.toVariable, id._2 + 1), b)))) }) -- GitLab