diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index ff30df5c980be8801ee115daf154b4b8fa3a85c6..1b5a4bb9f351b841a3c2d2f4e9a90c56ab4e5d51 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -461,7 +461,9 @@ object TreeOps { */ def simplifyLets(expr: Expr) : Expr = { def simplerLet(t: Expr) : Option[Expr] = t match { + case letExpr @ Let(i, t: Terminal, b) => Some(replace(Map((Variable(i) -> t)), b)) + case letExpr @ Let(i,e,b) => { val occurences = treeCatamorphism[Int]((e:Expr) => e match { case Variable(x) if x == i => 1 @@ -475,8 +477,6 @@ object TreeOps { None } } - //case letTuple @ LetTuple(ids, expr, body) if ids.size == 1 => - // simplerLet(Let(ids.head, TupleSelect(expr, 1).setType(ids.head.getType), body)) case letTuple @ LetTuple(ids, Tuple(exprs), body) => @@ -511,6 +511,35 @@ object TreeOps { } else { Some(LetTuple(remIds, Tuple(remExprs), newBody)) } + + case l @ LetTuple(ids, tExpr, body) => + val TupleType(types) = tExpr.getType + val arity = ids.size + // A map containing vectors of the form (0, ..., 1, ..., 0) where the one corresponds to the index of the identifier in the + // LetTuple. The idea is that we can sum such vectors up to compute the occurences of all variables in one traversal of the + // expression. + val zeroVec = Seq.fill(arity)(0) + val idMap = ids.zipWithIndex.toMap.mapValues(i => zeroVec.updated(i, 1)) + + val occurences : Seq[Int] = treeCatamorphism[Seq[Int]]((e : Expr) => e match { + case Variable(x) => idMap.getOrElse(x, zeroVec) + case _ => zeroVec + }, (v1 : Seq[Int], v2 : Seq[Int]) => (v1 zip v2).map(p => p._1 + p._2), body) + + val total = occurences.sum + + if(total == 0) { + Some(body) + } else if(total == 1) { + val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { + case (v,i) => (v -> TupleSelect(tExpr, i + 1).setType(v.getType)) + } + + Some(replace(substMap, body)) + } else { + None + } + case _ => None } searchAndReplaceDFS(simplerLet)(expr) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index c783bbb04287e6875cc28c4f396b034979d19571..4f88d3eb88550cbf13eb0fe8c1e6e891290664bb 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -37,6 +37,7 @@ object Trees { case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { binders.foreach(_.markAsLetBinder) + assert(value.getType.isInstanceOf[TupleType]) val et = body.getType if(et != Untyped) setType(et) @@ -58,8 +59,16 @@ object Trees { } // This must be 1-indexed ! (So are methods of Scala Tuples) - case class TupleSelect(tuple: Expr, index: Int) extends Expr { + case class TupleSelect(tuple: Expr, index: Int) extends Expr with FixedType { assert(index >= 1) + + val fixedType : TypeTree = tuple.getType match { + case TupleType(ts) => + assert(index <= ts.size) + ts(index - 1) + + case _ => assert(false); Untyped + } } object MatchExpr { diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala index 4cb8f22859770c07f7f80db7bda0e0187ca8514a..1c50c18c85bb1ec3a418a9be60bddc1f0eeb6a9f 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala @@ -90,7 +90,7 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { val outerPre = Or(globalPre) newFun.precondition = Some(funPre) - newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable(), p.phi)) + newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable().setType(resType), p.phi)) newFun.body = Some(MatchExpr(Variable(inductOn), cases))