diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 211b7dfc072d7773d7166ddab14c831ada2739ff..da0a4074da8470c952fc6d80aaff2022f4bd1f19 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -84,9 +84,6 @@ trait ASTExtractors { getResolvedTypeSym(sym) == scalaMapSym } - def isMultisetTraitSym(sym: Symbol) : Boolean = { - sym == multisetTraitSym - } def isOptionClassSym(sym : Symbol) : Boolean = { sym == optionClassSym || sym == someClassSym @@ -95,12 +92,6 @@ trait ASTExtractors { def isFunction(sym : Symbol, i: Int) : Boolean = 1 <= i && i <= 22 && sym == functionTraitSym(i) - protected lazy val multisetTraitSym = try { - classFromName("scala.collection.immutable.Multiset") - } catch { - case e: Throwable => - null - } def isArrayClassSym(sym: Symbol): Boolean = sym == arraySym @@ -863,22 +854,6 @@ trait ASTExtractors { } } - object ExEmptyMultiset { - def unapply(tree: TypeApply): Option[Tree] = tree match { - case TypeApply( - Select( - Select( - Select( - Select(Ident(s), collectionName), - immutableName), - setName), - emptyName), theTypeTree :: Nil) if ( - collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "empty" - ) => Some(theTypeTree) - case _ => None - } - } - object ExLiteralMap { def unapply(tree: Apply): Option[(Tree, Tree, Seq[Tree])] = tree match { case Apply(TypeApply(ExSelected("scala", "Predef", "Map", "apply"), fromTypeTree :: toTypeTree :: Nil), args) => @@ -918,23 +893,6 @@ trait ASTExtractors { } } - object ExFiniteMultiset { - def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { - case Apply( - TypeApply( - Select( - Select( - Select( - Select(Ident(s), collectionName), - immutableName), - setName), - emptyName), theTypeTree :: Nil), args) if ( - collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "apply" - )=> Some(theTypeTree, args) - case _ => None - } - } - object ExParameterLessCall { def unapply(tree: Tree): Option[(Tree, Symbol, Seq[Tree])] = tree match { case s @ Select(t, _) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index b69cda26f8d994714d03aa4927c8af470a384df6..a1ac056ed3a91c52830d16dbed7210a467c7aa5e 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1245,22 +1245,6 @@ trait CodeExtraction extends ASTExtractors { Lambda(vds, exBody) - case f @ ExForallExpression(args, body) => - val vds = args map { case (tpt, sym) => - val aTpe = extractType(tpt) - val newID = FreshIdentifier(sym.name.toString, aTpe) - owners += (newID -> None) - LeonValDef(newID) - } - - val newVars = (args zip vds) map { case ((_, sym), vd) => - sym -> (() => vd.toVariable) - } - - val exBody = extractTree(body)(dctx.withNewVars(newVars)) - - Forall(vds, exBody) - case ExFiniteMap(tptFrom, tptTo, args) => val singletons: Seq[(LeonExpr, LeonExpr)] = args.collect { case ExTuple(tpes, trees) if trees.size == 2 => @@ -1328,14 +1312,6 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(tr, "Invalid comparison: (_: "+rt+") == (_: "+lt+")") } - case ExFiniteMultiset(tt, args) => - val underlying = extractType(tt) - finiteMultiset(args.map(extractTree),underlying) - - case ExEmptyMultiset(tt) => - val underlying = extractType(tt) - EmptyMultiset(underlying) - case ExArrayFill(baseType, length, defaultValue) => val lengthRec = extractTree(length) val defaultValueRec = extractTree(defaultValue) @@ -1558,22 +1534,6 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, SetType(b1)), "isEmpty", List()) => Equals(a1, FiniteSet(Set(), b1)) - // Multiset methods - case (IsTyped(a1, MultisetType(b1)), "++", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetUnion(a1, a2) - - case (IsTyped(a1, MultisetType(b1)), "&", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetIntersection(a1, a2) - - case (IsTyped(a1, MultisetType(b1)), "--", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetDifference(a1, a2) - - case (IsTyped(a1, MultisetType(b1)), "+++", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetPlus(a1, a2) - - case (IsTyped(_, MultisetType(b1)), "toSet", Nil) => - MultisetToSet(rrec) - // Array methods case (IsTyped(a1, ArrayType(vt)), "apply", List(a2)) => ArraySelect(a1, a2) @@ -1689,9 +1649,6 @@ trait CodeExtraction extends ASTExtractors { case TypeRef(_, sym, btt :: Nil) if isSetSym(sym) => SetType(extractType(btt)) - case TypeRef(_, sym, btt :: Nil) if isMultisetTraitSym(sym) => - MultisetType(extractType(btt)) - case TypeRef(_, sym, List(ftt,ttt)) if isMapSym(sym) => MapType(extractType(ftt), extractType(ttt)) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 148f8d051139c7b772485b2960262d2d5e81f228..6ffb833e3ca6dd64dc70f84ef6a0d4cca10e76e8 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -184,11 +184,6 @@ object Constructors { case _ => Implies(lhs, rhs) } - def finiteMultiset(els: Seq[Expr], tpe: TypeTree) = { - if (els.isEmpty) EmptyMultiset(tpe) - else NonemptyMultiset(els) - } - def finiteArray(els: Seq[Expr]): Expr = { require(els.nonEmpty) finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 66b1ed28a43b7849049d4a8bf726c8b49ab47c6d..50740179a9a08bdc12ccb88eeb79dba93af5ff66 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -7,7 +7,7 @@ import Common._ import Types._ import Definitions._ import Expressions._ -import Extractors._ +import leon.purescala.Extractors._ import Constructors._ import utils.Simplifiers import solvers._ @@ -34,19 +34,9 @@ object ExprOps { def foldRight[T](f: (Expr, Seq[T]) => T)(e: Expr): T = { val rec = foldRight(f) _ - e match { - case t: Terminal => - f(t, Stream.empty) + val Operator(es, _) = e + f(e, es.view.map(rec)) - case u @ UnaryOperator(e, builder) => - f(u, List(e).view.map(rec)) - - case b @ BinaryOperator(e1, e2, builder) => - f(b, List(e1, e2).view.map(rec)) - - case n @ NAryOperator(es, builder) => - f(n, es.view.map(rec)) - } } /** @@ -64,19 +54,8 @@ object ExprOps { def preTraversal(f: Expr => Unit)(e: Expr): Unit = { val rec = preTraversal(f) _ f(e) - - e match { - case t: Terminal => - - case u @ UnaryOperator(e, builder) => - List(e).foreach(rec) - - case b @ BinaryOperator(e1, e2, builder) => - List(e1, e2).foreach(rec) - - case n @ NAryOperator(es, builder) => - es.foreach(rec) - } + val Operator(es, _) = e + es.foreach(rec) } /** @@ -93,20 +72,8 @@ object ExprOps { */ def postTraversal(f: Expr => Unit)(e: Expr): Unit = { val rec = postTraversal(f) _ - - e match { - case t: Terminal => - - case u @ UnaryOperator(e, builder) => - List(e).foreach(rec) - - case b @ BinaryOperator(e1, e2, builder) => - List(e1, e2).foreach(rec) - - case n @ NAryOperator(es, builder) => - es.foreach(rec) - } - + val Operator(es, _) = e + es.foreach(rec) f(e) } @@ -148,37 +115,13 @@ object ExprOps { f(e) getOrElse e } - newV match { - case u @ UnaryOperator(e, builder) => - val newE = rec(e) - - if (newE ne e) { - builder(newE).copiedFrom(u) - } else { - u - } - - case b @ BinaryOperator(e1, e2, builder) => - val newE1 = rec(e1) - val newE2 = rec(e2) - - if ((newE1 ne e1) || (newE2 ne e2)) { - builder(newE1, newE2).copiedFrom(b) - } else { - b - } - - case n @ NAryOperator(es, builder) => - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(n) - } else { - n - } + val Operator(es, builder) = newV + val newEs = es.map(rec) - case t: Terminal => - t + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV } } @@ -213,37 +156,15 @@ object ExprOps { def postMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { val rec = postMap(f, applyRec) _ - val newV = e match { - case u @ UnaryOperator(e, builder) => - val newE = rec(e) - - if (newE ne e) { - builder(newE).copiedFrom(u) - } else { - u - } - - case b @ BinaryOperator(e1, e2, builder) => - val newE1 = rec(e1) - val newE2 = rec(e2) - - if ((newE1 ne e1) || (newE2 ne e2)) { - builder(newE1, newE2).copiedFrom(b) - } else { - b - } + val Operator(es, builder) = e + val newV = { + val newEs = es.map(rec) - case n @ NAryOperator(es, builder) => - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(n) - } else { - n - } - - case t: Terminal => - t + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } } if (applyRec) { @@ -357,7 +278,6 @@ object ExprOps { case MatchExpr(_, cses) => subvs -- cses.map(_.pattern.binders).flatten case Passes(_, _ , cses) => subvs -- cses.map(_.pattern.binders).flatten case Lambda(args, body) => subvs -- args.map(_.id) - case Forall(args, body) => subvs -- args.map(_.id) case _ => subvs } }(expr) @@ -627,7 +547,7 @@ object ExprOps { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ NAryOperator(args, recons) => { + case n @ Operator(args, recons) => { var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -643,22 +563,6 @@ object ExprOps { else n } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1, s) - val r2 = rec(t2, s) - if(r1 != t1 || r2 != t2) - recons(r1,r2) - else - b - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t, s) - if(r != t) - recons(r) - else - u - } - case t if t.isInstanceOf[Terminal] => t case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) } @@ -1057,16 +961,7 @@ object ExprOps { def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case uop@UnaryOperator(IfExpr(c, t, e), op) => - Some(IfExpr(c, op(t).copiedFrom(uop), op(e).copiedFrom(uop)).copiedFrom(uop)) - - case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => - Some(IfExpr(c, op(t, t2).copiedFrom(bop), op(e, t2).copiedFrom(bop)).copiedFrom(bop)) - - case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => - Some(IfExpr(c, op(t1, t).copiedFrom(bop), op(t1, e).copiedFrom(bop)).copiedFrom(bop)) - - case nop@NAryOperator(ts, op) => { + case nop@Operator(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -1091,32 +986,12 @@ object ExprOps { def rec(eIn: Expr, cIn: C): (Expr, C) = { val (expr, ctx) = pre(eIn, cIn) + val Operator(es, builder) = expr + val (newExpr, newC) = { + val (nes, cs) = es.map{ rec(_, ctx)}.unzip + val newE = builder(nes).copiedFrom(expr) - val (newExpr, newC) = expr match { - case t: Terminal => - (expr, ctx) - - case UnaryOperator(e, builder) => - val (e1, c) = rec(e, ctx) - val newE = builder(e1).copiedFrom(expr) - - (newE, combiner(newE, Seq(c))) - - case BinaryOperator(e1, e2, builder) => - val (ne1, c1) = rec(e1, ctx) - val (ne2, c2) = rec(e2, ctx) - val newE = builder(ne1, ne2).copiedFrom(expr) - - (newE, combiner(newE, Seq(c1, c2))) - - case NAryOperator(es, builder) => - val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes).copiedFrom(expr) - - (newE, combiner(newE, cs)) - - case e => - sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + (newE, combiner(newE, cs)) } post(newExpr, newC) @@ -1270,22 +1145,13 @@ object ExprOps { } def formulaSize(e: Expr): Int = e match { - case t: Terminal => - 1 - case ml: MatchExpr => formulaSize(ml.scrutinee) + ml.cases.map { case MatchCase(p, og, rhs) => formulaSize(rhs) + og.map(formulaSize).getOrElse(0) + patternSize(p) }.sum - case UnaryOperator(e, builder) => - formulaSize(e)+1 - - case BinaryOperator(e1, e2, builder) => - formulaSize(e1)+formulaSize(e2)+1 - - case NAryOperator(es, _) => + case Operator(es, _) => es.map(formulaSize).sum+1 } @@ -1547,14 +1413,8 @@ object ExprOps { // TODO: Check type params fdHomo(tfd1.fd, tfd2.fd) && (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } - - case Same(UnaryOperator(e1, _), UnaryOperator(e2, _)) => - isHomo(e1, e2) - - case Same(BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) => - isHomo(e11, e21) && isHomo(e12, e22) - case Same(NAryOperator(es1, _), NAryOperator(es2, _)) => + case Same(Operator(es1, _), Operator(es2, _)) => if (es1.size == es2.size) { (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } } else { @@ -1848,19 +1708,8 @@ object ExprOps { f(e, initParent) - e match { - case u @ UnaryOperator(e, builder) => - rec(e) - - case b @ BinaryOperator(e1, e2, builder) => - rec(e1) - rec(e2) - - case n @ NAryOperator(es, builder) => - es.foreach(rec) - - case t: Terminal => - } + val Operator(es, _) = e + es foreach rec } def functionAppsOf(expr: Expr): Set[Application] = { @@ -1910,10 +1759,7 @@ object ExprOps { case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case NAryOperator(es, recons) => recons(es.map(rec(_, build))) - case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)) - case UnaryOperator(e, recons) => recons(rec(e, build)) - case t: Terminal => t + case Operator(es, recons) => recons(es.map(rec(_, build))) } rec(lift(expr), true) diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 580388348445316d04f32381661edd818b2ea7ed..4f6ce84337d8e3cba59f28d40638467700e91ea4 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -110,11 +110,6 @@ object Expressions { } } - case class Forall(args: Seq[ValDef], body: Expr) extends Expr { - require(body.getType == BooleanType) - val getType = BooleanType - } - case class This(ct: ClassType) extends Expr with Terminal { val getType = ct } @@ -486,17 +481,17 @@ object Expressions { /* Special trees */ // Provide an oracle (synthesizable, all-seeing choose) - case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable { + case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with Extractable { require(oracles.nonEmpty) val getType = body.getType def extract = { - Some((body, (e: Expr) => WithOracle(oracles, e).setPos(this))) + Some((Seq(body), (es: Seq[Expr]) => WithOracle(oracles, es.head).setPos(this))) } } - case class Hole(tpe: TypeTree, alts: Seq[Expr]) extends Expr with NAryExtractable { + case class Hole(tpe: TypeTree, alts: Seq[Expr]) extends Expr with Extractable { val getType = tpe def extract = { @@ -504,64 +499,4 @@ object Expressions { } } - /** - * DEPRECATED TREES - * These trees are not guaranteed to be supported by Leon. - **/ - @deprecated("3.0", "Use NonemptyArray with default value") - case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr { - val getType = ArrayType(defaultValue.getType).unveilUntyped - } - - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class SetMin(set: Expr) extends Expr { - val getType = Int32Type - } - - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class SetMax(set: Expr) extends Expr { - val getType = Int32Type - } - - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal { - val getType = MultisetType(baseType).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class NonemptyMultiset(elements: Seq[Expr]) extends Expr { - val getType = MultisetType(optionToType(leastUpperBound(elements.toSeq.map(_.getType)))).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class Multiplicity(element: Expr, multiset: Expr) extends Expr { - val getType = Int32Type - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetCardinality(multiset: Expr) extends Expr { - val getType = Int32Type - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetIntersection(multiset1: Expr, multiset2: Expr) extends Expr { - val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetUnion(multiset1: Expr, multiset2: Expr) extends Expr { - val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetPlus(multiset1: Expr, multiset2: Expr) extends Expr { // disjoint union - val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetDifference(multiset1: Expr, multiset2: Expr) extends Expr { - val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped).unveilUntyped - } - @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") - case class MultisetToSet(multiset: Expr) extends Expr { - val getType = multiset.getType match { - case MultisetType(base) => SetType(base).unveilUntyped - case _ => Untyped - } - } - - } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index d6e5ba0d9a7bc66551d302cfbf4cc98c0706b283..7591e251d2fe76eec56ed7039234e89ae0ad8cc3 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -12,108 +12,130 @@ import Definitions.Program object Extractors { - object UnaryOperator { - def unapply(expr: Expr) : Option[(Expr,(Expr)=>Expr)] = expr match { - case Not(t) => Some((t,not)) - case Choose(expr) => Some((expr,Choose)) - case UMinus(t) => Some((t,UMinus)) - case BVUMinus(t) => Some((t,BVUMinus)) - case BVNot(t) => Some((t,BVNot)) - case SetCardinality(t) => Some((t,SetCardinality)) - case MultisetCardinality(t) => Some((t,MultisetCardinality)) - case MultisetToSet(t) => Some((t,MultisetToSet)) - case SetMin(s) => Some((s,SetMin)) - case SetMax(s) => Some((s,SetMax)) - case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) - case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) - case TupleSelect(t, i) => Some((t, tupleSelect(_, i, t.getType.asInstanceOf[TupleType].dimension))) - case ArrayLength(a) => Some((a, ArrayLength)) - case Lambda(args, body) => Some((body, Lambda(args, _))) - case Forall(args, body) => Some((body, Forall(args, _))) - case (ue: UnaryExtractable) => ue.extract - case _ => None - } - } + object Operator { + def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { + /* Unary operators */ + case Not(t) => + Some((Seq(t), (es: Seq[Expr]) => not(es.head))) + case Choose(expr) => + Some((Seq(expr), (es: Seq[Expr]) => Choose(es.head))) + case UMinus(t) => + Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head))) + case BVUMinus(t) => + Some((Seq(t), (es: Seq[Expr]) => BVUMinus(es.head))) + case BVNot(t) => + Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head))) + case SetCardinality(t) => + Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) + case CaseClassSelector(cd, e, sel) => + Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(cd, es.head, sel))) + case CaseClassInstanceOf(cd, e) => + Some((Seq(e), (es: Seq[Expr]) => CaseClassInstanceOf(cd, es.head))) + case TupleSelect(t, i) => + Some((Seq(t), (es: Seq[Expr]) => TupleSelect(es.head, i))) + case ArrayLength(a) => + Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) + case Lambda(args, body) => + Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - trait UnaryExtractable { - def extract: Option[(Expr, (Expr)=>Expr)] - } - - object BinaryOperator { - def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { - case LetDef(fd, body) => Some((fd.fullBody, body, - (fdBd, body) => { - fd.fullBody = fdBd - LetDef(fd, body) + /* Binary operators */ + case LetDef(fd, body) => Some(( + Seq(fd.fullBody, body), + (es: Seq[Expr]) => { + fd.fullBody = es(0) + LetDef(fd, es(1)) } )) - case Equals(t1,t2) => Some((t1,t2,Equals)) - case Implies(t1,t2) => Some((t1,t2, implies)) - case Plus(t1,t2) => Some((t1,t2,plus)) - case Minus(t1,t2) => Some((t1,t2,minus)) - case Times(t1,t2) => Some((t1,t2,times)) - case Division(t1,t2) => Some((t1,t2,Division)) - case Remainder(t1,t2) => Some((t1,t2,Remainder)) - case Modulo(t1,t2) => Some((t1,t2,Modulo)) - case LessThan(t1,t2) => Some((t1,t2,LessThan)) - case GreaterThan(t1,t2) => Some((t1,t2,GreaterThan)) - case LessEquals(t1,t2) => Some((t1,t2,LessEquals)) - case GreaterEquals(t1,t2) => Some((t1,t2,GreaterEquals)) - case BVPlus(t1,t2) => Some((t1,t2,plus)) - case BVMinus(t1,t2) => Some((t1,t2,minus)) - case BVTimes(t1,t2) => Some((t1,t2,times)) - case BVDivision(t1,t2) => Some((t1,t2,BVDivision)) - case BVRemainder(t1,t2) => Some((t1,t2,BVRemainder)) - case BVAnd(t1,t2) => Some((t1,t2,BVAnd)) - case BVOr(t1,t2) => Some((t1,t2,BVOr)) - case BVXOr(t1,t2) => Some((t1,t2,BVXOr)) - case BVShiftLeft(t1,t2) => Some((t1,t2,BVShiftLeft)) - case BVAShiftRight(t1,t2) => Some((t1,t2,BVAShiftRight)) - case BVLShiftRight(t1,t2) => Some((t1,t2,BVLShiftRight)) - case ElementOfSet(t1,t2) => Some((t1,t2,ElementOfSet)) - case SubsetOf(t1,t2) => Some((t1,t2,SubsetOf)) - case SetIntersection(t1,t2) => Some((t1,t2,SetIntersection)) - case SetUnion(t1,t2) => Some((t1,t2,SetUnion)) - case SetDifference(t1,t2) => Some((t1,t2,SetDifference)) - case Multiplicity(t1,t2) => Some((t1,t2,Multiplicity)) - case MultisetIntersection(t1,t2) => Some((t1,t2,MultisetIntersection)) - case MultisetUnion(t1,t2) => Some((t1,t2,MultisetUnion)) - case MultisetPlus(t1,t2) => Some((t1,t2,MultisetPlus)) - case MultisetDifference(t1,t2) => Some((t1,t2,MultisetDifference)) - case mg@MapGet(t1,t2) => Some((t1,t2, MapGet)) - case MapUnion(t1,t2) => Some((t1,t2,MapUnion)) - case MapDifference(t1,t2) => Some((t1,t2,MapDifference)) - case MapIsDefinedAt(t1,t2) => Some((t1,t2, MapIsDefinedAt)) - case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) - case Let(binder, e, body) => Some((e, body, Let(binder, _, _))) - case Require(pre, body) => Some((pre, body, Require)) - case Ensuring(body, post) => Some((body, post, Ensuring)) - case Assert(const, oerr, body) => Some((const, body, Assert(_, oerr, _))) - case (ex: BinaryExtractable) => ex.extract - case _ => None - } - } - - trait BinaryExtractable { - def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] - } + case Equals(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => equality(es(0), es(1))) + case Implies(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => implies(es(0), es(1))) + case Plus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))) + case Minus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))) + case Times(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))) + case Division(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => Division(es(0), es(1))) + case Remainder(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => Remainder(es(0), es(1))) + case Modulo(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => Modulo(es(0), es(1))) + case LessThan(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => LessThan(es(0), es(1))) + case GreaterThan(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => GreaterThan(es(0), es(1))) + case LessEquals(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => LessEquals(es(0), es(1))) + case GreaterEquals(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => GreaterEquals(es(0), es(1))) + case BVPlus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))) + case BVMinus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))) + case BVTimes(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))) + case BVDivision(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVDivision(es(0), es(1))) + case BVRemainder(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVRemainder(es(0), es(1))) + case BVAnd(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVAnd(es(0), es(1))) + case BVOr(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVOr(es(0), es(1))) + case BVXOr(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVXOr(es(0), es(1))) + case BVShiftLeft(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVShiftLeft(es(0), es(1))) + case BVAShiftRight(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVAShiftRight(es(0), es(1))) + case BVLShiftRight(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => BVLShiftRight(es(0), es(1))) + case ElementOfSet(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => ElementOfSet(es(0), es(1))) + case SubsetOf(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => SubsetOf(es(0), es(1))) + case SetIntersection(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => SetIntersection(es(0), es(1))) + case SetUnion(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) + case SetDifference(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) + case mg@MapGet(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => MapGet(es(0), es(1))) + case MapUnion(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1))) + case MapDifference(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => MapDifference(es(0), es(1))) + case MapIsDefinedAt(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => MapIsDefinedAt(es(0), es(1))) + case ArraySelect(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => ArraySelect(es(0), es(1))) + case Let(binder, e, body) => + Some(Seq(e, body), (es: Seq[Expr]) => Let(binder, es(0), es(1))) + case Require(pre, body) => + Some(Seq(pre, body), (es: Seq[Expr]) => Require(es(0), es(1))) + case Ensuring(body, post) => + Some(Seq(body, post), (es: Seq[Expr]) => Ensuring(es(0), es(1))) + case Assert(const, oerr, body) => + Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))) - object NAryOperator { - def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { - case fi @ FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) - case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) - case fa @ Application(caller, args) => Some(caller +: args, as => application(as.head, as.tail)) + /* Other operators */ + case fi@FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) + case mi@MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) + case fa@Application(caller, args) => Some(caller +: args, as => application(as.head, as.tail)) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, and)) case Or(args) => Some((args, or)) case FiniteSet(els, base) => - Some(( els.toSeq, els => FiniteSet(els.toSet, base) )) + Some((els.toSeq, els => FiniteSet(els.toSet, base))) case FiniteMap(args, f, t) => { - val subArgs = args.flatMap{case (k, v) => Seq(k, v)} + val subArgs = args.flatMap { case (k, v) => Seq(k, v) } val builder = (as: Seq[Expr]) => { - def rec(kvs: Seq[Expr]) : Seq[(Expr, Expr)] = kvs match { + def rec(kvs: Seq[Expr]): Seq[(Expr, Expr)] = kvs match { case Seq(k, v, t@_*) => - (k,v) +: rec(t) + (k, v) +: rec(t) case Seq() => Seq() case _ => sys.error("odd number of key/value expressions") } @@ -121,59 +143,69 @@ object Extractors { } Some((subArgs, builder)) } - case NonemptyMultiset(args) => - Some((args, NonemptyMultiset)) - case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => - ArrayUpdated(as(0), as(1), as(2)))) - case NonemptyArray(elems, Some((default, length))) => { + case ArrayUpdated(t1, t2, t3) => Some(( + Seq(t1, t2, t3), + (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)) + )) + case NonemptyArray(elems, Some((default, length))) => val all = elems.map(_._2).toSeq :+ default :+ length - Some(( all, as => { + Some((all, as => { val l = as.length - nonemptyArray(as.take(l-2), Some( (as(l-2), as(l-1)) ) ) + nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1)))) })) - } case NonemptyArray(elems, None) => val all = elems.map(_._2).toSeq - Some(( all, finiteArray)) + Some((all, finiteArray)) case Tuple(args) => Some((args, tupleWrap)) case IfExpr(cond, thenn, elze) => Some(( Seq(cond, thenn, elze), - { case Seq(c,t,e) => IfExpr(c,t,e) } + { case Seq(c, t, e) => IfExpr(c, t, e) } )) case MatchExpr(scrut, cases) => Some(( - scrut +: cases.flatMap { + scrut +: cases.flatMap { case SimpleCase(_, e) => Seq(e) - case GuardedCase(_, e1, e2) => Seq(e1, e2) - }, + case GuardedCase(_, e1, e2) => Seq(e1, e2) + }, (es: Seq[Expr]) => { var i = 1 val newcases = for (caze <- cases) yield caze match { - case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) - case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-2), es(i-1)) + case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1)) + case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1)) } - matchExpr(es(0), newcases) + matchExpr(es.head, newcases) } )) case Passes(in, out, cases) => Some(( - in +: out +: cases.flatMap { _.expressions }, - { case Seq(in, out, es@_*) => { - var i = 0 - val newcases = for (caze <- cases) yield caze match { - case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) - case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-1), es(i-2)) - } + in +: out +: cases.flatMap { + _.expressions + }, { + case Seq(in, out, es@_*) => { + var i = 0 + val newcases = for (caze <- cases) yield caze match { + case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1)) + case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 1), es(i - 2)) + } - passes(in, out, newcases) - }} + passes(in, out, newcases) + } + } )) - case (ex: NAryExtractable) => ex.extract - case _ => None + + /* Terminals */ + case t: Terminal => Some(Seq[Expr](), (_:Seq[Expr]) => t) + + /* Expr's not handled here should implement this trait */ + case e: Extractable => + e.extract + + case _ => + None } } - trait NAryExtractable { - def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] + trait Extractable { + def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] } object StringLiteral { @@ -248,14 +280,6 @@ object Extractors { } } - object FiniteMultiset { - def unapply(e: Expr): Option[Seq[Expr]] = e match { - case EmptyMultiset(_) => Some(Seq()) - case NonemptyMultiset(els) => Some(els) - case _ => None - } - } - object FiniteArray { def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match { case EmptyArray(_) => diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index bedc124db6558eac471283c1750f16babacf9430..d27f90983931d5e1ce960800844e6669c75479b7 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -139,20 +139,10 @@ object FunctionClosure extends TransformationPhase { case None => v case Some(nid) => Variable(nid) } - case n @ NAryOperator(args, recons) => { + case n @ Operator(args, recons) => { val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd)) recons(rargs).copiedFrom(n) } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = functionClosure(t1, bindedVars, id2freshId, fd2FreshFd) - val r2 = functionClosure(t2, bindedVars, id2freshId, fd2FreshFd) - recons(r1,r2).copiedFrom(b) - } - case u @ UnaryOperator(t,recons) => { - val r = functionClosure(t, bindedVars, id2freshId, fd2FreshFd) - recons(r).copiedFrom(u) - } - case t : Terminal => t case unhandled => scala.sys.error("Non-terminal case should be handled in FunctionClosure: " + unhandled) } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 887c50464eab12fcc4bebf46b83f1c5201bd92b9..a7ead369a606ce2fb51d82de4097cbfc723bfa73 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -10,7 +10,6 @@ import leon.purescala.Extractors._ import leon.purescala.PrinterHelpers._ import leon.purescala.ExprOps.{isListLiteral, simplestValue} import leon.purescala.Expressions._ -import leon.purescala.TypeOps.leastUpperBound import leon.purescala.Types._ import leon.synthesis.Witnesses._ @@ -277,25 +276,15 @@ class PrettyPrinter(opts: PrinterOptions, case BVLShiftRight(l,r) => optP { p"$l >>> $r" } case fs @ FiniteSet(rs, _) => p"{${rs.toSeq}}" case fm @ FiniteMap(rs, _, _) => p"{$rs}" - case FiniteMultiset(rs) => p"{|$rs|)" - case EmptyMultiset(_) => p"\u2205" case Not(ElementOfSet(e,s)) => p"$e \u2209 $s" case ElementOfSet(e,s) => p"$e \u2208 $s" case SubsetOf(l,r) => p"$l \u2286 $r" case Not(SubsetOf(l,r)) => p"$l \u2288 $r" - case SetMin(s) => p"$s.min" - case SetMax(s) => p"$s.max" case SetUnion(l,r) => p"$l \u222A $r" - case MultisetUnion(l,r) => p"$l \u222A $r" case MapUnion(l,r) => p"$l \u222A $r" case SetDifference(l,r) => p"$l \\ $r" - case MultisetDifference(l,r) => p"$l \\ $r" case SetIntersection(l,r) => p"$l \u2229 $r" - case MultisetIntersection(l,r) => p"$l \u2229 $r" case SetCardinality(s) => p"|$s|" - case MultisetCardinality(s) => p"|$s|" - case MultisetPlus(l,r) => p"$l \u228E $r" - case MultisetToSet(e) => p"$e.toSet" case MapGet(m,k) => p"$m($k)" case MapIsDefinedAt(m,k) => p"$m.isDefinedAt($k)" case ArrayLength(a) => p"$a.length" @@ -428,7 +417,6 @@ class PrettyPrinter(opts: PrinterOptions, case ArrayType(bt) => p"Array[$bt]" case SetType(bt) => p"Set[$bt]" case MapType(ft,tt) => p"Map[$ft, $tt]" - case MultisetType(bt) => p"Multiset[$bt]" case TupleType(tpes) => p"($tpes)" case FunctionType(fts, tt) => p"($fts) => $tt" case c: ClassType => diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 1bc8e9379f88f63bfbfe10bd80d8abab679e5f83..5559308fe6f3e50b1ee22bc46a20d57d9339caf6 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -103,17 +103,9 @@ class ScopeSimplifier extends Transformer { FunctionInvocation(newFd.typed(tfd.tps), newArgs) - case UnaryOperator(e, builder) => - builder(rec(e, scope)) - - case BinaryOperator(e1, e2, builder) => - builder(rec(e1, scope), rec(e2, scope)) - - case NAryOperator(es, builder) => + case Operator(es, builder) => builder(es.map(rec(_, scope))) - case t : Terminal => t - case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") } diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 5037781d4c3db3505a58671bc80f9027ba048f81..d098bedf979fe8156f171a76b9b220bf45882603 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -67,13 +67,7 @@ abstract class TransformerWithPC extends Transformer { val rc = rec(lhs, path) Implies(rc, rec(rhs, register(rc, path))).copiedFrom(i) - case o @ UnaryOperator(e, builder) => - builder(rec(e, path)).copiedFrom(o) - - case o @ BinaryOperator(e1, e2, builder) => - builder(rec(e1, path), rec(e2, path)).copiedFrom(o) - - case o @ NAryOperator(es, builder) => + case o @ Operator(es, builder) => builder(es.map(rec(_, path))).copiedFrom(o) case t : Terminal => t diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 1ab068c7bfc91aa82dcf26d772287ee537010d9e..8247a1ff5b5ebbd2c7b1863eed5f949b231578c6 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -321,13 +321,7 @@ object TypeOps { case v @ Variable(id) if idsMap contains id => Variable(idsMap(id)).copiedFrom(v) - case u @ UnaryOperator(e, builder) => - builder(srec(e)).copiedFrom(u) - - case b @ BinaryOperator(e1, e2, builder) => - builder(srec(e1), srec(e2)).copiedFrom(b) - - case n @ NAryOperator(es, builder) => + case n @ Operator(es, builder) => builder(es.map(srec)).copiedFrom(n) case _ => diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index db75ce8cdc1bda7d7ce08490a1944139bea5659f..8bf48519c329e5404438c05b8365795b77f06d41 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -65,7 +65,6 @@ object Types { } case class SetType(base: TypeTree) extends TypeTree - case class MultisetType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree case class FunctionType(from: Seq[TypeTree], to: TypeTree) extends TypeTree case class ArrayType(base: TypeTree) extends TypeTree @@ -133,7 +132,6 @@ object Types { case TupleType(ts) => Some((ts, Constructors.tupleTypeWrap _)) case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) - case MultisetType(t) => Some((Seq(t), ts => MultisetType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) case t => Some(Nil, _ => t) diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 0ee3997c6eec28b9cba5c1ebfdba7ee635776cd5..f1ffe33bfda1fc38a370a570fc94200ea57d02ad 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -276,10 +276,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], Variable(lid) - case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))) - case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)) - case u @ UnaryOperator(a, r) => r(rec(pathVar, a)) - case t : Terminal => t + case Operator(as, r) => r(as.map(a => rec(pathVar, a))) + } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 62c0f1144cfbae792a8123696b9215e0722f9acc..38821e2a0a73c98fe45c05df0ff37b265d307aa9 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -446,15 +446,9 @@ class LambdaTemplate[T] private ( Seq.empty } - case (NAryOperator(es1, _), NAryOperator(es2, _)) => + case (Operator(es1, _), Operator(es2, _)) => (es1 zip es2).flatMap(p => rec(p._1, p._2)) - case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) => - rec(e11, e21) ++ rec(e12, e22) - - case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) => - rec(ue1, ue2) - case _ => Seq.empty } diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala index e3da7a2bc0df2973066dbe47db19c449237df708..1213766ec5aad957e8af5f5f9c6abbccb33d1c1c 100644 --- a/src/main/scala/leon/synthesis/Witnesses.scala +++ b/src/main/scala/leon/synthesis/Witnesses.scala @@ -15,11 +15,11 @@ object Witnesses { val getType = BooleanType } - case class Guide(e : Expr) extends Witness with UnaryExtractable { - def extract: Option[(Expr, Expr => Expr)] = Some((e, Guide)) + case class Guide(e : Expr) extends Witness with Extractable { + def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((Seq(e), (es: Seq[Expr]) => Guide(es.head))) } - case class Terminating(tfd: TypedFunDef, args: Seq[Expr]) extends Witness with NAryExtractable { + case class Terminating(tfd: TypedFunDef, args: Seq[Expr]) extends Witness with Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((args, Terminating(tfd, _))) } diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index c252880f59e6f1cba8c6248095f9448e21174574..2c539be22171749502a8b821f140264527502b67 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -333,12 +333,7 @@ object ExpressionGrammars { // We allow only exact call, and/or cegis extensions /*Seq(el -> Generator[L, Expr](Nil, { _ => e })) ++*/ cegis(gl) - case UnaryOperator(sub, builder) => - gens(e, gl, List(sub), { case Seq(s) => builder(s) }) - case BinaryOperator(sub1, sub2, builder) => - gens(e, gl, List(sub1, sub2), { case Seq(s1, s2) => builder(s1, s2) }) - - case NAryOperator(subs, builder) => + case Operator(subs, builder) => gens(e, gl, subs, { case ss => builder(ss) }) } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 0a487ef2842cc1b5da6e7aac77b585d4113e6335..1ff65119585e1fd2c6b5481d0e4040f4e2453dcc 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -119,17 +119,7 @@ object UnitElimination extends TransformationPhase { val elseRec = removeUnit(eExpr) IfExpr(removeUnit(cond), thenRec, elseRec) } - case n @ NAryOperator(args, recons) => { - recons(args.map(removeUnit)) - } - case b @ BinaryOperator(a1, a2, recons) => { - recons(removeUnit(a1), removeUnit(a2)) - } - case u @ UnaryOperator(a, recons) => { - recons(removeUnit(a)) - } case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v - case (t: Terminal) => t case m @ MatchExpr(scrut, cses) => { val scrutRec = removeUnit(scrut) val csesRec = cses.map{ cse => @@ -137,6 +127,11 @@ object UnitElimination extends TransformationPhase { } matchExpr(scrutRec, csesRec).setPos(m) } + case Operator(args, recons) => { + recons(args.map(removeUnit)) + } + // FIXME: It's dead (code) Jim! + case _ => sys.error("not supported: " + expr) } } diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index a05fc1bb8f453c22fd1c96f729c60cb0928a52f4..dfa87a993fde5eb4df3a418c4f018b4f97a7f19b 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -44,11 +44,8 @@ object ArrayTransformation extends UnitPhase[Program] { Variable(env.getOrElse(i, i)) } - case n @ NAryOperator(args, recons) => recons(args.map(transform)) - case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)) - case u @ UnaryOperator(a, recons) => recons(transform(a)) + case Operator(args, recons) => recons(args.map(transform)) - case (t: Terminal) => t case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled) }).setPos(expr) diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala index 237636db74d6b92d74704114713e7270836f4e2d..66763915953458f47ab74d45c8e5ae3e3c6317ce 100644 --- a/src/main/scala/leon/xlang/Expressions.scala +++ b/src/main/scala/leon/xlang/Expressions.scala @@ -15,7 +15,7 @@ object Expressions { trait XLangExpr extends Expr - case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with NAryExtractable with PrettyPrintable { + case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with Extractable with PrettyPrintable { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { Some((exprs :+ last, exprs => Block(exprs.init, exprs.last))) } @@ -33,11 +33,11 @@ object Expressions { val getType = last.getType } - case class Assignment(varId: Identifier, expr: Expr) extends XLangExpr with UnaryExtractable with PrettyPrintable { + case class Assignment(varId: Identifier, expr: Expr) extends XLangExpr with Extractable with PrettyPrintable { val getType = UnitType - def extract: Option[(Expr, (Expr)=>Expr)] = { - Some((expr, Assignment(varId, _))) + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(expr), (es: Seq[Expr]) => Assignment(varId, es.head))) } def printWith(implicit pctx: PrinterContext) { @@ -45,7 +45,7 @@ object Expressions { } } - case class While(cond: Expr, body: Expr) extends XLangExpr with NAryExtractable with PrettyPrintable { + case class While(cond: Expr, body: Expr) extends XLangExpr with Extractable with PrettyPrintable { val getType = UnitType var invariant: Option[Expr] = None @@ -54,10 +54,10 @@ object Expressions { def setInvariant(inv: Option[Expr]) = { invariant = inv; this } def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some((Seq(cond, body) ++ invariant, { + Some((Seq(cond, body) ++ invariant, { (es:Seq[Expr]) => es match { case Seq(e1, e2) => While(e1, e2).setPos(this) case Seq(e1, e2, e3) => While(e1, e2).setInvariant(e3).setPos(this) - })) + }})) } def printWith(implicit pctx: PrinterContext) { @@ -74,9 +74,9 @@ object Expressions { } } - case class Epsilon(pred: Expr, tpe: TypeTree) extends XLangExpr with UnaryExtractable with PrettyPrintable { - def extract: Option[(Expr, (Expr)=>Expr)] = { - Some((pred, (expr: Expr) => Epsilon(expr, this.getType).setPos(this))) + case class Epsilon(pred: Expr, tpe: TypeTree) extends XLangExpr with Extractable with PrettyPrintable { + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(pred), (es: Seq[Expr]) => Epsilon(es.head, this.getType).setPos(this))) } def printWith(implicit pctx: PrinterContext) { @@ -96,11 +96,11 @@ object Expressions { } //same as let, buf for mutable variable declaration - case class LetVar(binder: Identifier, value: Expr, body: Expr) extends XLangExpr with BinaryExtractable with PrettyPrintable { + case class LetVar(binder: Identifier, value: Expr, body: Expr) extends XLangExpr with Extractable with PrettyPrintable { val getType = body.getType - def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] = { - Some((value, body, (e: Expr, b: Expr) => LetVar(binder, e, b))) + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some( Seq(value, body), (es:Seq[Expr]) => LetVar(binder, es(0), es(1)) ) } def printWith(implicit pctx: PrinterContext) { @@ -111,9 +111,9 @@ object Expressions { } } - case class Waypoint(i: Int, expr: Expr, tpe: TypeTree) extends XLangExpr with UnaryExtractable with PrettyPrintable{ - def extract: Option[(Expr, (Expr)=>Expr)] = { - Some((expr, (e: Expr) => Waypoint(i, e, tpe))) + case class Waypoint(i: Int, expr: Expr, tpe: TypeTree) extends XLangExpr with Extractable with PrettyPrintable{ + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(expr), (es: Seq[Expr]) => Waypoint(i, es.head, tpe))) } def printWith(implicit pctx: PrinterContext) { @@ -123,7 +123,7 @@ object Expressions { val getType = tpe } - case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends XLangExpr with NAryExtractable with PrettyPrintable { + case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends XLangExpr with Extractable with PrettyPrintable { val getType = UnitType def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index e6ef2c65d717f4956844dba3f681e5a1cea92b1b..f9b2b85fbca82ba289986fe6221cc229c30cc84f 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -258,8 +258,7 @@ object ImperativeCodeElimination extends TransformationPhase { toFunction(ifExpr) } - case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map()) - case n @ NAryOperator(args, recons) => { + case n @ Operator(args, recons) => { val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { val (accArgs, accScope, accFun) = acc val (argVal, argScope, argFun) = toFunction(arg) @@ -268,23 +267,6 @@ object ImperativeCodeElimination extends TransformationPhase { }) (recons(recArgs).copiedFrom(n), scope, fun) } - case b @ BinaryOperator(a1, a2, recons) => { - val (argVal1, argScope1, argFun1) = toFunction(a1) - val (argVal2, argScope2, argFun2) = toFunction(a2) - val scope = (body: Expr) => { - val rhs = argScope2(replaceNames(argFun2, body)) - val lhs = argScope1(replaceNames(argFun1, rhs)) - lhs - } - (recons(argVal1, argVal2).copiedFrom(b), scope, argFun1 ++ argFun2) - } - case u @ UnaryOperator(a, recons) => { - val (argVal, argScope, argFun) = toFunction(a) - (recons(argVal).copiedFrom(u), argScope, argFun) - } - - case (t: Terminal) => (t, (body: Expr) => body, Map()) - case _ => sys.error("not supported: " + expr) }