From ba41c6b20ab2f0015b7f70b45e1aef75a4d3816f Mon Sep 17 00:00:00 2001 From: Philippe Suter <philippe.suter@gmail.com> Date: Fri, 4 Jan 2013 19:28:21 +0100 Subject: [PATCH] Simplification of 1-uples and 0-uples. This commit also fixes a serious bug (that apparently affected no one in purescala/Extractors, namely the reconstruction function for Let and LetTuple was broken). --- .../scala/leon/purescala/Extractors.scala | 4 +- src/main/scala/leon/purescala/TreeOps.scala | 96 +++++++++++++++++++ src/main/scala/leon/purescala/Trees.scala | 4 +- src/main/scala/leon/purescala/TypeTrees.scala | 11 ++- .../scala/leon/synthesis/SynthesisPhase.scala | 3 +- .../synthesis/heuristics/IntInduction.scala | 2 +- 6 files changed, 113 insertions(+), 7 deletions(-) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 904e6e3d3..00ce83c66 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -69,8 +69,8 @@ object Extractors { case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) case Concat(t1,t2) => Some((t1,t2,Concat)) case ListAt(t1,t2) => Some((t1,t2,ListAt)) - case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, body))) //TODO: shouldn't be "b" instead of "body" ? - case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, body))) //TODO: shouldn't be "b" instead of "body" ? + case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, b))) + case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, b))) case (ex: BinaryExtractable) => ex.extract case _ => None } diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index c3cc7809d..4a5ca356c 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1254,6 +1254,102 @@ object TreeOps { rec(expr, Nil) } + // Eliminates tuples of arity 0 and 1. This function also affects types! + // Only rewrites local fundefs (i.e. LetDef's). + def rewriteTuples(expr: Expr) : Expr = { + def mapType(tt : TypeTree) : Option[TypeTree] = tt match { + case TupleType(ts) => ts.size match { + case 0 => Some(UnitType) + case 1 => Some(ts(0)) + case _ => + val tss = ts.map(mapType) + if(tss.exists(_.isDefined)) { + Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2)))) + } else { + None + } + } + case ListType(t) => mapType(t).map(ListType(_)) + case SetType(t) => mapType(t).map(SetType(_)) + case MultisetType(t) => mapType(t).map(MultisetType(_)) + case ArrayType(t) => mapType(t).map(ArrayType(_)) + case MapType(f,t) => + val (f2,t2) = (mapType(f),mapType(t)) + if(f2.isDefined || t2.isDefined) { + Some(MapType(f2.getOrElse(f), t2.getOrElse(t))) + } else { + None + } + case a : AbstractClassType => None + case c : CaseClassType => + // This is really just one big assertion. We don't rewrite class defs. + val ccd = c.classDef + val fieldTypes = ccd.fields.map(_.tpe) + if(fieldTypes.exists(t => t match { + case TupleType(ts) if ts.size <= 1 => true + case _ => false + })) { + scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.") + } else { + None + } + case Untyped | AnyType | BottomType | BooleanType | Int32Type | UnitType => None + } + + var funDefMap = Map.empty[FunDef,FunDef] + + def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { + case Some(fd) => fd + case None => + if(funDef.args.map(vd => mapType(vd.tpe)).exists(_.isDefined)) { + scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,") + } + val newFD = mapType(funDef.returnType) match { + case None => funDef + case Some(rt) => + val fd = new FunDef(FreshIdentifier(funDef.id.name, true), rt, funDef.args) + // These will be taken care of in the recursive traversal. + fd.body = funDef.body + fd.precondition = funDef.precondition + fd.postcondition = funDef.postcondition + fd + } + funDefMap = funDefMap.updated(funDef, newFD) + newFD + } + + def pre(e : Expr) : Expr = e match { + case Tuple(Seq()) => UnitLiteral + + case Tuple(Seq(s)) => pre(s) + + case ts @ TupleSelect(t, 1) => t.getType match { + case TupleOneType(_) => pre(t) + case _ => ts + } + + case LetTuple(bs, v, bdy) if bs.size == 1 => + Let(bs(0), v, bdy) + + case l @ LetDef(fd, bdy) => + LetDef(fd2fd(fd), bdy) + + case r @ ResultVariable() => + mapType(r.getType).map { newType => + ResultVariable().setType(newType) + } getOrElse { + r + } + + case FunctionInvocation(fd, args) => + FunctionInvocation(fd2fd(fd), args) + + case _ => e + } + + simplePreTransform(pre)(expr) + } + def formulaSize(e: Expr): Int = e match { case t: Terminal => 1 diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 4f88d3eb8..4e5e0fea6 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -47,7 +47,7 @@ object Trees { case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional { val fixedType = funDef.returnType - funDef.args.zip(args).foreach{ case (a, c) => typeCheck(c, a.tpe) } + funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) } } case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr @@ -67,7 +67,7 @@ object Trees { assert(index <= ts.size) ts(index - 1) - case _ => assert(false); Untyped + case _ => scala.sys.error("Applying TupleSelect on a non-tuple tree [%s] of type [%s].".format(tuple, tuple.getType)) } } diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 6dd49f749..7d2897ad5 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -176,7 +176,16 @@ object TypeTrees { // case (tt: TupleType) => Some(tt.bases) // case _ => None //} - def unapply(tt: TupleType): Option[Seq[TypeTree]] = Some(tt.bases) + def unapply(tt: TupleType): Option[Seq[TypeTree]] = if(tt == null) None else Some(tt.bases) + } + object TupleOneType { + def unapply(tt : TupleType) : Option[TypeTree] = if(tt == null) None else { + if(tt.bases.size == 1) { + Some(tt.bases.head) + } else { + None + } + } } case class ListType(base: TypeTree) extends TypeTree diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 0d2701dbe..7adb83366 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -122,7 +122,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { decomposeIfs _, patternMatchReconstruction _, simplifyTautologies(uninterpretedZ3)(_), - simplifyLets _ + simplifyLets _, + rewriteTuples _ ) def simplify(e: Expr): Expr = simplifiers.foldLeft(e){ (x, sim) => sim(x) } diff --git a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala index a882fe32b..8151341a3 100644 --- a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala @@ -38,7 +38,7 @@ case object IntInduction extends Rule("Int Induction") with Heuristic { val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType))) newFun.precondition = Some(preIn) - newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable(), p.phi)) + newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable().setType(tpe), p.phi)) newFun.body = Some( IfExpr(Equals(Variable(inductOn), IntLiteral(0)), -- GitLab