diff --git a/run-tests.sh b/run-tests.sh index 7a36168c1eb87dea5b72074eaf7277a2cbcbc5e6..15852d10f97f27a322dc6fee842d07fed1b848f1 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -7,7 +7,7 @@ failedtests="" for f in $base/valid/*.scala; do echo -n "Running $f, expecting VALID, got: " - res=`./leon --noLuckyTests --timeout=10 --oneline "$f"` + res=`./leon --xlang --noLuckyTests --timeout=10 --oneline "$f"` echo $res | tr [a-z] [A-Z] if [ $res = valid ]; then nbsuccess=$((nbsuccess + 1)) @@ -18,7 +18,7 @@ done for f in $base/invalid/*.scala; do echo -n "Running $f, expecting INVALID, got: " - res=`./leon --noLuckyTests --timeout=10 --oneline "$f"` + res=`./leon --xlang --noLuckyTests --timeout=10 --oneline "$f"` echo $res | tr [a-z] [A-Z] if [ $res = invalid ]; then nbsuccess=$((nbsuccess + 1)) @@ -29,7 +29,7 @@ done for f in $base/error/*.scala; do echo -n "Running $f, expecting ERROR, got: " - res=`./leon --noLuckyTests --timeout=10 --oneline "$f"` + res=`./leon --xlang --noLuckyTests --timeout=10 --oneline "$f"` echo $res | tr [a-z] [A-Z] if [ $res = error ]; then nbsuccess=$((nbsuccess + 1)) diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala index ce98c70aaa2552f4506fd94610668c082aa56c71..62ea2a39b82b64a25bf15f17f579a0de2426808e 100644 --- a/src/main/scala/leon/Analysis.scala +++ b/src/main/scala/leon/Analysis.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ import Extensions._ import scala.collection.mutable.{Set => MutableSet} diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index 9d0403039b767571a4ae23a6d9ffc0c29ef7e28b..fd46f449b1504e4082c97bfcf5267831381dc760 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ import purescala.TypeTrees._ object ArrayTransformation extends TransformationPhase { diff --git a/src/main/scala/leon/DefaultTactic.scala b/src/main/scala/leon/DefaultTactic.scala index 70ad17f36150a47f846fe67a8bf8db350e03a9bf..4078e6e490f8a1bbe5adc5a985459d144a13c363 100644 --- a/src/main/scala/leon/DefaultTactic.scala +++ b/src/main/scala/leon/DefaultTactic.scala @@ -2,6 +2,8 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ +import purescala.Extractors._ import purescala.Definitions._ import Extensions.Tactic diff --git a/src/main/scala/leon/EpsilonElimination.scala b/src/main/scala/leon/EpsilonElimination.scala index a785ddf9e6710ec7c6ace7dda0c2f223c9a88eb1..dc69875d198f03366b99c30f4c5034dc89fa6b5d 100644 --- a/src/main/scala/leon/EpsilonElimination.scala +++ b/src/main/scala/leon/EpsilonElimination.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ object EpsilonElimination extends TransformationPhase { diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala index d5c5745c19bcb8b3c58c1496c6a63547d179dcee..9de1eafc5865efddb139d0f173563de698bded71 100644 --- a/src/main/scala/leon/Evaluator.scala +++ b/src/main/scala/leon/Evaluator.scala @@ -2,6 +2,7 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ object Evaluator { diff --git a/src/main/scala/leon/Extensions.scala b/src/main/scala/leon/Extensions.scala index 53bd6f37e0f0d419acfe4f1d879cde7e85537244..1d6bb8198bce940d3ede8a9381a6c48a7ee502c4 100644 --- a/src/main/scala/leon/Extensions.scala +++ b/src/main/scala/leon/Extensions.scala @@ -2,6 +2,7 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.Definitions._ object Extensions { diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index b54f5db1d9c7456dcda45756433167f8a0a7ebd4..23f93d497c15aab940a9af268775d6b660f90dcb 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -4,6 +4,8 @@ import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ import purescala.TypeTrees._ import Extensions._ diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index ed5920185c0940375fc0a270e9f13204830bebc3..4da2ac1e9cd639e8a736670a3746098d431e790c 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -3,6 +3,8 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ import purescala.TypeTrees._ object FunctionClosure extends TransformationPhase{ diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index d0fd3f83f80499b4986a986ff1fcfeee60ee2378..64a44b01b684c3dc7f4da2d040afff8a51d9608f 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ import purescala.TypeTrees._ object FunctionHoisting extends TransformationPhase { diff --git a/src/main/scala/leon/FunctionTemplate.scala b/src/main/scala/leon/FunctionTemplate.scala index be4189b4283b2d79eefd63c94726ea7f6d683000..279bafef9aa1aad7570da4a7faf8780d614f1166 100644 --- a/src/main/scala/leon/FunctionTemplate.scala +++ b/src/main/scala/leon/FunctionTemplate.scala @@ -2,6 +2,8 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 18243ea6ba735ec73b402003ac0fefbfa63c061e..74bb863094d485c5e6f2967583ecdc3aa3a93708 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -3,7 +3,9 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ import purescala.TypeTrees._ +import purescala.TreeOps._ object ImperativeCodeElimination extends TransformationPhase { diff --git a/src/main/scala/leon/InductionTactic.scala b/src/main/scala/leon/InductionTactic.scala index 1d30adddf62825cf8184c9eb42aa8677bb28134f..4adc158b323678d423ce9c7bb80705ac2512fc25 100644 --- a/src/main/scala/leon/InductionTactic.scala +++ b/src/main/scala/leon/InductionTactic.scala @@ -2,6 +2,7 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 013ae1cf5ce21156271ad66f6482a2716c125e88..b7681eaf955831465f2e8c2100ec8296ee1c6494 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -17,13 +17,28 @@ object Main { ) } + lazy val allOptions = allPhases.flatMap(_.definedOptions) ++ Set( + LeonOptionDef("synthesis", true, "--synthesis Partial synthesis or choose() constructs"), + LeonOptionDef("xlang", true, "--xlang Support for extra program constructs (imperative,...)"), + LeonOptionDef("parse", true, "--parse Checks only whether the program is valid PureScala"), + LeonOptionDef("debug", false, "--debug=[1-5] Debug level"), + LeonOptionDef("help", true, "--help This help") + ) + + def displayHelp(reporter: Reporter) { + reporter.info("usage: leon [--xlang] [--help] [--synthesis] [--help] [--debug=<N>] [..] <files>") + reporter.info("") + reporter.info("Leon options are:") + for (opt <- allOptions.toSeq.sortBy(_.name)) { + reporter.info(" "+opt.description) + } + sys.exit(1) + } + def processOptions(reporter: Reporter, args: List[String]) = { val phases = allPhases - val allOptions = allPhases.flatMap(_.definedOptions) ++ Set( - LeonOptionDef("synthesis", true, "--synthesis Magic!"), - LeonOptionDef("xlang", true, "--xlang Preprocessing and transformation from extended programs") - ) + val allOptions = this.allOptions val allOptionsMap = allOptions.map(o => o.name -> o).toMap @@ -47,10 +62,12 @@ object Main { case (false, LeonValueOption(name, value)) => Some(leonOpt) case _ => - reporter.fatalError("Invalid option usage") + reporter.error("Invalid option usage: "+opt) + displayHelp(reporter) None } } else { + reporter.error("leon: '"+opt+"' is not a valid option. See 'leon --help'") None } } @@ -63,6 +80,10 @@ object Main { settings = settings.copy(synthesis = true, xlang = false, analyze = false) case LeonFlagOption("xlang") => settings = settings.copy(synthesis = false, xlang = true) + case LeonFlagOption("parse") => + settings = settings.copy(synthesis = false, xlang = false, analyze = false) + case LeonFlagOption("help") => + displayHelp(reporter) case _ => } @@ -80,7 +101,8 @@ object Main { if (settings.xlang) { ArrayTransformation andThen EpsilonElimination andThen - ImperativeCodeElimination + ImperativeCodeElimination andThen + FunctionClosure } else { NoopPhase[Program]() } @@ -112,8 +134,10 @@ object Main { // Process options val ctx = processOptions(reporter, args.toList) + // Compute leon pipeline val pipeline = computePipeLine(ctx.settings) + // Run phases pipeline.run(ctx)(args.toList) } } diff --git a/src/main/scala/leon/RandomSolver.scala b/src/main/scala/leon/RandomSolver.scala index 4cee5b06181bbc803cc2061af9f62c33c5d04386..778a44e5942c3c4a1138367e63c88c826d1f4fc2 100644 --- a/src/main/scala/leon/RandomSolver.scala +++ b/src/main/scala/leon/RandomSolver.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ import Extensions._ diff --git a/src/main/scala/leon/Simplificator.scala b/src/main/scala/leon/Simplificator.scala index 92b54072a1d89e176f64225225789161a25870af..bdc989e20dbda46eb03eb50a190c7a6390a9ae92 100644 --- a/src/main/scala/leon/Simplificator.scala +++ b/src/main/scala/leon/Simplificator.scala @@ -4,6 +4,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ import purescala.TypeTrees._ +import purescala.TreeOps._ object Simplificator extends TransformationPhase { diff --git a/src/main/scala/leon/TestExtension.scala b/src/main/scala/leon/TestExtension.scala index c76f81caf5aec469c7b22a1de442e01a83236674..5fd517280e6c00eb4939e0615a5b6d41f0822379 100644 --- a/src/main/scala/leon/TestExtension.scala +++ b/src/main/scala/leon/TestExtension.scala @@ -2,6 +2,8 @@ package leon import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ +import purescala.Extractors._ import purescala.TypeTrees._ import purescala.Definitions._ import Extensions._ diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala index d30ac296f3206780e743c56bc54364a0c60e766e..6d3aa84562cbe56fee5b018e9b93578dd0b9b6ce 100644 --- a/src/main/scala/leon/UnitElimination.scala +++ b/src/main/scala/leon/UnitElimination.scala @@ -3,6 +3,7 @@ package leon import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.Extractors._ import purescala.TypeTrees._ object UnitElimination extends TransformationPhase { diff --git a/src/main/scala/leon/Z3ModelReconstruction.scala b/src/main/scala/leon/Z3ModelReconstruction.scala index 913da2fa32b64056026f24e3d1ea1cbd8dd81672..87c75ddd5fb6667764727daa85f018deeeb4b6ec 100644 --- a/src/main/scala/leon/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/Z3ModelReconstruction.scala @@ -4,6 +4,7 @@ import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ import Extensions._ diff --git a/src/main/scala/leon/isabelle/Main.scala b/src/main/scala/leon/isabelle/Main.scala index 3fd96cff3553e8893dce9b9ba0e24609796c6685..ac95c8095ec4a9af8543e8d0cecc795a5f8acd19 100644 --- a/src/main/scala/leon/isabelle/Main.scala +++ b/src/main/scala/leon/isabelle/Main.scala @@ -8,6 +8,8 @@ import leon.purescala.Common.Identifier import leon.purescala.Definitions._ import leon.purescala.PrettyPrinter import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.Extractors._ import leon.purescala.TypeTrees._ import java.lang.StringBuffer diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 9c2d4fbf80afe77e878ec4fb1892365a69fad44e..385557d150caa7f989fe0350a19d406e97d6959c 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.Trees.{Block => PBlock, _} import purescala.TypeTrees._ import purescala.Common._ +import purescala.TreeOps._ trait CodeExtraction extends Extractors { self: AnalysisComponent => diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index d39d8abf502cf075688a7839fdb3e8eb4977ae7b..969ef362c266e2c82bed2203509031ee2c86c4c8 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -4,6 +4,8 @@ package purescala object Definitions { import Common._ import Trees._ + import TreeOps._ + import Extractors._ import TypeTrees._ sealed abstract class Definition extends Serializable { @@ -53,7 +55,7 @@ object Definitions { def isCatamorphism(f1: FunDef) = mainObject.isCatamorphism(f1) def caseClassDef(name: String) = mainObject.caseClassDef(name) def allIdentifiers : Set[Identifier] = mainObject.allIdentifiers + id - def isPure: Boolean = definedFunctions.forall(fd => fd.body.forall(Trees.isPure) && fd.precondition.forall(Trees.isPure) && fd.postcondition.forall(Trees.isPure)) + def isPure: Boolean = definedFunctions.forall(fd => fd.body.forall(TreeOps.isPure) && fd.precondition.forall(TreeOps.isPure) && fd.postcondition.forall(TreeOps.isPure)) def writeScalaFile(filename: String) { import java.io.FileWriter @@ -77,7 +79,7 @@ object Definitions { def allIdentifiers : Set[Identifier] = { (defs map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ - (invariants map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + id + (invariants map (TreeOps.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + id } lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) @@ -276,7 +278,7 @@ object Definitions { /** Values */ case class ValDef(varDecl: VarDecl, value: Expr) extends Definition { val id: Identifier = varDecl.id - def allIdentifiers : Set[Identifier] = Trees.allIdentifiers(value) + id + def allIdentifiers : Set[Identifier] = TreeOps.allIdentifiers(value) + id } /** Functions (= 'methods' of objects) */ @@ -312,9 +314,9 @@ object Definitions { def allIdentifiers : Set[Identifier] = { args.map(_.id).toSet ++ - body.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ - precondition.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ - postcondition.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) + id + body.map(TreeOps.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + precondition.map(TreeOps.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + postcondition.map(TreeOps.allIdentifiers(_)).getOrElse(Set[Identifier]()) + id } private var annots: Set[String] = Set.empty[String] diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala new file mode 100644 index 0000000000000000000000000000000000000000..c169ba5f913b100c4118baaed4cb6cc4b574345c --- /dev/null +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -0,0 +1,220 @@ +package leon +package purescala + +import Trees._ + +object Extractors { + import Common._ + import TypeTrees._ + import Definitions._ + import Extractors._ + import TreeOps._ + + object UnaryOperator { + def unapply(expr: Expr) : Option[(Expr,(Expr)=>Expr)] = expr match { + case Not(t) => Some((t,Not(_))) + case UMinus(t) => Some((t,UMinus)) + case SetCardinality(t) => Some((t,SetCardinality)) + case MultisetCardinality(t) => Some((t,MultisetCardinality)) + case MultisetToSet(t) => Some((t,MultisetToSet)) + case Car(t) => Some((t,Car)) + case Cdr(t) => Some((t,Cdr)) + case SetMin(s) => Some((s,SetMin)) + case SetMax(s) => Some((s,SetMax)) + case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) + case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) + case Assignment(id, e) => Some((e, Assignment(id, _))) + case TupleSelect(t, i) => Some((t, TupleSelect(_, i))) + case ArrayLength(a) => Some((a, ArrayLength)) + case ArrayClone(a) => Some((a, ArrayClone)) + case ArrayMake(t) => Some((t, ArrayMake)) + case Waypoint(i, t) => Some((t, (expr: Expr) => Waypoint(i, expr))) + case e@Epsilon(t) => Some((t, (expr: Expr) => Epsilon(expr).setType(e.getType).setPosInfo(e))) + case _ => None + } + } + + object BinaryOperator { + def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { + case Equals(t1,t2) => Some((t1,t2,Equals.apply)) + case Iff(t1,t2) => Some((t1,t2,Iff(_,_))) + case Implies(t1,t2) => Some((t1,t2,Implies.apply)) + case Plus(t1,t2) => Some((t1,t2,Plus)) + case Minus(t1,t2) => Some((t1,t2,Minus)) + case Times(t1,t2) => Some((t1,t2,Times)) + case Division(t1,t2) => Some((t1,t2,Division)) + case Modulo(t1,t2) => Some((t1,t2,Modulo)) + case LessThan(t1,t2) => Some((t1,t2,LessThan)) + case GreaterThan(t1,t2) => Some((t1,t2,GreaterThan)) + case LessEquals(t1,t2) => Some((t1,t2,LessEquals)) + case GreaterEquals(t1,t2) => Some((t1,t2,GreaterEquals)) + case ElementOfSet(t1,t2) => Some((t1,t2,ElementOfSet)) + case SubsetOf(t1,t2) => Some((t1,t2,SubsetOf)) + case SetIntersection(t1,t2) => Some((t1,t2,SetIntersection)) + case SetUnion(t1,t2) => Some((t1,t2,SetUnion)) + case SetDifference(t1,t2) => Some((t1,t2,SetDifference)) + case Multiplicity(t1,t2) => Some((t1,t2,Multiplicity)) + case SubmultisetOf(t1,t2) => Some((t1,t2,SubmultisetOf)) + case MultisetIntersection(t1,t2) => Some((t1,t2,MultisetIntersection)) + case MultisetUnion(t1,t2) => Some((t1,t2,MultisetUnion)) + case MultisetPlus(t1,t2) => Some((t1,t2,MultisetPlus)) + case MultisetDifference(t1,t2) => Some((t1,t2,MultisetDifference)) + case SingletonMap(t1,t2) => Some((t1,t2,SingletonMap)) + case mg@MapGet(t1,t2) => Some((t1,t2, (t1, t2) => MapGet(t1, t2).setPosInfo(mg))) + case MapUnion(t1,t2) => Some((t1,t2,MapUnion)) + case MapDifference(t1,t2) => Some((t1,t2,MapDifference)) + case MapIsDefinedAt(t1,t2) => Some((t1,t2, MapIsDefinedAt)) + case ArrayFill(t1, t2) => Some((t1, t2, ArrayFill)) + case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) + case Concat(t1,t2) => Some((t1,t2,Concat)) + case ListAt(t1,t2) => Some((t1,t2,ListAt)) + case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh))) + case _ => None + } + } + + object NAryOperator { + def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { + case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPosInfo(fi)))) + case AnonymousFunctionInvocation(id, args) => Some((args, (as => AnonymousFunctionInvocation(id, as)))) + case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) + case And(args) => Some((args, And.apply)) + case Or(args) => Some((args, Or.apply)) + case FiniteSet(args) => Some((args, FiniteSet)) + case FiniteMap(args) => Some((args, (as : Seq[Expr]) => FiniteMap(as.asInstanceOf[Seq[SingletonMap]]))) + case FiniteMultiset(args) => Some((args, FiniteMultiset)) + case ArrayUpdate(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2)))) + case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)))) + case FiniteArray(args) => Some((args, FiniteArray)) + case Distinct(args) => Some((args, Distinct)) + case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) + case Tuple(args) => Some((args, Tuple)) + case _ => None + } + } + + object SimplePatternMatching { + def isSimple(me: MatchExpr) : Boolean = unapply(me).isDefined + + // (scrutinee, classtype, list((caseclassdef, variable, list(variable), rhs))) + def unapply(e: MatchExpr) : Option[(Expr,ClassType,Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)])] = { + val MatchExpr(scrutinee, cases) = e + val sType = scrutinee.getType + + if(sType.isInstanceOf[TupleType]) { + None + } else if(sType.isInstanceOf[AbstractClassType]) { + val cCD = sType.asInstanceOf[AbstractClassType].classDef + if(cases.size == cCD.knownChildren.size && cases.forall(!_.hasGuard)) { + var seen = Set.empty[ClassTypeDef] + + var lle : List[(CaseClassDef,Identifier,List[Identifier],Expr)] = Nil + for(cse <- cases) { + cse match { + case SimpleCase(CaseClassPattern(binder, ccd, subPats), rhs) if subPats.forall(_.isInstanceOf[WildcardPattern]) => { + seen = seen + ccd + + val patID : Identifier = if(binder.isDefined) { + binder.get + } else { + FreshIdentifier("cse", true).setType(CaseClassType(ccd)) + } + + val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { + p._2.binder.get + } else { + FreshIdentifier("pat", true).setType(p._1.tpe) + }).toList + + lle = (ccd, patID, argIDs, rhs) :: lle + } + case _ => ; + } + } + lle = lle.reverse + + if(seen.size == cases.size) { + Some((scrutinee, sType.asInstanceOf[AbstractClassType], lle)) + } else { + None + } + } else { + None + } + } else { + val cCD = sType.asInstanceOf[CaseClassType].classDef + if(cases.size == 1 && !cases(0).hasGuard) { + val SimpleCase(pat,rhs) = cases(0).asInstanceOf[SimpleCase] + pat match { + case CaseClassPattern(binder, ccd, subPats) if (ccd == cCD && subPats.forall(_.isInstanceOf[WildcardPattern])) => { + val patID : Identifier = if(binder.isDefined) { + binder.get + } else { + FreshIdentifier("cse", true).setType(CaseClassType(ccd)) + } + + val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { + p._2.binder.get + } else { + FreshIdentifier("pat", true).setType(p._1.tpe) + }).toList + + Some((scrutinee, CaseClassType(cCD), List((cCD, patID, argIDs, rhs)))) + } + case _ => None + } + } else { + None + } + } + } + } + + object NotSoSimplePatternMatching { + def coversType(tpe: ClassTypeDef, patterns: Seq[Pattern]) : Boolean = { + if(patterns.isEmpty) { + false + } else if(patterns.exists(_.isInstanceOf[WildcardPattern])) { + true + } else { + val allSubtypes: Seq[CaseClassDef] = tpe match { + case acd @ AbstractClassDef(_,_) => acd.knownDescendents.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) + case ccd: CaseClassDef => List(ccd) + } + + var seen: Set[CaseClassDef] = Set.empty + var secondLevel: Map[(CaseClassDef,Int),List[Pattern]] = Map.empty + + for(pat <- patterns) if (pat.isInstanceOf[CaseClassPattern]) { + val pattern: CaseClassPattern = pat.asInstanceOf[CaseClassPattern] + val ccd: CaseClassDef = pattern.caseClassDef + seen = seen + ccd + + for((subPattern,i) <- (pattern.subPatterns.zipWithIndex)) { + val seenSoFar = secondLevel.getOrElse((ccd,i), Nil) + secondLevel = secondLevel + ((ccd,i) -> (subPattern :: seenSoFar)) + } + } + + allSubtypes.forall(ccd => { + seen(ccd) && ccd.fields.zipWithIndex.forall(p => p._1.tpe match { + case t: ClassType => coversType(t.classDef, secondLevel.getOrElse((ccd, p._2), Nil)) + case _ => true + }) + }) + } + } + + def unapply(pm : MatchExpr) : Option[MatchExpr] = if(!Settings.experimental) None else (pm match { + case MatchExpr(scrutinee, cases) if cases.forall(_.isInstanceOf[SimpleCase]) => { + val allPatterns = cases.map(_.pattern) + Settings.reporter.info("This might be a complete pattern-matching expression:") + Settings.reporter.info(pm) + Settings.reporter.info("Covered? " + coversType(pm.scrutineeClassType.classDef, allPatterns)) + None + } + case _ => None + }) + } + +} diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..a0494a8e440a799ca8d2d600a5a6fda6d9268170 --- /dev/null +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -0,0 +1,953 @@ +package leon +package purescala + + +object TreeOps { + import Common._ + import TypeTrees._ + import Definitions._ + import Trees._ + import Extractors._ + + def negate(expr: Expr) : Expr = expr match { + case Let(i,b,e) => Let(i,b,negate(e)) + case Not(e) => e + case Iff(e1,e2) => Iff(negate(e1),e2) + case Implies(e1,e2) => And(e1, negate(e2)) + case Or(exs) => And(exs map negate) + case And(exs) => Or(exs map negate) + case LessThan(e1,e2) => GreaterEquals(e1,e2) + case LessEquals(e1,e2) => GreaterThan(e1,e2) + case GreaterThan(e1,e2) => LessEquals(e1,e2) + case GreaterEquals(e1,e2) => LessThan(e1,e2) + case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)).setType(i.getType) + case BooleanLiteral(b) => BooleanLiteral(!b) + case _ => Not(expr) + } + + // Warning ! This may loop forever if the substitutions are not + // well-formed! + def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { + searchAndReplaceDFS(substs.get)(expr) + } + + // Can't just be overloading because of type erasure :'( + def replaceFromIDs(substs: Map[Identifier,Expr], expr: Expr) : Expr = { + replace(substs.map(p => (Variable(p._1) -> p._2)), expr) + } + + def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { + def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { + case Some(newExpr) => { + if(newExpr.getType == Untyped) { + Settings.reporter.error("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) + } + if(ex == newExpr) + if(recursive) rec(ex, ex) else ex + else + if(recursive) rec(newExpr) else newExpr + } + case None => ex match { + case l @ Let(i,e,b) => { + val re = rec(e) + val rb = rec(b) + if(re != e || rb != b) + Let(i, re, rb).setType(l.getType) + else + l + } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + if(re != e || rb != b) + LetVar(i, re, rb).setType(l.getType) + else + l + } + case l @ LetDef(fd, b) => { + //TODO, not sure, see comment for the next LetDef + fd.body = fd.body.map(rec(_)) + fd.precondition = fd.precondition.map(rec(_)) + fd.postcondition = fd.postcondition.map(rec(_)) + LetDef(fd, rec(b)).setType(l.getType) + } + + case lt @ LetTuple(ids, expr, body) => { + val re = rec(expr) + val rb = rec(body) + if (re != expr || rb != body) { + LetTuple(ids, re, rb).setType(lt.getType) + } else { + lt + } + } + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a) + if(ra != a) { + change = true + ra + } else { + a + } + }) + if(change) + recons(rargs).setType(n.getType) + else + n + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1) + val r2 = rec(t2) + if(r1 != t1 || r2 != t2) + recons(r1,r2).setType(b.getType) + else + b + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t) + if(r != t) + recons(r).setType(u.getType) + else + u + } + case i @ IfExpr(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + if(r1 != t1 || r2 != t2 || r3 != t3) + IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) + else + i + } + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + + case c @ Choose(args, body) => + val body2 = rec(body) + + if (body != body2) { + Choose(args, body2).setType(c.getType) + } else { + c + } + + case t if t.isInstanceOf[Terminal] => t + case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) + } + } + + def inCase(cse: MatchCase) : MatchCase = cse match { + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) + } + + rec(expr) + } + + def searchAndReplaceDFS(subst: Expr=>Option[Expr])(expr: Expr) : Expr = { + val (res,_) = searchAndReplaceDFSandTrackChanges(subst)(expr) + res + } + + def searchAndReplaceDFSandTrackChanges(subst: Expr=>Option[Expr])(expr: Expr) : (Expr,Boolean) = { + var somethingChanged: Boolean = false + def applySubst(ex: Expr) : Expr = subst(ex) match { + case None => ex + case Some(newEx) => { + somethingChanged = true + if(newEx.getType == Untyped) { + Settings.reporter.warning("REPLACING [" + ex + "] WITH AN UNTYPED EXPRESSION !") + Settings.reporter.warning("Here's the new expression: " + newEx) + } + newEx + } + } + + def rec(ex: Expr) : Expr = ex match { + case l @ Let(i,e,b) => { + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + Let(i,re,rb).setType(l.getType) + } else { + l + }) + } + case l @ LetTuple(ids,e,b) => { + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + LetTuple(ids,re,rb).setType(l.getType) + } else { + l + }) + } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + LetVar(i,re,rb).setType(l.getType) + } else { + l + }) + } + case l @ LetDef(fd,b) => { + //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct + fd.body = fd.body.map(rec(_)) + fd.precondition = fd.precondition.map(rec(_)) + fd.postcondition = fd.postcondition.map(rec(_)) + val rl = LetDef(fd, rec(b)).setType(l.getType) + applySubst(rl) + } + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a) + if(ra != a) { + change = true + ra + } else { + a + } + }) + applySubst(if(change) { + recons(rargs).setType(n.getType) + } else { + n + }) + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1) + val r2 = rec(t2) + applySubst(if(r1 != t1 || r2 != t2) { + recons(r1,r2).setType(b.getType) + } else { + b + }) + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t) + applySubst(if(r != t) { + recons(r).setType(u.getType) + } else { + u + }) + } + case i @ IfExpr(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + applySubst(if(r1 != t1 || r2 != t2 || r3 != t3) { + IfExpr(r1,r2,r3).setType(i.getType) + } else { + i + }) + } + case m @ MatchExpr(scrut,cses) => { + val rscrut = rec(scrut) + val (newCses,changes) = cses.map(inCase(_)).unzip + applySubst(if(rscrut != scrut || changes.exists(res=>res)) { + MatchExpr(rscrut, newCses).setType(m.getType).setPosInfo(m) + } else { + m + }) + } + + case c @ Choose(args, body) => + val body2 = rec(body) + + applySubst(if (body != body2) { + Choose(args, body2).setType(c.getType).setPosInfo(c) + } else { + c + }) + + case t if t.isInstanceOf[Terminal] => applySubst(t) + case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled) + } + + def inCase(cse: MatchCase) : (MatchCase,Boolean) = cse match { + case s @ SimpleCase(pat, rhs) => { + val rrhs = rec(rhs) + if(rrhs != rhs) { + (SimpleCase(pat, rrhs), true) + } else { + (s, false) + } + } + case g @ GuardedCase(pat, guard, rhs) => { + val rguard = rec(guard) + val rrhs = rec(rhs) + if(rguard != guard || rrhs != rhs) { + (GuardedCase(pat, rguard, rrhs), true) + } else { + (g, false) + } + } + } + + val res = rec(expr) + (res, somethingChanged) + } + + // rewrites pattern-matching expressions to use fresh variables for the binders + def freshenLocals(expr: Expr) : Expr = { + def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match { + case InstanceOfPattern(Some(b), ctd) => InstanceOfPattern(Some(sm(b)), ctd) + case WildcardPattern(Some(b)) => WildcardPattern(Some(sm(b))) + case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm))) + case other => other + } + + def freshenCase(cse: MatchCase) : MatchCase = { + val allBinders: Set[Identifier] = cse.pattern.binders + val subMap: Map[Identifier,Identifier] = Map(allBinders.map(i => (i, FreshIdentifier(i.name, true).setType(i.getType))).toSeq : _*) + val subVarMap: Map[Expr,Expr] = subMap.map(kv => (Variable(kv._1) -> Variable(kv._2))) + + cse match { + case SimpleCase(pattern, rhs) => SimpleCase(rewritePattern(pattern, subMap), replace(subVarMap, rhs)) + case GuardedCase(pattern, guard, rhs) => GuardedCase(rewritePattern(pattern, subMap), replace(subVarMap, guard), replace(subVarMap, rhs)) + } + } + + def applyToTree(e : Expr) : Option[Expr] = e match { + case m @ MatchExpr(s, cses) => Some(MatchExpr(s, cses.map(freshenCase(_))).setType(m.getType).setPosInfo(m)) + case l @ Let(i,e,b) => { + val newID = FreshIdentifier(i.name, true).setType(i.getType) + Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) + } + case _ => None + } + + searchAndReplaceDFS(applyToTree)(expr) + } + + // convert describes how to compute a value for the leaves (that includes + // functions with no args.) + // combine descriess how to combine two values + def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, expression: Expr) : A = { + treeCatamorphism(convert, combine, (e:Expr,a:A)=>a, expression) + } + // compute allows the catamorphism to change the combined value depending on the tree + def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { + def rec(expr: Expr) : A = expr match { + case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) + case l @ LetVar(_, e, b) => compute(l, combine(rec(e), rec(b))) + case l @ LetDef(fd, b) => {//TODO, still not sure about the semantic + val exprs: Seq[Expr] = fd.precondition.toSeq ++ fd.body.toSeq ++ fd.postcondition.toSeq ++ Seq(b) + compute(l, exprs.map(rec(_)).reduceLeft(combine)) + } + case n @ NAryOperator(args, _) => { + if(args.size == 0) + compute(n, convert(n)) + else + compute(n, args.map(rec(_)).reduceLeft(combine)) + } + case b @ BinaryOperator(a1,a2,_) => compute(b, combine(rec(a1),rec(a2))) + case u @ UnaryOperator(a,_) => compute(u, rec(a)) + case i @ IfExpr(a1,a2,a3) => compute(i, combine(combine(rec(a1), rec(a2)), rec(a3))) + case m @ MatchExpr(scrut, cses) => compute(m, (scrut +: cses.flatMap(_.expressions)).map(rec(_)).reduceLeft(combine)) + case a @ AnonymousFunction(es, ev) => compute(a, (es.flatMap(e => e._1 ++ Seq(e._2)) ++ Seq(ev)).map(rec(_)).reduceLeft(combine)) + case c @ Choose(args, body) => compute(c, rec(body)) + case t: Terminal => compute(t, convert(t)) + case unhandled => scala.sys.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled) + } + + rec(expression) + } + + def flattenBlocks(expr: Expr): Expr = { + def applyToTree(expr: Expr): Option[Expr] = expr match { + case Block(exprs, last) => { + val nexprs = (exprs :+ last).flatMap{ + case Block(es2, el) => es2 :+ el + case UnitLiteral => Seq() + case e2 => Seq(e2) + } + val fexpr = nexprs match { + case Seq() => UnitLiteral + case Seq(e) => e + case es => Block(es.init, es.last).setType(es.last.getType) + } + Some(fexpr) + } + case _ => None + } + searchAndReplaceDFS(applyToTree)(expr) + } + + //checking whether the expr is pure, that is do not contains any non-pure construct: assign, while, blocks, array, ... + //this is expected to be true when entering the "backend" of Leon + def isPure(expr: Expr): Boolean = { + def convert(t: Expr) : Boolean = t match { + case Block(_, _) => false + case Assignment(_, _) => false + case While(_, _) => false + case LetVar(_, _, _) => false + case LetDef(_, _) => false + case ArrayUpdate(_, _, _) => false + case ArrayMake(_) => false + case ArrayClone(_) => false + case Epsilon(_) => false + case _ => true + } + def combine(b1: Boolean, b2: Boolean) = b1 && b2 + def compute(e: Expr, b: Boolean) = e match { + case Block(_, _) => false + case Assignment(_, _) => false + case While(_, _) => false + case LetVar(_, _, _) => false + case LetDef(_, _) => false + case ArrayUpdate(_, _, _) => false + case ArrayMake(_) => false + case ArrayClone(_) => false + case Epsilon(_) => false + case _ => b + } + treeCatamorphism(convert, combine, compute, expr) + } + + def containsEpsilon(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (l : Epsilon) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (l : Epsilon) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + + def containsLetDef(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (l : LetDef) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (l : LetDef) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + def containsIfExpr(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (i: IfExpr) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (i: IfExpr) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + + def variablesOf(expr: Expr) : Set[Identifier] = { + def convert(t: Expr) : Set[Identifier] = t match { + case Variable(i) => Set(i) + case _ => Set.empty + } + def combine(s1: Set[Identifier], s2: Set[Identifier]) = s1 ++ s2 + def compute(t: Expr, s: Set[Identifier]) = t match { + case Let(i,_,_) => s -- Set(i) + case MatchExpr(_, cses) => s -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) + case AnonymousFunctionInvocation(i,_) => s ++ Set[Identifier](i) + case _ => s + } + treeCatamorphism(convert, combine, compute, expr) + } + + def containsFunctionCalls(expr : Expr) : Boolean = { + def convert(t : Expr) : Boolean = t match { + case f : FunctionInvocation => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case f : FunctionInvocation => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + + def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = { + def convert(t: Expr) : Set[FunctionInvocation] = t match { + case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) + case _ => Set.empty + } + def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 + def compute(t: Expr, s: Set[FunctionInvocation]) = t match { + case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) // ++ s that's the difference with the one below + case _ => s + } + treeCatamorphism(convert, combine, compute, expr) + } + + def allNonRecursiveFunctionCallsOf(expr: Expr, program: Program) : Set[FunctionInvocation] = { + def convert(t: Expr) : Set[FunctionInvocation] = t match { + case f @ FunctionInvocation(fd, _) if program.isRecursive(fd) => Set(f) + case _ => Set.empty + } + + def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 + + def compute(t: Expr, s: Set[FunctionInvocation]) = t match { + case f @ FunctionInvocation(fd,_) if program.isRecursive(fd) => Set(f) ++ s + case _ => s + } + treeCatamorphism(convert, combine, compute, expr) + } + + def functionCallsOf(expr: Expr) : Set[FunctionInvocation] = { + def convert(t: Expr) : Set[FunctionInvocation] = t match { + case f @ FunctionInvocation(_, _) => Set(f) + case _ => Set.empty + } + def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 + def compute(t: Expr, s: Set[FunctionInvocation]) = t match { + case f @ FunctionInvocation(_, _) => Set(f) ++ s + case _ => s + } + treeCatamorphism(convert, combine, compute, expr) + } + + def contains(expr: Expr, matcher: Expr=>Boolean) : Boolean = { + treeCatamorphism[Boolean]( + matcher, + (b1: Boolean, b2: Boolean) => b1 || b2, + (t: Expr, b: Boolean) => b || matcher(t), + expr) + } + + def allIdentifiers(expr: Expr) : Set[Identifier] = expr match { + case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case l @ LetVar(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case l @ LetDef(fd, b) => allIdentifiers(fd.getBody) ++ allIdentifiers(b) + fd.id + case n @ NAryOperator(args, _) => + (args map (TreeOps.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + case b @ BinaryOperator(a1,a2,_) => allIdentifiers(a1) ++ allIdentifiers(a2) + case u @ UnaryOperator(a,_) => allIdentifiers(a) + case i @ IfExpr(a1,a2,a3) => allIdentifiers(a1) ++ allIdentifiers(a2) ++ allIdentifiers(a3) + case m @ MatchExpr(scrut, cses) => + (cses map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ allIdentifiers(scrut) + case Variable(id) => Set(id) + case t: Terminal => Set.empty + } + + def allDeBruijnIndices(expr: Expr) : Set[DeBruijnIndex] = { + def convert(t: Expr) : Set[DeBruijnIndex] = t match { + case i @ DeBruijnIndex(idx) => Set(i) + case _ => Set.empty + } + def combine(s1: Set[DeBruijnIndex], s2: Set[DeBruijnIndex]) = s1 ++ s2 + treeCatamorphism(convert, combine, expr) + } + + /* Simplifies let expressions: + * - removes lets when expression never occurs + * - simplifies when expressions occurs exactly once + * - expands when expression is just a variable. + * Note that the code is simple but far from optimal (many traversals...) + */ + def simplifyLets(expr: Expr) : Expr = { + def simplerLet(t: Expr) : Option[Expr] = t match { + case letExpr @ Let(i, t: Terminal, b) => Some(replace(Map((Variable(i) -> t)), b)) + case letExpr @ Let(i,e,b) => { + val occurences = treeCatamorphism[Int]((e:Expr) => e match { + case Variable(x) if x == i => 1 + case _ => 0 + }, (x:Int,y:Int)=>x+y, b) + if(occurences == 0) { + Some(b) + } else if(occurences == 1) { + Some(replace(Map((Variable(i) -> e)), b)) + } else { + None + } + } + //case letTuple @ LetTuple(ids, expr, body) if ids.size == 1 => + // simplerLet(Let(ids.head, TupleSelect(expr, 1).setType(ids.head.getType), body)) + + case letTuple @ LetTuple(ids, Tuple(exprs), body) => + + var newBody = body + + val (remIds, remExprs) = (ids zip exprs).filter { + case (id, value: Terminal) => + newBody = replace(Map((Variable(id) -> value)), newBody) + //we replace, so we drop old + false + case (id, value) => + val occurences = treeCatamorphism[Int]((e:Expr) => e match { + case Variable(x) if x == id => 1 + case _ => 0 + }, (x:Int,y:Int)=>x+y, body) + + if(occurences == 0) { + false + } else if(occurences == 1) { + newBody = replace(Map((Variable(id) -> value)), newBody) + false + } else { + true + } + }.unzip + + + if (remIds.isEmpty) { + Some(newBody) + } else if (remIds.tail.isEmpty) { + Some(Let(remIds.head, remExprs.head, newBody)) + } else { + Some(LetTuple(remIds, Tuple(remExprs), newBody)) + } + case _ => None + } + searchAndReplaceDFS(simplerLet)(expr) + } + + // Pulls out all let constructs to the top level, and makes sure they're + // properly ordered. + private type DefPair = (Identifier,Expr) + private type DefPairs = List[DefPair] + private def allLetDefinitions(expr: Expr) : DefPairs = treeCatamorphism[DefPairs]( + (e: Expr) => Nil, + (s1: DefPairs, s2: DefPairs) => s1 ::: s2, + (e: Expr, dps: DefPairs) => e match { + case Let(i, e, _) => (i,e) :: dps + case _ => dps + }, + expr) + + private def killAllLets(expr: Expr) : Expr = searchAndReplaceDFS((e: Expr) => e match { + case Let(_,_,ex) => Some(ex) + case _ => None + })(expr) + + def liftLets(expr: Expr) : Expr = { + val initialDefinitionPairs = allLetDefinitions(expr) + val definitionPairs = initialDefinitionPairs.map(p => (p._1, killAllLets(p._2))) + val occursLists : Map[Identifier,Set[Identifier]] = Map(definitionPairs.map((dp: DefPair) => (dp._1 -> variablesOf(dp._2).toSet.filter(_.isLetBinder))) : _*) + var newList : DefPairs = Nil + var placed : Set[Identifier] = Set.empty + val toPlace = definitionPairs.size + var placedC = 0 + var traversals = 0 + + while(placedC < toPlace) { + if(traversals > toPlace + 1) { + scala.sys.error("Cycle in let definitions or multiple definition for the same identifier in liftLets : " + definitionPairs.mkString("\n")) + } + for((id,ex) <- definitionPairs) if (!placed(id)) { + if((occursLists(id) -- placed) == Set.empty) { + placed = placed + id + newList = (id,ex) :: newList + placedC = placedC + 1 + } + } + traversals = traversals + 1 + } + + val noLets = killAllLets(expr) + + val res = (newList.foldLeft(noLets)((e,iap) => Let(iap._1, iap._2, e))) + simplifyLets(res) + } + + def wellOrderedLets(tree : Expr) : Boolean = { + val pairs = allLetDefinitions(tree) + val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) + val vars: Set[Identifier] = variablesOf(tree) + val intersection = vars intersect definitions + if(!intersection.isEmpty) { + intersection.foreach(id => { + Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !") + }) + false + } else { + vars.forall(id => if(id.isLetBinder) { + Settings.reporter.error("Variable with identifier '" + id + "' has lost its let-definition (it disappeared??)") + false + } else { + true + }) + } + } + + /* Fully expands all let expressions. */ + def expandLets(expr: Expr) : Expr = { + def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { + case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) + case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) + case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType) + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType).setPosInfo(m) + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a, s) + if(ra != a) { + change = true + ra + } else { + a + } + }) + if(change) + recons(rargs).setType(n.getType) + else + n + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1, s) + val r2 = rec(t2, s) + if(r1 != t1 || r2 != t2) + recons(r1,r2).setType(b.getType) + else + b + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t, s) + if(r != t) + recons(r).setType(u.getType) + else + u + } + case t if t.isInstanceOf[Terminal] => t + case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) + } + + def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) + } + + rec(expr, Map.empty) + } + + private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() + /** Rewrites all pattern-matching expressions into if-then-else expressions, + * with additional error conditions. Does not introduce additional variables. + * We use a cache because we can. */ + def matchToIfThenElse(expr: Expr) : Expr = { + val toRet = if(matchConverterCache.isDefinedAt(expr)) { + matchConverterCache(expr) + } else { + val converted = convertMatchToIfThenElse(expr) + matchConverterCache(expr) = converted + converted + } + + toRet + } + + def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match { + case WildcardPattern(_) => BooleanLiteral(true) + case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.") + case CaseClassPattern(_, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = And(subTests) + And(CaseClassInstanceOf(ccd, in), together) + } + case TuplePattern(_, subps) => { + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} + And(subTests) + } + } + + private def convertMatchToIfThenElse(expr: Expr) : Expr = { + def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { + case WildcardPattern(None) => Map.empty + case WildcardPattern(Some(id)) => Map(id -> in) + case InstanceOfPattern(None, _) => Map.empty + case InstanceOfPattern(Some(id), _) => Map(id -> in) + case CaseClassPattern(b, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) + b match { + case Some(id) => Map(id -> in) ++ together + case None => together + } + } + case TuplePattern(b, subps) => { + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} + val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) + b match { + case Some(id) => map + (id -> in) + case None => map + } + } + } + + def rewritePM(e: Expr) : Option[Expr] = e match { + case m @ MatchExpr(scrut, cases) => { + // println("Rewriting the following PM: " + e) + + val condsAndRhs = for(cse <- cases) yield { + val map = mapForPattern(scrut, cse.pattern) + val patCond = conditionForPattern(scrut, cse.pattern) + val realCond = cse.theGuard match { + case Some(g) => And(patCond, replaceFromIDs(map, g)) + case None => patCond + } + val newRhs = replaceFromIDs(map, cse.rhs) + (realCond, newRhs) + } + + val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) { + // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check) + val lastExpr = condsAndRhs.last._2 + + condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr)) + } else { + condsAndRhs + } + + val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => { + if(p1._1 == BooleanLiteral(true)) { + p1._2 + } else { + IfExpr(p1._1, p1._2, ex).setType(m.getType) + } + }) + + Some(bigIte) + } + case _ => None + } + + searchAndReplaceDFS(rewritePM)(expr) + } + + private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() + /** Rewrites all map accesses with additional error conditions. */ + def mapGetWithChecks(expr: Expr) : Expr = { + val toRet = if (mapGetConverterCache.isDefinedAt(expr)) { + mapGetConverterCache(expr) + } else { + val converted = convertMapGet(expr) + mapGetConverterCache(expr) = converted + converted + } + + toRet + } + + private def convertMapGet(expr: Expr) : Expr = { + def rewriteMapGet(e: Expr) : Option[Expr] = e match { + case mg @ MapGet(m,k) => + val ida = MapIsDefinedAt(m, k) + Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType)) + case _ => None + } + + searchAndReplaceDFS(rewriteMapGet)(expr) + } + + // prec: expression does not contain match expressions + def measureADTChildrenDepth(expression: Expr) : Int = { + import scala.math.max + + def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match { + case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm))) + case Variable(id) => lm.getOrElse(id, 0) + case CaseClassSelector(_, e, _) => rec(e,lm) + 1 + case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max + case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm)) + case UnaryOperator(e,_) => rec(e,lm) + case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm)) + case t: Terminal => 0 + case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex) + } + + rec(expression,Map.empty) + } + + private val random = new scala.util.Random() + + def randomValue(v: Variable) : Expr = randomValue(v.getType) + def simplestValue(v: Variable) : Expr = simplestValue(v.getType) + + def randomValue(tpe: TypeTree) : Expr = tpe match { + case Int32Type => IntLiteral(random.nextInt(42)) + case BooleanType => BooleanLiteral(random.nextBoolean()) + case AbstractClassType(acd) => + val children = acd.knownChildren + randomValue(classDefToClassType(children(random.nextInt(children.size)))) + case CaseClassType(cd) => + val fields = cd.fields + CaseClass(cd, fields.map(f => randomValue(f.getType))) + case _ => throw new Exception("I can't choose random value for type " + tpe) + } + + def simplestValue(tpe: TypeTree) : Expr = tpe match { + case Int32Type => IntLiteral(0) + case BooleanType => BooleanLiteral(false) + case AbstractClassType(acd) => { + val children = acd.knownChildren + val simplerChildren = children.filter{ + case ccd @ CaseClassDef(id, Some(parent), fields) => + !fields.exists(vd => vd.getType match { + case AbstractClassType(fieldAcd) => acd == fieldAcd + case CaseClassType(fieldCcd) => ccd == fieldCcd + case _ => false + }) + case _ => false + } + def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match { + case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size + case _ => true + } + val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields) + simplestValue(classDefToClassType(orderedChildren.head)) + } + case CaseClassType(ccd) => + val fields = ccd.fields + CaseClass(ccd, fields.map(f => simplestValue(f.getType))) + case SetType(baseType) => EmptySet(baseType).setType(tpe) + case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe) + case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe) + case _ => throw new Exception("I can't choose simplest value for type " + tpe) + } + + //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions + //require no-match, no-ets and only pure code + def hoistIte(expr: Expr): Expr = { + def transform(expr: Expr): Option[Expr] = expr match { + case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) + case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) + case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) + case nop@NAryOperator(ts, op) => { + val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } + if(iteIndex == -1) None else { + val (beforeIte, startIte) = ts.splitAt(iteIndex) + val afterIte = startIte.tail + val IfExpr(c, t, e) = startIte.head + Some(IfExpr(c, + op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), + op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) + ).setType(nop.getType)) + } + } + case _ => None + } + + def fix[A](f: (A) => A, a: A): A = { + val na = f(a) + if(a == na) a else fix(f, na) + } + fix(searchAndReplaceDFS(transform), expr) + } +} diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 1f18b52c19c38e37ed45dfba22b96d8a2a3a431e..246f335ada8a1c0d084ce1ac7e355db619805320 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -9,11 +9,11 @@ object Trees { /* EXPRESSIONS */ - sealed abstract class Expr extends Typed with Serializable { + abstract class Expr extends Typed with Serializable { override def toString: String = PrettyPrinter(this) } - sealed trait Terminal { + trait Terminal { self: Expr => } @@ -122,9 +122,9 @@ object Trees { def allIdentifiers : Set[Identifier] = { pattern.allIdentifiers ++ - Trees.allIdentifiers(rhs) ++ - theGuard.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ - (expressions map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + TreeOps.allIdentifiers(rhs) ++ + theGuard.map(TreeOps.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + (expressions map (TreeOps.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) } } @@ -439,1153 +439,4 @@ object Trees { val fixedType = BooleanType } - object UnaryOperator { - def unapply(expr: Expr) : Option[(Expr,(Expr)=>Expr)] = expr match { - case Not(t) => Some((t,Not(_))) - case UMinus(t) => Some((t,UMinus)) - case SetCardinality(t) => Some((t,SetCardinality)) - case MultisetCardinality(t) => Some((t,MultisetCardinality)) - case MultisetToSet(t) => Some((t,MultisetToSet)) - case Car(t) => Some((t,Car)) - case Cdr(t) => Some((t,Cdr)) - case SetMin(s) => Some((s,SetMin)) - case SetMax(s) => Some((s,SetMax)) - case CaseClassSelector(cd, e, sel) => Some((e, CaseClassSelector(cd, _, sel))) - case CaseClassInstanceOf(cd, e) => Some((e, CaseClassInstanceOf(cd, _))) - case Assignment(id, e) => Some((e, Assignment(id, _))) - case TupleSelect(t, i) => Some((t, TupleSelect(_, i))) - case ArrayLength(a) => Some((a, ArrayLength)) - case ArrayClone(a) => Some((a, ArrayClone)) - case ArrayMake(t) => Some((t, ArrayMake)) - case Waypoint(i, t) => Some((t, (expr: Expr) => Waypoint(i, expr))) - case e@Epsilon(t) => Some((t, (expr: Expr) => Epsilon(expr).setType(e.getType).setPosInfo(e))) - case _ => None - } - } - - object BinaryOperator { - def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { - case Equals(t1,t2) => Some((t1,t2,Equals.apply)) - case Iff(t1,t2) => Some((t1,t2,Iff(_,_))) - case Implies(t1,t2) => Some((t1,t2,Implies.apply)) - case Plus(t1,t2) => Some((t1,t2,Plus)) - case Minus(t1,t2) => Some((t1,t2,Minus)) - case Times(t1,t2) => Some((t1,t2,Times)) - case Division(t1,t2) => Some((t1,t2,Division)) - case Modulo(t1,t2) => Some((t1,t2,Modulo)) - case LessThan(t1,t2) => Some((t1,t2,LessThan)) - case GreaterThan(t1,t2) => Some((t1,t2,GreaterThan)) - case LessEquals(t1,t2) => Some((t1,t2,LessEquals)) - case GreaterEquals(t1,t2) => Some((t1,t2,GreaterEquals)) - case ElementOfSet(t1,t2) => Some((t1,t2,ElementOfSet)) - case SubsetOf(t1,t2) => Some((t1,t2,SubsetOf)) - case SetIntersection(t1,t2) => Some((t1,t2,SetIntersection)) - case SetUnion(t1,t2) => Some((t1,t2,SetUnion)) - case SetDifference(t1,t2) => Some((t1,t2,SetDifference)) - case Multiplicity(t1,t2) => Some((t1,t2,Multiplicity)) - case SubmultisetOf(t1,t2) => Some((t1,t2,SubmultisetOf)) - case MultisetIntersection(t1,t2) => Some((t1,t2,MultisetIntersection)) - case MultisetUnion(t1,t2) => Some((t1,t2,MultisetUnion)) - case MultisetPlus(t1,t2) => Some((t1,t2,MultisetPlus)) - case MultisetDifference(t1,t2) => Some((t1,t2,MultisetDifference)) - case SingletonMap(t1,t2) => Some((t1,t2,SingletonMap)) - case mg@MapGet(t1,t2) => Some((t1,t2, (t1, t2) => MapGet(t1, t2).setPosInfo(mg))) - case MapUnion(t1,t2) => Some((t1,t2,MapUnion)) - case MapDifference(t1,t2) => Some((t1,t2,MapDifference)) - case MapIsDefinedAt(t1,t2) => Some((t1,t2, MapIsDefinedAt)) - case ArrayFill(t1, t2) => Some((t1, t2, ArrayFill)) - case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) - case Concat(t1,t2) => Some((t1,t2,Concat)) - case ListAt(t1,t2) => Some((t1,t2,ListAt)) - case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh))) - case _ => None - } - } - - object NAryOperator { - def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { - case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPosInfo(fi)))) - case AnonymousFunctionInvocation(id, args) => Some((args, (as => AnonymousFunctionInvocation(id, as)))) - case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) - case And(args) => Some((args, And.apply)) - case Or(args) => Some((args, Or.apply)) - case FiniteSet(args) => Some((args, FiniteSet)) - case FiniteMap(args) => Some((args, (as : Seq[Expr]) => FiniteMap(as.asInstanceOf[Seq[SingletonMap]]))) - case FiniteMultiset(args) => Some((args, FiniteMultiset)) - case ArrayUpdate(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2)))) - case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)))) - case FiniteArray(args) => Some((args, FiniteArray)) - case Distinct(args) => Some((args, Distinct)) - case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) - case Tuple(args) => Some((args, Tuple)) - case _ => None - } - } - - def negate(expr: Expr) : Expr = expr match { - case Let(i,b,e) => Let(i,b,negate(e)) - case Not(e) => e - case Iff(e1,e2) => Iff(negate(e1),e2) - case Implies(e1,e2) => And(e1, negate(e2)) - case Or(exs) => And(exs map negate) - case And(exs) => Or(exs map negate) - case LessThan(e1,e2) => GreaterEquals(e1,e2) - case LessEquals(e1,e2) => GreaterThan(e1,e2) - case GreaterThan(e1,e2) => LessEquals(e1,e2) - case GreaterEquals(e1,e2) => LessThan(e1,e2) - case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)).setType(i.getType) - case BooleanLiteral(b) => BooleanLiteral(!b) - case _ => Not(expr) - } - - // Warning ! This may loop forever if the substitutions are not - // well-formed! - def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - searchAndReplaceDFS(substs.get)(expr) - } - - // Can't just be overloading because of type erasure :'( - def replaceFromIDs(substs: Map[Identifier,Expr], expr: Expr) : Expr = { - replace(substs.map(p => (Variable(p._1) -> p._2)), expr) - } - - def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { - def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { - case Some(newExpr) => { - if(newExpr.getType == Untyped) { - Settings.reporter.error("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) - } - if(ex == newExpr) - if(recursive) rec(ex, ex) else ex - else - if(recursive) rec(newExpr) else newExpr - } - case None => ex match { - case l @ Let(i,e,b) => { - val re = rec(e) - val rb = rec(b) - if(re != e || rb != b) - Let(i, re, rb).setType(l.getType) - else - l - } - case l @ LetVar(i,e,b) => { - val re = rec(e) - val rb = rec(b) - if(re != e || rb != b) - LetVar(i, re, rb).setType(l.getType) - else - l - } - case l @ LetDef(fd, b) => { - //TODO, not sure, see comment for the next LetDef - fd.body = fd.body.map(rec(_)) - fd.precondition = fd.precondition.map(rec(_)) - fd.postcondition = fd.postcondition.map(rec(_)) - LetDef(fd, rec(b)).setType(l.getType) - } - - case lt @ LetTuple(ids, expr, body) => { - val re = rec(expr) - val rb = rec(body) - if (re != expr || rb != body) { - LetTuple(ids, re, rb).setType(lt.getType) - } else { - lt - } - } - case n @ NAryOperator(args, recons) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a) - if(ra != a) { - change = true - ra - } else { - a - } - }) - if(change) - recons(rargs).setType(n.getType) - else - n - } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1) - val r2 = rec(t2) - if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) - else - b - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t) - if(r != t) - recons(r).setType(u.getType) - else - u - } - case i @ IfExpr(t1,t2,t3) => { - val r1 = rec(t1) - val r2 = rec(t2) - val r3 = rec(t3) - if(r1 != t1 || r2 != t2 || r3 != t3) - IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) - else - i - } - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) - - case c @ Choose(args, body) => - val body2 = rec(body) - - if (body != body2) { - Choose(args, body2).setType(c.getType) - } else { - c - } - - case t if t.isInstanceOf[Terminal] => t - case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) - } - } - - def inCase(cse: MatchCase) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) - } - - rec(expr) - } - - def searchAndReplaceDFS(subst: Expr=>Option[Expr])(expr: Expr) : Expr = { - val (res,_) = searchAndReplaceDFSandTrackChanges(subst)(expr) - res - } - - def searchAndReplaceDFSandTrackChanges(subst: Expr=>Option[Expr])(expr: Expr) : (Expr,Boolean) = { - var somethingChanged: Boolean = false - def applySubst(ex: Expr) : Expr = subst(ex) match { - case None => ex - case Some(newEx) => { - somethingChanged = true - if(newEx.getType == Untyped) { - Settings.reporter.warning("REPLACING [" + ex + "] WITH AN UNTYPED EXPRESSION !") - Settings.reporter.warning("Here's the new expression: " + newEx) - } - newEx - } - } - - def rec(ex: Expr) : Expr = ex match { - case l @ Let(i,e,b) => { - val re = rec(e) - val rb = rec(b) - applySubst(if(re != e || rb != b) { - Let(i,re,rb).setType(l.getType) - } else { - l - }) - } - case l @ LetTuple(ids,e,b) => { - val re = rec(e) - val rb = rec(b) - applySubst(if(re != e || rb != b) { - LetTuple(ids,re,rb).setType(l.getType) - } else { - l - }) - } - case l @ LetVar(i,e,b) => { - val re = rec(e) - val rb = rec(b) - applySubst(if(re != e || rb != b) { - LetVar(i,re,rb).setType(l.getType) - } else { - l - }) - } - case l @ LetDef(fd,b) => { - //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct - fd.body = fd.body.map(rec(_)) - fd.precondition = fd.precondition.map(rec(_)) - fd.postcondition = fd.postcondition.map(rec(_)) - val rl = LetDef(fd, rec(b)).setType(l.getType) - applySubst(rl) - } - case n @ NAryOperator(args, recons) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a) - if(ra != a) { - change = true - ra - } else { - a - } - }) - applySubst(if(change) { - recons(rargs).setType(n.getType) - } else { - n - }) - } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1) - val r2 = rec(t2) - applySubst(if(r1 != t1 || r2 != t2) { - recons(r1,r2).setType(b.getType) - } else { - b - }) - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t) - applySubst(if(r != t) { - recons(r).setType(u.getType) - } else { - u - }) - } - case i @ IfExpr(t1,t2,t3) => { - val r1 = rec(t1) - val r2 = rec(t2) - val r3 = rec(t3) - applySubst(if(r1 != t1 || r2 != t2 || r3 != t3) { - IfExpr(r1,r2,r3).setType(i.getType) - } else { - i - }) - } - case m @ MatchExpr(scrut,cses) => { - val rscrut = rec(scrut) - val (newCses,changes) = cses.map(inCase(_)).unzip - applySubst(if(rscrut != scrut || changes.exists(res=>res)) { - MatchExpr(rscrut, newCses).setType(m.getType).setPosInfo(m) - } else { - m - }) - } - - case c @ Choose(args, body) => - val body2 = rec(body) - - applySubst(if (body != body2) { - Choose(args, body2).setType(c.getType).setPosInfo(c) - } else { - c - }) - - case t if t.isInstanceOf[Terminal] => applySubst(t) - case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled) - } - - def inCase(cse: MatchCase) : (MatchCase,Boolean) = cse match { - case s @ SimpleCase(pat, rhs) => { - val rrhs = rec(rhs) - if(rrhs != rhs) { - (SimpleCase(pat, rrhs), true) - } else { - (s, false) - } - } - case g @ GuardedCase(pat, guard, rhs) => { - val rguard = rec(guard) - val rrhs = rec(rhs) - if(rguard != guard || rrhs != rhs) { - (GuardedCase(pat, rguard, rrhs), true) - } else { - (g, false) - } - } - } - - val res = rec(expr) - (res, somethingChanged) - } - - // rewrites pattern-matching expressions to use fresh variables for the binders - def freshenLocals(expr: Expr) : Expr = { - def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match { - case InstanceOfPattern(Some(b), ctd) => InstanceOfPattern(Some(sm(b)), ctd) - case WildcardPattern(Some(b)) => WildcardPattern(Some(sm(b))) - case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm))) - case other => other - } - - def freshenCase(cse: MatchCase) : MatchCase = { - val allBinders: Set[Identifier] = cse.pattern.binders - val subMap: Map[Identifier,Identifier] = Map(allBinders.map(i => (i, FreshIdentifier(i.name, true).setType(i.getType))).toSeq : _*) - val subVarMap: Map[Expr,Expr] = subMap.map(kv => (Variable(kv._1) -> Variable(kv._2))) - - cse match { - case SimpleCase(pattern, rhs) => SimpleCase(rewritePattern(pattern, subMap), replace(subVarMap, rhs)) - case GuardedCase(pattern, guard, rhs) => GuardedCase(rewritePattern(pattern, subMap), replace(subVarMap, guard), replace(subVarMap, rhs)) - } - } - - def applyToTree(e : Expr) : Option[Expr] = e match { - case m @ MatchExpr(s, cses) => Some(MatchExpr(s, cses.map(freshenCase(_))).setType(m.getType).setPosInfo(m)) - case l @ Let(i,e,b) => { - val newID = FreshIdentifier(i.name, true).setType(i.getType) - Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) - } - case _ => None - } - - searchAndReplaceDFS(applyToTree)(expr) - } - - // convert describes how to compute a value for the leaves (that includes - // functions with no args.) - // combine descriess how to combine two values - def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, expression: Expr) : A = { - treeCatamorphism(convert, combine, (e:Expr,a:A)=>a, expression) - } - // compute allows the catamorphism to change the combined value depending on the tree - def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { - def rec(expr: Expr) : A = expr match { - case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) - case l @ LetVar(_, e, b) => compute(l, combine(rec(e), rec(b))) - case l @ LetDef(fd, b) => {//TODO, still not sure about the semantic - val exprs: Seq[Expr] = fd.precondition.toSeq ++ fd.body.toSeq ++ fd.postcondition.toSeq ++ Seq(b) - compute(l, exprs.map(rec(_)).reduceLeft(combine)) - } - case n @ NAryOperator(args, _) => { - if(args.size == 0) - compute(n, convert(n)) - else - compute(n, args.map(rec(_)).reduceLeft(combine)) - } - case b @ BinaryOperator(a1,a2,_) => compute(b, combine(rec(a1),rec(a2))) - case u @ UnaryOperator(a,_) => compute(u, rec(a)) - case i @ IfExpr(a1,a2,a3) => compute(i, combine(combine(rec(a1), rec(a2)), rec(a3))) - case m @ MatchExpr(scrut, cses) => compute(m, (scrut +: cses.flatMap(_.expressions)).map(rec(_)).reduceLeft(combine)) - case a @ AnonymousFunction(es, ev) => compute(a, (es.flatMap(e => e._1 ++ Seq(e._2)) ++ Seq(ev)).map(rec(_)).reduceLeft(combine)) - case c @ Choose(args, body) => compute(c, rec(body)) - case t: Terminal => compute(t, convert(t)) - case unhandled => scala.sys.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled) - } - - rec(expression) - } - - def flattenBlocks(expr: Expr): Expr = { - def applyToTree(expr: Expr): Option[Expr] = expr match { - case Block(exprs, last) => { - val nexprs = (exprs :+ last).flatMap{ - case Block(es2, el) => es2 :+ el - case UnitLiteral => Seq() - case e2 => Seq(e2) - } - val fexpr = nexprs match { - case Seq() => UnitLiteral - case Seq(e) => e - case es => Block(es.init, es.last).setType(es.last.getType) - } - Some(fexpr) - } - case _ => None - } - searchAndReplaceDFS(applyToTree)(expr) - } - - //checking whether the expr is pure, that is do not contains any non-pure construct: assign, while, blocks, array, ... - //this is expected to be true when entering the "backend" of Leon - def isPure(expr: Expr): Boolean = { - def convert(t: Expr) : Boolean = t match { - case Block(_, _) => false - case Assignment(_, _) => false - case While(_, _) => false - case LetVar(_, _, _) => false - case LetDef(_, _) => false - case ArrayUpdate(_, _, _) => false - case ArrayMake(_) => false - case ArrayClone(_) => false - case Epsilon(_) => false - case _ => true - } - def combine(b1: Boolean, b2: Boolean) = b1 && b2 - def compute(e: Expr, b: Boolean) = e match { - case Block(_, _) => false - case Assignment(_, _) => false - case While(_, _) => false - case LetVar(_, _, _) => false - case LetDef(_, _) => false - case ArrayUpdate(_, _, _) => false - case ArrayMake(_) => false - case ArrayClone(_) => false - case Epsilon(_) => false - case _ => b - } - treeCatamorphism(convert, combine, compute, expr) - } - - def containsEpsilon(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (l : Epsilon) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (l : Epsilon) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - - def containsLetDef(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (l : LetDef) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (l : LetDef) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - def containsIfExpr(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (i: IfExpr) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (i: IfExpr) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - - def variablesOf(expr: Expr) : Set[Identifier] = { - def convert(t: Expr) : Set[Identifier] = t match { - case Variable(i) => Set(i) - case _ => Set.empty - } - def combine(s1: Set[Identifier], s2: Set[Identifier]) = s1 ++ s2 - def compute(t: Expr, s: Set[Identifier]) = t match { - case Let(i,_,_) => s -- Set(i) - case MatchExpr(_, cses) => s -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) - case AnonymousFunctionInvocation(i,_) => s ++ Set[Identifier](i) - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) - } - - def containsFunctionCalls(expr : Expr) : Boolean = { - def convert(t : Expr) : Boolean = t match { - case f : FunctionInvocation => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case f : FunctionInvocation => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - - def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = { - def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) - case _ => Set.empty - } - def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 - def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) // ++ s that's the difference with the one below - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) - } - - def allNonRecursiveFunctionCallsOf(expr: Expr, program: Program) : Set[FunctionInvocation] = { - def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(fd, _) if program.isRecursive(fd) => Set(f) - case _ => Set.empty - } - - def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 - - def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(fd,_) if program.isRecursive(fd) => Set(f) ++ s - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) - } - - def functionCallsOf(expr: Expr) : Set[FunctionInvocation] = { - def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(_, _) => Set(f) - case _ => Set.empty - } - def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 - def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(_, _) => Set(f) ++ s - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) - } - - def contains(expr: Expr, matcher: Expr=>Boolean) : Boolean = { - treeCatamorphism[Boolean]( - matcher, - (b1: Boolean, b2: Boolean) => b1 || b2, - (t: Expr, b: Boolean) => b || matcher(t), - expr) - } - - def allIdentifiers(expr: Expr) : Set[Identifier] = expr match { - case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder - case l @ LetVar(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder - case l @ LetDef(fd, b) => allIdentifiers(fd.getBody) ++ allIdentifiers(b) + fd.id - case n @ NAryOperator(args, _) => - (args map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) - case b @ BinaryOperator(a1,a2,_) => allIdentifiers(a1) ++ allIdentifiers(a2) - case u @ UnaryOperator(a,_) => allIdentifiers(a) - case i @ IfExpr(a1,a2,a3) => allIdentifiers(a1) ++ allIdentifiers(a2) ++ allIdentifiers(a3) - case m @ MatchExpr(scrut, cses) => - (cses map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ allIdentifiers(scrut) - case Variable(id) => Set(id) - case t: Terminal => Set.empty - } - - def allDeBruijnIndices(expr: Expr) : Set[DeBruijnIndex] = { - def convert(t: Expr) : Set[DeBruijnIndex] = t match { - case i @ DeBruijnIndex(idx) => Set(i) - case _ => Set.empty - } - def combine(s1: Set[DeBruijnIndex], s2: Set[DeBruijnIndex]) = s1 ++ s2 - treeCatamorphism(convert, combine, expr) - } - - /* Simplifies let expressions: - * - removes lets when expression never occurs - * - simplifies when expressions occurs exactly once - * - expands when expression is just a variable. - * Note that the code is simple but far from optimal (many traversals...) - */ - def simplifyLets(expr: Expr) : Expr = { - def simplerLet(t: Expr) : Option[Expr] = t match { - case letExpr @ Let(i, t: Terminal, b) => Some(replace(Map((Variable(i) -> t)), b)) - case letExpr @ Let(i,e,b) => { - val occurences = treeCatamorphism[Int]((e:Expr) => e match { - case Variable(x) if x == i => 1 - case _ => 0 - }, (x:Int,y:Int)=>x+y, b) - if(occurences == 0) { - Some(b) - } else if(occurences == 1) { - Some(replace(Map((Variable(i) -> e)), b)) - } else { - None - } - } - //case letTuple @ LetTuple(ids, expr, body) if ids.size == 1 => - // simplerLet(Let(ids.head, TupleSelect(expr, 1).setType(ids.head.getType), body)) - - case letTuple @ LetTuple(ids, Tuple(exprs), body) => - - var newBody = body - - val (remIds, remExprs) = (ids zip exprs).filter { - case (id, value: Terminal) => - newBody = replace(Map((Variable(id) -> value)), newBody) - //we replace, so we drop old - false - case (id, value) => - val occurences = treeCatamorphism[Int]((e:Expr) => e match { - case Variable(x) if x == id => 1 - case _ => 0 - }, (x:Int,y:Int)=>x+y, body) - - if(occurences == 0) { - false - } else if(occurences == 1) { - newBody = replace(Map((Variable(id) -> value)), newBody) - false - } else { - true - } - }.unzip - - - if (remIds.isEmpty) { - Some(newBody) - } else if (remIds.tail.isEmpty) { - Some(Let(remIds.head, remExprs.head, newBody)) - } else { - Some(LetTuple(remIds, Tuple(remExprs), newBody)) - } - case _ => None - } - searchAndReplaceDFS(simplerLet)(expr) - } - - // Pulls out all let constructs to the top level, and makes sure they're - // properly ordered. - private type DefPair = (Identifier,Expr) - private type DefPairs = List[DefPair] - private def allLetDefinitions(expr: Expr) : DefPairs = treeCatamorphism[DefPairs]( - (e: Expr) => Nil, - (s1: DefPairs, s2: DefPairs) => s1 ::: s2, - (e: Expr, dps: DefPairs) => e match { - case Let(i, e, _) => (i,e) :: dps - case _ => dps - }, - expr) - - private def killAllLets(expr: Expr) : Expr = searchAndReplaceDFS((e: Expr) => e match { - case Let(_,_,ex) => Some(ex) - case _ => None - })(expr) - - def liftLets(expr: Expr) : Expr = { - val initialDefinitionPairs = allLetDefinitions(expr) - val definitionPairs = initialDefinitionPairs.map(p => (p._1, killAllLets(p._2))) - val occursLists : Map[Identifier,Set[Identifier]] = Map(definitionPairs.map((dp: DefPair) => (dp._1 -> variablesOf(dp._2).toSet.filter(_.isLetBinder))) : _*) - var newList : DefPairs = Nil - var placed : Set[Identifier] = Set.empty - val toPlace = definitionPairs.size - var placedC = 0 - var traversals = 0 - - while(placedC < toPlace) { - if(traversals > toPlace + 1) { - scala.sys.error("Cycle in let definitions or multiple definition for the same identifier in liftLets : " + definitionPairs.mkString("\n")) - } - for((id,ex) <- definitionPairs) if (!placed(id)) { - if((occursLists(id) -- placed) == Set.empty) { - placed = placed + id - newList = (id,ex) :: newList - placedC = placedC + 1 - } - } - traversals = traversals + 1 - } - - val noLets = killAllLets(expr) - - val res = (newList.foldLeft(noLets)((e,iap) => Let(iap._1, iap._2, e))) - simplifyLets(res) - } - - def wellOrderedLets(tree : Expr) : Boolean = { - val pairs = allLetDefinitions(tree) - val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) - val vars: Set[Identifier] = variablesOf(tree) - val intersection = vars intersect definitions - if(!intersection.isEmpty) { - intersection.foreach(id => { - Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !") - }) - false - } else { - vars.forall(id => if(id.isLetBinder) { - Settings.reporter.error("Variable with identifier '" + id + "' has lost its let-definition (it disappeared??)") - false - } else { - true - }) - } - } - - /* Fully expands all let expressions. */ - def expandLets(expr: Expr) : Expr = { - def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { - case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) - case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) - case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType) - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType).setPosInfo(m) - case n @ NAryOperator(args, recons) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a, s) - if(ra != a) { - change = true - ra - } else { - a - } - }) - if(change) - recons(rargs).setType(n.getType) - else - n - } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1, s) - val r2 = rec(t2, s) - if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) - else - b - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t, s) - if(r != t) - recons(r).setType(u.getType) - else - u - } - case t if t.isInstanceOf[Terminal] => t - case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) - } - - def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) - } - - rec(expr, Map.empty) - } - - object SimplePatternMatching { - def isSimple(me: MatchExpr) : Boolean = unapply(me).isDefined - - // (scrutinee, classtype, list((caseclassdef, variable, list(variable), rhs))) - def unapply(e: MatchExpr) : Option[(Expr,ClassType,Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)])] = { - val MatchExpr(scrutinee, cases) = e - val sType = scrutinee.getType - - if(sType.isInstanceOf[TupleType]) { - None - } else if(sType.isInstanceOf[AbstractClassType]) { - val cCD = sType.asInstanceOf[AbstractClassType].classDef - if(cases.size == cCD.knownChildren.size && cases.forall(!_.hasGuard)) { - var seen = Set.empty[ClassTypeDef] - - var lle : List[(CaseClassDef,Identifier,List[Identifier],Expr)] = Nil - for(cse <- cases) { - cse match { - case SimpleCase(CaseClassPattern(binder, ccd, subPats), rhs) if subPats.forall(_.isInstanceOf[WildcardPattern]) => { - seen = seen + ccd - - val patID : Identifier = if(binder.isDefined) { - binder.get - } else { - FreshIdentifier("cse", true).setType(CaseClassType(ccd)) - } - - val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { - p._2.binder.get - } else { - FreshIdentifier("pat", true).setType(p._1.tpe) - }).toList - - lle = (ccd, patID, argIDs, rhs) :: lle - } - case _ => ; - } - } - lle = lle.reverse - - if(seen.size == cases.size) { - Some((scrutinee, sType.asInstanceOf[AbstractClassType], lle)) - } else { - None - } - } else { - None - } - } else { - val cCD = sType.asInstanceOf[CaseClassType].classDef - if(cases.size == 1 && !cases(0).hasGuard) { - val SimpleCase(pat,rhs) = cases(0).asInstanceOf[SimpleCase] - pat match { - case CaseClassPattern(binder, ccd, subPats) if (ccd == cCD && subPats.forall(_.isInstanceOf[WildcardPattern])) => { - val patID : Identifier = if(binder.isDefined) { - binder.get - } else { - FreshIdentifier("cse", true).setType(CaseClassType(ccd)) - } - - val argIDs : List[Identifier] = (ccd.fields zip subPats.map(_.asInstanceOf[WildcardPattern])).map(p => if(p._2.binder.isDefined) { - p._2.binder.get - } else { - FreshIdentifier("pat", true).setType(p._1.tpe) - }).toList - - Some((scrutinee, CaseClassType(cCD), List((cCD, patID, argIDs, rhs)))) - } - case _ => None - } - } else { - None - } - } - } - } - - object NotSoSimplePatternMatching { - def coversType(tpe: ClassTypeDef, patterns: Seq[Pattern]) : Boolean = { - if(patterns.isEmpty) { - false - } else if(patterns.exists(_.isInstanceOf[WildcardPattern])) { - true - } else { - val allSubtypes: Seq[CaseClassDef] = tpe match { - case acd @ AbstractClassDef(_,_) => acd.knownDescendents.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) - case ccd: CaseClassDef => List(ccd) - } - - var seen: Set[CaseClassDef] = Set.empty - var secondLevel: Map[(CaseClassDef,Int),List[Pattern]] = Map.empty - - for(pat <- patterns) if (pat.isInstanceOf[CaseClassPattern]) { - val pattern: CaseClassPattern = pat.asInstanceOf[CaseClassPattern] - val ccd: CaseClassDef = pattern.caseClassDef - seen = seen + ccd - - for((subPattern,i) <- (pattern.subPatterns.zipWithIndex)) { - val seenSoFar = secondLevel.getOrElse((ccd,i), Nil) - secondLevel = secondLevel + ((ccd,i) -> (subPattern :: seenSoFar)) - } - } - - allSubtypes.forall(ccd => { - seen(ccd) && ccd.fields.zipWithIndex.forall(p => p._1.tpe match { - case t: ClassType => coversType(t.classDef, secondLevel.getOrElse((ccd, p._2), Nil)) - case _ => true - }) - }) - } - } - - def unapply(pm : MatchExpr) : Option[MatchExpr] = if(!Settings.experimental) None else (pm match { - case MatchExpr(scrutinee, cases) if cases.forall(_.isInstanceOf[SimpleCase]) => { - val allPatterns = cases.map(_.pattern) - Settings.reporter.info("This might be a complete pattern-matching expression:") - Settings.reporter.info(pm) - Settings.reporter.info("Covered? " + coversType(pm.scrutineeClassType.classDef, allPatterns)) - None - } - case _ => None - }) - } - - private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() - /** Rewrites all pattern-matching expressions into if-then-else expressions, - * with additional error conditions. Does not introduce additional variables. - * We use a cache because we can. */ - def matchToIfThenElse(expr: Expr) : Expr = { - val toRet = if(matchConverterCache.isDefinedAt(expr)) { - matchConverterCache(expr) - } else { - val converted = convertMatchToIfThenElse(expr) - matchConverterCache(expr) = converted - converted - } - - toRet - } - - def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match { - case WildcardPattern(_) => BooleanLiteral(true) - case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.") - case CaseClassPattern(_, ccd, subps) => { - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2)) - val together = And(subTests) - And(CaseClassInstanceOf(ccd, in), together) - } - case TuplePattern(_, subps) => { - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} - And(subTests) - } - } - - private def convertMatchToIfThenElse(expr: Expr) : Expr = { - def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { - case WildcardPattern(None) => Map.empty - case WildcardPattern(Some(id)) => Map(id -> in) - case InstanceOfPattern(None, _) => Map.empty - case InstanceOfPattern(Some(id), _) => Map(id -> in) - case CaseClassPattern(b, ccd, subps) => { - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) - val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) - b match { - case Some(id) => Map(id -> in) ++ together - case None => together - } - } - case TuplePattern(b, subps) => { - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} - val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) - b match { - case Some(id) => map + (id -> in) - case None => map - } - } - } - - def rewritePM(e: Expr) : Option[Expr] = e match { - case m @ MatchExpr(scrut, cases) => { - // println("Rewriting the following PM: " + e) - - val condsAndRhs = for(cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern(scrut, cse.pattern) - val realCond = cse.theGuard match { - case Some(g) => And(patCond, replaceFromIDs(map, g)) - case None => patCond - } - val newRhs = replaceFromIDs(map, cse.rhs) - (realCond, newRhs) - } - - val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) { - // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check) - val lastExpr = condsAndRhs.last._2 - - condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr)) - } else { - condsAndRhs - } - - val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => { - if(p1._1 == BooleanLiteral(true)) { - p1._2 - } else { - IfExpr(p1._1, p1._2, ex).setType(m.getType) - } - }) - - Some(bigIte) - } - case _ => None - } - - searchAndReplaceDFS(rewritePM)(expr) - } - - private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() - /** Rewrites all map accesses with additional error conditions. */ - def mapGetWithChecks(expr: Expr) : Expr = { - val toRet = if (mapGetConverterCache.isDefinedAt(expr)) { - mapGetConverterCache(expr) - } else { - val converted = convertMapGet(expr) - mapGetConverterCache(expr) = converted - converted - } - - toRet - } - - private def convertMapGet(expr: Expr) : Expr = { - def rewriteMapGet(e: Expr) : Option[Expr] = e match { - case mg @ MapGet(m,k) => - val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType)) - case _ => None - } - - searchAndReplaceDFS(rewriteMapGet)(expr) - } - - // prec: expression does not contain match expressions - def measureADTChildrenDepth(expression: Expr) : Int = { - import scala.math.max - - def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match { - case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm))) - case Variable(id) => lm.getOrElse(id, 0) - case CaseClassSelector(_, e, _) => rec(e,lm) + 1 - case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max - case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm)) - case UnaryOperator(e,_) => rec(e,lm) - case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm)) - case t: Terminal => 0 - case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex) - } - - rec(expression,Map.empty) - } - - private val random = new scala.util.Random() - - def randomValue(v: Variable) : Expr = randomValue(v.getType) - def simplestValue(v: Variable) : Expr = simplestValue(v.getType) - - def randomValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(random.nextInt(42)) - case BooleanType => BooleanLiteral(random.nextBoolean()) - case AbstractClassType(acd) => - val children = acd.knownChildren - randomValue(classDefToClassType(children(random.nextInt(children.size)))) - case CaseClassType(cd) => - val fields = cd.fields - CaseClass(cd, fields.map(f => randomValue(f.getType))) - case _ => throw new Exception("I can't choose random value for type " + tpe) - } - - def simplestValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(0) - case BooleanType => BooleanLiteral(false) - case AbstractClassType(acd) => { - val children = acd.knownChildren - val simplerChildren = children.filter{ - case ccd @ CaseClassDef(id, Some(parent), fields) => - !fields.exists(vd => vd.getType match { - case AbstractClassType(fieldAcd) => acd == fieldAcd - case CaseClassType(fieldCcd) => ccd == fieldCcd - case _ => false - }) - case _ => false - } - def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match { - case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size - case _ => true - } - val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields) - simplestValue(classDefToClassType(orderedChildren.head)) - } - case CaseClassType(ccd) => - val fields = ccd.fields - CaseClass(ccd, fields.map(f => simplestValue(f.getType))) - case SetType(baseType) => EmptySet(baseType).setType(tpe) - case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe) - case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe) - case _ => throw new Exception("I can't choose simplest value for type " + tpe) - } - - //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions - //require no-match, no-ets and only pure code - def hoistIte(expr: Expr): Expr = { - def transform(expr: Expr): Option[Expr] = expr match { - case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) - case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) - case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) - case nop@NAryOperator(ts, op) => { - val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } - if(iteIndex == -1) None else { - val (beforeIte, startIte) = ts.splitAt(iteIndex) - val afterIte = startIte.tail - val IfExpr(c, t, e) = startIte.head - Some(IfExpr(c, - op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), - op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) - ).setType(nop.getType)) - } - } - case _ => None - } - - def fix[A](f: (A) => A, a: A): A = { - val na = f(a) - if(a == na) a else fix(f, na) - } - fix(searchAndReplaceDFS(transform), expr) - } - } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index d5151cfca7b7a2b1bcddb2276906f4cd6a420f88..f0fb2e9c273f1a99f1d63d75e1fec9b52fecf8f2 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -3,6 +3,7 @@ package synthesis import purescala.Common._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ object Rules { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index b754201f22b9d559eda51b51b148d9c0643fb3d3..9e02f8025770e08043950dd7adf5a821cb8cf75b 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -3,6 +3,7 @@ package synthesis import purescala.Common._ import purescala.Definitions.{Program, FunDef} +import purescala.TreeOps._ import purescala.Trees.{Expr, Not} import purescala.ScalaPrinter diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala index 16e31057390e2cea0fd1714d7631c819593ecc04..afebd8bdebd0d1bf5a370c16bb6b1ed2c5655e8e 100644 --- a/src/main/scala/leon/testgen/CallGraph.scala +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -2,6 +2,8 @@ package leon.testgen import leon.purescala.Definitions._ import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.Extractors._ import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.FairZ3Solver diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index 1b48d72579f0ff620ef7d35f41e9bcfe841f3ac4..8404294eed8975f25a21f2fcfa95882de2f48d0f 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -3,6 +3,7 @@ package leon.testgen import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Trees._ +import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ import leon.purescala.ScalaPrinter import leon.Extensions._