From d03bb32b9737ae58624a696cd928bffb1c245423 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Thu, 25 Jun 2015 19:38:34 +0200 Subject: [PATCH] Simplifications --- .../scala/leon/datagen/NaiveDataGen.scala | 8 ++--- .../leon/evaluators/RecursiveEvaluator.scala | 14 +++------ .../frontends/scalac/CodeExtraction.scala | 15 +++++---- .../leon/purescala/TreeNormalizations.scala | 5 ++- .../scala/leon/termination/ChainBuilder.scala | 16 +++++----- .../SimpleTerminationChecker.scala | 5 ++- .../scala/leon/termination/Strengthener.scala | 9 ++++-- .../leon/verification/InductionTactic.scala | 31 ++++++++++--------- 8 files changed, 53 insertions(+), 50 deletions(-) diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala index 6460ec491..9726d1876 100644 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -63,14 +63,10 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : val sortedConss = conss sortBy { _.fields.count{ _.getType.isInstanceOf[ClassType]}} // The stream for leafs... - val leafsStream = leafs.toStream.flatMap { cct => - generate(cct) - } + val leafsStream = leafs.toStream.flatMap(generate) // ...to which we append the streams for constructors. - leafsStream.append(interleave(sortedConss.map { cct => - generate(cct) - })) + leafsStream.append(interleave(sortedConss.map(generate))) case cct : CaseClassType => if(cct.fields.isEmpty) { diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index ff2e83091..1c3a6fd48 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -623,19 +623,15 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int caze match { case SimpleCase(p, rhs) => - matchesPattern(p, scrut).map( r => + matchesPattern(p, scrut).map(r => (caze, r) ) case GuardedCase(p, g, rhs) => - matchesPattern(p, scrut).flatMap( r => - e(g)(rctx.withNewVars(r), gctx) match { - case BooleanLiteral(true) => - Some((caze, r)) - case _ => - None - } - ) + for { + r <- matchesPattern(p, scrut) + BooleanLiteral(true) = e(g)(rctx.withNewVars(r), gctx) + } yield (caze, r) } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3586bfc55..d95d4530e 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -48,14 +48,13 @@ trait CodeExtraction extends ASTExtractors { def annotationsOf(s: Symbol): Set[String] = { val actualSymbol = s.accessedOrSelf - (for(a <- actualSymbol.annotations ++ actualSymbol.owner.annotations) yield { - val name = a.atp.safeToString.replaceAll("\\.package\\.", ".") - if (name startsWith "leon.annotation.") { - Some(name.split("\\.", 3)(2)) - } else { - None - } - }).flatten.toSet + (for { + a <- actualSymbol.annotations ++ actualSymbol.owner.annotations + name = a.atp.safeToString.replaceAll("\\.package\\.", ".") + if (name startsWith "leon.annotation.") + } yield { + name.split("\\.", 3)(2) + }).toSet } implicit def scalaPosToLeonPos(p: global.Position): LeonPosition = { diff --git a/src/main/scala/leon/purescala/TreeNormalizations.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala index 8b264756d..d3dff445c 100644 --- a/src/main/scala/leon/purescala/TreeNormalizations.scala +++ b/src/main/scala/leon/purescala/TreeNormalizations.scala @@ -70,7 +70,10 @@ object TreeNormalizations { //multiply two sums together and distribute in a larger sum //do not keep the evaluation order def multiply(es1: Seq[Expr], es2: Seq[Expr]): Seq[Expr] = { - es1.flatMap(e1 => es2.map(e2 => Times(e1, e2))) + for { + e1 <- es1 + e2 <- es2 + } yield Times(e1,e2) } //expand the expr in a sum of "atoms", each atom being a product of literal and variable diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index 11874f798..be8b27f3c 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -100,14 +100,14 @@ final case class Chain(relations: List[Relation]) { val tmap = that.relations.zipWithIndex.map(p => p._1.funDef -> p._2).groupBy(_._1).mapValues(_.map(_._2)) val keys = map.keys.toSet & tmap.keys.toSet - keys.flatMap(fd => map(fd).flatMap { i1 => - val (start1, end1) = relations.splitAt(i1) - val called = if (start1.isEmpty) relations.head.funDef else start1.last.call.tfd.fd - tmap(called).map { i2 => - val (start2, end2) = that.relations.splitAt(i2) - Chain(start1 ++ end2 ++ start2 ++ end1) - } - }) + for { + fd <- keys + i1 <- map(fd) + (start1, end1) = relations.splitAt(i1) + called = if (start1.isEmpty) relations.head.funDef else start1.last.call.tfd.fd + i2 <- tmap(called) + (start2, end2) = that.relations.splitAt(i2) + } yield Chain(start1 ++ end2 ++ start2 ++ end1) } lazy val inlined: Seq[Expr] = inlining.map(_._2) diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index 2a6c5bcbc..e615fc373 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -31,7 +31,10 @@ class SimpleTerminationChecker(context: LeonContext, program: Program) extends T v -> (0 until cSize).find(i => sccArray(i)(v)).get).toMap val sccGraph = (0 until cSize).map({ i => - val dsts = sccArray(i).flatMap(v => callGraph.getOrElse(v, Set.empty)).map(funDefToSCCIndex(_)) + val dsts = for { + v <- sccArray(i) + c <- callGraph.getOrElse(v, Set.empty) + } yield funDefToSCCIndex(c) i -> dsts }).toMap diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala index ed77e8bef..192e78599 100644 --- a/src/main/scala/leon/termination/Strengthener.scala +++ b/src/main/scala/leon/termination/Strengthener.scala @@ -81,7 +81,7 @@ trait Strengthener { self : RelationComparator => } def strengthenApplications(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { - val transitiveFunDefs = funDefs ++ funDefs.flatMap(fd => checker.program.callGraph.transitiveCallees(fd)) + val transitiveFunDefs = funDefs ++ funDefs.flatMap(checker.program.callGraph.transitiveCallees) val sortedFunDefs = transitiveFunDefs.toSeq.sortWith((fd1, fd2) => checker.program.callGraph.transitivelyCalls(fd2, fd1)) for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) { @@ -131,8 +131,11 @@ trait Strengthener { self : RelationComparator => val invocations = fiCollector.traverse(funDef) val id2invocations : Seq[(Identifier, ((FunDef, Identifier), Expr, Seq[Expr]))] = - invocations.flatMap(p => p._3.map(c => c._1 -> ((c._2, p._1, p._2)))) - val invocationMap : Map[Identifier, Seq[((FunDef, Identifier), Expr, Seq[Expr])]] = + for { + p <- invocations + c <- p._3 + } yield c._1 -> (c._2, p._1, p._2) + val invocationMap: Map[Identifier, Seq[((FunDef, Identifier), Expr, Seq[Expr])]] = id2invocations.groupBy(_._1).mapValues(_.map(_._2)) def constraint(id: Identifier, passings: Seq[((FunDef, Identifier), Expr, Seq[Expr])]): SizeConstraint = { diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index ba4039b26..4352adc27 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -64,24 +64,27 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { case fi @ FunctionInvocation(tfd, _) if tfd.hasPrecondition => (fi, tfd.precondition.get) }(body) - calls.flatMap { - case ((fi @ FunctionInvocation(tfd, args), pre), path) => - for (cct <- parentType.knownCCDescendents) yield { - val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) + for { + ((fi@FunctionInvocation(tfd, args), pre), path) <- calls + cct <- parentType.knownCCDescendents + } yield { + val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) - val subCases = selectors.map { sel => - replace(Map(arg.toVariable -> sel), - implies(precOrTrue(fd), replace((tfd.params.map(_.toVariable) zip args).toMap, pre)) - ) - } + val subCases = selectors.map { sel => + replace(Map(arg.toVariable -> sel), + implies(precOrTrue(fd), replace((tfd.params.map(_.toVariable) zip args).toMap, pre)) + ) + } - val vc = implies(and(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd), path), implies(andJoin(subCases), replace((tfd.params.map(_.toVariable) zip args).toMap, pre))) + val vc = implies( + andJoin(Seq(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd), path) ++ subCases), + replace((tfd.params.map(_.toVariable) zip args).toMap, pre) + ) - // Crop the call to display it properly - val fiS = sizeLimit(fi.toString, 25) + // Crop the call to display it properly + val fiS = sizeLimit(fi.toString, 25) - VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS, ind. on ($arg : ${cct.classDef.id})"), this).setPos(fi) - } + VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS, ind. on ($arg : ${cct.classDef.id})"), this).setPos(fi) } case (body, _) => -- GitLab