diff --git a/src/main/scala/leon/GlobalOptions.scala b/src/main/scala/inox/GlobalOptions.scala similarity index 100% rename from src/main/scala/leon/GlobalOptions.scala rename to src/main/scala/inox/GlobalOptions.scala diff --git a/src/main/scala/leon/LeonComponent.scala b/src/main/scala/inox/LeonComponent.scala similarity index 100% rename from src/main/scala/leon/LeonComponent.scala rename to src/main/scala/inox/LeonComponent.scala diff --git a/src/main/scala/leon/LeonContext.scala b/src/main/scala/inox/LeonContext.scala similarity index 100% rename from src/main/scala/leon/LeonContext.scala rename to src/main/scala/inox/LeonContext.scala diff --git a/src/main/scala/leon/LeonExceptions.scala b/src/main/scala/inox/LeonExceptions.scala similarity index 100% rename from src/main/scala/leon/LeonExceptions.scala rename to src/main/scala/inox/LeonExceptions.scala diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/inox/LeonOption.scala similarity index 100% rename from src/main/scala/leon/LeonOption.scala rename to src/main/scala/inox/LeonOption.scala diff --git a/src/main/scala/leon/Printable.scala b/src/main/scala/inox/Printable.scala similarity index 100% rename from src/main/scala/leon/Printable.scala rename to src/main/scala/inox/Printable.scala diff --git a/src/main/scala/leon/Reporter.scala b/src/main/scala/inox/Reporter.scala similarity index 100% rename from src/main/scala/leon/Reporter.scala rename to src/main/scala/inox/Reporter.scala diff --git a/src/main/scala/leon/evaluators/AbstractEvaluator.scala b/src/main/scala/inox/evaluators/AbstractEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/AbstractEvaluator.scala rename to src/main/scala/inox/evaluators/AbstractEvaluator.scala diff --git a/src/main/scala/leon/evaluators/AbstractOnlyEvaluator.scala b/src/main/scala/inox/evaluators/AbstractOnlyEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/AbstractOnlyEvaluator.scala rename to src/main/scala/inox/evaluators/AbstractOnlyEvaluator.scala diff --git a/src/main/scala/leon/evaluators/AngelicEvaluator.scala b/src/main/scala/inox/evaluators/AngelicEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/AngelicEvaluator.scala rename to src/main/scala/inox/evaluators/AngelicEvaluator.scala diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/inox/evaluators/CodeGenEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/CodeGenEvaluator.scala rename to src/main/scala/inox/evaluators/CodeGenEvaluator.scala diff --git a/src/main/scala/leon/evaluators/ContextualEvaluator.scala b/src/main/scala/inox/evaluators/ContextualEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/ContextualEvaluator.scala rename to src/main/scala/inox/evaluators/ContextualEvaluator.scala diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/inox/evaluators/DefaultEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/DefaultEvaluator.scala rename to src/main/scala/inox/evaluators/DefaultEvaluator.scala diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/inox/evaluators/DualEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/DualEvaluator.scala rename to src/main/scala/inox/evaluators/DualEvaluator.scala diff --git a/src/main/scala/leon/evaluators/EvaluationPhase.scala b/src/main/scala/inox/evaluators/EvaluationPhase.scala similarity index 100% rename from src/main/scala/leon/evaluators/EvaluationPhase.scala rename to src/main/scala/inox/evaluators/EvaluationPhase.scala diff --git a/src/main/scala/leon/evaluators/EvaluationResults.scala b/src/main/scala/inox/evaluators/EvaluationResults.scala similarity index 100% rename from src/main/scala/leon/evaluators/EvaluationResults.scala rename to src/main/scala/inox/evaluators/EvaluationResults.scala diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/inox/evaluators/Evaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/Evaluator.scala rename to src/main/scala/inox/evaluators/Evaluator.scala diff --git a/src/main/scala/leon/evaluators/EvaluatorContexts.scala b/src/main/scala/inox/evaluators/EvaluatorContexts.scala similarity index 100% rename from src/main/scala/leon/evaluators/EvaluatorContexts.scala rename to src/main/scala/inox/evaluators/EvaluatorContexts.scala diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/RecursiveEvaluator.scala rename to src/main/scala/inox/evaluators/RecursiveEvaluator.scala diff --git a/src/main/scala/leon/evaluators/ScalacEvaluator.scala b/src/main/scala/inox/evaluators/ScalacEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/ScalacEvaluator.scala rename to src/main/scala/inox/evaluators/ScalacEvaluator.scala diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/inox/evaluators/StreamEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/StreamEvaluator.scala rename to src/main/scala/inox/evaluators/StreamEvaluator.scala diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/inox/evaluators/TracingEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/TracingEvaluator.scala rename to src/main/scala/inox/evaluators/TracingEvaluator.scala diff --git a/src/main/scala/leon/evaluators/TrackingEvaluator.scala b/src/main/scala/inox/evaluators/TrackingEvaluator.scala similarity index 100% rename from src/main/scala/leon/evaluators/TrackingEvaluator.scala rename to src/main/scala/inox/evaluators/TrackingEvaluator.scala diff --git a/src/main/scala/leon/grammars/Aspect.scala b/src/main/scala/inox/grammars/Aspect.scala similarity index 100% rename from src/main/scala/leon/grammars/Aspect.scala rename to src/main/scala/inox/grammars/Aspect.scala diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/inox/grammars/BaseGrammar.scala similarity index 100% rename from src/main/scala/leon/grammars/BaseGrammar.scala rename to src/main/scala/inox/grammars/BaseGrammar.scala diff --git a/src/main/scala/leon/grammars/Closures.scala b/src/main/scala/inox/grammars/Closures.scala similarity index 100% rename from src/main/scala/leon/grammars/Closures.scala rename to src/main/scala/inox/grammars/Closures.scala diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/inox/grammars/Constants.scala similarity index 100% rename from src/main/scala/leon/grammars/Constants.scala rename to src/main/scala/inox/grammars/Constants.scala diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/inox/grammars/Empty.scala similarity index 100% rename from src/main/scala/leon/grammars/Empty.scala rename to src/main/scala/inox/grammars/Empty.scala diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/inox/grammars/EqualityGrammar.scala similarity index 100% rename from src/main/scala/leon/grammars/EqualityGrammar.scala rename to src/main/scala/inox/grammars/EqualityGrammar.scala diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/inox/grammars/ExpressionGrammar.scala similarity index 100% rename from src/main/scala/leon/grammars/ExpressionGrammar.scala rename to src/main/scala/inox/grammars/ExpressionGrammar.scala diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/inox/grammars/FunctionCalls.scala similarity index 100% rename from src/main/scala/leon/grammars/FunctionCalls.scala rename to src/main/scala/inox/grammars/FunctionCalls.scala diff --git a/src/main/scala/leon/grammars/Label.scala b/src/main/scala/inox/grammars/Label.scala similarity index 100% rename from src/main/scala/leon/grammars/Label.scala rename to src/main/scala/inox/grammars/Label.scala diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/inox/grammars/OneOf.scala similarity index 100% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/inox/grammars/OneOf.scala diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/inox/grammars/ProductionRule.scala similarity index 100% rename from src/main/scala/leon/grammars/ProductionRule.scala rename to src/main/scala/inox/grammars/ProductionRule.scala diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/inox/grammars/SafeRecursiveCalls.scala similarity index 100% rename from src/main/scala/leon/grammars/SafeRecursiveCalls.scala rename to src/main/scala/inox/grammars/SafeRecursiveCalls.scala diff --git a/src/main/scala/leon/grammars/SimpleExpressionGrammar.scala b/src/main/scala/inox/grammars/SimpleExpressionGrammar.scala similarity index 100% rename from src/main/scala/leon/grammars/SimpleExpressionGrammar.scala rename to src/main/scala/inox/grammars/SimpleExpressionGrammar.scala diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/inox/grammars/Tags.scala similarity index 100% rename from src/main/scala/leon/grammars/Tags.scala rename to src/main/scala/inox/grammars/Tags.scala diff --git a/src/main/scala/leon/grammars/Union.scala b/src/main/scala/inox/grammars/Union.scala similarity index 100% rename from src/main/scala/leon/grammars/Union.scala rename to src/main/scala/inox/grammars/Union.scala diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/inox/grammars/ValueGrammar.scala similarity index 100% rename from src/main/scala/leon/grammars/ValueGrammar.scala rename to src/main/scala/inox/grammars/ValueGrammar.scala diff --git a/src/main/scala/leon/grammars/aspects/DepthBound.scala b/src/main/scala/inox/grammars/aspects/DepthBound.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/DepthBound.scala rename to src/main/scala/inox/grammars/aspects/DepthBound.scala diff --git a/src/main/scala/leon/grammars/aspects/ExtraTerminals.scala b/src/main/scala/inox/grammars/aspects/ExtraTerminals.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/ExtraTerminals.scala rename to src/main/scala/inox/grammars/aspects/ExtraTerminals.scala diff --git a/src/main/scala/leon/grammars/aspects/PersistentAspect.scala b/src/main/scala/inox/grammars/aspects/PersistentAspect.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/PersistentAspect.scala rename to src/main/scala/inox/grammars/aspects/PersistentAspect.scala diff --git a/src/main/scala/leon/grammars/aspects/SimilarTo.scala b/src/main/scala/inox/grammars/aspects/SimilarTo.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/SimilarTo.scala rename to src/main/scala/inox/grammars/aspects/SimilarTo.scala diff --git a/src/main/scala/leon/grammars/aspects/Sized.scala b/src/main/scala/inox/grammars/aspects/Sized.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/Sized.scala rename to src/main/scala/inox/grammars/aspects/Sized.scala diff --git a/src/main/scala/leon/grammars/aspects/Tagged.scala b/src/main/scala/inox/grammars/aspects/Tagged.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/Tagged.scala rename to src/main/scala/inox/grammars/aspects/Tagged.scala diff --git a/src/main/scala/leon/grammars/aspects/TypeDepthBound.scala b/src/main/scala/inox/grammars/aspects/TypeDepthBound.scala similarity index 100% rename from src/main/scala/leon/grammars/aspects/TypeDepthBound.scala rename to src/main/scala/inox/grammars/aspects/TypeDepthBound.scala diff --git a/src/main/scala/leon/package.scala b/src/main/scala/inox/package.scala similarity index 100% rename from src/main/scala/leon/package.scala rename to src/main/scala/inox/package.scala diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/inox/solvers/ADTManager.scala similarity index 100% rename from src/main/scala/leon/solvers/ADTManager.scala rename to src/main/scala/inox/solvers/ADTManager.scala diff --git a/src/main/scala/leon/solvers/CantResetException.scala b/src/main/scala/inox/solvers/CantResetException.scala similarity index 100% rename from src/main/scala/leon/solvers/CantResetException.scala rename to src/main/scala/inox/solvers/CantResetException.scala diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/inox/solvers/EnumerationSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/EnumerationSolver.scala rename to src/main/scala/inox/solvers/EnumerationSolver.scala diff --git a/src/main/scala/leon/solvers/EvaluatingSolver.scala b/src/main/scala/inox/solvers/EvaluatingSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/EvaluatingSolver.scala rename to src/main/scala/inox/solvers/EvaluatingSolver.scala diff --git a/src/main/scala/leon/solvers/GroundSolver.scala b/src/main/scala/inox/solvers/GroundSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/GroundSolver.scala rename to src/main/scala/inox/solvers/GroundSolver.scala diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/inox/solvers/Model.scala similarity index 100% rename from src/main/scala/leon/solvers/Model.scala rename to src/main/scala/inox/solvers/Model.scala diff --git a/src/main/scala/leon/solvers/NaiveAssumptionSolver.scala b/src/main/scala/inox/solvers/NaiveAssumptionSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/NaiveAssumptionSolver.scala rename to src/main/scala/inox/solvers/NaiveAssumptionSolver.scala diff --git a/src/main/scala/leon/solvers/RawArray.scala b/src/main/scala/inox/solvers/RawArray.scala similarity index 100% rename from src/main/scala/leon/solvers/RawArray.scala rename to src/main/scala/inox/solvers/RawArray.scala diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/inox/solvers/SimpleSolverAPI.scala similarity index 100% rename from src/main/scala/leon/solvers/SimpleSolverAPI.scala rename to src/main/scala/inox/solvers/SimpleSolverAPI.scala diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/inox/solvers/Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/Solver.scala rename to src/main/scala/inox/solvers/Solver.scala diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/inox/solvers/SolverFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/SolverFactory.scala rename to src/main/scala/inox/solvers/SolverFactory.scala diff --git a/src/main/scala/leon/solvers/SolverUnsupportedError.scala b/src/main/scala/inox/solvers/SolverUnsupportedError.scala similarity index 100% rename from src/main/scala/leon/solvers/SolverUnsupportedError.scala rename to src/main/scala/inox/solvers/SolverUnsupportedError.scala diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/inox/solvers/TimeoutSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/TimeoutSolver.scala rename to src/main/scala/inox/solvers/TimeoutSolver.scala diff --git a/src/main/scala/leon/solvers/TimeoutSolverFactory.scala b/src/main/scala/inox/solvers/TimeoutSolverFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/TimeoutSolverFactory.scala rename to src/main/scala/inox/solvers/TimeoutSolverFactory.scala diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/inox/solvers/combinators/PortfolioSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/combinators/PortfolioSolver.scala rename to src/main/scala/inox/solvers/combinators/PortfolioSolver.scala diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala b/src/main/scala/inox/solvers/combinators/PortfolioSolverFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala rename to src/main/scala/inox/solvers/combinators/PortfolioSolverFactory.scala diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala b/src/main/scala/inox/solvers/combinators/RewritingSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/combinators/RewritingSolver.scala rename to src/main/scala/inox/solvers/combinators/RewritingSolver.scala diff --git a/src/main/scala/leon/solvers/combinators/SolverPoolFactory.scala b/src/main/scala/inox/solvers/combinators/SolverPoolFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/combinators/SolverPoolFactory.scala rename to src/main/scala/inox/solvers/combinators/SolverPoolFactory.scala diff --git a/src/main/scala/leon/solvers/cvc4/CVC4Solver.scala b/src/main/scala/inox/solvers/cvc4/CVC4Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/cvc4/CVC4Solver.scala rename to src/main/scala/inox/solvers/cvc4/CVC4Solver.scala diff --git a/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala b/src/main/scala/inox/solvers/cvc4/CVC4UnrollingSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala rename to src/main/scala/inox/solvers/cvc4/CVC4UnrollingSolver.scala diff --git a/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala b/src/main/scala/inox/solvers/isabelle/AdaptationPhase.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala rename to src/main/scala/inox/solvers/isabelle/AdaptationPhase.scala diff --git a/src/main/scala/leon/solvers/isabelle/Component.scala b/src/main/scala/inox/solvers/isabelle/Component.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/Component.scala rename to src/main/scala/inox/solvers/isabelle/Component.scala diff --git a/src/main/scala/leon/solvers/isabelle/Functions.scala b/src/main/scala/inox/solvers/isabelle/Functions.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/Functions.scala rename to src/main/scala/inox/solvers/isabelle/Functions.scala diff --git a/src/main/scala/leon/solvers/isabelle/IsabelleEnvironment.scala b/src/main/scala/inox/solvers/isabelle/IsabelleEnvironment.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/IsabelleEnvironment.scala rename to src/main/scala/inox/solvers/isabelle/IsabelleEnvironment.scala diff --git a/src/main/scala/leon/solvers/isabelle/IsabellePhase.scala b/src/main/scala/inox/solvers/isabelle/IsabellePhase.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/IsabellePhase.scala rename to src/main/scala/inox/solvers/isabelle/IsabellePhase.scala diff --git a/src/main/scala/leon/solvers/isabelle/IsabelleSolver.scala b/src/main/scala/inox/solvers/isabelle/IsabelleSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/IsabelleSolver.scala rename to src/main/scala/inox/solvers/isabelle/IsabelleSolver.scala diff --git a/src/main/scala/leon/solvers/isabelle/IsabelleSolverFactory.scala b/src/main/scala/inox/solvers/isabelle/IsabelleSolverFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/IsabelleSolverFactory.scala rename to src/main/scala/inox/solvers/isabelle/IsabelleSolverFactory.scala diff --git a/src/main/scala/leon/solvers/isabelle/LeonLoggerFactory.scala b/src/main/scala/inox/solvers/isabelle/LeonLoggerFactory.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/LeonLoggerFactory.scala rename to src/main/scala/inox/solvers/isabelle/LeonLoggerFactory.scala diff --git a/src/main/scala/leon/solvers/isabelle/Translator.scala b/src/main/scala/inox/solvers/isabelle/Translator.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/Translator.scala rename to src/main/scala/inox/solvers/isabelle/Translator.scala diff --git a/src/main/scala/leon/solvers/isabelle/Types.scala b/src/main/scala/inox/solvers/isabelle/Types.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/Types.scala rename to src/main/scala/inox/solvers/isabelle/Types.scala diff --git a/src/main/scala/leon/solvers/isabelle/package.scala b/src/main/scala/inox/solvers/isabelle/package.scala similarity index 100% rename from src/main/scala/leon/solvers/isabelle/package.scala rename to src/main/scala/inox/solvers/isabelle/package.scala diff --git a/src/main/scala/leon/solvers/package.scala b/src/main/scala/inox/solvers/package.scala similarity index 100% rename from src/main/scala/leon/solvers/package.scala rename to src/main/scala/inox/solvers/package.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4CounterExampleSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4CounterExampleSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4CounterExampleSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4CounterExampleSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4ProofSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4ProofSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4QuantifiedTarget.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Solver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBQuantifiedSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBQuantifiedSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBQuantifiedTarget.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBQuantifiedTarget.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBUnsupportedError.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBUnsupportedError.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBUnsupportedError.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBZ3QuantifiedSolver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBZ3QuantifiedTarget.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBZ3Solver.scala diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala similarity index 100% rename from src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala rename to src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala diff --git a/src/main/scala/leon/solvers/string/StringSolver.scala b/src/main/scala/inox/solvers/string/StringSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/string/StringSolver.scala rename to src/main/scala/inox/solvers/string/StringSolver.scala diff --git a/src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala b/src/main/scala/inox/solvers/sygus/CVC4SygusSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/sygus/CVC4SygusSolver.scala rename to src/main/scala/inox/solvers/sygus/CVC4SygusSolver.scala diff --git a/src/main/scala/leon/solvers/sygus/SygusSolver.scala b/src/main/scala/inox/solvers/sygus/SygusSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/sygus/SygusSolver.scala rename to src/main/scala/inox/solvers/sygus/SygusSolver.scala diff --git a/src/main/scala/leon/solvers/theories/ArrayEncoder.scala b/src/main/scala/inox/solvers/theories/ArrayEncoder.scala similarity index 100% rename from src/main/scala/leon/solvers/theories/ArrayEncoder.scala rename to src/main/scala/inox/solvers/theories/ArrayEncoder.scala diff --git a/src/main/scala/leon/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala similarity index 100% rename from src/main/scala/leon/solvers/theories/BagEncoder.scala rename to src/main/scala/inox/solvers/theories/BagEncoder.scala diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala similarity index 100% rename from src/main/scala/leon/solvers/theories/StringEncoder.scala rename to src/main/scala/inox/solvers/theories/StringEncoder.scala diff --git a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala similarity index 100% rename from src/main/scala/leon/solvers/theories/TheoryEncoder.scala rename to src/main/scala/inox/solvers/theories/TheoryEncoder.scala diff --git a/src/main/scala/leon/solvers/unrolling/DatatypeManager.scala b/src/main/scala/inox/solvers/unrolling/DatatypeManager.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/DatatypeManager.scala rename to src/main/scala/inox/solvers/unrolling/DatatypeManager.scala diff --git a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala b/src/main/scala/inox/solvers/unrolling/LambdaManager.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/LambdaManager.scala rename to src/main/scala/inox/solvers/unrolling/LambdaManager.scala diff --git a/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala b/src/main/scala/inox/solvers/unrolling/QuantificationManager.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/QuantificationManager.scala rename to src/main/scala/inox/solvers/unrolling/QuantificationManager.scala diff --git a/src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala b/src/main/scala/inox/solvers/unrolling/TemplateEncoder.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/TemplateEncoder.scala rename to src/main/scala/inox/solvers/unrolling/TemplateEncoder.scala diff --git a/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala rename to src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala diff --git a/src/main/scala/leon/solvers/unrolling/TemplateInfo.scala b/src/main/scala/inox/solvers/unrolling/TemplateInfo.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/TemplateInfo.scala rename to src/main/scala/inox/solvers/unrolling/TemplateInfo.scala diff --git a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala b/src/main/scala/inox/solvers/unrolling/TemplateManager.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/TemplateManager.scala rename to src/main/scala/inox/solvers/unrolling/TemplateManager.scala diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala b/src/main/scala/inox/solvers/unrolling/UnrollingBank.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/UnrollingBank.scala rename to src/main/scala/inox/solvers/unrolling/UnrollingBank.scala diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala rename to src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala rename to src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/inox/solvers/z3/FairZ3Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/FairZ3Solver.scala rename to src/main/scala/inox/solvers/z3/FairZ3Solver.scala diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/inox/solvers/z3/UninterpretedZ3Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala rename to src/main/scala/inox/solvers/z3/UninterpretedZ3Solver.scala diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/inox/solvers/z3/Z3ModelReconstruction.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala rename to src/main/scala/inox/solvers/z3/Z3ModelReconstruction.scala diff --git a/src/main/scala/leon/solvers/z3/Z3Solver.scala b/src/main/scala/inox/solvers/z3/Z3Solver.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/Z3Solver.scala rename to src/main/scala/inox/solvers/z3/Z3Solver.scala diff --git a/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala b/src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala similarity index 100% rename from src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala rename to src/main/scala/inox/solvers/z3/Z3UnrollingSolver.scala diff --git a/src/main/scala/leon/purescala/CallGraph.scala b/src/main/scala/inox/trees/CallGraph.scala similarity index 100% rename from src/main/scala/leon/purescala/CallGraph.scala rename to src/main/scala/inox/trees/CallGraph.scala diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/inox/trees/Constructors.scala similarity index 76% rename from src/main/scala/leon/purescala/Constructors.scala rename to src/main/scala/inox/trees/Constructors.scala index e82595f30c57768db3aa908264708c3b8640932a..22844a42297c5f2bf0fb771f9680472ff9f318fb 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/inox/trees/Constructors.scala @@ -17,7 +17,7 @@ import Types._ * potentially use a different expression node if one is more suited. * @define encodingof Encoding of * */ -object Constructors { +trait Constructors { self: Expressions => /** If `isTuple`: * `tupleSelect(tupleWrap(Seq(Tuple(x,y))),1) -> x` @@ -41,17 +41,6 @@ object Constructors { */ def tupleSelect(t: Expr, index: Int, originalSize: Int): Expr = tupleSelect(t, index, originalSize > 1) - /** $encodingof ``def foo(..) {} ...; e``. - * @see [[purescala.Expressions.LetDef]] - */ - def letDef(defs: Seq[FunDef], e: Expr) = { - if (defs.isEmpty) { - e - } else { - LetDef(defs, e) - } - } - /** $encodingof ``val id = e; bd``, and returns `bd` if the identifier is not bound in `bd`. * @see [[purescala.Expressions.Let]] */ @@ -105,7 +94,7 @@ object Constructors { * If the sequence is empty, the [[purescala.Types.UnitType UnitType]] is returned. * @see [[purescala.Types.TupleType]] */ - def tupleTypeWrap(tps : Seq[TypeTree]) = tps match { + def tupleTypeWrap(tps: Seq[TypeTree]) = tps match { case Seq() => UnitType case Seq(elem) => elem case more => TupleType(more) @@ -115,8 +104,7 @@ object Constructors { * @return A [[purescala.Expressions.FunctionInvocation FunctionInvocation]] if it type checks, else throws an error. * @see [[purescala.Expressions.FunctionInvocation]] */ - def functionInvocation(fd : FunDef, args : Seq[Expr]) = { - + def functionInvocation(fd: FunDef, args: Seq[Expr]) = { require(fd.params.length == args.length, "Invoking function with incorrect number of arguments") val formalType = tupleTypeWrap(fd.params map { _.getType }) @@ -145,7 +133,7 @@ object Constructors { * @see [[purescala.Expressions.CaseClassPattern MatchExpr]] * @see [[purescala.Expressions.CaseClassPattern CaseClassPattern]] */ - private def filterCases(scrutType : TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = { + private def filterCases(scrutType: TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = { val casesFiltered = scrutType match { case c: CaseClassType => cases.filter(_.pattern match { @@ -165,17 +153,6 @@ object Constructors { } } - /** $encodingof the I/O example specification, simplified to '''true''' if the cases are trivially true. - * @see [[purescala.Expressions.Passes Passes]] - */ - def passes(in : Expr, out : Expr, cases : Seq[MatchCase]): Expr = { - val resultingCases = filterCases(in.getType, Some(out.getType), cases) - if (resultingCases.nonEmpty) { - Passes(in, out, resultingCases) - } else { - BooleanLiteral(true) - } - } /** $encodingof `... match { ... }` but simplified if possible. Simplifies to [[Error]] if no case can match the scrutined expression. * @see [[purescala.Expressions.MatchExpr MatchExpr]] */ @@ -261,33 +238,6 @@ object Constructors { case _ => Implies(lhs, rhs) } - /** $encodingof Simplified `Array(...)` (array length defined at compile-time) - * @see [[purescala.Expressions.NonemptyArray NonemptyArray]] - */ - def finiteArray(els: Seq[Expr], tpe: TypeTree = Untyped): Expr = { - require(els.nonEmpty || tpe != Untyped) - finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway - } - /** $encodingof Simplified `Array[...](...)` (array length and default element defined at run-time) with type information - * @see [[purescala.Constructors#finiteArray(els:Map* finiteArray]] - */ - def finiteArray(els: Seq[Expr], defaultLength: Option[(Expr, Expr)], tpe: TypeTree): Expr = { - finiteArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength, tpe) - } - /** $encodingof Simplified `Array[...](...)` (array length and default element defined at run-time) with type information - * @see [[purescala.Expressions.EmptyArray EmptyArray]] - */ - def finiteArray(els: Map[Int, Expr], defaultLength: Option[(Expr, Expr)], tpe: TypeTree): Expr = { - if (els.isEmpty && (defaultLength.isEmpty || defaultLength.get._2 == IntLiteral(0))) EmptyArray(tpe) - else NonemptyArray(els, defaultLength) - } - /** $encodingof simplified `Array(...)` (array length and default element defined at run-time). - * @see [[purescala.Expressions.NonemptyArray NonemptyArray]] - */ - def nonemptyArray(els: Seq[Expr], defaultLength: Option[(Expr, Expr)]): Expr = { - NonemptyArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength) - } - /** $encodingof simplified `... == ...` (equality). * @see [[purescala.Expressions.Equals Equals]] */ @@ -350,14 +300,12 @@ object Constructors { * @see [[purescala.Expressions.RealPlus RealPlus]] */ def plus(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { - case (InfiniteIntegerLiteral(bi), _) if bi == 0 => rhs - case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => lhs + case (IntegerLiteral(bi), _) if bi == 0 => rhs + case (_, IntegerLiteral(bi)) if bi == 0 => lhs case (IntLiteral(0), _) => rhs case (_, IntLiteral(0)) => lhs case (FractionalLiteral(n, d), _) if n == 0 => rhs case (_, FractionalLiteral(n, d)) if n == 0 => lhs - case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVPlus(lhs, rhs) - case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealPlus(lhs, rhs) case _ => Plus(lhs, rhs) } @@ -367,21 +315,17 @@ object Constructors { * @see [[purescala.Expressions.RealMinus RealMinus]] */ def minus(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { - case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => lhs + case (_, IntegerLiteral(bi)) if bi == 0 => lhs case (_, IntLiteral(0)) => lhs - case (InfiniteIntegerLiteral(bi), _) if bi == 0 => UMinus(rhs) - case (IntLiteral(0), _) => BVUMinus(rhs) - case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVMinus(lhs, rhs) - case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealMinus(lhs, rhs) + case (IntegerLiteral(bi), _) if bi == 0 => UMinus(rhs) case _ => Minus(lhs, rhs) } def uminus(e: Expr): Expr = e match { - case InfiniteIntegerLiteral(bi) if bi == 0 => e + case IntegerLiteral(bi) if bi == 0 => e case IntLiteral(0) => e - case InfiniteIntegerLiteral(bi) if bi < 0 => InfiniteIntegerLiteral(-bi) - case IsTyped(_, Int32Type) => BVUMinus(e) - case IsTyped(_, RealType) => RealUMinus(e) + case IntegerLiteral(bi) if bi < 0 => IntegerLiteral(-bi) + case UMinus(i) if i.getType == IntegerType => i case _ => UMinus(e) } @@ -391,16 +335,14 @@ object Constructors { * @see [[purescala.Expressions.RealTimes RealTimes]] */ def times(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { - case (InfiniteIntegerLiteral(bi), _) if bi == 1 => rhs - case (_, InfiniteIntegerLiteral(bi)) if bi == 1 => lhs - case (InfiniteIntegerLiteral(bi), _) if bi == 0 => InfiniteIntegerLiteral(0) - case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => InfiniteIntegerLiteral(0) + case (IntegerLiteral(bi), _) if bi == 1 => rhs + case (_, IntegerLiteral(bi)) if bi == 1 => lhs + case (IntegerLiteral(bi), _) if bi == 0 => IntegerLiteral(0) + case (_, IntegerLiteral(bi)) if bi == 0 => IntegerLiteral(0) case (IntLiteral(1), _) => rhs case (_, IntLiteral(1)) => lhs case (IntLiteral(0), _) => IntLiteral(0) case (_, IntLiteral(0)) => IntLiteral(0) - case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVTimes(lhs, rhs) - case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealTimes(lhs, rhs) case _ => Times(lhs, rhs) } diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/inox/trees/DefinitionTransformer.scala similarity index 100% rename from src/main/scala/leon/purescala/DefinitionTransformer.scala rename to src/main/scala/inox/trees/DefinitionTransformer.scala diff --git a/src/main/scala/inox/trees/Definitions.scala b/src/main/scala/inox/trees/Definitions.scala new file mode 100644 index 0000000000000000000000000000000000000000..ded2b36156d6d14d10e3cc97dd32a0bbbf996ad0 --- /dev/null +++ b/src/main/scala/inox/trees/Definitions.scala @@ -0,0 +1,386 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon +package purescala + +import utils.Library +import Common._ +import Expressions._ +import ExprOps._ +import Types._ +import TypeOps._ + +import scala.collection.mutable.{Map => MutableMap} + +object Definitions { + + sealed trait Definition extends Tree { + val id: Identifier + + override def equals(that: Any): Boolean = that match { + case d: Definition => id == d.id + case _=> false + } + + override def hashCode = id.hashCode + } + + abstract class LookupException(id: Identifier, what: String) extends Exception("Lookup failed for " + what + " with symbol " + id) + case class FunctionLookupException(id: Identifier) extends LookupException(id, "function") + case class ClassLookupException(id: Identifier) extends LookupException(id, "class") + + /** + * A ValDef declares a formal parameter (with symbol [[id]]) to be of a certain type. + */ + case class ValDef(id: Identifier, tpe: Type) extends Definition with Typed { + def getType(implicit p: Program): Type = tpe + + /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] */ + def toVariable: Variable = Variable(id, tpe) + } + + /** A wrapper for a program. For now a program is simply a single object. */ + case class Program(classes: Map[Identifier, ClassDef], functions: Map[Identifier, FunDef]) extends Tree { + lazy val callGraph = new CallGraph(this) + + private val typedClassCache: MutableMap[(Identifier, Seq[Type]), TypedClassDef] = MutableMap.empty + def lookupClass(id: Identifier): Option[ClassDef] = classes.get(id) + def lookupClass(id: Identifier, tps: Seq[Type]): Option[TypedClassDef] = + typedClassCache.getOrElseUpdated(id -> tps, lookupClass(id).typed(tps)) + + def getClass(id: Identifier): ClassDef = lookupClass(id).getOrElse(throw ClassLookupException(id)) + def getClass(id: Identifier, tps: Seq[Type]): TypedClassDef = lookupClass(id, tps).getOrElse(throw ClassLookupException(id)) + + private val typedFunctionCache: MutableMap[(Identifier, Seq[Type]), TypedFunDef] = MutableMap.empty + def lookupFunction(id: Identifier): Option[FunDef] = functions.get(id) + def lookupFunction(id: Identifier, tps: Seq[Type]): Option[TypedFunDef] = + typedFunctionCache.getOrElseUpdated(id -> tps, lookupFunction(id).typed(tps)) + + def getFunction(id: Identifier): FunDef = lookupFunction(id).getOrElse(throw FunctionLookupException(id)) + def getFunction(id: Identifier, tps: Seq[Type]): TypedFunDef = lookupFunction(id, tps).getOrElse(throw FunctionLookupException(id)) + } + + object Program { + lazy val empty: Program = Program(Nil) + } + + case class TypeParameterDef(tp: TypeParameter) extends Definition { + def freshen = TypeParameterDef(tp.freshen) + val id = tp.id + } + + /** A trait that represents flags that annotate a ClassDef with different attributes */ + sealed trait ClassFlag + + object ClassFlag { + def fromName(name: String, args: Seq[Option[Any]]): ClassFlag = Annotation(name, args) + } + + /** A trait that represents flags that annotate a FunDef with different attributes */ + sealed trait FunctionFlag + + object FunctionFlag { + def fromName(name: String, args: Seq[Option[Any]]): FunctionFlag = name match { + case "inline" => IsInlined + case _ => Annotation(name, args) + } + } + + // Compiler annotations given in the source code as @annot + case class Annotation(annot: String, args: Seq[Option[Any]]) extends FunctionFlag with ClassFlag + // Is inlined + case object IsInlined extends FunctionFlag + // Is an ADT invariant method + case object IsADTInvariant(id: Identifier) extends FunctionFlag + + /** Represents a class definition (either an abstract- or a case-class) */ + sealed trait ClassDef extends Definition { + val id: Identifier + val tparams: Seq[TypeParameterDef] + val fields: Seq[ValDef] + val flags: Set[ClassFlag] + + val parent: Option[Identifier] + val children: Seq[Identifier] + + def hasParent = parent.isDefined + + def invariana(implicit p: Program)t: Option[FunDef] = { + // TODO + parent.flatMap(_.classDef.invariant).orElse(_invariant) + + } + + def annotations: Set[String] = extAnnotations.keySet + def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap + + def ancestors(implicit p: Program): Seq[ClassDef] = parent + .map(p.getClass).toSeq + .flatMap(parentCls => parentCls +: parentCls.ancestors) + + def root(implicit p: Program) = ancestors.lastOption.getOrElse(this) + + def descendants(implicit p: Program): Seq[ClassDef] = children + .map(p.getClass) + .flatMap(cd => cd +: cd.descendants) + + def ccDescendants(implicit p: Program): Seq[CaseClassDef] = + descendants collect { case ccd: CaseClassDef => ccd } + + def isInductive(implicit p: Program): Boolean = { + def induct(tpe: Type, seen: Set[ClassDef]): Boolean = tpe match { + case ct: ClassType => + val tcd = ct.lookupClass.getOrElse(throw ClassLookupException(ct.id)) + val root = tcd.cd.root + seen(root) || tcd.fields.forall(vd => induct(vd.getType, seen + root)) + case TupleType(tpes) => + tpes.forall(tpe => induct(tpe, seen)) + case _ => true + } + + if (this == root && !this.isAbstract) false + else if (this != root) root.isInductive + else ccDescendants.forall { ccd => + ccd.fields.forall(vd => induct(vd.getType, Set(root))) + } + } + + val isAbstract: Boolean + + def typeArgs = tparams map (_.tp) + + def typed(tps: Seq[Type]): TypedClassDef + def typed: TypedClassDef + } + + /** Abstract classes. */ + class AbstractClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[Identifier], + val children: Seq[Identifier], + val flags: Set[Flag]) extends ClassDef { + + val fields = Nil + val isAbstract = true + + def typed: TypedAbstractClassDef = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type]): TypedAbstractClassDef = { + require(tps.length == tparams.length) + TypedAbstractClassDef(this, tps) + } + } + + /** Case classes/ case objects. */ + class CaseClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[Identifier], + val fields: Seq[ValDef], + val flags: Set[Flag]) extends ClassDef { + + val children = Nil + val isAbstract = false + + def selectorID2Index(id: Identifier) : Int = { + val index = fields.indexWhere(_.id == id) + + if (index < 0) { + scala.sys.error( + "Could not find '"+id+"' ("+id.uniqueName+") within "+ + fields.map(_.id.uniqueName).mkString(", ") + ) + } else index + } + + def typed: TypedCaseClassDef = typed(tparams.map(_.tp)) + def typed(tps: Seq[Type]): TypedCaseClassDef = { + require(tps.length == tparams.length) + TypedCaseClassDef(this, tps) + } + } + + sealed abstract class TypedClassDef extends Tree { + val cd: ClassDef + val tps: Seq[Type] + implicit val program: Program + + val id: Identifier = cd.id + lazy val fields: Seq[ValDef] = { + val tmap = (cd.typeArgs zip tps).toMap + if (tmap.isEmpty) cd.fields + else classDef.fields.map(vd => vd.copy(tpe = instantiateType(vd.getType, tmap))) + } + + lazy val parent: Option[TypedAbstractClassDef] = cd.parent.map(id => p.getClass(id) match { + case acd: AbstractClassDef => TypedAbstractClassDef(acd, tps) + case _ => scala.sys.error("Expected parent to be an AbstractClassDef") + }) + + lazy val invariant: Option[TypedFunDef] = cd.invariant.map { fd => + TypedFunDef(fd, tps) + } + + lazy val root = parent.map(_.root).getOrElse(this) + + def descendants: Seq[TypedClassDef] = cd.descendants.map(_.typed(tps)) + def ccDescendants: Seq[TypedCaseClassDef] = cd.ccDescendants.map(_.typed(tps)) + } + + case class TypedAbstractClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit program: Program) extends TypedClassDef + case class TypedCaseClassDef(cd: AbstractClassDef, tps: Seq[Type])(implicit program: Program) extends TypedClassDef + + + /** Function/method definition. + * + * This class represents methods or fields of objects or classes. By "fields" we mean + * fields defined in the body of a class/object, not the constructor arguments of a case class + * (those are accessible through [[leon.purescala.Definitions.ClassDef.fields]]). + * + * When it comes to verification, all are treated the same (as functions). + * They are only differentiated when it comes to code generation/ pretty printing. + * By default, the FunDef represents a function/method as opposed to a field, + * unless otherwise specified by its flags. + * + * Bear in mind that [[id]] will not be consistently typed. + */ + class FunDef( + val id: Identifier, + val tparams: Seq[TypeParameterDef], + val params: Seq[ValDef], + val returnType: Type, + val fullBody: Expr, + val flags: Set[Flag] + ) extends Definition { + + /* Body manipulation */ + + lazy val body: Option[Expr] = withoutSpec(fullBody) + lazy val precondition = preconditionOf(fullBody) + lazy val precOrTrue = precondition getOrElse BooleanLiteral(true) + + lazy val postcondition = postconditionOf(fullBody) + lazy val postOrTrue = postcondition getOrElse { + val arg = ValDef(FreshIdentifier("res", returnType, alwaysShowUniqueID = true)) + Lambda(Seq(arg), BooleanLiteral(true)) + } + + def hasBody = body.isDefined + def hasPrecondition = precondition.isDefined + def hasPostcondition = postcondition.isDefined + + def annotations: Set[String] = extAnnotations.keySet + def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { + case Annotation(s, args) => s -> args + }.toMap + + def canBeLazyField = flags.contains(IsField(true)) && params.isEmpty && tparams.isEmpty + def canBeStrictField = flags.contains(IsField(false)) && params.isEmpty && tparams.isEmpty + def canBeField = canBeLazyField || canBeStrictField + def isRealFunction = !canBeField + def isInvariant = flags contains IsADTInvariant + + /* Wrapping in TypedFunDef */ + + def typed(tps: Seq[Type]): TypedFunDef = { + assert(tps.size == tparams.size) + TypedFunDef(this, tps) + } + + def typed: TypedFunDef = typed(tparams.map(_.tp)) + + /* Auxiliary methods */ + + def isRecursive(implicit p: Program) = p.callGraph.transitiveCallees(this) contains this + + def paramIds = params map { _.id } + + def typeArgs = tparams map (_.tp) + + def applied(args: Seq[Expr]): FunctionInvocation = Constructors.functionInvocation(this, args) + def applied = FunctionInvocation(this.typed, this.paramIds map Variable) + } + + + // Wrapper for typing function according to valuations for type parameters + case class TypedFunDef(fd: FunDef, tps: Seq[Type])(implicit program: Program) extends Tree { + val id = fd.id + + def signature = { + if (tps.nonEmpty) { + id.toString+tps.mkString("[", ", ", "]") + } else { + id.toString + } + } + + private lazy val typesMap: Map[TypeParameter, Type] = { + (fd.typeArgs zip tps).toMap.filter(tt => tt._1 != tt._2) + } + + def translated(t: Type): Type = instantiateType(t, typesMap) + + def translated(e: Expr): Expr = instantiateType(e, typesMap, paramsMap) + + /** A mapping from this [[TypedFunDef]]'s formal parameters to real arguments + * + * @param realArgs The arguments to which the formal argumentas are mapped + * */ + def paramSubst(realArgs: Seq[Expr]) = { + require(realArgs.size == params.size) + (paramIds zip realArgs).toMap + } + + /** Substitute this [[TypedFunDef]]'s formal parameters with real arguments in some expression + * + * @param realArgs The arguments to which the formal argumentas are mapped + * @param e The expression in which the substitution will take place + */ + def withParamSubst(realArgs: Seq[Expr], e: Expr) = { + replaceFromIDs(paramSubst(realArgs), e) + } + + def applied(realArgs: Seq[Expr]): FunctionInvocation = { + FunctionInvocation(fd, tps, realArgs) + } + + def applied: FunctionInvocation = + applied(params map { _.toVariable }) + + /** + * Params will return ValDefs instantiated with the correct types + * For such a ValDef(id,tp) it may hold that (id.getType != tp) + */ + lazy val (params: Seq[ValDef], paramsMap: Map[Identifier, Identifier]) = { + if (typesMap.isEmpty) { + (fd.params, Map()) + } else { + val newParams = fd.params.map { vd => + val newTpe = translated(vd.getType) + val newId = FreshIdentifier(vd.id.name, newTpe, true).copiedFrom(vd.id) + vd.copy(id = newId).setPos(vd) + } + + val paramsMap: Map[Identifier, Identifier] = (fd.params zip newParams).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap + + (newParams, paramsMap) + } + } + + lazy val functionType = FunctionType(params.map(_.getType).toList, returnType) + + lazy val returnType: Type = translated(fd.returnType) + + lazy val paramIds = params map { _.id } + + lazy val fullBody = translated(fd.fullBody) + lazy val body = fd.body map translated + lazy val precondition = fd.precondition map translated + lazy val precOrTrue = translated(fd.precOrTrue) + lazy val postcondition = fd.postcondition map translated + lazy val postOrTrue = translated(fd.postOrTrue) + + def hasImplementation = body.isDefined + def hasBody = hasImplementation + def hasPrecondition = precondition.isDefined + def hasPostcondition = postcondition.isDefined + } +} diff --git a/src/main/scala/leon/purescala/DependencyFinder.scala b/src/main/scala/inox/trees/DependencyFinder.scala similarity index 90% rename from src/main/scala/leon/purescala/DependencyFinder.scala rename to src/main/scala/inox/trees/DependencyFinder.scala index 704cce67aa90313d899821cc3a0729465059bb93..95e894f166eb140e5fb62d0631b43dc8db1cb5e6 100644 --- a/src/main/scala/leon/purescala/DependencyFinder.scala +++ b/src/main/scala/inox/trees/DependencyFinder.scala @@ -11,14 +11,9 @@ import Types._ import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} -/** Accumulate dependency information for a program - * - * As long as the Program does not change, you should reuse the same - * DependencyFinder, as it could cache its computation internally. If the program - * does change, all bets are off. - */ -class DependencyFinder { - private val deps: MutableMap[Definition, Set[Definition]] = MutableMap.empty +/** Accumulate dependency information for a given program */ +class DependencyFinder(p: Program) { + private val deps: MutableMap[Identifier, Set[Identifier]] = MutableMap.empty /** Return all dependencies for a given Definition * diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/inox/trees/ExprOps.scala similarity index 97% rename from src/main/scala/leon/purescala/ExprOps.scala rename to src/main/scala/inox/trees/ExprOps.scala index 130b5840e5c2f9ab27b3375f0a3211a81b1c8c4c..e3406cbbfc268b4ad9e237a005f2fb2c4c8b1790 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/inox/trees/ExprOps.scala @@ -32,7 +32,7 @@ import scala.language.implicitConversions * operations on Leon expressions. * */ -object ExprOps extends GenTreeOps[Expr] { +trait ExprOps extends GenTreeOps[Expr] with Extractors { self: Trees => val Deconstructor = Operator @@ -1103,54 +1103,91 @@ object ExprOps extends GenTreeOps[Expr] { postMap(transform, applyRec = true)(expr) } - def simplifyPaths(sf: SolverFactory[Solver], initPC: Path = Path.empty): Expr => Expr = { - new SimplifierWithPaths(sf, initPC).transform - } - - trait Traverser[T] { - def traverse(e: Expr): T - } + def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { - object CollectorWithPaths { - def apply[T](p: PartialFunction[Expr,T]): CollectorWithPaths[(T, Path)] = new CollectorWithPaths[(T, Path)] { - def collect(e: Expr, path: Path): Option[(T, Path)] = if (!p.isDefinedAt(e)) None else { - Some(p(e) -> path) + def rec(expr: Expr, path: Path): Seq[(T, Path)] = { + val seq = if (f.isDefinedAt(expr)) { + Seq(f(expr) -> path) + } else { + Seq.empty[(T, Path)] } - } - } - - trait CollectorWithPaths[T] extends TransformerWithPC with Traverser[Seq[T]] { - protected val initPath: Path = Path.empty - - private var results: Seq[T] = Nil - def collect(e: Expr, path: Path): Option[T] + val rseq = expr match { + case Let(i, v, b) => + rec(v, path) ++ + rec(b, path withBinding (i -> v)) + + case Ensuring(Require(pre, body), Lambda(Seq(arg), post)) => + rec(pre, path) ++ + rec(body, path withCond pre) ++ + rec(post, path withCond pre withBinding (arg.toVariable -> body)) + + case Ensuring(body, Lambda(Seq(arg), post)) => + rec(body, path) ++ + rec(post, path withBinding (arg.toVariable -> body)) + + case Require(pre, body) => + rec(pre, path) ++ + rec(body, path withCond pre) + + case Assert(pred, err, body) => + rec(pred, path) ++ + rec(body, path withCond pred) + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + var soFar = path + + rs ++ cases.flatMap { c => + val patternPathPos = conditionForPattern(scrut, c.pattern, includeBinders = true) + val patternPathNeg = conditionForPattern(scrut, c.pattern, includeBinders = false) + val map = mapForPattern(scrut, c.pattern) + val guardOrTrue = c.optGuard.getOrElse(BooleanLiteral(true)) + val guardMapped = replaceFromIDs(map, guardOrTrue) + + val rc = rec((patternPathPos withCond guardOrTrue).fullClause, soFar) + val subPath = soFar merge (patternPathPos withCond guardOrTrue) + val rrhs = rec(c.rhs, subPath) + + soFar = soFar merge (patternPathNeg withCond guardMapped).negate + rc ++ rrhs + } - def walk(e: Expr, path: Path): Option[Expr] = None + case IfExpr(cond, thenn, elze) => + rec(cond, path) ++ + rec(thenn, path withCond cond) ++ + rec(elze, path withCond Not(cond)) + + case And(es) => + var soFar = path + es.flatMap { e => + val re = rec(e, soFar) + soFar = soFar withCond e + re + } - override def rec(e: Expr, path: Path) = { - collect(e, path).foreach { results :+= _ } - walk(e, path) match { - case Some(r) => r - case _ => super.rec(e, path) - } - } + case Or(es) => + var soFar = path + es.flatMap { e => + val re = rec(e, soFar) + soFar = soFar withCond Not(e) + re + } - def traverse(funDef: FunDef): Seq[T] = traverse(funDef.fullBody) + case Implies(lhs, rhs) => + rec(lhs, path) ++ + rec(rhs, path withCond lhs) - def traverse(e: Expr): Seq[T] = traverse(e, initPath) + case Operator(es, _) => + es.flatMap(rec(_, path)) - def traverse(e: Expr, init: Expr): Seq[T] = traverse(e, Path(init)) + case _ => sys.error("Expression " + e + "["+e.getClass+"] is not extractable") + } - def traverse(e: Expr, init: Path): Seq[T] = { - results = Nil - rec(e, init) - results + seq ++ rseq } - } - def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { - CollectorWithPaths(f).traverse(expr) + rec(expr, Path.empty) } override def formulaSize(e: Expr): Int = e match { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/inox/trees/Expressions.scala similarity index 53% rename from src/main/scala/leon/purescala/Expressions.scala rename to src/main/scala/inox/trees/Expressions.scala index 65fbfb0dc99411ace43fa47ff4d2a6956a574bf2..1215fd7c0913a21abe91c5fff7e44f2dd0f239c2 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/inox/trees/Expressions.scala @@ -30,10 +30,10 @@ import ExprOps.replaceFromIDs * @define noteBitvector (32-bit vector) * @define noteReal (Real) */ -object Expressions { +trait Expressions { self: Trees => - private def checkParamTypes(real: Seq[Typed], formal: Seq[Typed], result: TypeTree): TypeTree = { - if (real zip formal forall { case (real, formal) => isSubtypeOf(real.getType, formal.getType)} ) { + private def checkParamTypes(real: Seq[Type], formal: Seq[Type], result: Type): Type = { + if (real zip formal forall { case (real, formal) => isSubtypeOf(real, formal)} ) { result.unveilUntyped } else { //println(s"Failed to type as $result") @@ -52,16 +52,6 @@ object Expressions { } - /** Stands for an undefined Expr, similar to `???` or `null` - * - * During code generation, it gets compiled to `null`, or the 0 of the - * respective type for value types. - */ - case class NoTree(tpe: TypeTree) extends Expr with Terminal { - val getType = tpe - } - - /* Specifications */ /** Computational errors (unmatched case, taking min of an empty set, @@ -72,45 +62,8 @@ object Expressions { * @param tpe The type of this expression * @param description The description of the error */ - case class Error(tpe: TypeTree, description: String) extends Expr with Terminal { - val getType = tpe - } - - /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* - * - * @param pred The precondition formula inside ``require(...)`` - * @param body The body following the ``require(...)`` - */ - case class Require(pred: Expr, body: Expr) extends Expr { - val getType = { - if (pred.getType == BooleanType) - body.getType - else Untyped - } - } - - /** Postcondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *ensuring* - * - * @param body The body of the expression. It can contain at most one [[Expressions.Require]] sub-expression. - * @param pred The predicate to satisfy. It should be a function whose argument's type can handle the type of the body - */ - case class Ensuring(body: Expr, pred: Expr) extends Expr { - require(pred.isInstanceOf[Lambda]) - - val getType = pred.getType match { - case FunctionType(Seq(bodyType), BooleanType) if isSubtypeOf(body.getType, bodyType) => - body.getType - case _ => - Untyped - } - /** Converts this ensuring clause to the body followed by an assert statement */ - def toAssert: Expr = { - val res = FreshIdentifier("res", getType, true) - Let(res, body, Assert( - application(pred, Seq(Variable(res))), - Some("Postcondition failed @" + this.getPos), Variable(res) - )) - } + case class Error(tpe: Type, description: String) extends Expr with Terminal { + def getType(implicit p: Program): Type = tpe } /** Local assertions with customizable error message @@ -119,10 +72,9 @@ object Expressions { * @param error An optional error string to display if the assert fails. Second argument of `assert(..., ...)` * @param body The expression following `assert(..., ...)` */ - case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr { - val getType = { - if (pred.getType == BooleanType) - body.getType + case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { + if (pred.getType == BooleanType) body.getType else Untyped } } @@ -131,8 +83,8 @@ object Expressions { /** Variable * @param id The identifier of this variable */ - case class Variable(id: Identifier) extends Expr with Terminal { - val getType = id.getType + case class Variable(id: Identifier) extends Expr with Terminal with CachingTyped { + protected def computeType(implicit p: Program): Type = id.getType } @@ -143,8 +95,8 @@ object Expressions { * @param body The expression following the ``val ... = ... ;`` construct * @see [[purescala.Constructors#let purescala's constructor let]] */ - case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { - val getType = { + case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { // We can't demand anything sticter here, because some binders are // typed context-wise if (typesCompatible(value.getType, binder.getType)) @@ -155,59 +107,11 @@ object Expressions { } } - /** $encodingof multiple `def ... = ...; ...` (local function definition and possibly mutually recursive) - * - * @param fds The function definitions. - * @param body The body of the expression after the function - */ - case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr { - require(fds.nonEmpty) - val getType = body.getType - } - - - /* OO Trees */ - - /** $encodingof `(...).method(args)` (method invocation) - * - * Both [[Expressions.MethodInvocation]] and [[Expressions.This]] get removed by phase [[MethodLifting]]. - * Methods become functions, [[Expressions.This]] becomes first argument, - * and [[Expressions.MethodInvocation]] becomes [[Expressions.FunctionInvocation]]. - * - * @param rec The expression evaluating to an object - * @param cd The class definition typing `rec` - * @param tfd The typed function definition of the method - * @param args The arguments provided to the method - */ - case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr { - val getType = { - // We need ot instantiate the type based on the type of the function as well as receiver - val fdret = tfd.returnType - val extraMap: Map[TypeParameter, TypeTree] = rec.getType match { - case ct: ClassType => - (cd.typeArgs zip ct.tps).toMap - case _ => - Map() - } - instantiateType(fdret, extraMap) - } - } - - /** $encodingof the '''this''' keyword - * Both [[Expressions.MethodInvocation]] and [[Expressions.This]] get removed by phase [[MethodLifting]]. - * Methods become functions, [[Expressions.This]] becomes first argument, - * and [[Expressions.MethodInvocation]] becomes [[Expressions.FunctionInvocation]]. - */ - case class This(ct: ClassType) extends Expr with Terminal { - val getType = ct - } - - /* Higher-order Functions */ /** $encodingof `callee(args...)`, where [[callee]] is an expression of a function type (not a method) */ - case class Application(callee: Expr, args: Seq[Expr]) extends Expr { - val getType = callee.getType match { + case class Application(callee: Expr, args: Seq[Expr]) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = callee.getType match { case FunctionType(from, to) => checkParamTypes(args, from, to) case _ => @@ -216,38 +120,43 @@ object Expressions { } /** $encodingof `(args) => body` */ - case class Lambda(args: Seq[ValDef], body: Expr) extends Expr { - val getType = FunctionType(args.map(_.getType), body.getType).unveilUntyped + case class Lambda(args: Seq[ValDef], body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + FunctionType(args.map(_.getType), body.getType).unveilUntyped + def paramSubst(realArgs: Seq[Expr]) = { require(realArgs.size == args.size) (args map { _.id } zip realArgs).toMap } + def withParamSubst(realArgs: Seq[Expr], e: Expr) = { replaceFromIDs(paramSubst(realArgs), e) } } - case class FiniteLambda(mapping: Seq[(Seq[Expr], Expr)], default: Expr, tpe: FunctionType) extends Expr { - val getType = tpe - } - /* Universal Quantification */ - case class Forall(args: Seq[ValDef], body: Expr) extends Expr { - val getType = BooleanType + case class Forall(args: Seq[ValDef], body: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = body.getType } /* Control flow */ /** $encodingof `function(...)` (function invocation) */ - case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr { - require(tfd.params.size == args.size) - val getType = checkParamTypes(args, tfd.params, tfd.returnType) + case class FunctionInvocation(id: Identifier, tps: Seq[Type], args: Seq[Expr]) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = p.lookupFunction(id) match { + case Some(fd) => + val tfd = fd.typed(tps) + require(args.size == tfd.params.size) + checkParamTypes(args.map(_.getType), tfd.params.map(_.getType), tfd.returnType) + case _ => Untyped + } } /** $encodingof `if(...) ... else ...` */ - case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr { - val getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped + case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped } /** $encodingof `... match { ... }` @@ -258,9 +167,10 @@ object Expressions { * @param scrutinee Expression to the left of the '''match''' keyword * @param cases A sequence of cases to match `scrutinee` against */ - case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { + case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr with CachingTyped { require(cases.nonEmpty) - val getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped + protected def computeType(implicit p: Program): Type = + leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `case pattern [if optGuard] => rhs` @@ -298,10 +208,12 @@ object Expressions { case class InstanceOfPattern(binder: Option[Identifier], ct: ClassType) extends Pattern { val subPatterns = Seq() } + /** Pattern encoding `case _ => `, or `case binder => ` if identifier [[binder]] is present */ case class WildcardPattern(binder: Option[Identifier]) extends Pattern { // c @ _ val subPatterns = Seq() } + /** Pattern encoding `case binder @ ct(subPatterns...) =>` * * If [[binder]] is empty, consider a wildcard `_` in its place. @@ -367,9 +279,8 @@ object Expressions { ) def isSome(scrut: Expr) = IsInstanceOf(FunctionInvocation(unapplyFun, Seq(scrut)), someType) - } - + // Extracts without taking care of the binder. (contrary to Extractos.Pattern) object PatternExtractor extends TreeExtractor[Pattern] { def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { @@ -389,65 +300,52 @@ object Expressions { val Deconstructor = PatternExtractor } - /** Symbolic I/O examples as a match/case. - * $encodingof `out == (in match { cases; case _ => out })` - * - * [[cases]] should be nonempty. If you are not sure about this, you should use - * [[purescala.Constructors#passes purescala's constructor passes]] - * - * @param in The input expression - * @param out The output expression - * @param cases The cases to compare against - */ - case class Passes(in: Expr, out: Expr, cases: Seq[MatchCase]) extends Expr { - //require(cases.nonEmpty) - - val getType = leastUpperBound(cases.map(_.rhs.getType)) match { - case None => Untyped - case Some(_) => BooleanType - } - - /** Transforms the set of I/O examples to a constraint equality. */ - def asConstraint = { - val defaultCase = SimpleCase(WildcardPattern(None), out) - Equals(out, MatchExpr(in, cases :+ defaultCase)) - } - } - - /** Literals */ + sealed abstract class Literal[+T] extends Expr with Terminal { val value: T } + /** $encodingof a character literal */ case class CharLiteral(value: Char) extends Literal[Char] { - val getType = CharType + def getType(implicit p: Program): Type = CharType } + /** $encodingof a 32-bit integer literal */ case class IntLiteral(value: Int) extends Literal[Int] { - val getType = Int32Type + def getType(implicit p: Program): Type = Int32Type + } + + /** $encodingof a n-bit bitvector literal */ + case class BVLiteral(value: BigInt, size: Int) extends Literal[BigInt] { + def getType(implicit p: Program): Type = BVType(size) } + /** $encodingof an infinite precision integer literal */ - case class InfiniteIntegerLiteral(value: BigInt) extends Literal[BigInt] { - val getType = IntegerType + case class IntegerLiteral(value: BigInt) extends Literal[BigInt] { + def getType(implicit p: Program): Type = IntegerType } + /** $encodingof a fraction literal */ case class FractionalLiteral(numerator: BigInt, denominator: BigInt) extends Literal[(BigInt, BigInt)] { val value = (numerator, denominator) - val getType = RealType + def getType(implicit p: Program): Type = RealType } + /** $encodingof a boolean literal '''true''' or '''false''' */ case class BooleanLiteral(value: Boolean) extends Literal[Boolean] { - val getType = BooleanType + def getType(implicit p: Program): Type = BooleanType } + /** $encodingof the unit literal `()` */ - case class UnitLiteral() extends Literal[Unit] { - val getType = UnitType + case object UnitLiteral extends Literal[Unit] { val value = () + def getType(implicit p: Program): Type = UnitType } + /** $encodingof a string literal */ case class StringLiteral(value: String) extends Literal[String] { - val getType = StringType + def getType(implicit p: Program): Type = StringType } @@ -455,7 +353,7 @@ object Expressions { * This is useful e.g. to present counterexamples of generic types. */ case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { - val getType = tp + def getType(implicit p: Program): Type = tp } @@ -464,13 +362,17 @@ object Expressions { * @param ct The case class name and inherited attributes * @param args The arguments of the case class */ - case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr { - val getType = checkParamTypes(args, ct.fieldsTypes, ct) + case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = ct.lookupClass match { + case Some(tcd) => checkParamTypes(args.map(_.getType), tcd.fieldsTypes, ct) + case _ => Untyped + } } /** $encodingof `.isInstanceOf[...]` */ - case class IsInstanceOf(expr: Expr, classType: ClassType) extends Expr { - val getType = BooleanType + case class IsInstanceOf(expr: Expr, classType: ClassType) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + if (isSubtypeOf(expr.getType, classType)) BooleanType else Untyped } /** $encodingof `expr.asInstanceOf[tpe]` @@ -478,8 +380,9 @@ object Expressions { * Introduced by matchToIfThenElse to transform match-cases to type-correct * if bodies. */ - case class AsInstanceOf(expr: Expr, tpe: ClassType) extends Expr { - val getType = tpe + case class AsInstanceOf(expr: Expr, tpe: ClassType) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + if (typesCompatible(tpe, expr.getType)) tpe else Untyped } /** $encodingof `value.selector` where value is of a case class type @@ -487,24 +390,22 @@ object Expressions { * If you are not sure about the requirement you should use * [[purescala.Constructors#caseClassSelector purescala's constructor caseClassSelector]] */ - case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr { - val selectorIndex = classType.classDef.selectorID2Index(selector) - val getType = { - // We don't demand equality because we may construct a mistyped field retrieval - // (retrieving from a supertype before) passing it to the solver. - // E.g. l.head where l:List[A] or even l: Nil[A]. This is ok for the solvers. - if (typesCompatible(classType, caseClass.getType)) { - classType.fieldsTypes(selectorIndex) - } else { - Untyped - } + case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = classType.lookupClass match { + case Some(tcd: TypedCaseClassDef) => + val index = tcd.selectorID2Index(selector) + if (classType == caseClass.getType) { + tcd.fieldsTypes(index) + } else { + Untyped + } + case _ => Untyped } } - /** $encodingof `... == ...` */ - case class Equals(lhs: Expr, rhs: Expr) extends Expr { - val getType = { + case class Equals(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { if (typesCompatible(lhs.getType, rhs.getType)) BooleanType else { //println(s"Incompatible argument types: arguments: ($lhs, $rhs) types: ${lhs.getType}, ${rhs.getType}") @@ -515,17 +416,18 @@ object Expressions { /* Propositional logic */ + /** $encodingof `... && ...` * * [[exprs]] must contain at least two elements; if you are not sure about this, * you should use [[purescala.Constructors#and purescala's constructor and]] * or [[purescala.Constructors#andJoin purescala's constructor andJoin]] */ - case class And(exprs: Seq[Expr]) extends Expr { + case class And(exprs: Seq[Expr]) extends Expr with CachingTyped { require(exprs.size >= 2) - val getType = { + protected def computeType(implicit p: Program): Type = { if (exprs forall (_.getType == BooleanType)) BooleanType - else Untyped + else checkBVCompatible(exprs.map(_.getType) : _*) } } @@ -541,9 +443,9 @@ object Expressions { */ case class Or(exprs: Seq[Expr]) extends Expr { require(exprs.size >= 2) - val getType = { + protected def computeType(implicit p: Program): Type = { if (exprs forall (_.getType == BooleanType)) BooleanType - else Untyped + else checkBVCompatible(exprs.map(_.getType) : _*) } } @@ -559,7 +461,7 @@ object Expressions { * @see [[leon.purescala.Constructors.implies]] */ case class Implies(lhs: Expr, rhs: Expr) extends Expr { - val getType = { + protected def computeType(implicit p: Program): Type = { if(lhs.getType == BooleanType && rhs.getType == BooleanType) BooleanType else Untyped } @@ -570,38 +472,47 @@ object Expressions { * @see [[leon.purescala.Constructors.not]] */ case class Not(expr: Expr) extends Expr { - val getType = { + protected def computeType(implicit p: Program): Type = { if (expr.getType == BooleanType) BooleanType - else Untyped + else bitVectorType(expr.getType) } } - - abstract class ConverterToString(fromType: TypeTree, toType: TypeTree) extends Expr { - def expr: Expr - val getType = if(expr.getType == fromType) toType else Untyped - } - + + /* String Theory */ + + abstract class ConverterToString(fromType: Type, toType: Type) extends Expr with CachingTyped { + val expr: Expr + protected def computeType(implicit p: Program): Type = + if (expr.getType == fromType) toType else Untyped + } + /** $encodingof `expr.toString` for Int32 to String */ case class Int32ToString(expr: Expr) extends ConverterToString(Int32Type, StringType) + /** $encodingof `expr.toString` for boolean to String */ case class BooleanToString(expr: Expr) extends ConverterToString(BooleanType, StringType) + /** $encodingof `expr.toString` for BigInt to String */ case class IntegerToString(expr: Expr) extends ConverterToString(IntegerType, StringType) + /** $encodingof `expr.toString` for char to String */ case class CharToString(expr: Expr) extends ConverterToString(CharType, StringType) + /** $encodingof `expr.toString` for real to String */ case class RealToString(expr: Expr) extends ConverterToString(RealType, StringType) + /** $encodingof `lhs + rhs` for strings */ - case class StringConcat(lhs: Expr, rhs: Expr) extends Expr { - val getType = { + case class StringConcat(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { if (lhs.getType == StringType && rhs.getType == StringType) StringType else Untyped } } + /** $encodingof `lhs.subString(start, end)` for strings */ - case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr { - val getType = { + case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { val ext = expr.getType val st = start.getType val et = end.getType @@ -609,9 +520,10 @@ object Expressions { else Untyped } } + /** $encodingof `lhs.subString(start, end)` for strings */ - case class BigSubString(expr: Expr, start: Expr, end: Expr) extends Expr { - val getType = { + case class BigSubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { val ext = expr.getType val st = start.getType val et = end.getType @@ -619,51 +531,68 @@ object Expressions { else Untyped } } + /** $encodingof `lhs.length` for strings */ - case class StringLength(expr: Expr) extends Expr { - val getType = { + case class StringLength(expr: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { if (expr.getType == StringType) Int32Type else Untyped } } + /** $encodingof `lhs.length` for strings */ - case class StringBigLength(expr: Expr) extends Expr { - val getType = { + case class StringBigLength(expr: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { if (expr.getType == StringType) IntegerType else Untyped } } - /* Integer arithmetic */ + + /* General arithmetic */ + + def numericType(tpe: TypeTree, tpes: TypeTree*): TypeTree = { + lazy val intType = integerType(tpe, tpes) + lazy val bvType = bitVectorType(tpe, tpes) + lazy val realType = realType(tpe, tpes) + if (intType.isTyped) intType else if (bvType.isTyped) bvType else realType + } + + def integerType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { + case IntegerType if typesCompatible(tpe, tpes : _*) => tpe + case _ => Untyped + } + + def bitVectorType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { + case _: BVType if typesCompatible(tpe, tpes: _*) => tpe + case _ => Untyped + } + + def realType(tpe: TypeTree, tpes: TypeTree*): TypeTree = tpe match { + case RealType if typesCompatible(tpe, tpes : _*) => tpe + case _ => Untyped + } /** $encodingof `... + ...` for BigInts */ - case class Plus(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped - } + case class Plus(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) } + /** $encodingof `... - ...` */ - case class Minus(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped - } + case class Minus(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) } + /** $encodingof `- ... for BigInts`*/ - case class UMinus(expr: Expr) extends Expr { - val getType = { - if (expr.getType == IntegerType) IntegerType - else Untyped - } + case class UMinus(expr: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = numericType(expr.getType) } + /** $encodingof `... * ...` */ - case class Times(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped - } + case class Times(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) } + /** $encodingof `... / ...` * * Division and Remainder follows Java/Scala semantics. Division corresponds @@ -675,136 +604,77 @@ object Expressions { * * Division(x, y) * y + Remainder(x, y) == x */ - case class Division(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped - } + case class Division(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = numericType(lhs.getType, rhs.getType) } + /** $encodingof `... % ...` (can return negative numbers) * * @see [[Expressions.Division]] */ - case class Remainder(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped + case class Remainder(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = integerType(lhs.getType, rhs.getType) match { + case Untyped => bitVectorType(lhs.getType, rhs.getType) + case tpe => tpe } } + /** $encodingof `... mod ...` (cannot return negative numbers) * * @see [[Expressions.Division]] */ - case class Modulo(lhs: Expr, rhs: Expr) extends Expr { - val getType = { - if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType - else Untyped + case class Modulo(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = integerType(lhs.getType, rhs.getType) match { + case Untyped => bitVectorType(lhs.getType, rhs.getType) + case tpe => tpe } } + /** $encodingof `... < ...`*/ - case class LessThan(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + case class LessThan(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } + /** $encodingof `... > ...`*/ case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + protected def computeType(implicit p: Program): Type = + if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } + /** $encodingof `... <= ...`*/ case class LessEquals(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + protected def computeType(implicit p: Program): Type = + if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } + /** $encodingof `... >= ...`*/ case class GreaterEquals(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + protected def computeType(implicit p: Program): Type = + if (numericType(lhs.getType, rhs.getType) != Untyped) BooleanType else Untyped } - /* Bit-vector arithmetic */ - /** $encodingof `... + ...` $noteBitvector*/ - case class BVPlus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == Int32Type && rhs.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `... - ...` $noteBitvector*/ - case class BVMinus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == Int32Type && rhs.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `- ...` $noteBitvector*/ - case class BVUMinus(expr: Expr) extends Expr { - require(expr.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `... * ...` $noteBitvector*/ - case class BVTimes(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == Int32Type && rhs.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `... / ...` $noteBitvector*/ - case class BVDivision(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == Int32Type && rhs.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `... % ...` $noteBitvector*/ - case class BVRemainder(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == Int32Type && rhs.getType == Int32Type) - val getType = Int32Type - } - /** $encodingof `! ...` $noteBitvector */ - case class BVNot(expr: Expr) extends Expr { - val getType = Int32Type - } - /** $encodingof `... & ...` $noteBitvector */ - case class BVAnd(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type - } - /** $encodingof `... | ...` $noteBitvector */ - case class BVOr(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type - } + /* Bit-vector operations */ + /** $encodingof `... ^ ...` $noteBitvector */ - case class BVXOr(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type + case class BVXOr(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) } + /** $encodingof `... << ...` $noteBitvector */ - case class BVShiftLeft(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type + case class BVShiftLeft(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) } + /** $encodingof `... >> ...` $noteBitvector (arithmetic shift, sign-preserving) */ - case class BVAShiftRight(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type - } - /** $encodingof `... >>> ...` $noteBitvector (logical shift) */ - case class BVLShiftRight(lhs: Expr, rhs: Expr) extends Expr { - val getType = Int32Type + case class BVAShiftRight(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) } - - /* Real arithmetic */ - /** $encodingof `... + ...` $noteReal */ - case class RealPlus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == RealType && rhs.getType == RealType) - val getType = RealType - } - /** $encodingof `... - ...` $noteReal */ - case class RealMinus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == RealType && rhs.getType == RealType) - val getType = RealType - } - /** $encodingof `- ...` $noteReal */ - case class RealUMinus(expr: Expr) extends Expr { - require(expr.getType == RealType) - val getType = RealType - } - /** $encodingof `... * ...` $noteReal */ - case class RealTimes(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == RealType && rhs.getType == RealType) - val getType = RealType - } - /** $encodingof `... / ...` $noteReal */ - case class RealDivision(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == RealType && rhs.getType == RealType) - val getType = RealType + /** $encodingof `... >>> ...` $noteBitvector (logical shift) */ + case class BVLShiftRight(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = bitVectorType(lhs.getType, rhs.getType) } @@ -818,9 +688,9 @@ object Expressions { * * @param exprs The expressions in the tuple */ - case class Tuple(exprs: Seq[Expr]) extends Expr { + case class Tuple(exprs: Seq[Expr]) extends Expr with CachingTyped { require(exprs.size >= 2) - val getType = TupleType(exprs.map(_.getType)).unveilUntyped + protected def computeType(implicit p: Program): Type = TupleType(exprs.map(_.getType)).unveilUntyped } /** $encodingof `(tuple)._i` @@ -829,16 +699,14 @@ object Expressions { * If you are not sure that [[tuple]] is indeed of a TupleType, * you should use [[leon.purescala.Constructors.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] */ - case class TupleSelect(tuple: Expr, index: Int) extends Expr { + case class TupleSelect(tuple: Expr, index: Int) extends Expr with CachingTyped { require(index >= 1) - val getType = tuple.getType match { - case tp@TupleType(ts) => + protected def computeType(implicit p: Program): Type = tuple.getType match { + case tp @ TupleType(ts) => require(index <= ts.size, s"Got index $index for '$tuple' of type '$tp") ts(index - 1) - - case _ => - Untyped + case _ => Untyped } } @@ -846,12 +714,14 @@ object Expressions { /* Set operations */ /** $encodingof `Set[base](elements)` */ - case class FiniteSet(elements: Set[Expr], base: TypeTree) extends Expr { - val getType = SetType(base).unveilUntyped + case class FiniteSet(elements: Seq[Expr], base: Type) extends Expr { + private lazy val tpe = SetType(base).unveilUntyped + def getType(implicit p: Program): Type = tpe } + /** $encodingof `set + elem` */ - case class SetAdd(set: Expr, elem: Expr) extends Expr { - val getType = { + case class SetAdd(set: Expr, elem: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { val base = set.getType match { case SetType(base) => base case _ => Untyped @@ -859,45 +729,61 @@ object Expressions { checkParamTypes(Seq(elem.getType), Seq(base), SetType(base).unveilUntyped) } } + /** $encodingof `set.contains(element)` or `set(element)` */ - case class ElementOfSet(element: Expr, set: Expr) extends Expr { - val getType = checkParamTypes(Seq(element.getType), Seq(set.getType match { + case class ElementOfSet(element: Expr, set: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program) = checkParamTypes(Seq(element.getType), Seq(set.getType match { case SetType(base) => base case _ => Untyped }), BooleanType) } + /** $encodingof `set.length` */ - case class SetCardinality(set: Expr) extends Expr { - val getType = IntegerType + case class SetCardinality(set: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = set.getType match { + case SetType(_) => IntegerType + case _ => Untyped + } } + /** $encodingof `set.subsetOf(set2)` */ - case class SubsetOf(set1: Expr, set2: Expr) extends Expr { - val getType = (set1.getType, set2.getType) match { + case class SubsetOf(set1: Expr, set2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = (set1.getType, set2.getType) match { case (SetType(b1), SetType(b2)) if b1 == b2 => BooleanType case _ => Untyped } } + /** $encodingof `set & set2` */ - case class SetIntersection(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class SetIntersection(set1: Expr, set2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /** $encodingof `set ++ set2` */ - case class SetUnion(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class SetUnion(set1: Expr, set2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /** $encodingof `set -- set2` */ - case class SetDifference(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class SetDifference(set1: Expr, set2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /* Bag operations */ + /** $encodingof `Bag[base](elements)` */ - case class FiniteBag(elements: Map[Expr, Expr], base: TypeTree) extends Expr { - val getType = BagType(base).unveilUntyped + case class FiniteBag(elements: Seq[(Expr, Expr)], base: TypeTree) extends Expr { + lazy val tpe = BagType(base).unveilUntyped + def getType(implicit p: Program): Type = tpe } + /** $encodingof `bag + elem` */ - case class BagAdd(bag: Expr, elem: Expr) extends Expr { - val getType = { + case class BagAdd(bag: Expr, elem: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = { val base = bag.getType match { case BagType(base) => base case _ => Untyped @@ -905,136 +791,47 @@ object Expressions { checkParamTypes(Seq(base), Seq(elem.getType), BagType(base).unveilUntyped) } } + /** $encodingof `bag.get(element)` or `bag(element)` */ - case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { - val getType = checkParamTypes(Seq(element.getType), Seq(bag.getType match { + case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = checkParamTypes(Seq(element.getType), Seq(bag.getType match { case BagType(base) => base case _ => Untyped }), IntegerType) } + /** $encodingof `bag1 & bag2` */ - case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { - val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /** $encodingof `bag1 ++ bag2` */ - case class BagUnion(bag1: Expr, bag2: Expr) extends Expr { - val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class BagUnion(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } + /** $encodingof `bag1 -- bag2` */ - case class BagDifference(bag1: Expr, bag2: Expr) extends Expr { - val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class BagDifference(bag1: Expr, bag2: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = + leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } - // TODO: Add checks for these expressions too + /* Total map operations */ - /* Map operations */ /** $encodingof `Map[keyType, valueType](key1 -> value1, key2 -> value2 ...)` */ - case class FiniteMap(pairs: Map[Expr, Expr], keyType: TypeTree, valueType: TypeTree) extends Expr { - val getType = MapType(keyType, valueType).unveilUntyped + case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: TypeTree) extends Expr { + lazy val tpe = MapType(keyType, default.getType).unveilUntyped + def getType(implicit p: Program): Type = tpe } + /** $encodingof `map.apply(key)` (or `map(key)`)*/ - case class MapApply(map: Expr, key: Expr) extends Expr { - val getType = map.getType match { - case MapType(from, to) if isSubtypeOf(key.getType, from) => - to + case class MapApply(map: Expr, key: Expr) extends Expr with CachingTyped { + protected def computeType(implicit p: Program): Type = map.getType match { + case MapType(from, to) => checkParamTypes(Seq(key.getType), Seq(from), to) case _ => Untyped } } - /** $encodingof `map ++ map2` */ - case class MapUnion(map1: Expr, map2: Expr) extends Expr { - val getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(Untyped).unveilUntyped - } - /** $encodingof `map -- map2` */ - case class MapDifference(map: Expr, keys: Expr) extends Expr { - val getType = map.getType - } - /** $encodingof `map.isDefinedAt(key)` */ - case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr { - val getType = BooleanType - } - - - /* Array operations */ - /** $encodingof `array(key)` */ - case class ArraySelect(array: Expr, index: Expr) extends Expr { - val getType = array.getType match { - case ArrayType(base) => - base - case _ => - Untyped - } - } - - /** $encodingof `array.updated(key, index)` */ - case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr { - val getType = array.getType match { - case ArrayType(base) => - leastUpperBound(base, newValue.getType).map(ArrayType).getOrElse(Untyped).unveilUntyped - case _ => - Untyped - } - } - - /** $encodingof `array.length` */ - case class ArrayLength(array: Expr) extends Expr { - val getType = Int32Type - } - - /** $encodingof Array(elems...) with predetermined elements - * @param elems The map from the position to the elements. - * @param defaultLength An optional pair where the first element is the default value - * and the second is the size of the array. Set this for big arrays - * with a default value (as genereted with `Array.fill` in Scala). - */ - case class NonemptyArray(elems: Map[Int, Expr], defaultLength: Option[(Expr, Expr)]) extends Expr { - require(elems.nonEmpty || (defaultLength.nonEmpty && defaultLength.get._2 != IntLiteral(0))) - private val elements = elems.values.toList ++ defaultLength.map(_._1) - val getType = ArrayType(optionToType(leastUpperBound(elements map { _.getType }))).unveilUntyped - } - - /** $encodingof `Array[tpe]()` */ - case class EmptyArray(tpe: TypeTree) extends Expr with Terminal { - val getType = ArrayType(tpe).unveilUntyped - } - - /* Special trees for synthesis */ - /** $encodingof `choose(pred)`, the non-deterministic choice in Leon. - * - * The semantics of this expression is some value - * @note [[pred]] should be a of a [[Types.FunctionType]]. - */ - case class Choose(pred: Expr) extends Expr { - val getType = pred.getType match { - case FunctionType(from, BooleanType) if from.nonEmpty => // @mk why nonEmpty? - tupleTypeWrap(from) - case _ => - Untyped - } - } - - /** Provide an oracle (synthesizable, all-seeing choose) */ - case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with Extractable { - require(oracles.nonEmpty) - - val getType = body.getType - - def extract = { - Some((Seq(body), (es: Seq[Expr]) => WithOracle(oracles, es.head).setPos(this))) - } - } - - /** $encodingof a synthesizable hole in a program. Represented by `???[tpe]` - * in Leon source code. - * - * A [[Hole]] gets transformed into a [[Choose]] construct during [[leon.synthesis.ConversionPhase the ConvertHoles phase]]. - */ - case class Hole(tpe: TypeTree, alts: Seq[Expr]) extends Expr with Extractable { - val getType = tpe - - def extract = { - Some((alts, (es: Seq[Expr]) => Hole(tpe, es).setPos(this))) - } - } - } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/inox/trees/Extractors.scala similarity index 70% rename from src/main/scala/leon/purescala/Extractors.scala rename to src/main/scala/inox/trees/Extractors.scala index 36b0887dc14fa86623357b1c2157a75b30b3f49e..ac3a22663e0723e5dc87db6859c96f62dacf9ea2 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/inox/trees/Extractors.scala @@ -8,7 +8,7 @@ import Common._ import Types._ import Constructors._ -object Extractors { +trait Extractors { self: Expressions => /** Operator Extractor to extract any Expression in a consistent way. * @@ -28,16 +28,8 @@ object Extractors { /* Unary operators */ case Not(t) => Some((Seq(t), (es: Seq[Expr]) => Not(es.head))) - case Choose(expr) => - Some((Seq(expr), (es: Seq[Expr]) => Choose(es.head))) case UMinus(t) => Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head))) - case BVUMinus(t) => - Some((Seq(t), (es: Seq[Expr]) => BVUMinus(es.head))) - case RealUMinus(t) => - Some((Seq(t), (es: Seq[Expr]) => RealUMinus(es.head))) - case BVNot(t) => - Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head))) case StringLength(t) => Some((Seq(t), (es: Seq[Expr]) => StringLength(es.head))) case StringBigLength(t) => @@ -62,37 +54,12 @@ object Extractors { Some((Seq(e), (es: Seq[Expr]) => AsInstanceOf(es.head, ct))) case TupleSelect(t, i) => Some((Seq(t), (es: Seq[Expr]) => TupleSelect(es.head, i))) - case ArrayLength(a) => - Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - case FiniteLambda(mapping, dflt, tpe) => - val sze = tpe.from.size + 1 - val subArgs = mapping.flatMap { case (args, v) => args :+ v } - val builder = (as: Seq[Expr]) => { - def rec(kvs: Seq[Expr]): Seq[(Seq[Expr], Expr)] = kvs match { - case seq if seq.size >= sze => - val (args :+ res, rest) = seq.splitAt(sze) - (args -> res) +: rec(rest) - case Seq() => Seq.empty - case _ => sys.error("unexpected number of key/value expressions") - } - FiniteLambda(rec(as.init), as.last, tpe) - } - Some((subArgs :+ dflt, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) /* Binary operators */ - case LetDef(fds, rest) => Some(( - fds.map(_.fullBody) ++ Seq(rest), - (es: Seq[Expr]) => { - for((fd, i) <- fds.zipWithIndex) { - fd.fullBody = es(i) - } - LetDef(fds, es(fds.length)) - } - )) case Equals(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => Equals(es(0), es(1))) case Implies(t1, t2) => @@ -117,20 +84,6 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => LessEquals(es(0), es(1))) case GreaterEquals(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => GreaterEquals(es(0), es(1))) - case BVPlus(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVPlus(es(0), es(1))) - case BVMinus(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVMinus(es(0), es(1))) - case BVTimes(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVTimes(es(0), es(1))) - case BVDivision(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVDivision(es(0), es(1))) - case BVRemainder(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVRemainder(es(0), es(1))) - case BVAnd(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVAnd(es(0), es(1))) - case BVOr(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => BVOr(es(0), es(1))) case BVXOr(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => BVXOr(es(0), es(1))) case BVShiftLeft(t1, t2) => @@ -139,14 +92,6 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => BVAShiftRight(es(0), es(1))) case BVLShiftRight(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => BVLShiftRight(es(0), es(1))) - case RealPlus(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => RealPlus(es(0), es(1))) - case RealMinus(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => RealMinus(es(0), es(1))) - case RealTimes(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => RealTimes(es(0), es(1))) - case RealDivision(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => RealDivision(es(0), es(1))) case StringConcat(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => StringConcat(es(0), es(1))) case SetAdd(t1, t2) => @@ -173,14 +118,6 @@ object Extractors { Some(Seq(e1, e2), (es: Seq[Expr]) => BagDifference(es(0), es(1))) case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) - case MapUnion(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1))) - case MapDifference(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => MapDifference(es(0), es(1))) - case MapIsDefinedAt(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => MapIsDefinedAt(es(0), es(1))) - case ArraySelect(t1, t2) => - Some(Seq(t1, t2), (es: Seq[Expr]) => ArraySelect(es(0), es(1))) case Let(binder, e, body) => Some(Seq(e, body), (es: Seq[Expr]) => Let(binder, es(0), es(1))) case Require(pre, body) => @@ -191,8 +128,7 @@ object Extractors { Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))) /* Other operators */ - case fi @ FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) - case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) + case fi @ FunctionInvocation(fd, tps, args) => Some((args, FunctionInvocation(fd, tps, _))) case fa @ Application(caller, args) => Some(caller +: args, as => Application(as.head, as.tail)) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, es => And(es))) @@ -226,32 +162,12 @@ object Extractors { } Some((subArgs, builder)) } - case ArrayUpdated(t1, t2, t3) => Some(( - Seq(t1, t2, t3), - (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)) - )) - case NonemptyArray(elems, Some((default, length))) => - val elemsSeq: Seq[(Int, Expr)] = elems.toSeq - val all = elemsSeq.map(_._2) :+ default :+ length - Some((all, as => { - val l = as.length - NonemptyArray(elemsSeq.map(_._1).zip(as.take(l - 2)).toMap, - Some((as(l - 2), as(l - 1)))) - })) - case na @ NonemptyArray(elems, None) => - val ArrayType(tpe) = na.getType - val (indexes, elsOrdered) = elems.toSeq.unzip - - Some(( - elsOrdered, - es => NonemptyArray(indexes.zip(es).toMap, None) - )) case Tuple(args) => Some((args, es => Tuple(es))) case IfExpr(cond, thenn, elze) => Some(( Seq(cond, thenn, elze), { case Seq(c, t, e) => IfExpr(c, t, e) } )) - case m@MatchExpr(scrut, cases) => Some(( + case m @ MatchExpr(scrut, cases) => Some(( scrut +: cases.flatMap { _.expressions }, (es: Seq[Expr]) => { var i = 1 @@ -263,20 +179,6 @@ object Extractors { MatchExpr(es.head, newcases) } )) - case Passes(in, out, cases) => Some(( - in +: out +: cases.flatMap { _.expressions }, - { - case Seq(in, out, es@_*) => { - var i = 0 - val newcases = for (caze <- cases) yield caze match { - case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1)) - case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1)) - } - - Passes(in, out, newcases) - } - } - )) /* Terminals */ case t: Terminal => Some(Seq[Expr](), (_:Seq[Expr]) => t) @@ -289,14 +191,14 @@ object Extractors { None } } - + // Extractors for types are available at Types.NAryType trait Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] } - object TopLevelOrs { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) + object TopLevelOrs { // expr1 OR (expr2 OR (expr3 OR ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { case Let(i, e, TopLevelOrs(bs)) => Some(bs map (let(i,e,_))) case Or(exprs) => @@ -305,6 +207,7 @@ object Extractors { Some(Seq(e)) } } + object TopLevelAnds { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { case Let(i, e, TopLevelAnds(bs)) => Some(bs map (let(i,e,_))) @@ -316,11 +219,11 @@ object Extractors { } object IsTyped { - def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) + def unapply[T <: Typed](e: T): Option[(T, Type)] = Some((e, e.getType)) } - + object WithStringconverter { - def unapply(t: TypeTree): Option[Expr => Expr] = t match { + def unapply(t: Type): Option[Expr => Expr] = t match { case BooleanType => Some(BooleanToString) case Int32Type => Some(Int32ToString) case IntegerType => Some(IntegerToString) @@ -330,19 +233,6 @@ object Extractors { } } - object FiniteArray { - def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match { - case EmptyArray(_) => - Some((Map(), None, IntLiteral(0))) - case NonemptyArray(els, Some((default, length))) => - Some((els, Some(default), length)) - case NonemptyArray(els, None) => - Some((els, None, IntLiteral(els.size))) - case _ => - None - } - } - object SimpleCase { def apply(p : Pattern, rhs : Expr) = MatchCase(p, None, rhs) def unapply(c : MatchCase) = c match { @@ -350,7 +240,7 @@ object Extractors { case _ => None } } - + object GuardedCase { def apply(p : Pattern, g: Expr, rhs : Expr) = MatchCase(p, Some(g), rhs) def unapply(c : MatchCase) = c match { @@ -358,7 +248,7 @@ object Extractors { case _ => None } } - + object Pattern { def unapply(p : Pattern) : Option[( Option[Identifier], @@ -380,14 +270,16 @@ object Extractors { case _ if !isTuple => Seq(e) case tp => sys.error(s"Calling unwrapTuple on non-tuple $e of type $tp") } + def unwrapTuple(e: Expr, expectedSize: Int): Seq[Expr] = unwrapTuple(e, expectedSize > 1) - def unwrapTupleType(tp: TypeTree, isTuple: Boolean): Seq[TypeTree] = tp match { + def unwrapTupleType(tp: Type, isTuple: Boolean): Seq[Type] = tp match { case TupleType(subs) if isTuple => subs case tp if !isTuple => Seq(tp) case tp => sys.error(s"Calling unwrapTupleType on $tp") } - def unwrapTupleType(tp: TypeTree, expectedSize: Int): Seq[TypeTree] = + + def unwrapTupleType(tp: Type, expectedSize: Int): Seq[Type] = unwrapTupleType(tp, expectedSize > 1) def unwrapTuplePattern(p: Pattern, isTuple: Boolean): Seq[Pattern] = p match { @@ -395,6 +287,7 @@ object Extractors { case tp if !isTuple => Seq(tp) case tp => sys.error(s"Calling unwrapTuplePattern on $p") } + def unwrapTuplePattern(p: Pattern, expectedSize: Int): Seq[Pattern] = unwrapTuplePattern(p, expectedSize > 1) diff --git a/src/main/scala/leon/purescala/GenTreeOps.scala b/src/main/scala/inox/trees/GenTreeOps.scala similarity index 100% rename from src/main/scala/leon/purescala/GenTreeOps.scala rename to src/main/scala/inox/trees/GenTreeOps.scala diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/inox/trees/Path.scala similarity index 100% rename from src/main/scala/leon/purescala/Path.scala rename to src/main/scala/inox/trees/Path.scala diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/inox/trees/PrettyPrinter.scala similarity index 99% rename from src/main/scala/leon/purescala/PrettyPrinter.scala rename to src/main/scala/inox/trees/PrettyPrinter.scala index 871f269c98b86f5866cd0a7517465ae0f413ad10..1d93d9725b9b703fc0ee247c1631daa993ee9916 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/inox/trees/PrettyPrinter.scala @@ -26,7 +26,7 @@ case class PrinterContext( } /** This pretty-printer uses Unicode for some operators, to make sure we - * distinguish PureScala from "real" Scala (and also because it's cute). */ + * distinguish PureScala from "real" Scala (and also because it's cute). */ class PrettyPrinter(opts: PrinterOptions, opgm: Option[Program], val sb: StringBuffer = new StringBuffer) { diff --git a/src/main/scala/leon/purescala/PrinterHelpers.scala b/src/main/scala/inox/trees/PrinterHelpers.scala similarity index 100% rename from src/main/scala/leon/purescala/PrinterHelpers.scala rename to src/main/scala/inox/trees/PrinterHelpers.scala diff --git a/src/main/scala/leon/purescala/PrinterOptions.scala b/src/main/scala/inox/trees/PrinterOptions.scala similarity index 100% rename from src/main/scala/leon/purescala/PrinterOptions.scala rename to src/main/scala/inox/trees/PrinterOptions.scala diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/inox/trees/Quantification.scala similarity index 100% rename from src/main/scala/leon/purescala/Quantification.scala rename to src/main/scala/inox/trees/Quantification.scala diff --git a/src/main/scala/leon/purescala/Transformer.scala b/src/main/scala/inox/trees/Transformer.scala similarity index 100% rename from src/main/scala/leon/purescala/Transformer.scala rename to src/main/scala/inox/trees/Transformer.scala diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/inox/trees/TransformerWithPC.scala similarity index 82% rename from src/main/scala/leon/purescala/TransformerWithPC.scala rename to src/main/scala/inox/trees/TransformerWithPC.scala index 95b2c89d4bb976f20b9ab628131fee53a4d7dbf1..2408757a6734690ab6b784b702c116b8a79abfee 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/inox/trees/TransformerWithPC.scala @@ -8,18 +8,16 @@ import Constructors._ import Extractors._ import ExprOps._ -/** Traverses/ transforms expressions with path condition awareness. +/** Traverses expressions with path condition awareness. * * As lets cannot be encoded as Equals due to types for which equality * is not well-founded, path conditions reconstruct lets around the * final condition one wishes to verify through [[Path.getClause]]. */ -abstract class TransformerWithPC extends Transformer { +abstract class TraverserWithPaths[T](trees: Trees) { + import trees._ - /** The initial path condition */ - protected val initPath: Path - - protected def rec(e: Expr, path: Path): Expr = e match { + protected def rec(e: Expr, path: Path): Unit = e match { case Let(i, v, b) => val se = rec(v, path) val sb = rec(b, path withBinding (i -> se)) @@ -47,18 +45,13 @@ abstract class TransformerWithPC extends Transformer { val sb = rec(body, path withCond pre) Require(sp, sb).copiedFrom(e) - //@mk: TODO Discuss if we should include asserted predicates in the pc - //case Assert(pred, err, body) => - // val sp = rec(pred, path) - // val sb = rec(body, register(sp, path)) - // Assert(sp, err, sb).copiedFrom(e) - - case p: Passes => - applyAsMatches(p, rec(_,path)) + case Assert(pred, err, body) => + val sp = rec(pred, path) + val sb = rec(body, register(sp, path)) + Assert(sp, err, sb).copiedFrom(e) case MatchExpr(scrut, cases) => val rs = rec(scrut, path) - var soFar = path MatchExpr(rs, cases.map { c => @@ -104,9 +97,5 @@ abstract class TransformerWithPC extends Transformer { case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") } - - def transform(e: Expr): Expr = { - rec(e, initPath) - } } diff --git a/src/main/scala/leon/purescala/TreeTransformer.scala b/src/main/scala/inox/trees/TreeTransformer.scala similarity index 100% rename from src/main/scala/leon/purescala/TreeTransformer.scala rename to src/main/scala/inox/trees/TreeTransformer.scala diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/inox/trees/Trees.scala similarity index 90% rename from src/main/scala/leon/purescala/Common.scala rename to src/main/scala/inox/trees/Trees.scala index 11aa558586d99ae0a667f5b7d1bb08f3706450b9..39d3178916d24f0412a112ddfa50b81b768d8d23 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/inox/trees/Trees.scala @@ -8,7 +8,7 @@ import Expressions.Variable import Types._ import Definitions.Program -object Common { +trait Trees extends Expressions with ExprOps with Types with TypeOps { abstract class Tree extends Positioned with Serializable with Printable { def copiedFrom(o: Tree): this.type = { @@ -18,18 +18,14 @@ object Common { // @EK: toString is considered harmful for non-internal things. Use asString(ctx) instead. - def asString(implicit ctx: LeonContext): String = { - ScalaPrinter(this, ctx) - } - - def asString(pgm: Program)(implicit ctx: LeonContext): String = { + def asString(implicit pgm: Program, ctx: Context): String = { ScalaPrinter(this, ctx, pgm) } override def toString = asString(LeonContext.printNames) } - /** Represents a unique symbol in Leon. + /** Represents a unique symbol in Inox. * * The name is stored in the decoded (source code) form rather than encoded (JVM) form. * The type may be left blank (Untyped) for Identifiers that are not variables. @@ -108,12 +104,12 @@ object Common { } } - def aliased(id1 : Identifier, id2 : Identifier) = { + def aliased(id1: Identifier, id2: Identifier) = { id1.toString == id2.toString } /** Returns true if the two group of identifiers ovelap. */ - def aliased(ids1 : Set[Identifier], ids2 : Set[Identifier]) = { + def aliased(ids1: Set[Identifier], ids2: Set[Identifier]) = { val s1 = ids1.groupBy{ _.toString }.keySet val s2 = ids2.groupBy{ _.toString }.keySet (s1 & s2).nonEmpty diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/inox/trees/TypeOps.scala similarity index 98% rename from src/main/scala/leon/purescala/TypeOps.scala rename to src/main/scala/inox/trees/TypeOps.scala index 3fd35f7190d8cdafb617bac9375eae9d1d08ffc4..2a9e96b2dd340dd6c16b49c1e0809aade8f5e8a7 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/inox/trees/TypeOps.scala @@ -168,8 +168,8 @@ object TypeOps extends GenTreeOps[TypeTree] { leastUpperBound(t1, t2) == Some(t2) } - def typesCompatible(t1: TypeTree, t2: TypeTree) = { - leastUpperBound(t1, t2).isDefined + def typesCompatible(t1: TypeTree, t2s: TypeTree*) = { + leastUpperBound(t1 +: t2s).isDefined } def typeCheck(obj: Expr, exps: TypeTree*) { diff --git a/src/main/scala/inox/trees/Types.scala b/src/main/scala/inox/trees/Types.scala new file mode 100644 index 0000000000000000000000000000000000000000..608f01ccf6891306d632ef30e64df6d587001792 --- /dev/null +++ b/src/main/scala/inox/trees/Types.scala @@ -0,0 +1,112 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Expressions._ +import Definitions._ +import TypeOps._ + +trait Types { self: Trees => + + trait Typed extends Printable { + def getType(implicit p: Program): Type + def isTyped(implicit p: Program): Boolean = getType != Untyped + } + + private[trees] trait CachingTyped extends Typed { + private var lastProgram: Program = null + private var lastType: Type = null + + final def getType(implicit p: Program): Type = + if (p eq lastProgram) lastType else { + val tpe = computeType + lastProgram = p + lastType = tpe + tpe + } + + protected def computeType(implicit p: Program): Type + } + + abstract class Type extends Tree with Typed { + def getType(implicit p: Program): Type = this + + // Checks whether the subtypes of this type contain Untyped, + // and if so sets this to Untyped. + // Assumes the subtypes are correctly formed, so it does not descend + // deep into the TypeTree. + def unveilUntyped: Type = this match { + case NAryType(tps, _) => + if (tps contains Untyped) Untyped else this + } + } + + case object Untyped extends Type + case object BooleanType extends Type + case object UnitType extends Type + case object CharType extends Type + case object IntegerType extends Type + case object RealType extends Type + case object StringType extends Type + + case class BVType(size: Int) extends Type + case object Int32Type extends BVType(32) + + class TypeParameter private (name: String) extends Type { + val id = FreshIdentifier(name, this) + def freshen = new TypeParameter(name) + + override def equals(that: Any) = that match { + case TypeParameter(id) => this.id == id + case _ => false + } + + override def hashCode = id.hashCode + } + + object TypeParameter { + def unapply(tp: TypeParameter): Option[Identifier] = Some(tp.id) + def fresh(name: String) = new TypeParameter(name) + } + + /* + * If you are not sure about the requirement, + * you should use tupleTypeWrap in purescala.Constructors + */ + case class TupleType(bases: Seq[Type]) extends Type { + val dimension: Int = bases.length + require(dimension >= 2) + } + + case class SetType(base: Type) extends Type + case class BagType(base: Type) extends Type + case class MapType(from: Type, to: Type) extends Type + case class FunctionType(from: Seq[Type], to: Type) extends Type + + case class ClassType(id: Identifier, tps: Seq[Type]) extends Type { + def lookupClass(implicit p: Program): Option[ClassDef] = p.lookupClass(id, tps) + } + + object NAryType extends TreeExtractor[Type] { + def unapply(t: Type): Option[(Seq[Type], Seq[Type] => Type)] = t match { + case ClassType(ccd, ts) => Some((ts, ts => ClassType(ccd, ts))) + case TupleType(ts) => Some((ts, TupleType)) + case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) + case BagType(t) => Some((Seq(t), ts => BagType(ts.head))) + case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) + case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + /* nullary types */ + case t => Some(Nil, _ => t) + } + } + + object FirstOrderFunctionType { + def unapply(tpe: Type): Option[(Seq[Type], Type)] = tpe match { + case FunctionType(from, to) => + unapply(to).map(p => (from ++ p._1) -> p._2) orElse Some(from -> to) + case _ => None + } + } +} diff --git a/src/main/scala/leon/purescala/package.scala b/src/main/scala/inox/trees/package.scala similarity index 100% rename from src/main/scala/leon/purescala/package.scala rename to src/main/scala/inox/trees/package.scala diff --git a/src/main/scala/leon/utils/ASCIIHelpers.scala b/src/main/scala/inox/utils/ASCIIHelpers.scala similarity index 100% rename from src/main/scala/leon/utils/ASCIIHelpers.scala rename to src/main/scala/inox/utils/ASCIIHelpers.scala diff --git a/src/main/scala/leon/utils/Benchmarks.scala b/src/main/scala/inox/utils/Benchmarks.scala similarity index 100% rename from src/main/scala/leon/utils/Benchmarks.scala rename to src/main/scala/inox/utils/Benchmarks.scala diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/inox/utils/Bijection.scala similarity index 100% rename from src/main/scala/leon/utils/Bijection.scala rename to src/main/scala/inox/utils/Bijection.scala diff --git a/src/main/scala/leon/utils/DebugSections.scala b/src/main/scala/inox/utils/DebugSections.scala similarity index 100% rename from src/main/scala/leon/utils/DebugSections.scala rename to src/main/scala/inox/utils/DebugSections.scala diff --git a/src/main/scala/leon/utils/FileOutputPhase.scala b/src/main/scala/inox/utils/FileOutputPhase.scala similarity index 100% rename from src/main/scala/leon/utils/FileOutputPhase.scala rename to src/main/scala/inox/utils/FileOutputPhase.scala diff --git a/src/main/scala/leon/utils/FilesWatcher.scala b/src/main/scala/inox/utils/FilesWatcher.scala similarity index 100% rename from src/main/scala/leon/utils/FilesWatcher.scala rename to src/main/scala/inox/utils/FilesWatcher.scala diff --git a/src/main/scala/leon/utils/FreeableIterator.scala b/src/main/scala/inox/utils/FreeableIterator.scala similarity index 100% rename from src/main/scala/leon/utils/FreeableIterator.scala rename to src/main/scala/inox/utils/FreeableIterator.scala diff --git a/src/main/scala/leon/utils/GraphOps.scala b/src/main/scala/inox/utils/GraphOps.scala similarity index 100% rename from src/main/scala/leon/utils/GraphOps.scala rename to src/main/scala/inox/utils/GraphOps.scala diff --git a/src/main/scala/leon/utils/GraphPrinters.scala b/src/main/scala/inox/utils/GraphPrinters.scala similarity index 100% rename from src/main/scala/leon/utils/GraphPrinters.scala rename to src/main/scala/inox/utils/GraphPrinters.scala diff --git a/src/main/scala/leon/utils/Graphs.scala b/src/main/scala/inox/utils/Graphs.scala similarity index 100% rename from src/main/scala/leon/utils/Graphs.scala rename to src/main/scala/inox/utils/Graphs.scala diff --git a/src/main/scala/leon/utils/GrowableIterable.scala b/src/main/scala/inox/utils/GrowableIterable.scala similarity index 100% rename from src/main/scala/leon/utils/GrowableIterable.scala rename to src/main/scala/inox/utils/GrowableIterable.scala diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/inox/utils/IncrementalBijection.scala similarity index 100% rename from src/main/scala/leon/utils/IncrementalBijection.scala rename to src/main/scala/inox/utils/IncrementalBijection.scala diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/inox/utils/IncrementalMap.scala similarity index 100% rename from src/main/scala/leon/utils/IncrementalMap.scala rename to src/main/scala/inox/utils/IncrementalMap.scala diff --git a/src/main/scala/leon/utils/IncrementalSeq.scala b/src/main/scala/inox/utils/IncrementalSeq.scala similarity index 100% rename from src/main/scala/leon/utils/IncrementalSeq.scala rename to src/main/scala/inox/utils/IncrementalSeq.scala diff --git a/src/main/scala/leon/utils/IncrementalSet.scala b/src/main/scala/inox/utils/IncrementalSet.scala similarity index 100% rename from src/main/scala/leon/utils/IncrementalSet.scala rename to src/main/scala/inox/utils/IncrementalSet.scala diff --git a/src/main/scala/leon/utils/IncrementalState.scala b/src/main/scala/inox/utils/IncrementalState.scala similarity index 100% rename from src/main/scala/leon/utils/IncrementalState.scala rename to src/main/scala/inox/utils/IncrementalState.scala diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/inox/utils/InliningPhase.scala similarity index 100% rename from src/main/scala/leon/utils/InliningPhase.scala rename to src/main/scala/inox/utils/InliningPhase.scala diff --git a/src/main/scala/leon/utils/InterruptManager.scala b/src/main/scala/inox/utils/InterruptManager.scala similarity index 100% rename from src/main/scala/leon/utils/InterruptManager.scala rename to src/main/scala/inox/utils/InterruptManager.scala diff --git a/src/main/scala/leon/utils/Interruptible.scala b/src/main/scala/inox/utils/Interruptible.scala similarity index 100% rename from src/main/scala/leon/utils/Interruptible.scala rename to src/main/scala/inox/utils/Interruptible.scala diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/inox/utils/Library.scala similarity index 100% rename from src/main/scala/leon/utils/Library.scala rename to src/main/scala/inox/utils/Library.scala diff --git a/src/main/scala/leon/utils/ModelEnumerator.scala b/src/main/scala/inox/utils/ModelEnumerator.scala similarity index 100% rename from src/main/scala/leon/utils/ModelEnumerator.scala rename to src/main/scala/inox/utils/ModelEnumerator.scala diff --git a/src/main/scala/leon/utils/OracleTraverser.scala b/src/main/scala/inox/utils/OracleTraverser.scala similarity index 100% rename from src/main/scala/leon/utils/OracleTraverser.scala rename to src/main/scala/inox/utils/OracleTraverser.scala diff --git a/src/main/scala/leon/utils/Positions.scala b/src/main/scala/inox/utils/Positions.scala similarity index 100% rename from src/main/scala/leon/utils/Positions.scala rename to src/main/scala/inox/utils/Positions.scala diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/inox/utils/PreprocessingPhase.scala similarity index 100% rename from src/main/scala/leon/utils/PreprocessingPhase.scala rename to src/main/scala/inox/utils/PreprocessingPhase.scala diff --git a/src/main/scala/leon/utils/PrintReportPhase.scala b/src/main/scala/inox/utils/PrintReportPhase.scala similarity index 100% rename from src/main/scala/leon/utils/PrintReportPhase.scala rename to src/main/scala/inox/utils/PrintReportPhase.scala diff --git a/src/main/scala/leon/utils/PrintTreePhase.scala b/src/main/scala/inox/utils/PrintTreePhase.scala similarity index 100% rename from src/main/scala/leon/utils/PrintTreePhase.scala rename to src/main/scala/inox/utils/PrintTreePhase.scala diff --git a/src/main/scala/leon/utils/Report.scala b/src/main/scala/inox/utils/Report.scala similarity index 100% rename from src/main/scala/leon/utils/Report.scala rename to src/main/scala/inox/utils/Report.scala diff --git a/src/main/scala/leon/utils/SCC.scala b/src/main/scala/inox/utils/SCC.scala similarity index 100% rename from src/main/scala/leon/utils/SCC.scala rename to src/main/scala/inox/utils/SCC.scala diff --git a/src/main/scala/leon/utils/SearchSpace.scala b/src/main/scala/inox/utils/SearchSpace.scala similarity index 100% rename from src/main/scala/leon/utils/SearchSpace.scala rename to src/main/scala/inox/utils/SearchSpace.scala diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/inox/utils/SeqUtils.scala similarity index 100% rename from src/main/scala/leon/utils/SeqUtils.scala rename to src/main/scala/inox/utils/SeqUtils.scala diff --git a/src/main/scala/leon/utils/StreamUtils.scala b/src/main/scala/inox/utils/StreamUtils.scala similarity index 100% rename from src/main/scala/leon/utils/StreamUtils.scala rename to src/main/scala/inox/utils/StreamUtils.scala diff --git a/src/main/scala/leon/utils/TemporaryInputPhase.scala b/src/main/scala/inox/utils/TemporaryInputPhase.scala similarity index 100% rename from src/main/scala/leon/utils/TemporaryInputPhase.scala rename to src/main/scala/inox/utils/TemporaryInputPhase.scala diff --git a/src/main/scala/leon/utils/TimeoutFor.scala b/src/main/scala/inox/utils/TimeoutFor.scala similarity index 100% rename from src/main/scala/leon/utils/TimeoutFor.scala rename to src/main/scala/inox/utils/TimeoutFor.scala diff --git a/src/main/scala/leon/utils/Timer.scala b/src/main/scala/inox/utils/Timer.scala similarity index 100% rename from src/main/scala/leon/utils/Timer.scala rename to src/main/scala/inox/utils/Timer.scala diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/inox/utils/TypingPhase.scala similarity index 100% rename from src/main/scala/leon/utils/TypingPhase.scala rename to src/main/scala/inox/utils/TypingPhase.scala diff --git a/src/main/scala/leon/utils/UniqueCounter.scala b/src/main/scala/inox/utils/UniqueCounter.scala similarity index 100% rename from src/main/scala/leon/utils/UniqueCounter.scala rename to src/main/scala/inox/utils/UniqueCounter.scala diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/inox/utils/UnitElimination.scala similarity index 100% rename from src/main/scala/leon/utils/UnitElimination.scala rename to src/main/scala/inox/utils/UnitElimination.scala diff --git a/src/main/scala/leon/utils/package.scala b/src/main/scala/inox/utils/package.scala similarity index 100% rename from src/main/scala/leon/utils/package.scala rename to src/main/scala/inox/utils/package.scala diff --git a/src/main/scala/leon/LeonPhase.scala b/src/main/scala/leon/LeonPhase.scala deleted file mode 100644 index 8fb5307955305fb6612cd7b3933fcda553cd0fe7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/LeonPhase.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -import purescala.Definitions.Program - -trait LeonPhase[-F, +T] extends Pipeline[F, T] with LeonComponent { - // def run(ac: LeonContext)(v: F): T -} - -trait SimpleLeonPhase[-F, +T] extends LeonPhase[F, T] { - def apply(ctx: LeonContext, v: F): T - - def run(ctx: LeonContext, v: F): (LeonContext, T) = (ctx, apply(ctx, v)) -} - -abstract class TransformationPhase extends LeonPhase[Program, Program] { - def apply(ctx: LeonContext, p: Program): Program - - override def run(ctx: LeonContext, p: Program) = { - ctx.reporter.debug("Running transformation phase: " + name)(utils.DebugSectionLeon) - (ctx, apply(ctx, p)) - } - -} - -abstract class UnitPhase[T] extends LeonPhase[T, T] { - def apply(ctx: LeonContext, p: T): Unit - - override def run(ctx: LeonContext, p: T) = { - ctx.reporter.debug("Running unit phase: " + name)(utils.DebugSectionLeon) - apply(ctx, p) - (ctx, p) - } -} - -case class NoopPhase[T]() extends LeonPhase[T, T] { - val name = "noop" - val description = "no-op" - override def run(ctx: LeonContext, v: T) = (ctx, v) -} diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala deleted file mode 100644 index 583296566a74dd5ed1e8dcc1038aa32c3272690b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/Main.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -import leon.utils._ - -object Main { - - lazy val allPhases: List[LeonPhase[_, _]] = { - List( - frontends.scalac.ExtractionPhase, - frontends.scalac.ClassgenPhase, - utils.TypingPhase, - utils.FileOutputPhase, - purescala.RestoreMethods, - xlang.AntiAliasingPhase, - xlang.EpsilonElimination, - xlang.ImperativeCodeElimination, - xlang.FixReportLabels, - xlang.XLangDesugaringPhase, - purescala.FunctionClosure, - synthesis.SynthesisPhase, - termination.TerminationPhase, - verification.VerificationPhase, - repair.RepairPhase, - evaluators.EvaluationPhase, - solvers.isabelle.AdaptationPhase, - solvers.isabelle.IsabellePhase, - transformations.InstrumentationPhase, - invariant.engine.InferInvariantsPhase, - laziness.LazinessEliminationPhase, - genc.GenerateCPhase, - genc.CFileOutputPhase - ) - } - - // Add whatever you need here. - lazy val allComponents : Set[LeonComponent] = allPhases.toSet ++ Set( - solvers.unrolling.UnrollingProcedure, MainComponent, GlobalOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component - ) - - /* - * This object holds the options that determine the selected pipeline of Leon. - * Please put any further such options here to have them print nicely in --help message. - */ - object MainComponent extends LeonComponent { - val name = "main" - val description = "Selection of Leon functionality. Default: verify" - - val optEval = LeonStringOptionDef("eval", "Evaluate ground functions through code generation or evaluation (default: evaluation)", "default", "[codegen|default]") - val optTermination = LeonFlagOptionDef("termination", "Check program termination. Can be used along --verify", false) - val optRepair = LeonFlagOptionDef("repair", "Repair selected functions", false) - val optSynthesis = LeonFlagOptionDef("synthesis", "Partial synthesis of choose() constructs", false) - val optIsabelle = LeonFlagOptionDef("isabelle", "Run Isabelle verification", false) - val optNoop = LeonFlagOptionDef("noop", "No operation performed, just output program", false) - val optVerify = LeonFlagOptionDef("verify", "Verify function contracts", false) - val optHelp = LeonFlagOptionDef("help", "Show help message", false) - val optInstrument = LeonFlagOptionDef("instrument", "Instrument the code for inferring time/depth/stack bounds", false) - val optInferInv = LeonFlagOptionDef("inferInv", "Infer invariants from (instrumented) the code", false) - val optLazyEval = LeonFlagOptionDef("lazy", "Handles programs that may use the 'lazy' construct", false) - val optGenc = LeonFlagOptionDef("genc", "Generate C code", false) - - override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify, optInstrument, optInferInv, optLazyEval, optGenc) - } - - lazy val allOptions: Set[LeonOptionDef[Any]] = allComponents.flatMap(_.definedOptions) - - def displayHelp(reporter: Reporter, error: Boolean) = { - - reporter.title(MainComponent.description) - for (opt <- MainComponent.definedOptions.toSeq.sortBy(_.name)) { - reporter.info(opt.helpString) - } - reporter.info("") - - reporter.title("Additional global options") - for (opt <- GlobalOptions.definedOptions.toSeq.sortBy(_.name)) { - reporter.info(opt.helpString) - } - reporter.info("") - - reporter.title("Additional options, by component:") - - for (c <- (allComponents - MainComponent - GlobalOptions).toSeq.sortBy(_.name) if c.definedOptions.nonEmpty) { - reporter.info("") - reporter.info(s"${c.name} (${c.description})") - for (opt <- c.definedOptions.toSeq.sortBy(_.name)) { - // there is a non-breaking space at the beginning of the string :) - reporter.info(opt.helpString) - } - } - exit(error) - } - - def displayVersion(reporter: Reporter) = { - reporter.title("Leon verification and synthesis tool (http://leon.epfl.ch/)") - reporter.info("") - } - - private def exit(error: Boolean) = sys.exit(if (error) 1 else 0) - - def processOptions(args: Seq[String]): LeonContext = { - - val initReporter = new DefaultReporter(Set()) - - val options = args.filter(_.startsWith("--")) - - val files = args.filterNot(_.startsWith("-")).map(new java.io.File(_)) - - val leonOptions: Seq[LeonOption[Any]] = options.map { opt => - val (name, value) = OptionsHelpers.nameValue(opt).getOrElse( - initReporter.fatalError( - s"Malformed option $opt. Options should have the form --name or --name=value" - ) - ) - // Find respective LeonOptionDef, or report an unknown option - val df = allOptions.find(_.name == name).getOrElse{ - initReporter.fatalError( - s"Unknown option: $name\n" + - "Try 'leon --help' for more information." - ) - } - df.parse(value)(initReporter) - } - - val reporter = new DefaultReporter( - leonOptions.collectFirst { - case LeonOption(GlobalOptions.optDebug, sections) => - sections.asInstanceOf[Set[DebugSection]] - }.getOrElse(Set[DebugSection]()) - ) - - reporter.whenDebug(DebugSectionOptions) { debug => - debug("Options considered by Leon:") - for (lo <- leonOptions) debug(lo.toString) - } - - LeonContext( - reporter = reporter, - files = files, - options = leonOptions, - interruptManager = new InterruptManager(reporter) - ) - } - - def computePipeline(ctx: LeonContext): Pipeline[List[String], Any] = { - - import purescala.Definitions.Program - import purescala.RestoreMethods - import utils.FileOutputPhase - import frontends.scalac.{ ExtractionPhase, ClassgenPhase } - import synthesis.SynthesisPhase - import termination.TerminationPhase - import xlang.FixReportLabels - import verification.VerificationPhase - import repair.RepairPhase - import evaluators.EvaluationPhase - import solvers.isabelle.IsabellePhase - import genc.GenerateCPhase - import genc.CFileOutputPhase - import MainComponent._ - import invariant.engine.InferInvariantsPhase - import transformations.InstrumentationPhase - import laziness._ - - val helpF = ctx.findOptionOrDefault(optHelp) - val noopF = ctx.findOptionOrDefault(optNoop) - val synthesisF = ctx.findOptionOrDefault(optSynthesis) - val repairF = ctx.findOptionOrDefault(optRepair) - val isabelleF = ctx.findOptionOrDefault(optIsabelle) - val terminationF = ctx.findOptionOrDefault(optTermination) - val verifyF = ctx.findOptionOrDefault(optVerify) - val gencF = ctx.findOptionOrDefault(optGenc) - val evalF = ctx.findOption(optEval).isDefined - val inferInvF = ctx.findOptionOrDefault(optInferInv) - val instrumentF = ctx.findOptionOrDefault(optInstrument) - val lazyevalF = ctx.findOptionOrDefault(optLazyEval) - val analysisF = verifyF && terminationF - // Check consistency in options - - if (helpF) { - displayVersion(ctx.reporter) - displayHelp(ctx.reporter, error = false) - } else { - val pipeBegin: Pipeline[List[String], Program] = - ClassgenPhase andThen - ExtractionPhase andThen - new PreprocessingPhase(genc = gencF) - - val verification = - VerificationPhase andThen - FixReportLabels andThen - PrintReportPhase - val termination = TerminationPhase andThen PrintReportPhase - - val pipeProcess: Pipeline[Program, Any] = { - if (noopF) RestoreMethods andThen FileOutputPhase - else if (synthesisF) SynthesisPhase - else if (repairF) RepairPhase - else if (analysisF) Pipeline.both(verification, termination) - else if (terminationF) termination - else if (isabelleF) IsabellePhase andThen PrintReportPhase - else if (evalF) EvaluationPhase - else if (inferInvF) InferInvariantsPhase - else if (instrumentF) InstrumentationPhase andThen FileOutputPhase - else if (gencF) GenerateCPhase andThen CFileOutputPhase - else if (lazyevalF) LazinessEliminationPhase - else verification - } - - pipeBegin andThen - pipeProcess - } - } - - private var hasFatal = false - - def main(args: Array[String]) { - val argsl = args.toList - - // Process options - val ctx = try { - processOptions(argsl) - } catch { - case LeonFatalError(None) => - exit(error = true) - - case LeonFatalError(Some(msg)) => - // For the special case of fatal errors not sent though Reporter, we - // send them through reporter one time - try { - new DefaultReporter(Set()).fatalError(msg) - } catch { - case _: LeonFatalError => - } - - exit(error = true) - } - - ctx.interruptManager.registerSignalHandler() - - val doWatch = ctx.findOptionOrDefault(GlobalOptions.optWatch) - - if (doWatch) { - val watcher = new FilesWatcher(ctx, ctx.files ++ Build.libFiles.map { new java.io.File(_) }) - watcher.onChange { - execute(args, ctx) - } - } else { - execute(args, ctx) - } - - exit(hasFatal) - } - - def execute(args: Seq[String], ctx0: LeonContext): Unit = { - val ctx = ctx0.copy(reporter = new DefaultReporter(ctx0.reporter.debugSections)) - - try { - // Compute leon pipeline - val pipeline = computePipeline(ctx) - - val timer = ctx.timers.total.start() - - // Run pipeline - val (ctx2, _) = pipeline.run(ctx, args.toList) - - timer.stop() - - ctx2.reporter.whenDebug(DebugSectionTimers) { debug => - ctx2.timers.outputTable(debug) - } - hasFatal = false - } catch { - case LeonFatalError(None) => - hasFatal = true - - case LeonFatalError(Some(msg)) => - // For the special case of fatal errors not sent though Reporter, we - // send them through reporter one time - try { - ctx.reporter.fatalError(msg) - } catch { - case _: LeonFatalError => - } - - hasFatal = true - } - } - -} diff --git a/src/main/scala/leon/Pipeline.scala b/src/main/scala/leon/Pipeline.scala deleted file mode 100644 index 48f4192e3058548af64ddab922e84c4149299b61..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/Pipeline.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -abstract class Pipeline[-F, +T] { - self => - - def andThen[G](thenn: Pipeline[T, G]): Pipeline[F, G] = new Pipeline[F,G] { - def run(ctx: LeonContext, v: F): (LeonContext, G) = { - val (ctx2, s) = self.run(ctx, v) - if(ctx.findOptionOrDefault(GlobalOptions.optStrictPhases)) ctx.reporter.terminateIfError() - thenn.run(ctx2, s) - } - } - - def when[F2 <: F, T2 >: T](cond: Boolean)(implicit tps: F2 =:= T2): Pipeline[F2, T2] = { - if (cond) this else new Pipeline[F2, T2] { - def run(ctx: LeonContext, v: F2): (LeonContext, T2) = (ctx, v) - } - } - - def run(ctx: LeonContext, v: F): (LeonContext, T) -} - -object Pipeline { - - def both[T, R1, R2](f1: Pipeline[T, R1], f2: Pipeline[T, R2]): Pipeline[T, (R1, R2)] = new Pipeline[T, (R1, R2)] { - def run(ctx: LeonContext, t: T): (LeonContext, (R1, R2)) = { - val (ctx1, r1) = f1.run(ctx, t) - // don't check for SharedOptions.optStrictPhases because f2 does not depend on the result of f1 - val (ctx2, r2) = f2.run(ctx1, t) - (ctx2, (r1, r2)) - } - } - -} diff --git a/src/main/scala/leon/codegen/CodeGenParams.scala b/src/main/scala/leon/codegen/CodeGenParams.scala deleted file mode 100644 index ff99627659ddd7ef26b736f245d702d70e0318c2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CodeGenParams.scala +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen - -case class CodeGenParams ( - maxFunctionInvocations: Int, // Monitor calls and abort execution if more than X calls - checkContracts: Boolean, // Generate calls that checks pre/postconditions - doInstrument: Boolean // Instrument reads to case classes (mainly for vanuatoo) -) { - val recordInvocations = maxFunctionInvocations > -1 - - val requireMonitor = recordInvocations -} - -object CodeGenParams { - def default = CodeGenParams(maxFunctionInvocations = -1, checkContracts = true, doInstrument = false) -} diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala deleted file mode 100644 index 588d97c2024f12442bfc890136383ee4a8dca2c8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ /dev/null @@ -1,1956 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Quantification._ -import cafebabe._ -import cafebabe.AbstractByteCodes._ -import cafebabe.ByteCodes._ -import cafebabe.ClassFileTypes._ -import cafebabe.Defaults.constructorName -import cafebabe.Flags._ - -trait CodeGeneration { - self: CompilationUnit => - - /** A class providing information about the status of parameters in the function that is being currently compiled. - * vars is a mapping from local variables/ parameters to the offset of the respective JVM local register - * isStatic signifies if the current method is static (a function, in Leon terms) - */ - class Locals private[codegen] ( - vars : Map[Identifier, Int], - args : Map[Identifier, Int], - fields : Map[Identifier, (String,String,String)], - val tps : Seq[TypeParameter] - ) { - /** Fetches the offset of a local variable/ parameter from its identifier */ - def varToLocal(v: Identifier): Option[Int] = vars.get(v) - - def varToArg(v: Identifier): Option[Int] = args.get(v) - - def varToField(v: Identifier): Option[(String,String,String)] = fields.get(v) - - /** Adds some extra variables to the mapping */ - def withVars(newVars: Map[Identifier, Int]) = new Locals(vars ++ newVars, args, fields, tps) - - /** Adds an extra variable to the mapping */ - def withVar(nv: (Identifier, Int)) = new Locals(vars + nv, args, fields, tps) - - def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields, tps) - - def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields, tps) - - def withTypes(newTps: Seq[TypeParameter]) = new Locals(vars, args, fields, tps ++ newTps) - - override def toString = "Locals("+vars + ", " + args + ", " + fields + ", " + tps + ")" - } - - object NoLocals extends Locals(Map.empty, Map.empty, Map.empty, Seq.empty) - - lazy val monitorID = FreshIdentifier("__$monitor") - lazy val tpsID = FreshIdentifier("__$tps") - - private[codegen] val ObjectClass = "java/lang/Object" - private[codegen] val BoxedIntClass = "java/lang/Integer" - private[codegen] val BoxedBoolClass = "java/lang/Boolean" - private[codegen] val BoxedCharClass = "java/lang/Character" - private[codegen] val BoxedArrayClass = "leon/codegen/runtime/ArrayBox" - - private[codegen] val JavaListClass = "java/util/List" - private[codegen] val JavaIteratorClass = "java/util/Iterator" - private[codegen] val JavaStringClass = "java/lang/String" - - private[codegen] val TupleClass = "leon/codegen/runtime/Tuple" - private[codegen] val SetClass = "leon/codegen/runtime/Set" - private[codegen] val BagClass = "leon/codegen/runtime/Bag" - private[codegen] val MapClass = "leon/codegen/runtime/Map" - private[codegen] val BigIntClass = "leon/codegen/runtime/BigInt" - private[codegen] val RealClass = "leon/codegen/runtime/Real" - private[codegen] val RationalClass = "leon/codegen/runtime/Rational" - private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" - private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" - private[codegen] val FiniteLambdaClass = "leon/codegen/runtime/FiniteLambda" - private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" - private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" - private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" - private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" - private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" - private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" - private[codegen] val MonitorClass = "leon/codegen/runtime/Monitor" - private[codegen] val NoMonitorClass = "leon/codegen/runtime/NoMonitor" - private[codegen] val StrOpsClass = "leon/codegen/runtime/StrOps" - - def idToSafeJVMName(id: Identifier) = { - scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") - } - - def defToJVMName(d: Definition): String = "Leon$CodeGen$Def$" + idToSafeJVMName(d.id) - - /** Retrieve the name of the underlying lazy field from a lazy field accessor method */ - private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying" - - protected object ValueType { - def unapply(tp: TypeTree): Boolean = tp match { - case Int32Type | BooleanType | CharType | UnitType => true - case _ => false - } - } - - /** Return the respective JVM type from a Leon type */ - def typeToJVM(tpe : TypeTree) : String = tpe match { - case Int32Type => "I" - - case BooleanType => "Z" - - case CharType => "C" - - case UnitType => "Z" - - case c : ClassType => - leonClassToJVMInfo(c.classDef).map { case (n, _) => "L" + n + ";" }.getOrElse( - throw CompilationException("Unsupported class " + c.id) - ) - - case _ : TupleType => - "L" + TupleClass + ";" - - case _ : SetType => - "L" + SetClass + ";" - - case _ : BagType => - "L" + BagClass + ";" - - case _ : MapType => - "L" + MapClass + ";" - - case IntegerType => - "L" + BigIntClass + ";" - - case RealType => - "L" + RationalClass + ";" - - case _ : FunctionType => - "L" + LambdaClass + ";" - - case ArrayType(base) => - "[" + typeToJVM(base) - - case TypeParameter(_) => - "L" + ObjectClass + ";" - - case StringType => - "L" + JavaStringClass + ";" - - case _ => throw CompilationException("Unsupported type : " + tpe) - } - - /** Return the respective boxed JVM type from a Leon type */ - def typeToJVMBoxed(tpe : TypeTree) : String = tpe match { - case Int32Type => s"L$BoxedIntClass;" - case BooleanType | UnitType => s"L$BoxedBoolClass;" - case CharType => s"L$BoxedCharClass;" - case other => typeToJVM(other) - } - - /** - * Compiles a function/method definition. - * @param funDef The function definition to be compiled - * @param owner The module/class that contains `funDef` - */ - def compileFunDef(funDef: FunDef, owner: Definition) { - val isStatic = owner.isInstanceOf[ModuleDef] - - val cf = classes(owner) - val (_,mn,_) = leonFunDefToJVMInfo(funDef).get - - val tpeParam = if (funDef.tparams.isEmpty) Seq() else Seq("[I") - val realParams = ("L" + MonitorClass + ";") +: (tpeParam ++ funDef.params.map(a => typeToJVM(a.getType))) - - val m = cf.addMethod( - typeToJVM(funDef.returnType), - mn, - realParams : _* - ) - m.setFlags(( - // FIXME Not sure about this "FINAL" now that we can have methods in inheritable classes - if (isStatic) - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL | - METHOD_ACC_STATIC - else - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val ch = m.codeHandler - - // An offset we introduce to the parameters: - // 1 if this is a method, so we need "this" in position 0 of the stack - val receiverOffset = if (isStatic) 0 else 1 - val paramIds = Seq(monitorID) ++ - (if (funDef.tparams.nonEmpty) Seq(tpsID) else Seq.empty) ++ - funDef.paramIds - val newMapping = paramIds.zipWithIndex.toMap.mapValues(_ + receiverOffset) - - val body = if (params.checkContracts) { - funDef.fullBody - } else { - funDef.body.getOrElse( - if (funDef.annotations contains "extern") { - Error(funDef.id.getType, "Body of " + funDef.id.name + " not implemented at compile-time and still executed.") - } else { - throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name) - }) - } - - val locals = NoLocals.withVars(newMapping).withTypes(funDef.tparams.map(_.tp)) - - if (params.recordInvocations) { - load(monitorID, ch)(locals) - ch << InvokeVirtual(MonitorClass, "onInvocation", "()V") - } - - mkExpr(body, ch)(locals) - - funDef.returnType match { - case ValueType() => - ch << IRETURN - - case _ => - ch << ARETURN - } - - ch.freeze - } - - private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] - private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] - - protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], Seq[TypeParameter], String) = { - val tparams: Seq[TypeParameter] = typeParamsOf(l).toSeq.sortBy(_.id.uniqueName) - - val closedVars = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) - val closuresWithoutMonitor = closedVars.map(id => id -> typeToJVM(id.getType)) - val closures = (monitorID -> s"L$MonitorClass;") +: - ((if (tparams.nonEmpty) Seq(tpsID -> "[I") else Seq.empty) ++ closuresWithoutMonitor) - - val afName = lambdaToClass.getOrElse(l, { - val afId = FreshIdentifier("Leon$CodeGen$Lambda$") - val afName = afId.uniqueName - lambdaToClass += l -> afName - classToLambda += afName -> l - - val cf = new ClassFile(afName, Some(LambdaClass)) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - if (closures.isEmpty) { - cf.addDefaultConstructor - } else { - for ((id, jvmt) <- closures) { - val fh = cf.addField(jvmt, id.uniqueName) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - val cch = cf.addConstructor(closures.map(_._2).toList).codeHandler - - cch << ALoad(0) - cch << InvokeSpecial(LambdaClass, constructorName, "()V") - - var c = 1 - for ((id, jvmt) <- closures) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(afName, id.uniqueName, jvmt) - c += 1 - } - - cch << RETURN - cch.freeze - } - - val argMapping = l.args.map(_.id).zipWithIndex.toMap - val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap - val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping).withTypes(tparams) - - locally { - val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") - - apm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val apch = apm.codeHandler - - mkBoxedExpr(l.body, apch)(newLocals) - - apch << ARETURN - - apch.freeze - } - - locally { - val emh = cf.addMethod("Z", "equals", s"L$ObjectClass;") - emh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val ech = emh.codeHandler - - val notRefEq = ech.getFreshLabel("notrefeq") - val notEq = ech.getFreshLabel("noteq") - val castSlot = ech.getFreshVar - - // If references are equal, trees are equal. - ech << ALoad(0) << ALoad(1) << If_ACmpNe(notRefEq) << Ldc(1) << IRETURN << Label(notRefEq) - - // We check the type (this also checks against null).... - ech << ALoad(1) << InstanceOf(afName) << IfEq(notEq) - - // ...finally, we compare fields one by one, shortcircuiting on disequalities. - if(closures.nonEmpty) { - ech << ALoad(1) << CheckCast(afName) << AStore(castSlot) - - for((id,jvmt) <- closures) { - ech << ALoad(0) << GetField(afName, id.uniqueName, jvmt) - ech << ALoad(castSlot) << GetField(afName, id.uniqueName, jvmt) - - jvmt match { - case "I" | "Z" => - ech << If_ICmpNe(notEq) - - case ot => - ech << InvokeVirtual(ObjectClass, "equals", s"(L$ObjectClass;)Z") << IfEq(notEq) - } - } - } - - ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN - ech.freeze - } - - locally { - val hashFieldName = "$leon$hashCode" - cf.addField("I", hashFieldName).setFlags(FIELD_ACC_PRIVATE) - val hmh = cf.addMethod("I", "hashCode", "") - hmh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val hch = hmh.codeHandler - - val wasNotCached = hch.getFreshLabel("wasNotCached") - - hch << ALoad(0) << GetField(afName, hashFieldName, "I") << DUP - hch << IfEq(wasNotCached) - hch << IRETURN - hch << Label(wasNotCached) << POP - - hch << Ldc(closuresWithoutMonitor.size) << NewArray(s"$ObjectClass") - for (((id, jvmt),i) <- closuresWithoutMonitor.zipWithIndex) { - hch << DUP << Ldc(i) - hch << ALoad(0) << GetField(afName, id.uniqueName, jvmt) - mkBox(id.getType, hch) - hch << AASTORE - } - - hch << Ldc(afName.hashCode) - hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP - hch << ALoad(0) << SWAP << PutField(afName, hashFieldName, "I") - hch << IRETURN - - hch.freeze - } - - loader.register(cf) - - afName - }) - - (afName, closures.map { case p @ (id, jvmt) => - if (id == monitorID || id == tpsID) p else (id -> jvmt) - }, tparams, "(" + closures.map(_._2).mkString("") + ")V") - } - - // also makes tuples with 0/1 args - private def mkTuple(es: Seq[Expr], ch: CodeHandler)(implicit locals: Locals) : Unit = { - ch << New(TupleClass) << DUP - ch << Ldc(es.size) - ch << NewArray(s"$ObjectClass") - for((e,i) <- es.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkBoxedExpr(e, ch) - ch << AASTORE - } - ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") - } - - private def loadTypes(tps: Seq[TypeTree], ch: CodeHandler)(implicit locals: Locals): Unit = { - if (tps.nonEmpty) { - ch << Ldc(tps.size) - ch << NewArray.primitive("T_INT") - for ((tpe,idx) <- tps.zipWithIndex) { - ch << DUP << Ldc(idx) << Ldc(registerType(tpe)) << IASTORE - } - - if (locals.tps.nonEmpty) { - load(monitorID, ch) - ch << SWAP - - ch << Ldc(locals.tps.size) - ch << NewArray.primitive("T_INT") - for ((tpe,idx) <- locals.tps.zipWithIndex) { - ch << DUP << Ldc(idx) << Ldc(registerType(tpe)) << IASTORE - } - - ch << SWAP - load(tpsID, ch) - ch << InvokeVirtual(MonitorClass, "typeParams", s"([I[I[I)[I") - } - } - } - - private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { - e match { - case Variable(id) => - load(id, ch) - - case Assert(cond, oerr, body) => - mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) - - case en @ Ensuring(_, _) => - mkExpr(en.toAssert, ch) - - case Require(pre, body) => - mkExpr(IfExpr(pre, body, Error(body.getType, "Precondition failed")), ch) - - case Let(id, d, Variable(id2)) if id == id2 => // Optimization for local variables. - mkExpr(d, ch) - - case Let(id, d, Let(id3, Variable(id2), Variable(id4))) if id == id2 && id3 == id4 => // Optimization for local variables. - mkExpr(d, ch) - - case Let(i,d,b) => - mkExpr(d, ch) - val slot = ch.getFreshVar - val instr = i.getType match { - case ValueType() => - if(slot > 127) { - println("Error while converting one more slot which is too much " + e) - } - IStore(slot) - case _ => AStore(slot) - } - ch << instr - mkExpr(b, ch)(locals.withVar(i -> slot)) - - case IntLiteral(v) => - ch << Ldc(v) - - case CharLiteral(v) => - ch << Ldc(v) - - case BooleanLiteral(v) => - ch << Ldc(if(v) 1 else 0) - - case UnitLiteral() => - ch << Ldc(1) - - case StringLiteral(v) => - ch << Ldc(v) - - case InfiniteIntegerLiteral(v) => - ch << New(BigIntClass) << DUP - ch << Ldc(v.toString) - ch << InvokeSpecial(BigIntClass, constructorName, "(Ljava/lang/String;)V") - - case FractionalLiteral(n, d) => - ch << New(RationalClass) << DUP - ch << Ldc(n.toString) - ch << Ldc(d.toString) - ch << InvokeSpecial(RationalClass, constructorName, "(Ljava/lang/String;Ljava/lang/String;)V") - - // Case classes - case CaseClass(cct, as) => - val (ccName, ccApplySig) = leonClassToJVMInfo(cct.classDef).getOrElse { - throw CompilationException("Unknown class : " + cct.id) - } - ch << New(ccName) << DUP - load(monitorID, ch) - loadTypes(cct.tps, ch) - - for ((a, vd) <- as zip cct.classDef.fields) { - vd.getType match { - case TypeParameter(_) => - mkBoxedExpr(a, ch) - case _ => - mkExpr(a, ch) - } - } - ch << InvokeSpecial(ccName, constructorName, ccApplySig) - - case IsInstanceOf(e, cct) => - val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { - throw CompilationException("Unknown class : " + cct.id) - } - mkExpr(e, ch) - ch << InstanceOf(ccName) - - case AsInstanceOf(e, cct) => - val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { - throw CompilationException("Unknown class : " + cct.id) - } - mkExpr(e, ch) - ch << CheckCast(ccName) - - case CaseClassSelector(cct, e, sid) => - mkExpr(e, ch) - val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { - throw CompilationException("Unknown class : " + cct.id) - } - ch << CheckCast(ccName) - instrumentedGetField(ch, cct, sid) - - // Tuples (note that instanceOf checks are in mkBranch) - case Tuple(es) => mkTuple(es, ch) - - case TupleSelect(t, i) => - val TupleType(bs) = t.getType - mkExpr(t,ch) - ch << Ldc(i - 1) - ch << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(bs(i - 1), ch) - - // Sets - case FiniteSet(es, _) => - ch << DefaultNew(SetClass) - for(e <- es) { - ch << DUP - mkBoxedExpr(e, ch) - ch << InvokeVirtual(SetClass, "add", s"(L$ObjectClass;)V") - } - - case SetAdd(s, e) => - mkExpr(s, ch) - mkBoxedExpr(e, ch) - ch << InvokeVirtual(SetClass, "plus", s"(L$ObjectClass;)L$SetClass;") - - case ElementOfSet(e, s) => - mkExpr(s, ch) - mkBoxedExpr(e, ch) - ch << InvokeVirtual(SetClass, "contains", s"(L$ObjectClass;)Z") - - case SetCardinality(s) => - mkExpr(s, ch) - ch << InvokeVirtual(SetClass, "size", s"()$BigIntClass;") - - case SubsetOf(s1, s2) => - mkExpr(s1, ch) - mkExpr(s2, ch) - ch << InvokeVirtual(SetClass, "subsetOf", s"(L$SetClass;)Z") - - case SetIntersection(s1, s2) => - mkExpr(s1, ch) - mkExpr(s2, ch) - ch << InvokeVirtual(SetClass, "intersect", s"(L$SetClass;)L$SetClass;") - - case SetUnion(s1, s2) => - mkExpr(s1, ch) - mkExpr(s2, ch) - ch << InvokeVirtual(SetClass, "union", s"(L$SetClass;)L$SetClass;") - - case SetDifference(s1, s2) => - mkExpr(s1, ch) - mkExpr(s2, ch) - ch << InvokeVirtual(SetClass, "minus", s"(L$SetClass;)L$SetClass;") - - // Bags - case FiniteBag(els, _) => - ch << DefaultNew(BagClass) - for((k,v) <- els) { - ch << DUP - mkBoxedExpr(k, ch) - mkExpr(v, ch) - ch << InvokeVirtual(BagClass, "add", s"(L$ObjectClass;L$BigIntClass;)V") - } - - case BagAdd(b, e) => - mkExpr(b, ch) - mkBoxedExpr(e, ch) - ch << InvokeVirtual(BagClass, "plus", s"(L$ObjectClass;)L$BagClass;") - - case MultiplicityInBag(e, b) => - mkExpr(b, ch) - mkBoxedExpr(e, ch) - ch << InvokeVirtual(BagClass, "get", s"(L$ObjectClass;)L$BigIntClass;") - - case BagIntersection(b1, b2) => - mkExpr(b1, ch) - mkExpr(b2, ch) - ch << InvokeVirtual(BagClass, "intersect", s"(L$BagClass;)L$BagClass;") - - case BagUnion(b1, b2) => - mkExpr(b1, ch) - mkExpr(b2, ch) - ch << InvokeVirtual(BagClass, "union", s"(L$BagClass;)L$BagClass;") - - case BagDifference(b1, b2) => - mkExpr(b1, ch) - mkExpr(b2, ch) - ch << InvokeVirtual(BagClass, "difference", s"(L$BagClass;)L$BagClass;") - - // Maps - case FiniteMap(ss, _, _) => - ch << DefaultNew(MapClass) - for((f,t) <- ss) { - ch << DUP - mkBoxedExpr(f, ch) - mkBoxedExpr(t, ch) - ch << InvokeVirtual(MapClass, "add", s"(L$ObjectClass;L$ObjectClass;)V") - } - - case MapApply(m, k) => - val MapType(_, tt) = m.getType - mkExpr(m, ch) - mkBoxedExpr(k, ch) - ch << InvokeVirtual(MapClass, "get", s"(L$ObjectClass;)L$ObjectClass;") - mkUnbox(tt, ch) - - case MapIsDefinedAt(m, k) => - mkExpr(m, ch) - mkBoxedExpr(k, ch) - ch << InvokeVirtual(MapClass, "isDefinedAt", s"(L$ObjectClass;)Z") - - case MapUnion(m1, m2) => - mkExpr(m1, ch) - mkExpr(m2, ch) - ch << InvokeVirtual(MapClass, "union", s"(L$MapClass;)L$MapClass;") - - // Branching - case IfExpr(c, t, e) => - val tl = ch.getFreshLabel("then") - val el = ch.getFreshLabel("else") - val al = ch.getFreshLabel("after") - mkBranch(c, tl, el, ch) - ch << Label(tl) - mkExpr(t, ch) - ch << Goto(al) << Label(el) - mkExpr(e, ch) - ch << Label(al) - - // Strict static fields - case FunctionInvocation(tfd, as) if tfd.fd.canBeStrictField => - val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { - throw CompilationException("Unknown method : " + tfd.id) - } - - // Get static field - ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) - - // unbox field - (tfd.fd.returnType, tfd.returnType) match { - case (TypeParameter(_), tpe) => - mkUnbox(tpe, ch) - case _ => - } - - case FunctionInvocation(TypedFunDef(fd, Nil), Seq(a)) if fd == program.library.escape.get => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "escape", s"(L$JavaStringClass;)L$JavaStringClass;") - - case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get => - - val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) - val cons = program.library.Cons.get - val (consName, ccApplySig) = leonClassToJVMInfo(cons).getOrElse { - throw CompilationException("Unknown class : " + cons) - } - - mkExpr(nil, ch) - mkExpr(set, ch) - //if (params.requireMonitor) { - // ch << ALoad(locals.monitorIndex) - //} - - // No dynamic dispatching/overriding in Leon, - // so no need to take care of own vs. "super" methods - ch << InvokeVirtual(SetClass, "getElements", s"()L$JavaIteratorClass;") - - val loop = ch.getFreshLabel("loop") - val out = ch.getFreshLabel("out") - ch << Label(loop) - // list, it - ch << DUP - // list, it, it - ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") - // list, it, hasNext - ch << IfEq(out) - // list, it - ch << DUP2 - // list, it, list, it - ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") << SWAP - // list, it, elem, list - ch << New(consName) << DUP << DUP2_X2 - // list, it, cons, cons, elem, list, cons, cons - ch << POP << POP - // list, it, cons, cons, elem, list - - load(monitorID, ch) - ch << DUP_X2 << POP - loadTypes(Seq(tp), ch) - ch << DUP_X2 << POP - - ch << InvokeSpecial(consName, constructorName, ccApplySig) - // list, it, newList - ch << DUP_X2 << POP << SWAP << POP - // newList, it - ch << Goto(loop) - - ch << Label(out) - // list, it - ch << POP - // list - - case FunctionInvocation(tfd, as) if abstractFunDefs(tfd.fd.id) => - val id = registerAbstractFD(tfd.fd) - - load(monitorID, ch) - - ch << Ldc(id) - if (tfd.fd.tparams.nonEmpty) { - loadTypes(tfd.tps, ch) - } else { - ch << Ldc(0) << NewArray.primitive("T_INT") - } - - ch << Ldc(as.size) - ch << NewArray(ObjectClass) - - for ((e, i) <- as.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkExpr(e, ch) - mkBox(e.getType, ch) - ch << AASTORE - } - - ch << InvokeVirtual(MonitorClass, "onAbstractInvocation", s"(I[I[L$ObjectClass;)L$ObjectClass;") - - mkUnbox(tfd.returnType, ch) - - // Static lazy fields/ functions - case fi @ FunctionInvocation(tfd, as) => - val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { - throw CompilationException("Unknown method : " + tfd.id) - } - - load(monitorID, ch) - loadTypes(tfd.tps, ch) - - for((a, vd) <- as zip tfd.fd.params) { - vd.getType match { - case TypeParameter(_) => - mkBoxedExpr(a, ch) - case _ => - mkExpr(a, ch) - } - } - - ch << InvokeStatic(cn, mn, ms) - - (tfd.fd.returnType, tfd.returnType) match { - case (TypeParameter(_), tpe) => - mkUnbox(tpe, ch) - case _ => - } - - // Strict fields are handled as fields - case MethodInvocation(rec, _, tfd, _) if tfd.fd.canBeStrictField => - val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { - throw CompilationException("Unknown method : " + tfd.id) - } - - // Load receiver - mkExpr(rec,ch) - - // Get field - ch << GetField(className, fieldName, typeToJVM(tfd.fd.returnType)) - - // unbox field - (tfd.fd.returnType, tfd.returnType) match { - case (TypeParameter(_), tpe) => - mkUnbox(tpe, ch) - case _ => - } - - // This is for lazy fields and real methods. - // To access a lazy field, we call its accessor function. - case MethodInvocation(rec, cd, tfd, as) => - val (className, methodName, sig) = leonFunDefToJVMInfo(tfd.fd).getOrElse { - throw CompilationException("Unknown method : " + tfd.id) - } - - // Receiver of the method call - mkExpr(rec, ch) - - load(monitorID, ch) - loadTypes(tfd.tps, ch) - - for((a, vd) <- as zip tfd.fd.params) { - vd.getType match { - case TypeParameter(_) => - mkBoxedExpr(a, ch) - case _ => - mkExpr(a, ch) - } - } - - // No interfaces in Leon, so no need to use InvokeInterface - ch << InvokeVirtual(className, methodName, sig) - - (tfd.fd.returnType, tfd.returnType) match { - case (TypeParameter(_), tpe) => - mkUnbox(tpe, ch) - case _ => - } - - case app @ Application(caller, args) => - mkExpr(caller, ch) - ch << Ldc(args.size) << NewArray(s"$ObjectClass") - for ((arg,i) <- args.zipWithIndex) { - ch << DUP << Ldc(i) - mkBoxedExpr(arg, ch) - ch << AASTORE - } - - ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") - mkUnbox(app.getType, ch) - - case p @ FiniteLambda(mapping, dflt, _) => - ch << New(FiniteLambdaClass) << DUP - mkBoxedExpr(dflt, ch) - ch << InvokeSpecial(FiniteLambdaClass, constructorName, s"(L$ObjectClass;)V") - - for ((es,v) <- mapping) { - ch << DUP - mkTuple(es, ch) - mkBoxedExpr(v, ch) - ch << InvokeVirtual(FiniteLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") - } - - case l @ Lambda(args, body) => - val (afName, closures, tparams, consSig) = compileLambda(l) - - ch << New(afName) << DUP - for ((id,jvmt) <- closures) { - if (id == tpsID) { - loadTypes(tparams, ch) - } else { - mkExpr(Variable(id), ch) - } - } - ch << InvokeSpecial(afName, constructorName, consSig) - - // String processing => - case StringConcat(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeStatic(StrOpsClass, "concat", s"(L$JavaStringClass;L$JavaStringClass;)L$JavaStringClass;") - - case StringLength(a) => - mkExpr(a, ch) - ch << InvokeVirtual(JavaStringClass, "length", s"()I") - - case StringBigLength(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "bigLength", s"(L$JavaStringClass;)L$BigIntClass;") - - case Int32ToString(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "intToString", s"(I)L$JavaStringClass;") - case BooleanToString(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "booleanToString", s"(Z)L$JavaStringClass;") - case IntegerToString(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "bigIntToString", s"(L$BigIntClass;)L$JavaStringClass;") - case CharToString(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "charToString", s"(C)L$JavaStringClass;") - case RealToString(a) => - mkExpr(a, ch) - ch << InvokeStatic(StrOpsClass, "realToString", s"(L$RealClass;)L$JavaStringClass;") - - case SubString(a, start, end) => - mkExpr(a, ch) - mkExpr(start, ch) - mkExpr(end, ch) - ch << InvokeVirtual(JavaStringClass, "substring", s"(II)L$JavaStringClass;") - - case BigSubString(a, start, end) => - mkExpr(a, ch) - mkExpr(start, ch) - mkExpr(end, ch) - ch << InvokeStatic(StrOpsClass, "bigSubstring", s"(L$JavaStringClass;L$BigIntClass;L$BigIntClass;)L$JavaStringClass;") - - // Arithmetic - case Plus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "add", s"(L$BigIntClass;)L$BigIntClass;") - - case Minus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "sub", s"(L$BigIntClass;)L$BigIntClass;") - - case Times(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "mult", s"(L$BigIntClass;)L$BigIntClass;") - - case Division(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "div", s"(L$BigIntClass;)L$BigIntClass;") - - case Remainder(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "rem", s"(L$BigIntClass;)L$BigIntClass;") - - case Modulo(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(BigIntClass, "mod", s"(L$BigIntClass;)L$BigIntClass;") - - case UMinus(e) => - mkExpr(e, ch) - ch << InvokeVirtual(BigIntClass, "neg", s"()L$BigIntClass;") - - case RealPlus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(RationalClass, "add", s"(L$RationalClass;)L$RationalClass;") - - case RealMinus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(RationalClass, "sub", s"(L$RationalClass;)L$RationalClass;") - - case RealTimes(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(RationalClass, "mult", s"(L$RationalClass;)L$RationalClass;") - - case RealDivision(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << InvokeVirtual(RationalClass, "div", s"(L$RationalClass;)L$RationalClass;") - - case RealUMinus(e) => - mkExpr(e, ch) - ch << InvokeVirtual(RationalClass, "neg", s"()L$RationalClass;") - - - //BV arithmetic - case BVPlus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IADD - - case BVMinus(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << ISUB - - case BVTimes(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IMUL - - case BVDivision(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IDIV - - case BVRemainder(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IREM - - case BVUMinus(e) => - mkExpr(e, ch) - ch << INEG - - case BVNot(e) => - mkExpr(e, ch) - mkExpr(IntLiteral(-1), ch) - ch << IXOR - - case BVAnd(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IAND - - case BVOr(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IOR - - case BVXOr(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IXOR - - case BVShiftLeft(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << ISHL - - case BVLShiftRight(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << IUSHR - - case BVAShiftRight(l, r) => - mkExpr(l, ch) - mkExpr(r, ch) - ch << ISHR - - case ArrayLength(a) => - mkExpr(a, ch) - ch << ARRAYLENGTH - - case as @ ArraySelect(a,i) => - mkExpr(a, ch) - mkExpr(i, ch) - ch << (as.getType match { - case Untyped => throw CompilationException("Cannot compile untyped array access.") - case CharType => CALOAD - case Int32Type => IALOAD - case BooleanType => BALOAD - case _ => AALOAD - }) - - case au @ ArrayUpdated(a, i, v) => - mkExpr(a, ch) - ch << DUP - ch << ARRAYLENGTH - val storeInstr = a.getType match { - case ArrayType(CharType) => ch << NewArray.primitive("T_CHAR"); CASTORE - case ArrayType(Int32Type) => ch << NewArray.primitive("T_INT"); IASTORE - case ArrayType(BooleanType) => ch << NewArray.primitive("T_BOOLEAN"); BASTORE - case ArrayType(other) => ch << NewArray(typeToJVM(other)); AASTORE - case other => throw CompilationException(s"Cannot compile finite array expression whose type is $other.") - } - //srcArrary and targetArray is on the stack - ch << DUP_X1 //insert targetArray under srcArray - ch << Ldc(0) << SWAP //srcArray, 0, targetArray - ch << DUP << ARRAYLENGTH //targetArray, length on stack - ch << Ldc(0) << SWAP //final arguments: src, 0, target, 0, length - ch << InvokeStatic("java/lang/System", "arraycopy", s"(L$ObjectClass;IL$ObjectClass;II)V") - - //targetArray remains on the stack - ch << DUP - mkExpr(i, ch) - mkExpr(v, ch) - ch << storeInstr - //returns targetArray - - case a @ FiniteArray(elems, default, length) => - mkExpr(length, ch) - - val storeInstr = a.getType match { - case ArrayType(CharType) => ch << NewArray.primitive("T_CHAR"); CASTORE - case ArrayType(Int32Type) => ch << NewArray.primitive("T_INT"); IASTORE - case ArrayType(BooleanType) => ch << NewArray.primitive("T_BOOLEAN"); BASTORE - case ArrayType(other) => ch << NewArray(typeToJVM(other)); AASTORE - case other => throw CompilationException(s"Cannot compile finite array expression whose type is $other.") - } - - // Fill up with default - default foreach { df => - val loop = ch.getFreshLabel("array_loop") - val loopOut = ch.getFreshLabel("array_loop_out") - ch << Ldc(0) - // (array, index) - ch << Label(loop) - ch << DUP2 << SWAP - // (array, index, index, array) - ch << ARRAYLENGTH - // (array, index, index, length) - ch << If_ICmpGe(loopOut) << DUP2 - // (array, index, array, index) - mkExpr(df, ch) - ch << storeInstr - ch << Ldc(1) << IADD << Goto(loop) - ch << Label(loopOut) << POP - } - - // Replace present elements with correct value - for ((i,v) <- elems ) { - ch << DUP << Ldc(i) - mkExpr(v, ch) - ch << storeInstr - } - - // Misc and boolean tests - case Error(tpe, desc) => - ch << New(ErrorClass) << DUP - ch << Ldc(desc) - ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW - - case forall @ Forall(fargs, body) => - val id = registerForall(forall, locals.tps) - val args = variablesOf(forall).toSeq.sortBy(_.uniqueName) - - load(monitorID, ch) - ch << Ldc(id) - if (locals.tps.nonEmpty) { - load(tpsID, ch) - } else { - ch << Ldc(0) << NewArray.primitive("T_INT") - } - - ch << Ldc(args.size) - ch << NewArray(ObjectClass) - - for ((id, i) <- args.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkExpr(Variable(id), ch) - mkBox(id.getType, ch) - ch << AASTORE - } - - ch << InvokeVirtual(MonitorClass, "onForallInvocation", s"(I[I[L$ObjectClass;)Z") - - case choose: Choose => - val prob = synthesis.Problem.fromSpec(choose.pred) - - val id = registerProblem(prob, locals.tps) - - load(monitorID, ch) - ch << Ldc(id) - if (locals.tps.nonEmpty) { - load(tpsID, ch) - } else { - ch << Ldc(0) << NewArray.primitive("T_INT") - } - - ch << Ldc(prob.as.size) - ch << NewArray(ObjectClass) - - for ((id, i) <- prob.as.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkExpr(Variable(id), ch) - mkBox(id.getType, ch) - ch << AASTORE - } - - ch << InvokeVirtual(MonitorClass, "onChooseInvocation", s"(I[I[L$ObjectClass;)L$ObjectClass;") - - mkUnbox(choose.getType, ch) - - case gv @ GenericValue(tp, int) => - val id = runtime.GenericValues.register(gv) - ch << Ldc(id) - ch << InvokeStatic(GenericValuesClass, "get", s"(I)L$ObjectClass;") - - case nt @ NoTree( tp@ValueType() ) => - mkExpr(simplestValue(tp), ch) - - case NoTree(_) => - ch << ACONST_NULL - - case This(ct) => - ch << ALoad(0) - - case p : Passes => - mkExpr(matchToIfThenElse(p.asConstraint), ch) - - case m : MatchExpr => - mkExpr(matchToIfThenElse(m), ch) - - case b if b.getType == BooleanType && canDelegateToMkBranch => - val fl = ch.getFreshLabel("boolfalse") - val al = ch.getFreshLabel("boolafter") - ch << Ldc(1) - mkBranch(b, al, fl, ch, canDelegateToMkExpr = false) - ch << Label(fl) << POP << Ldc(0) << Label(al) - - case synthesis.utils.MutableExpr(e) => - mkExpr(e, ch) - - case _ => throw CompilationException("Unsupported expr " + e + " : " + e.getClass) - } - } - - // Leaves on the stack a value equal to `e`, always of a type compatible with java.lang.Object. - private[codegen] def mkBoxedExpr(e: Expr, ch: CodeHandler)(implicit locals: Locals) { - e.getType match { - case Int32Type => - ch << New(BoxedIntClass) << DUP - mkExpr(e, ch) - ch << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") - - case BooleanType | UnitType => - ch << New(BoxedBoolClass) << DUP - mkExpr(e, ch) - ch << InvokeSpecial(BoxedBoolClass, constructorName, "(Z)V") - - case CharType => - ch << New(BoxedCharClass) << DUP - mkExpr(e, ch) - ch << InvokeSpecial(BoxedCharClass, constructorName, "(C)V") - - case at @ ArrayType(et) => - ch << New(BoxedArrayClass) << DUP - mkExpr(e, ch) - ch << InvokeSpecial(BoxedArrayClass, constructorName, s"(${typeToJVM(at)})V") - - case _ => - mkExpr(e, ch) - } - } - - // Assumes the top of the stack contains of value of the right type, and makes it - // compatible with java.lang.Object. - private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler): Unit = { - tpe match { - case Int32Type => - ch << New(BoxedIntClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") - - case BooleanType | UnitType => - ch << New(BoxedBoolClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedBoolClass, constructorName, "(Z)V") - - case CharType => - ch << New(BoxedCharClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedCharClass, constructorName, "(C)V") - - case at @ ArrayType(et) => - ch << New(BoxedArrayClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedArrayClass, constructorName, s"(${typeToJVM(at)})V") - case _ => - } - } - - // Assumes that the top of the stack contains a value that should be of type `tpe`, and unboxes it to the right (JVM) type. - private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler): Unit = { - tpe match { - case Int32Type => - ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I") - - case BooleanType | UnitType => - ch << CheckCast(BoxedBoolClass) << InvokeVirtual(BoxedBoolClass, "booleanValue", "()Z") - - case CharType => - ch << CheckCast(BoxedCharClass) << InvokeVirtual(BoxedCharClass, "charValue", "()C") - - case ct : ClassType => - val (cn, _) = leonClassToJVMInfo(ct.classDef).getOrElse { - throw new CompilationException("Unsupported class type : " + ct) - } - ch << CheckCast(cn) - - case IntegerType => - ch << CheckCast(BigIntClass) - - case StringType => - ch << CheckCast(JavaStringClass) - - case RealType => - ch << CheckCast(RationalClass) - - case tt : TupleType => - ch << CheckCast(TupleClass) - - case st : SetType => - ch << CheckCast(SetClass) - - case mt : MapType => - ch << CheckCast(MapClass) - - case ft : FunctionType => - ch << CheckCast(LambdaClass) - - case tp : TypeParameter => - - case tp : ArrayType => - ch << CheckCast(BoxedArrayClass) << InvokeVirtual(BoxedArrayClass, "arrayValue", s"()${typeToJVM(tp)}") - ch << CheckCast(typeToJVM(tp)) - - case _ => - throw new CompilationException("Unsupported type in unboxing : " + tpe) - } - } - - private[codegen] def mkBranch(cond: Expr, thenn: String, elze: String, ch: CodeHandler, canDelegateToMkExpr: Boolean = true)(implicit locals: Locals) { - cond match { - case BooleanLiteral(true) => - ch << Goto(thenn) - - case BooleanLiteral(false) => - ch << Goto(elze) - - case And(es) => - val fl = ch.getFreshLabel("andnext") - mkBranch(es.head, fl, elze, ch) - ch << Label(fl) - mkBranch(andJoin(es.tail), thenn, elze, ch) - - case Or(es) => - val fl = ch.getFreshLabel("ornext") - mkBranch(es.head, thenn, fl, ch) - ch << Label(fl) - mkBranch(orJoin(es.tail), thenn, elze, ch) - - case Implies(l, r) => - mkBranch(or(not(l), r), thenn, elze, ch) - - case Not(c) => - mkBranch(c, elze, thenn, ch) - - case Variable(b) => - load(b, ch) - ch << IfEq(elze) << Goto(thenn) - - case Equals(l,r) => - mkExpr(l, ch) - mkExpr(r, ch) - l.getType match { - case ValueType() => - ch << If_ICmpEq(thenn) << Goto(elze) - - case _ => - ch << InvokeVirtual(s"$ObjectClass", "equals", s"(L$ObjectClass;)Z") - ch << IfEq(elze) << Goto(thenn) - } - - case LessThan(l,r) => - mkExpr(l, ch) - mkExpr(r, ch) - l.getType match { - case Int32Type | CharType => - ch << If_ICmpLt(thenn) << Goto(elze) - case IntegerType => - ch << InvokeVirtual(BigIntClass, "lessThan", s"(L$BigIntClass;)Z") - ch << IfEq(elze) << Goto(thenn) - case RealType => - ch << InvokeVirtual(RationalClass, "lessThan", s"(L$RationalClass;)Z") - ch << IfEq(elze) << Goto(thenn) - } - - case GreaterThan(l,r) => - mkExpr(l, ch) - mkExpr(r, ch) - l.getType match { - case Int32Type | CharType => - ch << If_ICmpGt(thenn) << Goto(elze) - case IntegerType => - ch << InvokeVirtual(BigIntClass, "greaterThan", s"(L$BigIntClass;)Z") - ch << IfEq(elze) << Goto(thenn) - case RealType => - ch << InvokeVirtual(RationalClass, "greaterThan", s"(L$RationalClass;)Z") - ch << IfEq(elze) << Goto(thenn) - } - - case LessEquals(l,r) => - mkExpr(l, ch) - mkExpr(r, ch) - l.getType match { - case Int32Type | CharType => - ch << If_ICmpLe(thenn) << Goto(elze) - case IntegerType => - ch << InvokeVirtual(BigIntClass, "lessEquals", s"(L$BigIntClass;)Z") - ch << IfEq(elze) << Goto(thenn) - case RealType => - ch << InvokeVirtual(RationalClass, "lessEquals", s"(L$RationalClass;)Z") - ch << IfEq(elze) << Goto(thenn) - } - - case GreaterEquals(l,r) => - mkExpr(l, ch) - mkExpr(r, ch) - l.getType match { - case Int32Type | CharType => - ch << If_ICmpGe(thenn) << Goto(elze) - case IntegerType => - ch << InvokeVirtual(BigIntClass, "greaterEquals", s"(L$BigIntClass;)Z") - ch << IfEq(elze) << Goto(thenn) - case RealType => - ch << InvokeVirtual(RationalClass, "greaterEquals", s"(L$RationalClass;)Z") - ch << IfEq(elze) << Goto(thenn) - } - - case IfExpr(c, t, e) => - val innerThen = ch.getFreshLabel("then") - val innerElse = ch.getFreshLabel("else") - mkBranch(c, innerThen, innerElse, ch) - ch << Label(innerThen) - mkBranch(t, thenn, elze, ch) - ch << Label(innerElse) - mkBranch(e, thenn, elze, ch) - - case cci@IsInstanceOf(cct, e) => - mkExpr(cci, ch) - ch << IfEq(elze) << Goto(thenn) - - case other if canDelegateToMkExpr => - mkExpr(other, ch, canDelegateToMkBranch = false) - ch << IfEq(elze) << Goto(thenn) - - case other => throw CompilationException("Unsupported branching expr. : " + other) - } - } - - private def load(id: Identifier, ch: CodeHandler)(implicit locals: Locals): Unit = { - locals.varToArg(id) match { - case Some(slot) => - ch << ALoad(1) << Ldc(slot) << AALOAD - mkUnbox(id.getType, ch) - case None => locals.varToField(id) match { - case Some((afName, nme, tpe)) => - ch << ALoad(0) << GetField(afName, nme, tpe) - case None => locals.varToLocal(id) match { - case Some(slot) => - val instr = id.getType match { - case ValueType() => ILoad(slot) - case _ => ALoad(slot) - } - ch << instr - case None => throw CompilationException("Unknown variable : " + id) - } - } - } - } - - /** Compiles a lazy field. - * - * To define a lazy field, we have to add an accessor method and an underlying field. - * The accessor method has the name of the original (Scala) lazy field and can be public. - * The underlying field has a different name, is private, and is of a boxed type - * to support null value (to signify uninitialized). - * - * @param lzy The lazy field to be compiled - * @param owner The module/class containing `lzy` - */ - def compileLazyField(lzy: FunDef, owner: Definition) { - ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field") - - val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get - val cf = classes(owner) - val cName = defToJVMName(owner) - - val isStatic = owner.isInstanceOf[ModuleDef] - - // Name of the underlying field - val underlyingName = underlyingField(accessorName) - // Underlying field is of boxed type - val underlyingType = typeToJVMBoxed(lzy.returnType) - - // Underlying field. It is of a boxed type - val fh = cf.addField(underlyingType,underlyingName) - fh.setFlags( if (isStatic) {( - FIELD_ACC_STATIC | - FIELD_ACC_PRIVATE - ).asInstanceOf[U2] } else { - FIELD_ACC_PRIVATE - }) // FIXME private etc? - - // accessor method - locally { - val parameters = Seq(monitorID -> s"L$MonitorClass;") - - val paramMapping = parameters.map(_._1).zipWithIndex.toMap.mapValues(_ + (if (isStatic) 0 else 1)) - val newLocs = NoLocals.withVars(paramMapping) - - val accM = cf.addMethod(typeToJVM(lzy.returnType), accessorName, parameters.map(_._2) : _*) - accM.setFlags( if (isStatic) {( - METHOD_ACC_STATIC | // FIXME other flags? Not always public? - METHOD_ACC_PUBLIC - ).asInstanceOf[U2] } else { - METHOD_ACC_PUBLIC - }) - val ch = accM.codeHandler - val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) - val initLabel = ch.getFreshLabel("isInitialized") - - if (isStatic) { - ch << GetStatic(cName, underlyingName, underlyingType) - } else { - ch << ALoad(0) << GetField(cName, underlyingName, underlyingType) // if (lzy == null) - } - // oldValue - ch << DUP << IfNonNull(initLabel) - // null - ch << POP - // - mkBoxedExpr(body,ch)(newLocs) // lzy = <expr> - ch << DUP - // newValue, newValue - if (isStatic) { - ch << PutStatic(cName, underlyingName, underlyingType) - //newValue - } - else { - ch << ALoad(0) << SWAP - // newValue, object, newValue - ch << PutField (cName, underlyingName, underlyingType) - //newValue - } - ch << Label(initLabel) // return lzy - //newValue - lzy.returnType match { - case ValueType() => - // Since the underlying field only has boxed types, we have to unbox them to return them - mkUnbox(lzy.returnType, ch) - ch << IRETURN - case _ => - ch << ARETURN - } - ch.freeze - } - } - - /** Compile the (strict) field `field` which is owned by class `owner` */ - def compileStrictField(field : FunDef, owner : Definition) = { - - ctx.reporter.internalAssertion(field.canBeStrictField, - s"Trying to compile ${field.id.name} as a strict field") - val (_, fieldName, _) = leonFunDefToJVMInfo(field).get - - val cf = classes(owner) - val fh = cf.addField(typeToJVM(field.returnType),fieldName) - fh.setFlags( owner match { - case _ : ModuleDef => ( - FIELD_ACC_STATIC | - FIELD_ACC_PUBLIC | // FIXME - FIELD_ACC_FINAL - ).asInstanceOf[U2] - case _ => ( - FIELD_ACC_PUBLIC | // FIXME - FIELD_ACC_FINAL - ).asInstanceOf[U2] - }) - } - - /** Initializes a lazy field to null - * @param ch the codehandler to add the initializing code to - * @param className the name of the class in which the field is initialized - * @param lzy the lazy field to be initialized - * @param isStatic true if this is a static field - */ - def initLazyField(ch: CodeHandler, className: String, lzy: FunDef, isStatic: Boolean)(implicit locals: Locals) = { - val (_, name, _) = leonFunDefToJVMInfo(lzy).get - val underlyingName = underlyingField(name) - val jvmType = typeToJVMBoxed(lzy.returnType) - if (isStatic){ - ch << ACONST_NULL << PutStatic(className, underlyingName, jvmType) - } else { - ch << ALoad(0) << ACONST_NULL << PutField(className, underlyingName, jvmType) - } - } - - /** Initializes a (strict) field - * @param ch the codehandler to add the initializing code to - * @param className the name of the class in which the field is initialized - * @param field the field to be initialized - * @param isStatic true if this is a static field - */ - def initStrictField(ch: CodeHandler, className: String, field: FunDef, isStatic: Boolean)(implicit locals: Locals) { - val (_, name , _) = leonFunDefToJVMInfo(field).get - val body = field.body.getOrElse(throw CompilationException("No body for field?")) - val jvmType = typeToJVM(field.returnType) - - mkExpr(body, ch) - - if (isStatic){ - ch << PutStatic(className, name, jvmType) - } else { - ch << ALoad(0) << SWAP << PutField (className, name, jvmType) - } - } - - def compileAbstractClassDef(acd: AbstractClassDef) { - - val cName = defToJVMName(acd) - - val cf = classes(acd) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_ABSTRACT - ).asInstanceOf[U2]) - - cf.addInterface(CaseClassClass) - - // add special monitor for method invocations - if (params.doInstrument) { - val fh = cf.addField("I", instrumentedField) - fh.setFlags(FIELD_ACC_PUBLIC) - } - - val (fields, methods) = acd.methods partition { _.canBeField } - val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - - // Compile methods - for (method <- methods) { - compileFunDef(method,acd) - } - - // Compile lazy fields - for (lzy <- lazyFields) { - compileLazyField(lzy, acd) - } - - // Compile strict fields - for (field <- strictFields) { - compileStrictField(field, acd) - } - - // definition of the constructor - locally { - val constrParams = Seq(monitorID -> s"L$MonitorClass;") - - val newLocs = NoLocals.withVars { - constrParams.map(_._1).zipWithIndex.toMap.mapValues(_ + 1) - } - - val cch = cf.addConstructor(constrParams.map(_._2) : _*).codeHandler - - for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } - for (field <- strictFields) { initStrictField(cch, cName, field, isStatic = false)(newLocs) } - - // Call parent constructor - cch << ALoad(0) - acd.parent match { - case Some(parent) => - val pName = defToJVMName(parent.classDef) - // Load monitor object - cch << ALoad(1) - val constrSig = "(L" + MonitorClass + ";)V" - cch << InvokeSpecial(pName, constructorName, constrSig) - - case None => - // Call constructor of java.lang.Object - cch << InvokeSpecial(ObjectClass, constructorName, "()V") - } - - // Initialize special monitor field - if (params.doInstrument) { - cch << ALoad(0) - cch << Ldc(0) - cch << PutField(cName, instrumentedField, "I") - } - - cch << RETURN - cch.freeze - } - - } - - /** - * Instrument read operations - */ - val instrumentedField = "__read" - - def instrumentedGetField(ch: CodeHandler, cct: ClassType, id: Identifier)(implicit locals: Locals): Unit = { - val ccd = cct.classDef - ccd.fields.zipWithIndex.find(_._1.id == id) match { - case Some((f, i)) => - val expType = cct.fields(i).getType - - val cName = defToJVMName(ccd) - if (params.doInstrument) { - ch << DUP << DUP - ch << GetField(cName, instrumentedField, "I") - ch << Ldc(1) - ch << Ldc(i) - ch << ISHL - ch << IOR - ch << PutField(cName, instrumentedField, "I") - } - ch << GetField(cName, f.id.name, typeToJVM(f.getType)) - - f.getType match { - case TypeParameter(_) => - mkUnbox(expType, ch) - case _ => - } - case None => - throw CompilationException("Unknown field: "+ccd.id.name+"."+id) - } - } - - def compileCaseClassDef(ccd: CaseClassDef) { - val cName = defToJVMName(ccd) - val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) - // An instantiation of ccd with its own type parameters - val cct = CaseClassType(ccd, ccd.tparams.map(_.tp)) - - val cf = classes(ccd) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - if (ccd.parent.isEmpty) { - cf.addInterface(CaseClassClass) - } - - // Case class parameters - val fieldsTypes = ccd.fields.map { vd => (vd.id, typeToJVM(vd.getType)) } - val tpeParam = if (ccd.tparams.isEmpty) Seq() else Seq(tpsID -> "[I") - val constructorArgs = (monitorID -> s"L$MonitorClass;") +: (tpeParam ++ fieldsTypes) - - val newLocs = NoLocals.withFields(constructorArgs.map { - case (id, jvmt) => (id, (cName, id.name, jvmt)) - }.toMap) - - locally { - val (fields, methods) = ccd.methods partition { _.canBeField } - val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - - // Compile methods - for (method <- methods) { - compileFunDef(method, ccd) - } - - // Compile lazy fields - for (lzy <- lazyFields) { - compileLazyField(lzy, ccd) - } - - // Compile strict fields - for (field <- strictFields) { - compileStrictField(field, ccd) - } - - // definition of the constructor - for ((id, jvmt) <- constructorArgs) { - val fh = cf.addField(jvmt, id.name) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - if (params.doInstrument) { - val fh = cf.addField("I", instrumentedField) - fh.setFlags(FIELD_ACC_PUBLIC) - } - - val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler - - if (params.doInstrument) { - cch << ALoad(0) - cch << Ldc(0) - cch << PutField(cName, instrumentedField, "I") - } - - var c = 1 - for ((id, jvmt) <- constructorArgs) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(cName, id.name, jvmt) - c += 1 - } - - // Call parent constructor AFTER initializing case class parameters - if (ccd.parent.isDefined) { - cch << ALoad(0) - cch << ALoad(1) - cch << InvokeSpecial(pName.get, constructorName, s"(L$MonitorClass;)V") - } else { - // Call constructor of java.lang.Object - cch << ALoad(0) - cch << InvokeSpecial(ObjectClass, constructorName, "()V") - } - - // Now initialize fields - for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } - for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } - - // Finally check invariant (if it exists) - if (params.checkContracts && ccd.hasInvariant) { - val skip = cch.getFreshLabel("skip_invariant") - load(monitorID, cch)(newLocs) - cch << ALoad(0) - cch << Ldc(registerType(cct)) - - cch << InvokeVirtual(MonitorClass, "invariantCheck", s"(L$ObjectClass;I)Z") - cch << IfEq(skip) - - load(monitorID, cch)(newLocs) - cch << ALoad(0) - cch << Ldc(registerType(cct)) - - val thisId = FreshIdentifier("this", cct, true) - val invLocals = newLocs.withVar(thisId -> 0) - mkExpr(FunctionInvocation(cct.invariant.get, Seq(Variable(thisId))), cch)(invLocals) - cch << InvokeVirtual(MonitorClass, "invariantResult", s"(L$ObjectClass;IZ)V") - cch << Label(skip) - } - - cch << RETURN - cch.freeze - } - - locally { - val pnm = cf.addMethod("I", "__getRead") - pnm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val pnch = pnm.codeHandler - - pnch << ALoad(0) << GetField(cName, instrumentedField, "I") << IRETURN - - pnch.freeze - } - - locally { - val pnm = cf.addMethod("Ljava/lang/String;", "productName") - pnm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val pnch = pnm.codeHandler - - pnch << Ldc(cName) << ARETURN - - pnch.freeze - } - - locally { - val pem = cf.addMethod(s"[L$ObjectClass;", "productElements") - pem.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val pech = pem.codeHandler - - pech << Ldc(ccd.fields.size) - pech << NewArray(ObjectClass) - - for ((f, i) <- ccd.fields.zipWithIndex) { - pech << DUP - pech << Ldc(i) - pech << ALoad(0) - instrumentedGetField(pech, cct, f.id)(newLocs) - mkBox(f.getType, pech) - pech << AASTORE - } - - pech << ARETURN - pech.freeze - } - - // definition of equals - locally { - val emh = cf.addMethod("Z", "equals", s"L$ObjectClass;") - emh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val ech = emh.codeHandler - - val notRefEq = ech.getFreshLabel("notrefeq") - val notEq = ech.getFreshLabel("noteq") - val castSlot = ech.getFreshVar - - // If references are equal, trees are equal. - ech << ALoad(0) << ALoad(1) << If_ACmpNe(notRefEq) << Ldc(1) << IRETURN << Label(notRefEq) - - // We check the type (this also checks against null).... - ech << ALoad(1) << InstanceOf(cName) << IfEq(notEq) - - // ...finally, we compare fields one by one, shortcircuiting on disequalities. - if(ccd.fields.nonEmpty) { - ech << ALoad(1) << CheckCast(cName) << AStore(castSlot) - - for(vd <- ccd.fields) { - ech << ALoad(0) - instrumentedGetField(ech, cct, vd.id)(newLocs) - ech << ALoad(castSlot) - instrumentedGetField(ech, cct, vd.id)(newLocs) - - typeToJVM(vd.getType) match { - case "I" | "Z" => - ech << If_ICmpNe(notEq) - - case ot => - ech << InvokeVirtual(ObjectClass, "equals", s"(L$ObjectClass;)Z") << IfEq(notEq) - } - } - } - - ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN - ech.freeze - } - - // definition of hashcode - locally { - val hashFieldName = "$leon$hashCode" - cf.addField("I", hashFieldName).setFlags(FIELD_ACC_PRIVATE) - val hmh = cf.addMethod("I", "hashCode", "") - hmh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val hch = hmh.codeHandler - - val wasNotCached = hch.getFreshLabel("wasNotCached") - - hch << ALoad(0) << GetField(cName, hashFieldName, "I") << DUP - hch << IfEq(wasNotCached) - hch << IRETURN - hch << Label(wasNotCached) << POP - hch << ALoad(0) << InvokeVirtual(cName, "productElements", s"()[L$ObjectClass;") - hch << ALoad(0) << InvokeVirtual(cName, "productName", "()Ljava/lang/String;") - hch << InvokeVirtual("java/lang/String", "hashCode", "()I") - hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP - hch << ALoad(0) << SWAP << PutField(cName, hashFieldName, "I") - hch << IRETURN - - hch.freeze - } - - } -} diff --git a/src/main/scala/leon/codegen/CompilationException.scala b/src/main/scala/leon/codegen/CompilationException.scala deleted file mode 100644 index a782e5c1f9a9d786a4e594e0b386dd333900945c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CompilationException.scala +++ /dev/null @@ -1,8 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen - -case class CompilationException(msg : String) extends Exception { - override def getMessage = msg -} diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala deleted file mode 100644 index 1eac956d238a9256d42277c3b15aae0e8b884672..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ /dev/null @@ -1,606 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps.typeParamsOf -import purescala.Extractors._ -import purescala.Constructors._ -import utils.UniqueCounter -import runtime.{Monitor, StdMonitor} - -import cafebabe._ -import cafebabe.AbstractByteCodes._ -import cafebabe.ByteCodes._ -import cafebabe.ClassFileTypes._ -import cafebabe.Flags._ - -import scala.collection.JavaConverters._ - -import java.lang.reflect.Constructor - -import synthesis.Problem -import evaluators._ - -class CompilationUnit(val ctx: LeonContext, - val program: Program, - val bank: EvaluationBank = new EvaluationBank, - val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration { - - - protected[codegen] val requireQuantification = program.definedFunctions.exists { fd => - exists { case _: Forall => true case _ => false } (fd.fullBody) - } - - val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) - - var classes = Map[Definition, ClassFile]() - - var defToModuleOrClass = Map[Definition, Definition]() - - val abstractFunDefs = program.definedFunctions.filter(_.body.isEmpty).map(_.id).toSet - - val runtimeCounter = new UniqueCounter[Unit] - - var runtimeTypeToIdMap = Map[TypeTree, Int]() - var runtimeIdToTypeMap = Map[Int, TypeTree]() - def registerType(tpe: TypeTree): Int = runtimeTypeToIdMap.get(tpe) match { - case Some(id) => id - case None => - val id = runtimeCounter.nextGlobal - runtimeTypeToIdMap += tpe -> id - runtimeIdToTypeMap += id -> tpe - id - } - - var runtimeProblemMap = Map[Int, (Seq[TypeParameter], Problem)]() - - def registerProblem(p: Problem, tps: Seq[TypeParameter]): Int = { - val id = runtimeCounter.nextGlobal - runtimeProblemMap += id -> (tps, p) - id - } - - var runtimeForallMap = Map[Int, (Seq[TypeParameter], Forall)]() - - def registerForall(f: Forall, tps: Seq[TypeParameter]): Int = { - val id = runtimeCounter.nextGlobal - runtimeForallMap += id -> (tps, f) - id - } - - var runtimeAbstractMap = Map[Int, FunDef]() - - def registerAbstractFD(fd: FunDef): Int = { - val id = runtimeCounter.nextGlobal - runtimeAbstractMap += id -> fd - id - } - - def defineClass(df: Definition): Unit = { - val cName = defToJVMName(df) - - val cf = df match { - case cd: ClassDef => - val pName = cd.parent.map(parent => defToJVMName(parent.classDef)) - new ClassFile(cName, pName) - - case ob: ModuleDef => - new ClassFile(cName, None) - - case _ => - sys.error("Unhandled definition type") - } - - classes += df -> cf - } - - def jvmClassToLeonClass(name: String): Option[Definition] = { - classes.find(_._2.className == name).map(_._1) - } - - def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { - classes.get(cd) match { - case Some(cf) => - val tpeParam = if (cd.tparams.isEmpty) "" else "[I" - val sig = "(L"+MonitorClass+";" + tpeParam + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" - Some((cf.className, sig)) - case _ => None - } - } - - // Returns className, methodName, methodSignature - private[this] var funDefInfo = Map[FunDef, (String, String, String)]() - - /** - * Returns (cn, mn, sig) where - * cn is the module name - * mn is the safe method name - * sig is the method signature - */ - def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { - funDefInfo.get(fd).orElse { - val sig = "(L"+MonitorClass+";" + - (if (fd.tparams.nonEmpty) "[I" else "") + - fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) - - defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match { - case Some(cf) => - val res = (cf.className, idToSafeJVMName(fd.id), sig) - funDefInfo += fd -> res - Some(res) - case None => - None - } - } - } - - // Get the Java constructor corresponding to the Case class - private[this] var ccdConstructors = Map[CaseClassDef, Constructor[_]]() - - private[this] def caseClassConstructor(ccd: CaseClassDef): Option[Constructor[_]] = { - ccdConstructors.get(ccd).orElse { - classes.get(ccd) match { - case Some(cf) => - val klass = loader.loadClass(cf.className) - // This is a hack: we pick the constructor with the most arguments. - val conss = klass.getConstructors.sortBy(_.getParameterTypes.length) - assert(conss.nonEmpty) - val cons = conss.last - - ccdConstructors += ccd -> cons - Some(cons) - case None => - None - } - } - } - - private[this] lazy val tupleConstructor: Constructor[_] = { - val tc = loader.loadClass("leon.codegen.runtime.Tuple") - val conss = tc.getConstructors.sortBy(_.getParameterTypes.length) - assert(conss.nonEmpty) - conss.last - } - - def getMonitor(model: solvers.Model, maxInvocations: Int): Monitor = { - val bodies = model.toSeq.filter { case (id, v) => abstractFunDefs(id) }.toMap - val domains = model match { - case hm: solvers.PartialModel => Some(hm.domains) - case _ => None - } - - new StdMonitor(this, maxInvocations, bodies, domains) - } - - /** Translates Leon values (not generic expressions) to JVM compatible objects. - * - * Currently, this method is only used to prepare arguments to reflective calls. - * This means it is safe to return AnyRef (as opposed to primitive types), because - * reflection needs this anyway. - */ - def valueToJVM(e: Expr)(implicit monitor: Monitor): AnyRef = e match { - case IntLiteral(v) => - new java.lang.Integer(v) - - case BooleanLiteral(v) => - new java.lang.Boolean(v) - - case UnitLiteral() => - new java.lang.Boolean(true) - - case CharLiteral(c) => - new Character(c) - - case InfiniteIntegerLiteral(v) => - new runtime.BigInt(v.toString) - - case FractionalLiteral(n, d) => - new runtime.Rational(n.toString, d.toString) - - case StringLiteral(v) => - new java.lang.String(v) - - case GenericValue(tp, id) => - e - - case Tuple(elems) => - tupleConstructor.newInstance(elems.map(valueToJVM).toArray).asInstanceOf[AnyRef] - - case CaseClass(cct, args) => - caseClassConstructor(cct.classDef) match { - case Some(cons) => - try { - val tpeParam = if (cct.tps.isEmpty) Seq() else Seq(cct.tps.map(registerType).toArray) - val jvmArgs = monitor +: (tpeParam ++ args.map(valueToJVM)) - cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef] - } catch { - case e : java.lang.reflect.InvocationTargetException => throw e.getCause - } - case None => - ctx.reporter.fatalError("Case class constructor not found?!?") - } - - // For now, we only treat boolean arrays separately. - // We have a use for these, mind you. - //case f @ FiniteArray(exprs) if f.getType == ArrayType(BooleanType) => - // exprs.map(e => exprToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray - - case s @ FiniteSet(els, _) => - val s = new leon.codegen.runtime.Set() - for (e <- els) { - s.add(valueToJVM(e)) - } - s - - case b @ FiniteBag(els, _) => - val b = new leon.codegen.runtime.Bag() - for ((k,v) <- els) { - b.add(valueToJVM(k), valueToJVM(v).asInstanceOf[leon.codegen.runtime.BigInt]) - } - b - - case m @ FiniteMap(els, _, _) => - val m = new leon.codegen.runtime.Map() - for ((k,v) <- els) { - m.add(valueToJVM(k), valueToJVM(v)) - } - m - - case f @ FiniteLambda(mapping, dflt, _) => - val l = new leon.codegen.runtime.FiniteLambda(valueToJVM(dflt)) - - for ((ks,v) <- mapping) { - // Force tuple even with 1/0 elems. - val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] - val vJvm = valueToJVM(v) - l.add(kJvm,vJvm) - } - l - - case l @ Lambda(args, body) => - val (afName, closures, tparams, consSig) = compileLambda(l) - val args = closures.map { case (id, _) => - if (id == monitorID) monitor - else if (id == tpsID) typeParamsOf(l).toSeq.sortBy(_.id.uniqueName).map(registerType).toArray - else throw CompilationException(s"Unexpected closure $id in Lambda compilation") - } - - val lc = loader.loadClass(afName) - val conss = lc.getConstructors.sortBy(_.getParameterTypes.length) - assert(conss.nonEmpty) - val lambdaConstructor = conss.last - lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef] - - case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) => - if (length < 0) { - throw LeonFatalError( - s"Whoops! Array ${f.asString(ctx)} has length $length. " + - default.map { df => s"default: ${df.asString(ctx)}" }.getOrElse("") - ) - } - - import scala.reflect.ClassTag - - def allocArray[A: ClassTag](f: Expr => A): Array[A] = { - val arr = new Array[A](length) - for { - df <- default.toSeq - v = f(df) - i <- 0 until length - } { - arr(i) = v - } - for ((ind, v) <- elems) { - arr(ind) = f(v) - } - arr - - } - - underlying match { - case Int32Type => - allocArray { case IntLiteral(v) => v } - case BooleanType => - allocArray { case BooleanLiteral(b) => b } - case UnitType => - allocArray { case UnitLiteral() => true } - case CharType => - allocArray { case CharLiteral(c) => c } - case _ => - allocArray(valueToJVM) - } - - case _ => - throw CompilationException(s"Unexpected expression $e in valueToJVM") - } - - /** Translates JVM objects back to Leon values of the appropriate type */ - def jvmToValue(e: AnyRef, tpe: TypeTree): Expr = (e, tpe) match { - case (i: Integer, Int32Type) => - IntLiteral(i.toInt) - - case (c: runtime.BigInt, IntegerType) => - InfiniteIntegerLiteral(BigInt(c.underlying)) - - case (c: runtime.Rational, RealType) => - val num = BigInt(c.numerator()) - val denom = BigInt(c.denominator()) - FractionalLiteral(num, denom) - - case (b: java.lang.Boolean, BooleanType) => - BooleanLiteral(b.booleanValue) - - case (c: java.lang.Character, CharType) => - CharLiteral(c.toChar) - - case (c: java.lang.String, StringType) => - StringLiteral(c) - - case (cc: runtime.CaseClass, ct: ClassType) => - val fields = cc.productElements() - - // identify case class type of ct - val cct = ct match { - case cc: CaseClassType => - cc - - case _ => - jvmClassToLeonClass(cc.getClass.getName) match { - case Some(cc: CaseClassDef) => - CaseClassType(cc, ct.tps) - case _ => - throw CompilationException("Unable to identify class "+cc.getClass.getName+" to descendant of "+ct) - } - } - - CaseClass(cct, (fields zip cct.fieldsTypes).map { case (e, tpe) => jvmToValue(e, tpe) }) - - case (tpl: runtime.Tuple, tpe) => - val stpe = unwrapTupleType(tpe, tpl.getArity) - val elems = stpe.zipWithIndex.map { case (tpe, i) => - jvmToValue(tpl.get(i), tpe) - } - tupleWrap(elems) - - case (gv @ GenericValue(gtp, id), tp: TypeParameter) => - if (gtp == tp) gv - else GenericValue(tp, id).copiedFrom(gv) - - case (set: runtime.Set, SetType(b)) => - FiniteSet(set.getElements.asScala.map(jvmToValue(_, b)).toSet, b) - - case (bag: runtime.Bag, BagType(b)) => - FiniteBag(bag.getElements.asScala.map { entry => - val k = jvmToValue(entry.getKey, b) - val v = jvmToValue(entry.getValue, IntegerType) - (k, v) - }.toMap, b) - - case (map: runtime.Map, MapType(from, to)) => - val pairs = map.getElements.asScala.map { entry => - val k = jvmToValue(entry.getKey, from) - val v = jvmToValue(entry.getValue, to) - (k, v) - }.toMap - FiniteMap(pairs, from, to) - - case (lambda: runtime.FiniteLambda, ft @ FunctionType(from, to)) => - val mapping = lambda.mapping.asScala.map { entry => - val k = jvmToValue(entry._1, tupleTypeWrap(from)) - val v = jvmToValue(entry._2, to) - unwrapTuple(k, from.size) -> v - } - val dflt = jvmToValue(lambda.dflt, to) - FiniteLambda(mapping.toSeq, dflt, ft) - - case (lambda: runtime.Lambda, _: FunctionType) => - val cls = lambda.getClass - - val l = classToLambda(cls.getName) - val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) - val closureVals = closures.map { id => - val fieldVal = lambda.getClass.getField(id.uniqueName).get(lambda) - jvmToValue(fieldVal, id.getType) - } - - purescala.ExprOps.replaceFromIDs((closures zip closureVals).toMap, l) - - case (_, UnitType) => - UnitLiteral() - - case (ar: Array[_], ArrayType(base)) => - if (ar.length == 0) { - EmptyArray(base) - } else { - val elems = for ((e: AnyRef, i) <- ar.zipWithIndex) yield { - i -> jvmToValue(e, base) - } - - NonemptyArray(elems.toMap, None) - } - - case _ => - throw CompilationException("Unsupported return value : " + e.getClass +" while expecting "+tpe) - } - - - def compileExpression(e: Expr, args: Seq[Identifier])(implicit ctx: LeonContext): CompiledExpression = { - if(e.getType == Untyped) { - throw new Unsupported(e, s"Cannot compile untyped expression.") - } - - val id = exprCounter.nextGlobal - - val cName = "Leon$CodeGen$Expr$"+id - - val cf = new ClassFile(cName, None) - cf.setFlags(( - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - cf.addDefaultConstructor - - val argsTypes = args.map(a => typeToJVM(a.getType)) - - val realArgs = ("L" + MonitorClass + ";") +: argsTypes - - val m = cf.addMethod( - typeToJVM(e.getType), - "eval", - realArgs : _* - ) - - m.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL | - METHOD_ACC_STATIC - ).asInstanceOf[U2]) - - val ch = m.codeHandler - - val newMapping = Map(monitorID -> 0) ++ args.zipWithIndex.toMap.mapValues(_ + 1) - - mkExpr(e, ch)(NoLocals.withVars(newMapping)) - - e.getType match { - case ValueType() => - ch << IRETURN - case _ => - ch << ARETURN - } - - ch.freeze - - loader.register(cf) - - new CompiledExpression(this, cf, e, args) - } - - def compileModule(module: ModuleDef) { - val cf = classes(module) - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - val (fields, functions) = module.definedFunctions partition { _.canBeField } - val (strictFields, lazyFields) = fields partition { _.canBeStrictField } - - // Compile methods - for (function <- functions) { - compileFunDef(function,module) - } - - // Compile lazy fields - for (lzy <- lazyFields) { - compileLazyField(lzy, module) - } - - // Compile strict fields - for (field <- strictFields) { - compileStrictField(field, module) - } - - // Constructor - cf.addDefaultConstructor - - val cName = defToJVMName(module) - - // Add class initializer method - locally{ - val mh = cf.addMethod("V", "<clinit>") - mh.setFlags(( - METHOD_ACC_STATIC | - METHOD_ACC_PUBLIC - ).asInstanceOf[U2]) - - val ch = mh.codeHandler - /* - * FIXME : - * Dirty hack to make this compatible with monitoring of method invocations. - * Because we don't have access to the monitor object here, we initialize a new one - * that will get lost when this method returns, so we can't hope to count - * method invocations here :( - */ - val locals = NoLocals.withVar(monitorID -> ch.getFreshVar) - ch << New(NoMonitorClass) << DUP - ch << InvokeSpecial(NoMonitorClass, cafebabe.Defaults.constructorName, "()V") - ch << AStore(locals.varToLocal(monitorID).get) // position 0 - - for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, isStatic = true)(locals) } - for (field <- strictFields) { initStrictField(ch, cName , field, isStatic = true)(locals) } - ch << RETURN - ch.freeze - } - - } - - /** Traverses the program to find all definitions, and stores those in global variables */ - def init() { - // First define all classes/ methods/ functions - for (u <- program.units) { - - for { - ch <- u.classHierarchies - cls <- ch - } { - defineClass(cls) - for (meth <- cls.methods) { - defToModuleOrClass += meth -> cls - } - } - - for (m <- u.modules) { - defineClass(m) - for (funDef <- m.definedFunctions) { - defToModuleOrClass += funDef -> m - } - } - } - } - - /** Compiles the program. - * - * Uses information provided by [[init]]. - */ - def compile() { - // Compile everything - for (u <- program.units) { - - for { - ch <- u.classHierarchies - c <- ch - } c match { - case acd: AbstractClassDef => - compileAbstractClassDef(acd) - case ccd: CaseClassDef => - compileCaseClassDef(ccd) - } - - for (m <- u.modules) compileModule(m) - } - - classes.values.foreach(loader.register) - } - - def writeClassFiles(prefix: String) { - for ((d, cl) <- classes) { - cl.writeToFile(prefix+cl.className + ".class") - } - } - - init() - compile() -} - -private [codegen] object exprCounter extends UniqueCounter[Unit] -private [codegen] object forallCounter extends UniqueCounter[Unit] - diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala deleted file mode 100644 index d6cb9a11eb611aabbf95519e1e9309c8562812f5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen - -import purescala.Common._ -import purescala.Expressions._ - -import cafebabe._ - -import runtime.Monitor - -import java.lang.reflect.InvocationTargetException - -class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, argsDecl: Seq[Identifier]) { - - private lazy val cl = unit.loader.loadClass(cf.className) - private lazy val meth = cl.getMethods()(0) - - private val exprType = expression.getType - - private val params = unit.params - - def argsToJVM(args: Seq[Expr], monitor: Monitor): Seq[AnyRef] = { - args.map(unit.valueToJVM(_)(monitor)) - } - - def evalToJVM(args: Seq[AnyRef], monitor: Monitor): AnyRef = { - assert(args.size == argsDecl.size) - - val allArgs = monitor +: args - - meth.invoke(null, allArgs.toArray : _*) - } - - // This may throw an exception. We unwrap it if needed. - // We also need to reattach a type in some cases (sets, maps). - def evalFromJVM(args: Seq[AnyRef], monitor: Monitor) : Expr = { - try { - unit.jvmToValue(evalToJVM(args, monitor), exprType) - } catch { - case ite : InvocationTargetException => throw ite.getCause - } - } - - def eval(model: solvers.Model) : Expr = { - try { - val monitor = unit.getMonitor(model, params.maxFunctionInvocations) - - evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) - } catch { - case ite : InvocationTargetException => throw ite.getCause - } - } -} diff --git a/src/main/scala/leon/codegen/runtime/GenericValues.scala b/src/main/scala/leon/codegen/runtime/GenericValues.scala deleted file mode 100644 index 58ecf8547d07c3c964017d8014b5dd86ca9d6358..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/runtime/GenericValues.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package codegen.runtime - -import purescala.Expressions.GenericValue -import scala.collection.immutable.{Map => ScalaMap} - -object GenericValues { - private[this] var counter = 0 - - private[this] var gvToI = ScalaMap[GenericValue, Int]() - private[this] var iTogv = ScalaMap[Int, GenericValue]() - - def register(gv: GenericValue): Int = { - if (gvToI contains gv) { - gvToI(gv) - } else { - counter += 1 - gvToI += gv -> counter - iTogv += counter -> gv - counter - } - } - - def get(i: Int): java.lang.Object = { - iTogv(i) - } -} diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala deleted file mode 100644 index efe230e8df471d05702d410bec41b9631f4dcddc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ /dev/null @@ -1,297 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package codegen.runtime - -import utils._ -import purescala.Expressions.{CaseClass => LeonCaseClass, _} -import purescala.Constructors._ -import purescala.Definitions._ -import purescala.Common._ -import purescala.Types._ -import purescala.TypeOps._ -import purescala.ExprOps.{valuateWithModel, replaceFromIDs, variablesOf} -import purescala.Quantification.{extractQuorums, Domains} - -import codegen.CompilationUnit - -import scala.collection.immutable.{Map => ScalaMap} -import scala.collection.mutable.{HashMap => MutableMap, Set => MutableSet} -import scala.concurrent.duration._ - -import solvers.{SolverContext, SolverFactory} -import solvers.unrolling.UnrollingProcedure - -import evaluators._ -import synthesis._ - -abstract class Monitor { - def onInvocation(): Unit - - def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] - - def invariantCheck(obj: AnyRef, tpeIdx: Int): Boolean - - def invariantResult(obj: AnyRef, tpeIdx: Int, result: Boolean): Unit - - def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef - - def onChooseInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef - - def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean -} - -class NoMonitor extends Monitor { - def onInvocation(): Unit = {} - - def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } - - def invariantCheck(obj: AnyRef, tpeIdx: Int): Boolean = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } - - def invariantResult(obj: AnyRef, tpeIdx: Int, result: Boolean): Unit = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } - - def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } - - def onChooseInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } - - def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean = { - throw new LeonCodeGenEvaluationException("No monitor available.") - } -} - -class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Identifier, Expr], domains: Option[Domains] = None) extends Monitor { - - private[this] var invocations = 0 - - def onInvocation(): Unit = { - if (invocationsMax >= 0) { - if (invocations < invocationsMax) { - invocations += 1; - } else { - throw new LeonCodeGenEvaluationException("Maximum number of invocations reached ("+invocationsMax+")."); - } - } - } - - def invariantCheck(obj: AnyRef, tpeIdx: Int): Boolean = { - val tpe = unit.runtimeIdToTypeMap(tpeIdx) - val cc = unit.jvmToValue(obj, tpe).asInstanceOf[LeonCaseClass] - val result = unit.bank.invariantCheck(cc) - if (result.isFailure) throw new LeonCodeGenRuntimeException("ADT invariant failed @" + cc.ct.classDef.invariant.get.getPos) - else result.isRequired - } - - def invariantResult(obj: AnyRef, tpeIdx: Int, result: Boolean): Unit = { - val tpe = unit.runtimeIdToTypeMap(tpeIdx) - val cc = unit.jvmToValue(obj, tpe).asInstanceOf[LeonCaseClass] - unit.bank.invariantResult(cc, result) - if (!result) throw new LeonCodeGenRuntimeException("ADT invariant failed @" + cc.ct.classDef.invariant.get.getPos) - } - - def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] = { - val tparams = params.toSeq.map(unit.runtimeIdToTypeMap(_).asInstanceOf[TypeParameter]) - val static = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val newTypes = newTps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams zip newTypes).toMap - static.map(tpe => unit.registerType(instantiateType(tpe, tpMap))).toArray - } - - def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { - val fd = unit.runtimeAbstractMap(id) - - // TODO: extract types too! - - bodies.get(fd.id) match { - case Some(expr) => - throw new LeonCodeGenRuntimeException("Found body!") - - case None => - throw new LeonCodeGenRuntimeException("Did not find body!") - } - } - - private[this] val chooseCache = new MutableMap[(Int, Seq[AnyRef]), AnyRef]() - - def onChooseInvocation(id: Int, tps: Array[Int], inputs: Array[AnyRef]): AnyRef = { - implicit val debugSection = DebugSectionSynthesis - - val (tparams, p) = unit.runtimeProblemMap(id) - - val program = unit.program - val ctx = unit.ctx - - ctx.reporter.debug("Executing choose (codegen)!") - val is = inputs.toSeq - - if (chooseCache contains ((id, is))) { - chooseCache((id, is)) - } else { - val tStart = System.currentTimeMillis - - val sctx = SolverContext(ctx, unit.bank) - val solverf = SolverFactory.getFromSettings(sctx, program).withTimeout(10.second) - val solver = solverf.getNewSolver() - - val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams zip newTypes).toMap - - val newXs = p.xs.map { id => - val newTpe = instantiateType(id.getType, tpMap) - if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) - } - - val newAs = p.as.map { id => - val newTpe = instantiateType(id.getType, tpMap) - if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) - } - - val inputsMap = (newAs zip inputs).map { - case (id, v) => id -> unit.jvmToValue(v, id.getType) - } - - val instTpe: Expr => Expr = { - val idMap = (p.as zip newAs).toMap ++ (p.xs zip newXs) - instantiateType(_: Expr, tpMap, idMap) - } - - val expr = p.pc map instTpe withBindings inputsMap and instTpe(p.phi) - solver.assertCnstr(expr) - - try { - solver.check match { - case Some(true) => - val model = solver.getModel - - val valModel = valuateWithModel(model) _ - - val res = newXs.map(valModel) - val leonRes = tupleWrap(res) - - val total = System.currentTimeMillis-tStart - - ctx.reporter.debug("Synthesis took "+total+"ms") - ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) - - val obj = unit.valueToJVM(leonRes)(this) - chooseCache += (id, is) -> obj - obj - case Some(false) => - throw new LeonCodeGenRuntimeException("Constraint is UNSAT") - case _ => - throw new LeonCodeGenRuntimeException("Timeout exceeded") - } - } finally { - solver.free() - solverf.shutdown() - } - } - } - - private[this] val forallCache = new MutableMap[(Int, Seq[AnyRef]), Boolean]() - - def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean = { - implicit val debugSection = DebugSectionVerification - - val (tparams, f) = unit.runtimeForallMap(id) - - val program = unit.program - - val newOptions = Seq( - LeonOption(UnrollingProcedure.optFeelingLucky)(false), - LeonOption(UnrollingProcedure.optSilentErrors)(true), - LeonOption(UnrollingProcedure.optCheckModels)(true) - ) - - val ctx = unit.ctx.copy(options = unit.ctx.options.filterNot { opt => - newOptions.exists(no => opt.optionDef == no.optionDef) - } ++ newOptions) - - ctx.reporter.debug("Executing forall (codegen)!") - val argsSeq = args.toSeq - - if (forallCache contains ((id, argsSeq))) { - forallCache((id, argsSeq)) - } else { - val tStart = System.currentTimeMillis - - val sctx = SolverContext(ctx, unit.bank) - val solverf = SolverFactory.getFromSettings(sctx, program).withTimeout(.5.second) - val solver = solverf.getNewSolver() - - val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams zip newTypes).toMap - - val vars = variablesOf(f).toSeq.sortBy(_.uniqueName) - val newVars = vars.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true)) - - val Forall(fargs, body) = instantiateType(f, tpMap, (vars zip newVars).toMap) - val mapping = (newVars zip argsSeq).map(p => p._1 -> unit.jvmToValue(p._2, p._1.getType)).toMap - val cnstr = Not(replaceFromIDs(mapping, body)) - solver.assertCnstr(cnstr) - - if (domains.isDefined) { - val dom = domains.get - val quantifiers = fargs.map(_.id).toSet - val quorums = extractQuorums(body, quantifiers) - - val domainCnstr = orJoin(quorums.map { quorum => - val quantifierDomains = quorum.flatMap { case (path, caller, args) => - val domain = caller match { - case Variable(id) => dom.get(mapping(id)) - case _ => ctx.reporter.fatalError("Unexpected quantifier matcher: " + caller) - } - - args.zipWithIndex.flatMap { - case (Variable(id),idx) if quantifiers(id) => - Some(id -> domain.map(cargs => path -> cargs(idx))) - case _ => None - } - } - - val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) - andJoin(domainMap.toSeq.map { case (id, dom) => - orJoin(dom.toSeq.map { case (path, value) => - // @nv: Note that we know id.getType is first-order since quantifiers can only - // range over basic types. This means equality is guaranteed well-defined - // between `id` and `value` - path and Equals(Variable(id), value) - }) - }) - }) - - solver.assertCnstr(domainCnstr) - } - - try { - solver.check match { - case Some(negRes) => - val res = !negRes - val total = System.currentTimeMillis-tStart - - ctx.reporter.debug("Verification took "+total+"ms") - ctx.reporter.debug("Finished forall evaluation with: "+res) - - forallCache += (id, argsSeq) -> res - res - - case _ => - throw new LeonCodeGenRuntimeException("Timeout exceeded") - } - } finally { - solver.free() - solverf.shutdown() - } - } - } -} - diff --git a/src/main/scala/leon/codegen/runtime/RuntimeResources.scala b/src/main/scala/leon/codegen/runtime/RuntimeResources.scala deleted file mode 100644 index f18f23b3f2c92cb93e2842695874d3b4ee661d55..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/runtime/RuntimeResources.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.codegen.runtime - -import leon.utils.UniqueCounter - -import java.util.WeakHashMap -import java.lang.ref.WeakReference - -/** - * This class allows an evaluator to statically register a resource, identified - * by an integer. This identifier can be stored in bytecode, allowing .class to - * access the resource at runtime. - * - * The user/evaluator should keep hold of the returned Token, otherwise the - * resource may be garbage-collected. - * - * This is not statically-typed, but - * - * get[A]( register[A]( ... ) ) - * - * should always be safe. - */ -object RuntimeResources { - case class Token(id: Int) - - private val intCounter = new UniqueCounter[Unit] - - private[this] val store = new WeakHashMap[Token, WeakReference[AnyRef]]() - - def register[T <: AnyRef](data: T): Token = synchronized { - val t = Token(intCounter.nextGlobal) - - store.put(t, new WeakReference(data)) - - t - } - - def get[T <: AnyRef](id: Int): T = { - store.get(Token(id)).get.asInstanceOf[T] - } -} diff --git a/src/main/scala/leon/datagen/DataGenerator.scala b/src/main/scala/leon/datagen/DataGenerator.scala deleted file mode 100644 index ea6aac313f357fd6c04577d8a00261e4dfa0fe5c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/datagen/DataGenerator.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package datagen - -import purescala.Expressions._ -import purescala.Common._ -import utils._ - -import java.util.concurrent.atomic.AtomicBoolean - -trait DataGenerator extends Interruptible { - implicit val debugSection = DebugSectionDataGen - - def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] - - protected val interrupted: AtomicBoolean = new AtomicBoolean(false) - - def interrupt(): Unit = { - interrupted.set(true) - } - - def recoverInterrupt(): Unit = { - interrupted.set(false) - } -} diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala deleted file mode 100644 index ffe7b8b4741d7323a99bdfc2741722fb356b5394..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package datagen - -import purescala.Expressions._ -import purescala.Types._ -import purescala.Common._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.ExprOps._ -import evaluators._ -import bonsai.enumerators._ - -import grammars._ -import utils.UniqueCounter -import utils.SeqUtils.cartesianProduct - -/** Utility functions to generate values of a given type. - * In fact, it could be used to generate *terms* of a given type, - * e.g. by passing trees representing variables for the "bounds". */ -class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar = ValueGrammar) extends DataGenerator { - implicit val ctx = evaluator.context - - // Assume e contains generic values with index 0. - // Return a series of expressions with all normalized combinations of generic values. - private def expandGenerics(e: Expr): Seq[Expr] = { - val c = new UniqueCounter[TypeParameter] - val withUniqueCounters: Expr = postMap { - case GenericValue(t, _) => - Some(GenericValue(t, c.next(t))) - case _ => None - }(e) - - val indices = c.current - - val (tps, substInt) = (for { - tp <- indices.keySet.toSeq - } yield tp -> (for { - from <- 0 to indices(tp) - to <- 0 to from - } yield (from, to))).unzip - - val combos = cartesianProduct(substInt) - - val substitutions = combos map { subst => - tps.zip(subst).map { case (tp, (from, to)) => - (GenericValue(tp, from): Expr) -> (GenericValue(tp, to): Expr) - }.toMap - } - - substitutions map (replace(_, withUniqueCounters)) - - } - - def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[Label, Expr, ProductionRule[Label, Expr]](grammar.getProductions) - enum.iterator(Label(tpe)).flatMap(expandGenerics).takeWhile(_ => !interrupted.get) - } - - def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { - - def filterCond(vs: Seq[Expr]): Boolean = satisfying match { - case BooleanLiteral(true) => - true - case e => - // in -> e should be enough. We shouldn't find any subexpressions of in. - evaluator.eval(e, (ins zip vs).toMap) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } - } - - if (ins.isEmpty) { - Iterator(Seq[Expr]()).filter(filterCond) - } else { - val values = generate(tupleTypeWrap(ins.map{ _.getType })) - - val detupled = values.map { - v => unwrapTuple(v, ins.size) - } - - detupled.take(maxEnumerated) - .filter(filterCond) - .take(maxValid) - .takeWhile(_ => !interrupted.get) - } - } - - def generateMapping(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int) = { - generateFor(ins, satisfying, maxValid, maxEnumerated) map (ins zip _) - } - -} diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala deleted file mode 100644 index e40359e686e76030fb29171fe18c6eab82482898..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package datagen - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.Definitions._ -import purescala.Quantification._ -import utils.StreamUtils._ - -import evaluators._ - -import scala.collection.mutable.{Map=>MutableMap} - -/** Utility functions to generate values of a given type. - * In fact, it could be used to generate *terms* of a given type, - * e.g. by passing trees representing variables for the "bounds". */ -@deprecated("Stream-based datagen is deprecated, use GrammarDataGen with ValueGrammar instead", "3.0") -class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : Option[Map[TypeTree,Seq[Expr]]] = None) extends DataGenerator { - - val bounds = _bounds.getOrElse(Map()) - - private def tpStream(tp: TypeParameter, i: Int = 1): Stream[Expr] = Stream.cons(GenericValue(tp, i), tpStream(tp, i+1)) - - private val streamCache : MutableMap[TypeTree,Stream[Expr]] = MutableMap.empty - def generate(tpe : TypeTree) : Stream[Expr] = { - try { - streamCache.getOrElse(tpe, { - val s = generate0(tpe) - streamCache(tpe) = s - s - }) - } catch { - case so: StackOverflowError => - Stream.empty - } - } - - private def generate0(tpe: TypeTree): Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse { - tpe match { - case BooleanType => - BooleanLiteral(true) #:: BooleanLiteral(false) #:: Stream.empty - - case Int32Type => - IntLiteral(0) #:: IntLiteral(1) #:: IntLiteral(2) #:: IntLiteral(-1) #:: Stream.empty - - case tp: TypeParameter => - tpStream(tp) - - case TupleType(bses) => - cartesianProduct(bses.map(generate)).map(Tuple) - - case act : AbstractClassType => - // We prioritize base cases among the children. - // Otherwise we run the risk of infinite recursion when - // generating lists. - val ccChildren = act.knownCCDescendants - - val (leafs,conss) = ccChildren.partition(_.fields.isEmpty) - - // FIXME: Will not work for mutually recursive types - val sortedConss = conss sortBy { _.fields.count{ _.getType.isInstanceOf[ClassType]}} - - // The stream for leafs... - val leafsStream = leafs.toStream.flatMap(generate) - - // ...to which we append the streams for constructors. - leafsStream.append(interleave(sortedConss.map(generate))) - - case cct : CaseClassType => - if(cct.fields.isEmpty) { - Stream.cons(CaseClass(cct, Nil), Stream.empty) - } else { - val fieldTypes = cct.fieldsTypes - cartesianProduct(fieldTypes.map(generate)).map(CaseClass(cct, _)) - } - - case _ => Stream.empty - } - } - - def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid : Int, maxEnumerated : Int) : Iterator[Seq[Expr]] = { - val filtering = if (satisfying == BooleanLiteral(true)) { - { (e: Seq[Expr]) => true } - } else { - evaluator.compile(satisfying, ins).map { evalFun => - val sat = EvaluationResults.Successful(BooleanLiteral(true)) - - { (e: Seq[Expr]) => evalFun(new solvers.Model((ins zip e).toMap)) == sat } - } getOrElse { - { (e: Seq[Expr]) => false } - } - } - - cartesianProduct(ins.map(id => generate(id.getType))) - .take(maxEnumerated) - .takeWhile(s => !interrupted.get) - .filter{filtering} - .take(maxValid) - .iterator - } -} diff --git a/src/main/scala/leon/datagen/SolverDataGen.scala b/src/main/scala/leon/datagen/SolverDataGen.scala deleted file mode 100644 index 37dcfba061b7ef9fd71bb2d31e57754950eed5d5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/datagen/SolverDataGen.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package datagen - -import purescala.Expressions._ -import purescala.Types._ -import purescala.Definitions._ -import purescala.Common._ -import purescala.Constructors._ -import solvers._ -import utils._ - -class SolverDataGen(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) extends DataGenerator { - implicit val ctx0 = ctx - - def generate(tpe: TypeTree): FreeableIterator[Expr] = { - generateFor(Seq(FreshIdentifier("tmp", tpe)), BooleanLiteral(true), 20, 20).map(_.head).takeWhile(_ => !interrupted.get) - } - - def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): FreeableIterator[Seq[Expr]] = { - if (ins.isEmpty) { - FreeableIterator.empty - } else { - - var fds = Map[ClassDef, FunDef]() - - def sizeFor(of: Expr): Expr = of.getType match { - case AbstractClassType(acd, tps) => - val fd = fds.getOrElse(acd, { - val actDef = AbstractClassType(acd, acd.tparams.map(_.tp)) - - val e = FreshIdentifier("e", actDef) - - val fd: FunDef = new FunDef(FreshIdentifier("sizeOf", Untyped), acd.tparams, Seq(ValDef(e)), IntegerType) - - fds += acd -> fd - - - fd.body = Some(MatchExpr(e.toVariable, - actDef.knownCCDescendants.map { cct => - val fields = cct.fieldsTypes.map ( t => FreshIdentifier("field", t)) - - val rhs = fields.foldLeft(InfiniteIntegerLiteral(1): Expr) { (i, f) => - plus(i, sizeFor(f.toVariable)) - } - - MatchCase(CaseClassPattern(None, cct, fields.map(f => WildcardPattern(Some(f)))), None, rhs) - } - )) - - fd - }) - - FunctionInvocation(fd.typed(tps), Seq(of)) - - case tt @ TupleType(tps) => - val exprs = for ((t,i) <- tps.zipWithIndex) yield { - sizeFor(tupleSelect(of, i+1, tps.size)) - } - - exprs.foldLeft(InfiniteIntegerLiteral(1): Expr)(plus) - - case _ => - InfiniteIntegerLiteral(1) - } - - val sizeOf = sizeFor(tupleWrap(ins.map(_.toVariable))) - - // We need to synthesize a size function for ins' types. - val pgm1 = Program(pgm.units :+ UnitDef(FreshIdentifier("new"), List( - ModuleDef(FreshIdentifier("new"), fds.values.toSeq, false) - ))) - - val modelEnum = new ModelEnumerator(ctx, pgm1, sf) - - val enum = modelEnum.enumVarying(ins, satisfying, sizeOf, 5) - - enum.take(maxValid).map(model => ins.map(model)).takeWhile(_ => !interrupted.get) - } - } - -} diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala deleted file mode 100644 index ccba55af919489697d9d9e6a3aa59768d29dbab2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ /dev/null @@ -1,409 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package datagen - -import purescala.Common._ -import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.Extractors._ -import purescala.Constructors._ - -import codegen.CompilationUnit -import codegen.CodeGenParams -import codegen.runtime.StdMonitor -import vanuatoo.{Pattern => VPattern, _} - -import evaluators._ - -class VanuatooDataGen(ctx: LeonContext, p: Program, bank: EvaluationBank = new EvaluationBank) extends DataGenerator { - val unit = new CompilationUnit(ctx, p, bank, CodeGenParams.default.copy(doInstrument = true)) - - val ints = (for (i <- Set(0, 1, 2, 3)) yield { - i -> Constructor[Expr, TypeTree](List(), Int32Type, s => IntLiteral(i), ""+i) - }).toMap - - val bigInts = (for (i <- Set(0, 1, 2, 3)) yield { - i -> Constructor[Expr, TypeTree](List(), IntegerType, s => InfiniteIntegerLiteral(i), ""+i) - }).toMap - - val booleans = (for (b <- Set(true, false)) yield { - b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b) - }).toMap - - val chars = (for (c <- Set('a', 'b', 'c', 'd')) yield { - c -> Constructor[Expr, TypeTree](List(), CharType, s => CharLiteral(c), ""+c) - }).toMap - - val rationals = (for (n <- Set(0, 1, 2, 3); d <- Set(1,2,3,4)) yield { - (n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d) - }).toMap - - val strings = (for (b <- Set("", "a", "foo", "bar")) yield { - b -> Constructor[Expr, TypeTree](List(), StringType, s => StringLiteral(b), b) - }).toMap - - - def intConstructor(i: Int) = ints(i) - - def bigIntConstructor(i: Int) = bigInts(i) - - def boolConstructor(b: Boolean) = booleans(b) - - def charConstructor(c: Char) = chars(c) - - def rationalConstructor(n: Int, d: Int) = rationals(n -> d) - - def stringConstructor(s: String) = strings(s) - - lazy val stubValues = ints.values ++ bigInts.values ++ booleans.values ++ chars.values ++ rationals.values ++ strings.values - - def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = { - ConstructorPattern[Expr, TypeTree](c, args) - } - - private var constructors = Map[TypeTree, List[Constructor[Expr, TypeTree]]]() - - private def getConstructorFor(t: CaseClassType, act: AbstractClassType): Constructor[Expr, TypeTree] = { - // We "up-cast" the returnType of the specific caseclass generator to match its superclass - getConstructors(t).head.copy(retType = act) - } - - private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match { - case UnitType => - constructors.getOrElse(t, { - val cs = List(Constructor[Expr, TypeTree](List(), t, s => UnitLiteral(), "()")) - constructors += t -> cs - cs - }) - - case at @ ArrayType(sub) => - constructors.getOrElse(at, { - val cs = for (size <- List(0, 1, 2, 5)) yield { - Constructor[Expr, TypeTree]( - (1 to size).map(i => sub).toList, - at, - s => finiteArray(s, None, sub), - at.asString(ctx)+"@"+size - ) - } - constructors += at -> cs - cs - }) - - case st @ SetType(sub) => - constructors.getOrElse(st, { - val cs = for (size <- List(0, 1, 2, 5)) yield { - Constructor[Expr, TypeTree]( - (1 to size).map(i => sub).toList, - st, - s => FiniteSet(s.toSet, sub), - st.asString(ctx)+"@"+size - ) - } - constructors += st -> cs - cs - }) - - case bt @ BagType(sub) => - constructors.getOrElse(bt, { - val cs = for (size <- List(0, 1, 2, 5)) yield { - val subs = (1 to size).flatMap(i => List(sub, IntegerType)).toList - Constructor[Expr, TypeTree](subs, bt, s => FiniteBag(s.grouped(2).map { - case Seq(k, i @ InfiniteIntegerLiteral(v)) => - k -> (if (v > 0) i else InfiniteIntegerLiteral(-v + 1)) - }.toMap, sub), bt.asString(ctx)+"@"+size) - } - constructors += bt -> cs - cs - }) - - case tt @ TupleType(parts) => - constructors.getOrElse(tt, { - val cs = List(Constructor[Expr, TypeTree](parts, tt, s => tupleWrap(s), tt.asString(ctx))) - constructors += tt -> cs - cs - }) - - case mt @ MapType(from, to) => - constructors.getOrElse(mt, { - val cs = for (size <- List(0, 1, 2, 5)) yield { - val subs = (1 to size).flatMap(i => List(from, to)).toList - Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size) - } - constructors += mt -> cs - cs - }) - - case ft @ FunctionType(from, to) => - constructors.getOrElse(ft, { - val cs = for (size <- List(1, 2, 3, 5)) yield { - val subs = (1 to size).flatMap(_ => from :+ to).toList - Constructor[Expr, TypeTree](subs, ft, { s => - val grouped = s.grouped(from.size + 1).toSeq - val mapping = grouped.init.map { case args :+ res => (args -> res) } - FiniteLambda(mapping, grouped.last.last, ft) - }, ft.asString(ctx) + "@" + size) - } - constructors += ft -> cs - cs - }) - - case tp: TypeParameter => - constructors.getOrElse(tp, { - val cs = for (i <- List(1, 2)) yield { - Constructor[Expr, TypeTree](List(), tp, s => GenericValue(tp, i), tp.id+"#"+i) - } - constructors += tp -> cs - cs - }) - - case act: AbstractClassType => - constructors.getOrElse(act, { - val cs = act.knownCCDescendants.map { - cct => getConstructorFor(cct, act) - }.toList - - constructors += act -> cs - - cs - }) - - case cct: CaseClassType => - constructors.getOrElse(cct, { - val c = List(Constructor[Expr, TypeTree](cct.fieldsTypes, cct, s => CaseClass(cct, s), cct.id.name)) - constructors += cct -> c - c - }) - - case _ => - ctx.reporter.error("Unknown type to generate constructor for: "+t) - Nil - } - - // Returns the pattern and whether it is fully precise - private def valueToPattern(v: AnyRef, expType: TypeTree): (VPattern[Expr, TypeTree], Boolean) = (v, expType) match { - case (i: Integer, Int32Type) => - (cPattern(intConstructor(i), List()), true) - - case (i: Integer, IntegerType) => - (cPattern(bigIntConstructor(i), List()), true) - - case (b: java.lang.Boolean, BooleanType) => - (cPattern(boolConstructor(b), List()), true) - - case (c: java.lang.Character, CharType) => - (cPattern(charConstructor(c), List()), true) - - case (b: java.lang.String, StringType) => - (cPattern(stringConstructor(b), List()), true) - - case (cc: codegen.runtime.CaseClass, ct: ClassType) => - val r = cc.__getRead() - - unit.jvmClassToLeonClass(cc.getClass.getName) match { - case Some(ccd: CaseClassDef) => - val cct = CaseClassType(ccd, ct.tps) - val c = ct match { - case act : AbstractClassType => - getConstructorFor(cct, act) - case cct : CaseClassType => - getConstructors(cct).head - } - - val fields = cc.productElements() - - val elems = for (i <- 0 until fields.length) yield { - if (((r >> i) & 1) == 1) { - // has been read - valueToPattern(fields(i), cct.fieldsTypes(i)) - } else { - (AnyPattern[Expr, TypeTree](), false) - } - } - - (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) - - case _ => - ctx.reporter.error("Could not retrieve type for :"+cc.getClass.getName) - (AnyPattern[Expr, TypeTree](), false) - } - - case (t: codegen.runtime.Tuple, tpe) => - val r = t.__getRead() - - val parts = unwrapTupleType(tpe, t.getArity) - - val c = getConstructors(tpe).head - - val elems = for (i <- 0 until t.getArity) yield { - if (((r >> i) & 1) == 1) { - // has been read - valueToPattern(t.get(i), parts(i)) - } else { - (AnyPattern[Expr, TypeTree](), false) - } - } - - (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) - - case (gv: GenericValue, t: TypeParameter) => - (cPattern(getConstructors(t)(gv.id-1), List()), true) - - case (v, t) => - ctx.reporter.debug("Unsupported value, can't paternify : "+v+" ("+v.getClass+") : "+t) - (AnyPattern[Expr, TypeTree](), false) - } - - type InstrumentedResult = (EvaluationResults.Result[Expr], Option[vanuatoo.Pattern[Expr, TypeTree]]) - - def compile(expression: Expr, argorder: Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { - import leon.codegen.runtime.LeonCodeGenRuntimeException - import leon.codegen.runtime.LeonCodeGenEvaluationException - - try { - val ttype = tupleTypeWrap(argorder.map(_.getType)) - val tid = FreshIdentifier("tup", ttype) - - val map = argorder.zipWithIndex.map{ case (id, i) => id -> tupleSelect(Variable(tid), i + 1, argorder.size) }.toMap - - val newExpr = replaceFromIDs(map, expression) - - val ce = unit.compileExpression(newExpr, Seq(tid))(ctx) - - Some((args : Expr) => { - try { - val monitor = new StdMonitor(unit, unit.params.maxFunctionInvocations, Map()) - - val jvmArgs = ce.argsToJVM(Seq(args), monitor) - - val result = ce.evalFromJVM(jvmArgs, monitor) - - // jvmArgs is getting updated by evaluating - val pattern = valueToPattern(jvmArgs.head, ttype) - - (EvaluationResults.Successful(result), if (!pattern._2) Some(pattern._1) else None) - } catch { - case e : StackOverflowError => - (EvaluationResults.RuntimeError(e.getMessage), None) - - case e : ClassCastException => - (EvaluationResults.RuntimeError(e.getMessage), None) - - case e : ArithmeticException => - (EvaluationResults.RuntimeError(e.getMessage), None) - - case e : ArrayIndexOutOfBoundsException => - (EvaluationResults.RuntimeError(e.getMessage), None) - - case e : LeonCodeGenRuntimeException => - (EvaluationResults.RuntimeError(e.getMessage), None) - - case e : LeonCodeGenEvaluationException => - (EvaluationResults.EvaluatorError(e.getMessage), None) - } - }) - } catch { - case t: Throwable => - ctx.reporter.warning("Error while compiling expression: "+t.getMessage); t.printStackTrace() - None - } - } - - def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { - // Split conjunctions - val TopLevelAnds(ands) = satisfying - - val runners = ands.flatMap(a => compile(a, ins) match { - case Some(runner) => Some(runner) - case None => - ctx.reporter.error("Could not compile predicate " + a) - None - }) - - val gen = new StubGenerator[Expr, TypeTree](stubValues.toSeq, - Some(getConstructors _), - treatEmptyStubsAsChildless = true) - - /** - * Gather at most <n> isomoprhic models before skipping them - * - Too little means skipping many excluding patterns - * - Too large means repetitive (and not useful models) before reaching maxEnumerated - */ - - val maxIsomorphicModels = maxValid+1 - - val it = gen.enumerate(tupleTypeWrap(ins.map(_.getType))) - - new Iterator[Seq[Expr]] { - var total = 0 - var found = 0 - - var theNext: Option[Seq[Expr]] = None - - def hasNext = { - if (total == 0) { - theNext = computeNext() - } - - theNext.isDefined - } - - def next() = { - val res = theNext.get - theNext = computeNext() - res - } - - - def computeNext(): Option[Seq[Expr]] = { - //return None - while (total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) { - val model = it.next() - it.hasNext // FIXME: required for some reason by StubGenerator or will return false during loop check - - if (model eq null) { - total = maxEnumerated - } else { - total += 1 - - var failed = false - - for (r <- runners) r(model) match { - case (EvaluationResults.Successful(BooleanLiteral(true)), _) => - - case (_, Some(pattern)) => - failed = true - it.exclude(pattern) - - case (_, None) => - failed = true; - } - - if (!failed) { - //println("Got model:") - //for ((i, v) <- (ins zip model.exprs)) { - // println(" - "+i+" -> "+v) - //} - - found += 1 - - if (found % maxIsomorphicModels == 0) { - it.skipIsomorphic() - } - - return Some(unwrapTuple(model, ins.size)) - } - - //if (total % 1000 == 0) { - // println("... "+total+" ...") - //} - } - } - None - } - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala deleted file mode 100644 index b99e73fce7a05378d208d526501a577c3eb2d2d7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ /dev/null @@ -1,1192 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc._ - -/** Contains extractors to pull-out interesting parts of the Scala ASTs. */ -trait ASTExtractors { - val global: Global - - import global._ - import global.definitions._ - - def classFromName(str: String) = { - rootMirror.getClassByName(newTypeName(str)) - } - - def objectFromName(str: String) = { - rootMirror.getClassByName(newTermName(str)) - } - - def annotationsOf(s: Symbol): Map[String, Seq[Option[Any]]] = { - val actualSymbol = s.accessedOrSelf - - (for { - a <- actualSymbol.annotations ++ actualSymbol.owner.annotations - name = a.atp.safeToString.replaceAll("\\.package\\.", ".") - if (name startsWith "leon.annotation.") - } yield { - val args = a.args.map { - case Literal(x) => Some(x.value) - case _ => None - } - (name.split("\\.", 3)(2), args) - }).toMap - } - - protected lazy val tuple2Sym = classFromName("scala.Tuple2") - protected lazy val tuple3Sym = classFromName("scala.Tuple3") - protected lazy val tuple4Sym = classFromName("scala.Tuple4") - protected lazy val tuple5Sym = classFromName("scala.Tuple5") - protected lazy val scalaMapSym = classFromName("scala.collection.immutable.Map") - protected lazy val scalaSetSym = classFromName("scala.collection.immutable.Set") - protected lazy val setSym = classFromName("leon.lang.Set") - protected lazy val mapSym = classFromName("leon.lang.Map") - protected lazy val bagSym = classFromName("leon.lang.Bag") - protected lazy val realSym = classFromName("leon.lang.Real") - protected lazy val optionClassSym = classFromName("scala.Option") - protected lazy val arraySym = classFromName("scala.Array") - protected lazy val someClassSym = classFromName("scala.Some") - protected lazy val byNameSym = classFromName("scala.<byname>") - protected lazy val bigIntSym = classFromName("scala.math.BigInt") - protected lazy val stringSym = classFromName("java.lang.String") - protected def functionTraitSym(i:Int) = { - require(0 <= i && i <= 22) - classFromName("scala.Function" + i) - } - - def isTuple2(sym : Symbol) : Boolean = sym == tuple2Sym - def isTuple3(sym : Symbol) : Boolean = sym == tuple3Sym - def isTuple4(sym : Symbol) : Boolean = sym == tuple4Sym - def isTuple5(sym : Symbol) : Boolean = sym == tuple5Sym - - def isBigIntSym(sym : Symbol) : Boolean = getResolvedTypeSym(sym) == bigIntSym - - def isStringSym(sym : Symbol) : Boolean = getResolvedTypeSym(sym) match { case `stringSym` => true case _ => false } - - def isByNameSym(sym : Symbol) : Boolean = getResolvedTypeSym(sym) == byNameSym - - // Resolve type aliases - def getResolvedTypeSym(sym: Symbol): Symbol = { - if (sym.isAliasType) { - getResolvedTypeSym(sym.tpe.resultType.typeSymbol) - } else { - sym - } - } - - def isSetSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == setSym - } - - def isBagSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == bagSym - } - - def isRealSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == realSym - } - - def isScalaSetSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == scalaSetSym - } - - def isMapSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == mapSym - } - - def isScalaMapSym(sym: Symbol) : Boolean = { - getResolvedTypeSym(sym) == scalaMapSym - } - - def isOptionClassSym(sym : Symbol) : Boolean = { - sym == optionClassSym || sym == someClassSym - } - - def isFunction(sym : Symbol, i: Int) : Boolean = - 0 <= i && i <= 22 && sym == functionTraitSym(i) - - def isArrayClassSym(sym: Symbol): Boolean = sym == arraySym - - def hasIntType(t : Tree) = { - val tpe = t.tpe.widen - tpe =:= IntClass.tpe - } - - def hasBigIntType(t : Tree) = isBigIntSym(t.tpe.typeSymbol) - - def hasStringType(t : Tree) = isStringSym(t.tpe.typeSymbol) - - def hasRealType(t : Tree) = isRealSym(t.tpe.typeSymbol) - - /** A set of helpers for extracting trees.*/ - object ExtractorHelpers { - /** Extracts the identifier as `"Ident(name)"` (who needs this?!) */ - object ExIdNamed { - def unapply(id: Ident): Option[String] = Some(id.toString) - } - - /** Extracts the tree and its type (who needs this?!) */ - object ExHasType { - def unapply(tr: Tree): Option[(Tree, Symbol)] = Some((tr, tr.tpe.typeSymbol)) - } - - /** Extracts the string representation of a name of something having the `Name` trait */ - object ExNamed { - def unapply(name: Name): Option[String] = Some(name.toString) - } - - /** Returns the full dot-separated names of the symbol as a list of strings */ - object ExSymbol { - def unapplySeq(t: Tree): Option[Seq[String]] = { - Some(t.symbol.fullName.toString.split('.').toSeq) - } - } - - /** Matches nested `Select(Select(...Select(a, b) ...y) , z)` and returns the list `a,b, ... y,z` */ - object ExSelected { - def unapplySeq(select: Select): Option[Seq[String]] = select match { - case Select(This(scalaName), name) => - Some(Seq(scalaName.toString, name.toString)) - - case Select(from: Select, name) => - unapplySeq(from).map(prefix => prefix :+ name.toString) - - case Select(from: Ident, name) => - Some(Seq(from.toString, name.toString)) - - case _ => - None - } - } - } - - object StructuralExtractors { - import ExtractorHelpers._ - - /** Extracts the 'ensuring' contract from an expression. */ - object ExEnsuredExpression { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(Select(Apply(TypeApply(ExSelected("scala", "Predef", "Ensuring"), _ :: Nil), body :: Nil), ExNamed("ensuring")), contract :: Nil) - => Some((body, contract)) - case Apply(Select(Apply(TypeApply(ExSelected("leon", "lang", "StaticChecks", "any2Ensuring"), _ :: Nil), body :: Nil), ExNamed("ensuring")), contract :: Nil) - => Some((body, contract)) - case _ => None - } - } - - /** Matches the `holds` expression at the end of any boolean expression, and returns the boolean expression.*/ - object ExHoldsExpression { - def unapply(tree: Select) : Option[Tree] = tree match { - case Select( - Apply(ExSelected("leon", "lang", "package", "BooleanDecorations"), realExpr :: Nil), - ExNamed("holds") - ) => Some(realExpr) - case _ => None - } - } - - /** Matches the `holds` expression at the end of any boolean expression with a proof as argument, and returns both of themn.*/ - object ExHoldsWithProofExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(Select(Apply(ExSelected("leon", "lang", "package", "BooleanDecorations"), body :: Nil), ExNamed("holds")), proof :: Nil) => - Some((body, proof)) - case _ => None - } - } - - /** Matches the `because` method at the end of any boolean expression, and return the assertion and the cause. If no "because" method, still returns the expression */ - object ExMaybeBecauseExpressionWrapper { - def unapply(tree: Tree) : Some[Tree] = tree match { - case Apply(ExSelected("leon", "lang", "package", "because"), body :: Nil) => - unapply(body) - case body => Some(body) - } - } - - /** Matches the `because` method at the end of any boolean expression, and return the assertion and the cause.*/ - object ExBecauseExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(Select(Apply(ExSelected("leon", "proof", "package", "boolean2ProofOps"), body :: Nil), ExNamed("because")), proof :: Nil) => - Some((body, proof)) - case _ => None - } - } - - /** Matches the `bigLength` expression at the end of any string expression, and returns the expression.*/ - object ExBigLengthExpression { - def unapply(tree: Apply) : Option[Tree] = tree match { - case Apply(Select( - Apply(ExSelected("leon", "lang", "package", "StringDecorations"), stringExpr :: Nil), - ExNamed("bigLength")), Nil) - => Some(stringExpr) - case _ => None - } - } - - /** Matches the `bigSubstring` method at the end of any string expression, and returns the expression and the start index expression.*/ - object ExBigSubstringExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(Select( - Apply(ExSelected("leon", "lang", "package", "StringDecorations"), stringExpr :: Nil), - ExNamed("bigSubstring")), startExpr :: Nil) - => Some(stringExpr, startExpr) - case _ => None - } - } - - /** Matches the `bigSubstring` expression at the end of any string expression, and returns the expression, the start and end index expressions.*/ - object ExBigSubstring2Expression { - def unapply(tree: Apply) : Option[(Tree, Tree, Tree)] = tree match { - case Apply(Select( - Apply(ExSelected("leon", "lang", "package", "StringDecorations"), stringExpr :: Nil), - ExNamed("bigSubstring")), startExpr :: endExpr :: Nil) - => Some(stringExpr, startExpr, endExpr) - case _ => None - } - } - - /** Matches an implication `lhs ==> rhs` and returns (lhs, rhs)*/ - object ExImplies { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case - Apply( - Select( - Apply( - ExSelected("leon", "lang", "package", "BooleanDecorations"), - lhs :: Nil - ), - ExNamed("$eq$eq$greater") - ), - rhs :: Nil - ) => Some((lhs, rhs)) - case _ => None - } - } - - /** Extracts the 'require' contract from an expression (only if it's the - * first call in the block). */ - object ExRequiredExpression { - def unapply(tree: Apply): Option[Tree] = tree match { - case Apply(ExSelected("scala", "Predef", "require"), contractBody :: Nil) => - Some(contractBody) - case _ => None - } - } - - /** Matches the `A computes B` expression at the end of any expression A, and returns (A, B).*/ - object ExComputesExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(Select( - Apply(TypeApply(ExSelected("leon", "lang", "package", "SpecsDecorations"), List(_)), realExpr :: Nil), - ExNamed("computes")), expected::Nil) - => Some((realExpr, expected)) - case _ => None - } - } - - /** Matches the `O ask I` expression at the end of any expression O, and returns (I, O).*/ - object ExAskExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(TypeApply(Select( - Apply(TypeApply(ExSelected("leon", "lang", "package", "SpecsDecorations"), List(_)), output :: Nil), - ExNamed("ask")), List(_)), input::Nil) - => Some((input, output)) - case _ => None - } - } - - object ExByExampleExpression { - def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { - case Apply(TypeApply(ExSelected("leon", "lang", "package", "byExample"), List(_, _)), input :: res_output :: Nil) - => Some((input, res_output)) - case _ => None - } - } - - /** Extracts the `(input, output) passes { case In => Out ...}` and returns (input, output, list of case classes) */ - object ExPasses { - def unapply(tree : Apply) : Option[(Tree, Tree, List[CaseDef])] = tree match { - case Apply( - Select( - Apply( - TypeApply( - ExSelected("leon", "lang", "package", "Passes"), - List(_, _) - ), - List(ExpressionExtractors.ExTuple(_, Seq(in,out))) - ), - ExNamed("passes") - ), - List(Function( - List(ValDef(_, _, _, EmptyTree)), - ExpressionExtractors.ExPatternMatching(_,tests) - )) - ) - => Some((in, out, tests)) - case _ => None - } - } - - /** Returns a string literal from a constant string literal. */ - object ExStringLiteral { - def unapply(tree: Tree): Option[String] = tree match { - case Literal(c @ Constant(i)) if c.tpe == StringClass.tpe => - Some(c.stringValue) - case _ => - None - } - } - - /** Returns the arguments of an unapply pattern */ - object ExUnapplyPattern { - def unapply(tree: Tree): Option[(Symbol, Seq[Tree])] = tree match { - case UnApply(Apply(s, _), args) => - Some((s.symbol, args)) - case _ => None - } - } - - /** Returns the argument of a bigint literal, either from scala or leon */ - object ExBigIntLiteral { - def unapply(tree: Tree): Option[Tree] = tree match { - case Apply(ExSelected("scala", "package", "BigInt", "apply"), n :: Nil) => - Some(n) - case Apply(ExSelected("leon", "lang", "package", "BigInt", "apply"), n :: Nil) => - Some(n) - case _ => - None - } - } - - /** Returns the two components (n, d) of a real n/d literal */ - object ExRealLiteral { - def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { - case Apply(ExSelected("leon", "lang", "Real", "apply"), n :: d :: Nil) => - Some((n, d)) - case _ => - None - } - } - - /** Matches Real(x) when n is an integer and returns x */ - object ExRealIntLiteral { - def unapply(tree: Tree): Option[Tree] = tree match { - case Apply(ExSelected("leon", "lang", "Real", "apply"), n :: Nil) => - Some(n) - case _ => - None - } - } - - /** Matches the construct int2bigInt(a) and returns a */ - object ExIntToBigInt { - def unapply(tree: Tree): Option[Tree] = tree match { - case Apply(ExSelected("math", "BigInt", "int2bigInt"), tree :: Nil) => Some(tree) - case _ => None - } - } - - /** Matches the construct List[tpe](a, b, ...) and returns tpe and arguments */ - object ExListLiteral { - def unapply(tree: Apply): Option[(Tree, List[Tree])] = tree match { - case Apply( - TypeApply(ExSelected("leon", "collection", "List", "apply"), tpe :: Nil), - args) => - Some((tpe, args)) - case _ => - None - } - } - - /** Extracts the 'assert' contract from an expression (only if it's the - * first call in the block). */ - object ExAssertExpression { - def unapply(tree: Apply): Option[(Tree, Option[String])] = tree match { - case Apply(ExSelected("scala", "Predef", "assert"), contractBody :: Nil) => - Some((contractBody, None)) - case Apply(ExSelected("scala", "Predef", "assert"), contractBody :: (error: Literal) :: Nil) => - Some((contractBody, Some(error.value.stringValue))) - case _ => - None - } - } - - /** Matches an object with no type parameters, and regardless of its - * visibility. Does not match on case objects or the automatically generated companion - * objects of case classes (or any synthetic class). */ - object ExObjectDef { - def unapply(cd: ClassDef): Option[(String,Template)] = cd match { - case ClassDef(_, name, tparams, impl) if - (cd.symbol.isModuleClass || cd.symbol.hasPackageFlag) && - tparams.isEmpty && - !cd.symbol.isSynthetic && - !cd.symbol.isCaseClass - => { - Some((name.toString, impl)) - } - case _ => None - } - } - - /** Matches an abstract class or a trait with no type parameters, no - * constructor args (in the case of a class), no implementation details, - * no abstract members. */ - object ExAbstractClass { - def unapply(cd: ClassDef): Option[(String, Symbol, Template)] = cd match { - // abstract class - case ClassDef(_, name, tparams, impl) if cd.symbol.isAbstractClass => Some((name.toString, cd.symbol, impl)) - - case _ => None - } - } - - /** Returns true if the class definition is a case class */ - private def isCaseClass(cd: ClassDef): Boolean = { - cd.symbol.isCase && !cd.symbol.isAbstractClass && cd.impl.body.size >= 8 - } - - /** Returns true if the class definition is an implicit class */ - private def isImplicitClass(cd: ClassDef): Boolean = { - cd.symbol.isImplicit - } - - object ExCaseClass { - def unapply(cd: ClassDef): Option[(String,Symbol,Seq[(Symbol,ValDef)], Template)] = cd match { - case ClassDef(_, name, tparams, impl) if isCaseClass(cd) || isImplicitClass(cd) => { - val constructor: DefDef = impl.children.find { - case ExConstructorDef() => true - case _ => false - }.get.asInstanceOf[DefDef] - - val valDefs = constructor.vparamss.flatten - //println("valDefs: " + valDefs) - - //impl.children foreach println - - val symbols = impl.children.collect { - case df@DefDef(_, name, _, _, _, _) if - df.symbol.isAccessor && df.symbol.isParamAccessor - && !name.endsWith("_$eq") => df.symbol - } - //println("symbols: " + symbols) - //println("symbols accessed: " + symbols.map(_.accessed)) - - //if (symbols.size != valDefs.size) { - // println(" >>>>> " + cd.name) - // symbols foreach println - // valDefs foreach println - //} - - val args = symbols zip valDefs - - Some((name.toString, cd.symbol, args, impl)) - } - case _ => None - } - } - - object ExCaseObject { - def unapply(s: Select): Option[Symbol] = { - if (s.tpe.typeSymbol.isModuleClass) { - Some(s.tpe.typeSymbol) - } else { - None - } - } - } - - object ExCompanionObjectSynthetic { - def unapply(cd : ClassDef) : Option[(String, Symbol, Template)] = { - val sym = cd.symbol - cd match { - case ClassDef(_, name, tparams, impl) if sym.isModule && sym.isSynthetic => //FIXME flags? - Some((name.toString, sym, impl)) - case _ => None - } - - } - } - - object ExCaseClassSyntheticJunk { - def unapply(cd: Tree): Boolean = cd match { - case ClassDef(_, _, _, _) if cd.symbol.isSynthetic => true - case DefDef(_, _, _, _, _, _) if cd.symbol.isSynthetic && (cd.symbol.isCase || cd.symbol.isPrivate) => true - case _ => false - } - } - - object ExConstructorDef { - def unapply(dd: DefDef): Boolean = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if name == nme.CONSTRUCTOR && tparams.isEmpty => true - case _ => false - } - } - - object ExMainFunctionDef { - def unapply(dd: DefDef): Boolean = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if name.toString == "main" && tparams.isEmpty && vparamss.size == 1 && vparamss.head.size == 1 => { - true - } - case _ => false - } - } - - object ExFunctionDef { - /** Matches a function with a single list of arguments, - * and regardless of its visibility. - */ - def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, Tree)] = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if name != nme.CONSTRUCTOR && !dd.symbol.isAccessor => - if (dd.symbol.isSynthetic && dd.symbol.isImplicit && dd.symbol.isMethod) { - // Check that the class it was generated from is not ignored - if (annotationsOf(tpt.symbol).isDefinedAt("ignore")) { - None - } else { - Some((dd.symbol, tparams.map(_.symbol), vparamss.flatten, tpt.tpe, rhs)) - } - } else if (!dd.symbol.isSynthetic) { - Some((dd.symbol, tparams.map(_.symbol), vparamss.flatten, tpt.tpe, rhs)) - } else { - None - } - case _ => None - } - } - - object ExLazyAccessorFunction { - def unapply(dd: DefDef): Option[(Symbol, Type, Tree)] = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if( - vparamss.size <= 1 && name != nme.CONSTRUCTOR && - !dd.symbol.isSynthetic && dd.symbol.isAccessor && dd.symbol.isLazy - ) => - Some((dd.symbol, tpt.tpe, rhs)) - case _ => None - } - } - - object ExMutatorAccessorFunction { - def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, Tree)] = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if( - vparamss.size <= 1 && name != nme.CONSTRUCTOR && - !dd.symbol.isSynthetic && dd.symbol.isAccessor && name.endsWith("_$eq") - ) => - Some((dd.symbol, tparams.map(_.symbol), vparamss.flatten, tpt.tpe, rhs)) - case _ => None - } - } - object ExMutableFieldDef { - - /** Matches a definition of a strict var field inside a class constructor */ - def unapply(vd: SymTree) : Option[(Symbol, Type, Tree)] = { - val sym = vd.symbol - vd match { - // Implemented fields - case ValDef(mods, name, tpt, rhs) if ( - !sym.isCaseAccessor && !sym.isParamAccessor && - !sym.isLazy && !sym.isSynthetic && !sym.isAccessor && sym.isVar - ) => - println("matched a var accessor field: sym is: " + sym) - println("getterIn is: " + sym.getterIn(sym.owner)) - // Since scalac uses the accessor symbol all over the place, we pass that instead: - Some( (sym.getterIn(sym.owner),tpt.tpe,rhs) ) - case _ => None - } - } - } - - object ExFieldDef { - /** Matches a definition of a strict field inside a class constructor */ - def unapply(vd: SymTree) : Option[(Symbol, Type, Tree)] = { - val sym = vd.symbol - vd match { - // Implemented fields - case ValDef(mods, name, tpt, rhs) if ( - !sym.isCaseAccessor && !sym.isParamAccessor && - !sym.isLazy && !sym.isSynthetic && !sym.isAccessor && !sym.isVar - ) => - // Since scalac uses the accessor symbol all over the place, we pass that instead: - Some( (sym.getterIn(sym.owner),tpt.tpe,rhs) ) - // Unimplemented fields - case df@DefDef(_, name, _, _, tpt, _) if ( - sym.isStable && sym.isAccessor && sym.name != nme.CONSTRUCTOR && - sym.accessed == NoSymbol // This is to exclude fields with implementation - ) => - Some( (sym, tpt.tpe, EmptyTree)) - case _ => None - } - } - } - - object ExLazyFieldDef { - /** Matches lazy field definitions. - * WARNING: Do NOT use this as extractor for lazy fields, - * as it does not contain the body of the lazy definition. - * It is here just to signify a Definition acceptable by Leon - */ - def unapply(vd : ValDef) : Boolean = { - val sym = vd.symbol - vd match { - case ValDef(mods, name, tpt, rhs) if ( - sym.isLazy && !sym.isCaseAccessor && !sym.isParamAccessor && - !sym.isSynthetic && !sym.isAccessor - ) => - // Since scalac uses the accessor symbol all over the place, we pass that instead: - true - case _ => false - } - } - } - - object ExFieldAccessorFunction{ - /** Matches the accessor function of a field - * WARNING: This is not meant to be used for any useful purpose, - * other than to satisfy Definition acceptable by Leon - */ - def unapply(dd: DefDef): Boolean = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if( - vparamss.size <= 1 && name != nme.CONSTRUCTOR && - dd.symbol.isAccessor && !dd.symbol.isLazy - ) => - true - case _ => false - } - } - - object ExDefaultValueFunction{ - /** Matches a function that defines the default value of a parameter */ - def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = { - val sym = dd.symbol - dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if( - vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic - ) => - - // Split the name into pieces, to find owner of the parameter + param.index - // Form has to be <owner name>$default$<param index> - val symPieces = sym.name.toString.reverse.split("\\$", 3).reverseMap(_.reverse) - - try { - if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("") - val ownerString = symPieces(0) - val index = symPieces(2).toInt - 1 - Some((sym, tparams.map(_.symbol), vparamss.headOption.getOrElse(Nil), tpt.tpe, ownerString, index, rhs)) - } catch { - case _ : NumberFormatException | _ : IllegalArgumentException | _ : ArrayIndexOutOfBoundsException => - None - } - - case _ => None - } - } - } - - } - - object ExpressionExtractors { - import ExtractorHelpers._ - - object ExEpsilonExpression { - def unapply(tree: Apply) : Option[(Tree, Symbol, Tree)] = tree match { - case Apply( - TypeApply(ExSymbol("leon", "lang", "xlang", "epsilon"), typeTree :: Nil), - Function((vd @ ValDef(_, _, _, EmptyTree)) :: Nil, predicateBody) :: Nil) => - Some((typeTree, vd.symbol, predicateBody)) - case _ => None - } - } - - object ExErrorExpression { - def unapply(tree: Apply) : Option[(String, Tree)] = tree match { - case a @ Apply(TypeApply(ExSymbol("leon", "lang", "error"), List(tpe)), List(lit : Literal)) => - Some((lit.value.stringValue, tpe)) - case _ => - None - } - } - - object ExOldExpression { - def unapply(tree: Apply) : Option[Tree] = tree match { - case a @ Apply(TypeApply(ExSymbol("leon", "lang", "old"), List(tpe)), List(arg)) => - Some(arg) - case _ => - None - } - } - - object ExHoleExpression { - def unapply(tree: Tree) : Option[(Tree, List[Tree])] = tree match { - case a @ Apply(TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark"), List(tpt)), args1) => - Some((tpt, args1)) - case TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark$qmark$qmark"), List(tpt)) => - Some((tpt, Nil)) - case _ => None - } - } - - object ExChooseExpression { - def unapply(tree: Apply) : Option[Tree] = tree match { - case a @ Apply( - TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "choose"), types), - predicate :: Nil) => - Some(predicate) - case _ => None - } - } - - object ExWithOracleExpression { - def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { - case a @ Apply( - TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "withOracle"), types), - Function(vds, body) :: Nil) => - Some((types zip vds.map(_.symbol), body)) - case _ => None - } - } - - object ExLambdaExpression { - def unapply(tree: Function) : Option[(Seq[ValDef], Tree)] = tree match { - case Function(vds, body) => Some((vds, body)) - case _ => None - } - } - - object ExForallExpression { - def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { - case a @ Apply( - TypeApply(s @ ExSymbol("leon", "lang", "forall"), types), - Function(vds, predicateBody) :: Nil) => - Some((types zip vds.map(_.symbol), predicateBody)) - case _ => None - } - } - - object ExArrayUpdated { - def unapply(tree: Apply): Option[(Tree,Tree,Tree)] = tree match { - case Apply( - Apply(TypeApply(Select(Apply(ExSelected("scala", "Predef", s), Seq(lhs)), n), _), Seq(index, value)), - List(Apply(_, _))) if (s.toString contains "Array") && (n.toString == "updated") => Some((lhs, index, value)) - case _ => None - } - } - - object ExValDef { - /** Extracts val's in the head of blocks. */ - def unapply(tree: ValDef): Option[(Symbol,Tree,Tree)] = tree match { - case vd @ ValDef(mods, _, tpt, rhs) if !mods.isMutable => Some((vd.symbol, tpt, rhs)) - case _ => None - } - } - object ExVarDef { - /** Extracts var's in the head of blocks. */ - def unapply(tree: ValDef): Option[(Symbol,Tree,Tree)] = tree match { - case vd @ ValDef(mods, _, tpt, rhs) if mods.isMutable => Some((vd.symbol, tpt, rhs)) - case _ => None - } - } - - object ExAssign { - def unapply(tree: Assign): Option[(Symbol,Tree)] = tree match { - case Assign(id@Ident(_), rhs) => Some((id.symbol, rhs)) - //case Assign(sym@Select(This(_), v), rhs) => Some((sym.symbol, rhs)) - case _ => None - } - } - - object ExWhile { - def unapply(tree: LabelDef): Option[(Tree,Tree)] = tree match { - case (label@LabelDef( - _, _, If(cond, Block(body, jump@Apply(_, _)), unit@ExUnitLiteral()))) - if label.symbol == jump.symbol && unit.symbol == null => Some((cond, Block(body, unit))) - case _ => None - } - } - - object ExWhileWithInvariant { - def unapply(tree: Apply): Option[(Tree, Tree, Tree)] = tree match { - case Apply( - Select( - Apply(while2invariant, List(ExWhile(cond, body))), - invariantSym), - List(invariant)) if invariantSym.toString == "invariant" => Some((cond, body, invariant)) - case _ => None - } - } - - object ExArrayLiteral { - def unapply(tree: Apply): Option[(Type, Seq[Tree])] = tree match { - case Apply(ExSelected("scala", "Array", "apply"), args) => - tree.tpe match { - case TypeRef(_, _, List(t1)) => - Some((t1, args)) - case _ => - None - } - case Apply(Apply(TypeApply(ExSelected("scala", "Array", "apply"), List(tpt)), args), ctags) => - Some((tpt.tpe, args)) - - case _ => - None - } - } - - object ExTuple { - def unapply(tree: Apply): Option[(Seq[Type], Seq[Tree])] = tree match { - case Apply( - Select(New(tupleType), _), - List(e1, e2) - ) if tupleType.symbol == tuple2Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2)) => Some((Seq(t1, t2), Seq(e1, e2))) - case _ => None - } - - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3) - ) if tupleType.symbol == tuple3Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3)) => Some((Seq(t1, t2, t3), Seq(e1, e2, e3))) - case _ => None - } - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3, e4) - ) if tupleType.symbol == tuple4Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3, t4)) => Some((Seq(t1, t2, t3, t4), Seq(e1, e2, e3, e4))) - case _ => None - } - case Apply( - Select(New(tupleType), _), - List(e1, e2, e3, e4, e5) - ) if tupleType.symbol == tuple5Sym => tupleType.tpe match { - case TypeRef(_, sym, List(t1, t2, t3, t4, t5)) => Some((Seq(t1, t2, t3, t4, t5), Seq(e1, e2, e3, e4, e5))) - case _ => None - } - // Match e1 -> e2 - case Apply(TypeApply(Select(Apply(TypeApply(ExSelected("scala", "Predef", "ArrowAssoc"), List(tpeFrom)), List(from)), ExNamed("$minus$greater")), List(tpeTo)), List(to)) => - - Some((Seq(tpeFrom.tpe, tpeTo.tpe), Seq(from, to))) - case _ => None - } - } - - object ExLocally { - def unapply(tree: Apply) : Option[Tree] = tree match { - case Apply(TypeApply(ExSelected("scala", "Predef", "locally"), _), List(body)) => - Some(body) - - case _ => - None - } - } - - object ExTupleExtract { - def unapply(tree: Select) : Option[(Tree,Int)] = tree match { - case Select(lhs, n) => { - val methodName = n.toString - if(methodName.head == '_') { - val indexString = methodName.tail - try { - val index = indexString.toInt - if(index > 0) { - Some((lhs, index)) - } else None - } catch { - case t: Throwable => - None - } - } else None - } - case _ => None - } - } - - object ExIfThenElse { - def unapply(tree: If): Option[(Tree,Tree,Tree)] = tree match { - case If(t1,t2,t3) => Some((t1,t2,t3)) - case _ => None - } - } - - object ExBooleanLiteral { - def unapply(tree: Literal): Option[Boolean] = tree match { - case Literal(Constant(true)) => Some(true) - case Literal(Constant(false)) => Some(false) - case _ => None - } - } - - object ExCharLiteral { - def unapply(tree: Literal): Option[Char] = tree match { - case Literal(c @ Constant(i)) if c.tpe == CharClass.tpe => Some(c.charValue) - case _ => None - } - } - - object ExInt32Literal { - def unapply(tree: Literal): Option[Int] = tree match { - case Literal(c @ Constant(i)) if c.tpe == IntClass.tpe => Some(c.intValue) - case _ => None - } - } - - object ExUnitLiteral { - def unapply(tree: Literal): Boolean = tree match { - case Literal(c @ Constant(_)) if c.tpe == UnitClass.tpe => true - case _ => false - } - } - - object ExSomeConstruction { - def unapply(tree: Apply) : Option[(Type,Tree)] = tree match { - case Apply(s @ Select(New(tpt), n), arg) if arg.size == 1 && n == nme.CONSTRUCTOR && tpt.symbol.name.toString == "Some" => tpt.tpe match { - case TypeRef(_, sym, tpe :: Nil) => Some((tpe, arg.head)) - case _ => None - } - case _ => None - } - } - - object ExCaseClassConstruction { - def unapply(tree: Apply): Option[(Tree,Seq[Tree])] = tree match { - case Apply(s @ Select(New(tpt), n), args) if n == nme.CONSTRUCTOR => { - Some((tpt, args)) - } - case _ => None - } - } - - object ExIdentifier { - def unapply(tree: Ident): Option[(Symbol,Tree)] = tree match { - case i: Ident => Some((i.symbol, i)) - case _ => None - } - } - - object ExTyped { - def unapply(tree : Typed): Option[(Tree,Tree)] = tree match { - case Typed(e,t) => Some((e,t)) - case _ => None - } - } - - object ExIntIdentifier { - def unapply(tree: Ident): Option[String] = tree match { - case i: Ident if i.symbol.tpe == IntClass.tpe => Some(i.symbol.name.toString) - case _ => None - } - } - - object ExAnd { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(s @ Select(lhs, _), List(rhs)) if s.symbol == Boolean_and => - Some((lhs,rhs)) - case _ => None - } - } - - object ExOr { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(s @ Select(lhs, _), List(rhs)) if s.symbol == Boolean_or => - Some((lhs,rhs)) - case _ => None - } - } - - object ExNot { - def unapply(tree: Select): Option[Tree] = tree match { - case Select(t, n) if n == nme.UNARY_! => Some(t) - case _ => None - } - } - - object ExEquals { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(Select(lhs, n), List(rhs)) if n == nme.EQ => Some((lhs,rhs)) - case _ => None - } - } - - object ExNotEquals { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(Select(lhs, n), List(rhs)) if n == nme.NE => Some((lhs,rhs)) - case _ => None - } - } - - object ExUMinus { - def unapply(tree: Select): Option[Tree] = tree match { - case Select(t, n) if n == nme.UNARY_- && hasBigIntType(t) => Some(t) - case _ => None - } - } - - object ExRealUMinus { - def unapply(tree: Select): Option[Tree] = tree match { - case Select(t, n) if n == nme.UNARY_- && hasRealType(t) => Some(t) - case _ => None - } - } - - object ExBVUMinus { - def unapply(tree: Select): Option[Tree] = tree match { - case Select(t, n) if n == nme.UNARY_- && hasIntType(t) => Some(t) - case _ => None - } - } - - object ExBVNot { - def unapply(tree: Select): Option[Tree] = tree match { - case Select(t, n) if n == nme.UNARY_~ && hasIntType(t) => Some(t) - case _ => None - } - } - - object ExPatternMatching { - def unapply(tree: Match): Option[(Tree,List[CaseDef])] = - if(tree != null) Some((tree.selector, tree.cases)) else None - } - - object ExBigIntPattern { - def unapply(tree: UnApply): Option[Tree] = tree match { - case ua @ UnApply(Apply(ExSelected("leon", "lang", "package", "BigInt", "unapply"), _), List(l)) => - Some(l) - case _ => - None - } - } - - object ExAsInstanceOf { - def unapply(tree: TypeApply) : Option[(Tree, Tree)] = tree match { - case TypeApply(Select(t, isInstanceOfName), typeTree :: Nil) if isInstanceOfName.toString == "asInstanceOf" => Some((t, typeTree)) - case _ => None - } - } - - object ExIsInstanceOf { - def unapply(tree: TypeApply) : Option[(Tree, Tree)] = tree match { - case TypeApply(Select(t, isInstanceOfName), typeTree :: Nil) if isInstanceOfName.toString == "isInstanceOf" => Some((typeTree, t)) - case _ => None - } - } - - object ExLiteralMap { - def unapply(tree: Apply): Option[(Tree, Tree, Seq[Tree])] = tree match { - case Apply(TypeApply(ExSelected("scala", "Predef", "Map", "apply"), fromTypeTree :: toTypeTree :: Nil), args) => - Some((fromTypeTree, toTypeTree, args)) - case _ => - None - } - } - object ExEmptyMap { - def unapply(tree: TypeApply): Option[(Tree, Tree)] = tree match { - case TypeApply(ExSelected("scala", "collection", "immutable", "Map", "empty"), fromTypeTree :: toTypeTree :: Nil) => - Some((fromTypeTree, toTypeTree)) - case TypeApply(ExSelected("scala", "Predef", "Map", "empty"), fromTypeTree :: toTypeTree :: Nil) => - Some((fromTypeTree, toTypeTree)) - case _ => - None - } - } - - object ExFiniteSet { - def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { - case Apply(TypeApply(ExSelected("Set", "apply"), Seq(tpt)), args) => - Some(tpt, args) - case Apply(TypeApply(ExSelected("leon", "lang", "Set", "apply"), Seq(tpt)), args) => - Some(tpt, args) - case _ => None - } - } - - object ExFiniteBag { - def unapply(tree: Apply): Option[(Tree, List[Tree])] = tree match { - case Apply(TypeApply(ExSelected("Bag", "apply"), Seq(tpt)), args) => - Some(tpt, args) - case Apply(TypeApply(ExSelected("leon", "lang", "Bag", "apply"), Seq(tpt)), args) => - Some(tpt, args) - case _ => None - } - } - - object ExFiniteMap { - def unapply(tree: Apply): Option[(Tree, Tree, List[Tree])] = tree match { - case Apply(TypeApply(ExSelected("Map", "apply"), Seq(tptFrom, tptTo)), args) => - Some((tptFrom, tptTo, args)) - case Apply(TypeApply(ExSelected("leon", "lang", "Map", "apply"), Seq(tptFrom, tptTo)), args) => - Some((tptFrom, tptTo, args)) - case _ => None - } - } - - object ExParameterLessCall { - def unapply(tree: Tree): Option[(Tree, Symbol, Seq[Tree])] = tree match { - case s @ Select(t, _) => - Some((t, s.symbol, Nil)) - - case TypeApply(s @ Select(t, _), tps) => - Some((t, s.symbol, tps)) - - case TypeApply(i: Ident, tps) => - Some((i, i.symbol, tps)) - - case _ => - None - } - } - - object ExCall { - def unapply(tree: Tree): Option[(Tree, Symbol, Seq[Tree], Seq[Tree])] = tree match { - // foo / foo[T] - case ExParameterLessCall(t, s, tps) => - Some((t, s, tps, Nil)) - - // foo(args) - case Apply(i: Ident, args) => - Some((i, i.symbol, Nil, args)) - - // foo(args1)(args2) - case Apply(Apply(i: Ident, args1), args2) => - Some((i, i.symbol, Nil, args1 ++ args2)) - - // foo[T](args) - case Apply(ExParameterLessCall(t, s, tps), args) => - Some((t, s, tps, args)) - - // foo[T](args1)(args2) - case Apply(Apply(ExParameterLessCall(t, s, tps), args1), args2) => - Some((t, s, tps, args1 ++ args2)) - - case _ => None - } - } - - object ExUpdate { - def unapply(tree: Apply): Option[(Tree, Tree, Tree)] = tree match { - case Apply( - s @ Select(lhs, update), - index :: newValue :: Nil) if s.symbol.fullName.endsWith("Array.update") => - Some((lhs, index, newValue)) - case _ => None - } - } - - object ExArrayFill { - def unapply(tree: Apply): Option[(Tree, Tree, Tree)] = tree match { - case Apply( - Apply( - Apply( - TypeApply(ExSelected("scala", "Array", "fill"), baseType :: Nil), - length :: Nil - ), - defaultValue :: Nil - ), - manifest - ) => - Some((baseType, length, defaultValue)) - case _ => None - } - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala b/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala deleted file mode 100644 index c7f5823ae6bf5e997758266693782df3a3e72518..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import utils._ - -import scala.tools.nsc.{Settings,CompilerCommand} -import java.io.File -import java.nio.file.Files - -object ClassgenPhase extends LeonPhase[List[String], List[String]] { - - val optExtern = LeonFlagOptionDef("extern", "Run @extern function on the JVM", false) - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optExtern) - - val name = "Scalac .class generation" - val description = "Generation of .class for evaluation of @extern functions" - - implicit val debug = DebugSectionTrees - - def run(ctx: LeonContext, args: List[String]): (LeonContext, List[String]) = { - if (ctx.findOptionOrDefault(optExtern)) { - val timer = ctx.timers.frontend.extern.start() - - val settings = new Settings - - val scalaLib = Option(scala.Predef.getClass.getProtectionDomain.getCodeSource).map{ - _.getLocation.getPath - }.orElse( for { - // We are in Eclipse. Look in Eclipse plugins to find scala lib - eclipseHome <- Option(System.getenv("ECLIPSE_HOME")) - pluginsHome = eclipseHome + "/plugins" - plugins <- scala.util.Try(new File(pluginsHome).listFiles().map{ _.getAbsolutePath }).toOption - path <- plugins.find{ _ contains "scala-library"} - } yield path).getOrElse( ctx.reporter.fatalError( - "No Scala library found. If you are working in Eclipse, " + - "make sure to set the ECLIPSE_HOME environment variable to your Eclipse installation home directory" - )) - - val tempOut = Files.createTempDirectory("classes").toFile - - settings.classpath.value = scalaLib - settings.usejavacp.value = false - settings.deprecation.value = true - settings.outdir.value = tempOut.getPath - - val compilerOpts = Build.libFiles ::: args.filterNot(_.startsWith("--")) - - val command = new CompilerCommand(compilerOpts, settings) { - override val cmdName = "leon" - } - - if(command.ok) { - // Debugging code for classpath crap - // new scala.tools.util.PathResolver(settings).Calculated.basis.foreach { cp => - // cp.foreach( p => - // println(" => "+p.toString) - // ) - // } - - - val compiler = new FullScalaCompiler(settings, ctx) - val run = new compiler.Run - run.compile(command.files) - - timer.stop() - - val ctx2 = ctx.copy(classDir = Some(tempOut)) - - (ctx2, args) - } else { - ctx.reporter.fatalError("No input program.") - } - } else { - (ctx, args) - } - } -} - diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala deleted file mode 100644 index 324b23f9c0e7eb5bc103edd3677b72fb602ea15a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ /dev/null @@ -1,2145 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.reflect.internal.util._ - -import scala.language.implicitConversions - -import purescala._ -import Definitions.{ - ClassDef => LeonClassDef, - ModuleDef => LeonModuleDef, - ValDef => LeonValDef, - Import => LeonImport, - _ -} - -import Expressions.{Expr => LeonExpr, This => LeonThis, _} -import Types.{TypeTree => LeonType, _} -import Common._ -import Extractors._ -import Constructors._ -import ExprOps.exists -import TypeOps.{exists => _, _} -import xlang.Expressions.{Block => _, _} -import xlang.ExprOps._ -import xlang.Constructors.{block => leonBlock} - -import leon.utils.{Position => LeonPosition, OffsetPosition => LeonOffsetPosition, RangePosition => LeonRangePosition, Bijection, DefinedPosition} - -trait CodeExtraction extends ASTExtractors { - self: LeonExtraction => - - import global._ - import global.definitions._ - import StructuralExtractors._ - import ExpressionExtractors._ - import scala.collection.immutable.Set - - val reporter = self.ctx.reporter - - implicit def scalaPosToLeonPos(p: global.Position): LeonPosition = { - if (p == NoPosition) { - leon.utils.NoPosition - } else if (p.isRange) { - val start = p.focusStart - val end = p.focusEnd - LeonRangePosition(start.line, start.column, start.point, - end.line, end.column, end.point, - p.source.file.file) - } else { - LeonOffsetPosition(p.line, p.column, p.point, - p.source.file.file) - } - } - - def leonPosToScalaPos(spos: global.Position, p: LeonPosition): global.Position = { - (spos, p) match { - case (NoPosition, _) => - NoPosition - - case (p, dp: DefinedPosition) => - new OffsetPosition(p.source, dp.focusBegin.point) - - case _ => - NoPosition - - } - } - - /** An exception thrown when non-purescala compatible code is encountered. */ - sealed class ImpureCodeEncounteredException(pos: Position, msg: String, ot: Option[Tree]) extends Exception(msg) { - def emit() { - val debugInfo = if (ctx.findOptionOrDefault(GlobalOptions.optDebug) contains utils.DebugSectionTrees) { - ot.map { t => - val strWr = new java.io.StringWriter() - new global.TreePrinter(new java.io.PrintWriter(strWr)).printTree(t) - " (Tree: "+strWr.toString+" ; Class: "+t.getClass+")" - }.getOrElse("") - } else { - "" - } - - if (ctx.findOptionOrDefault(ExtractionPhase.optStrictCompilation)) { - reporter.error(pos, msg + debugInfo) - } else { - reporter.warning(pos, msg + debugInfo) - } - } - } - - def outOfSubsetError(pos: Position, msg: String) = { - throw new ImpureCodeEncounteredException(pos, msg, None) - } - - def outOfSubsetError(t: Tree, msg: String) = { - throw new ImpureCodeEncounteredException(t.pos, msg, Some(t)) - } - - // Simple case classes to capture the representation of units/modules after discovering them. - case class ScalaUnit( - name : String, - pack : PackageRef, - imports : List[Import], - defs : List[Tree], - isPrintable : Boolean - ) - - class Extraction(units: List[CompilationUnit]) { - - case class DefContext(tparams: Map[Symbol, TypeParameter] = Map(), - vars: Map[Symbol, () => LeonExpr] = Map(), - mutableVars: Map[Symbol, () => LeonExpr] = Map(), - isExtern: Boolean = false){ - - def union(that: DefContext) = { - copy(this.tparams ++ that.tparams, - this.vars ++ that.vars, - this.mutableVars ++ that.mutableVars, - this.isExtern || that.isExtern) - } - - def isVariable(s: Symbol) = (vars contains s) || (mutableVars contains s) - - def withNewVars(nvars: Traversable[(Symbol, () => LeonExpr)]) = { - copy(vars = vars ++ nvars) - } - - def withNewVar(nvar: (Symbol, () => LeonExpr)) = { - copy(vars = vars + nvar) - } - - def withNewMutableVar(nvar: (Symbol, () => LeonExpr)) = { - copy(mutableVars = mutableVars + nvar) - } - - def withNewMutableVars(nvars: Traversable[(Symbol, () => LeonExpr)]) = { - copy(mutableVars = mutableVars ++ nvars) - } - } - - private var currentFunDef: FunDef = null - - // This one never fails, on error, it returns Untyped - def leonType(tpt: Type)(implicit dctx: DefContext, pos: Position): LeonType = { - try { - extractType(tpt) - } catch { - case e: ImpureCodeEncounteredException => - e.emit() - Untyped - } - } - - private def isIgnored(s: Symbol) = { - (annotationsOf(s) contains "ignore") - } - - private def isLibrary(u: CompilationUnit) = Build.libFiles contains u.source.file.absolute.path - - def extractProgram: Option[Program] = { - try { - val scalaUnits = units.map { u => u.body match { - // package object - case PackageDef(refTree, List(PackageDef(inner, body))) => - val name = extractPackageRef(inner).mkString("$") - val pack = extractPackageRef(refTree) ++ extractPackageRef(inner) - val imps = imports.getOrElse(refTree, Nil) - - ScalaUnit(name, pack, imps, body, !isLibrary(u)) - - // normal package - case pd@PackageDef(refTree, lst) => - val name = u.source.file.name.replaceFirst("[.][^.]+$", "") - val pack = extractPackageRef(refTree) - val imps = imports.getOrElse(refTree, Nil) - - ScalaUnit(name, pack, imps, lst, !isLibrary(u)) - - case _ => - outOfSubsetError(u.body, "Unexpected Unit body") - }} - - // Phase 1, we discover and define objects/classes/types - for (unit <- scalaUnits) collectClassSymbols(unit.defs) - - // Phase 2, build scafolding program with empty bodies - val leonUnits = scalaUnits.map { createLeonUnit } - - // Phase 3, define bodies of all functions/methods - for (unit <- scalaUnits) fillLeonUnit(unit) - - val pgm0 = Program(leonUnits) - - // Phase 4, resolve imports - val leonUnits1 = for ((sunit, lunit) <- scalaUnits zip leonUnits) yield { - val imports = sunit.imports.flatMap(i => extractImport(i, lunit)(pgm0)) - lunit.copy(imports = imports) - } - - val pgm1 = Program(leonUnits1) - - Some(pgm1) - } catch { - case icee: ImpureCodeEncounteredException => - icee.emit() - None - } - } - - private def collectClassSymbols(defs: List[Tree]) { - // We collect all defined classes - for (t <- defs if !t.isEmpty) t match { - case t if isIgnored(t.symbol) => - // ignore - - case ExAbstractClass(o2, sym, tmpl) => - seenClasses += sym -> ((Nil, tmpl)) - - case ExCaseClass(o2, sym, args, tmpl) => - seenClasses += sym -> ((args, tmpl)) - - case ExObjectDef(n, templ) => - for (t <- templ.body if !t.isEmpty) t match { - case t if isIgnored(t.symbol) => - // ignore - - case ExAbstractClass(_, sym, tmpl) => - seenClasses += sym -> ((Nil, tmpl)) - - case ExCaseClass(_, sym, args, tmpl) => - seenClasses += sym -> ((args, tmpl)) - - case _ => - } - - case _ => - } - } - - private def createLeonUnit(u: ScalaUnit): UnitDef = { - val ScalaUnit(name, pack, _, defs, isPrintable) = u - - val leonDefs = defs flatMap { - case t if isIgnored(t.symbol) => - // ignore - None - - case t @ ExAbstractClass(o2, sym, _) => - Some(getClassDef(sym, t.pos)) - - case t @ ExCaseClass(o2, sym, args, _) => - Some(getClassDef(sym, t.pos)) - - case t @ ExObjectDef(n, templ) => - // Module - val id = FreshIdentifier(n) - val leonDefs = templ.body.flatMap { - case t if isIgnored(t.symbol) => - // ignore - None - - case ExAbstractClass(_, sym, _) => - Some(getClassDef(sym, t.pos)) - - case ExCaseClass(_, sym, _, _) => - Some(getClassDef(sym, t.pos)) - - // Functions - case ExFunctionDef(sym, _, _, _, _) => - Some(defineFunDef(sym)(DefContext())) - - // Default value functions - case ExDefaultValueFunction(sym, _, _, _ ,_ , _, _) => - val fd = defineFunDef(sym)(DefContext()) - fd.addFlag(IsSynthetic) - - Some(fd) - - // Lazy vals - case ExLazyAccessorFunction(sym, _, _) => - Some(defineFieldFunDef(sym, true)(DefContext())) - - // Normal vals - case ExFieldDef(sym, _, _) => - Some(defineFieldFunDef(sym, false)(DefContext())) - - // var - case ExMutableFieldDef(sym, _, _) => - Some(defineFieldFunDef(sym, false)(DefContext())) - - // All these are expected, but useless - case ExCaseClassSyntheticJunk() - | ExConstructorDef() - | ExLazyFieldDef() - | ExFieldAccessorFunction() => - None - case d if (d.symbol.isImplicit && d.symbol.isSynthetic) => - None - - //vars are never accessed directly so we extract accessors and mutators and - //ignore bare variables - case d if d.symbol.isVar => - None - - // Everything else is unexpected - case tree => - println(tree) - outOfSubsetError(tree, "Don't know what to do with this. Not purescala?"); - } - - Some(LeonModuleDef(id, leonDefs, id.name == "package")) - - // Expected, but useless - case ExCaseClassSyntheticJunk() | ExConstructorDef() => None - case d if (d.symbol.isImplicit && d.symbol.isSynthetic) => None - - // Unexpected - case tree => - println(tree) - outOfSubsetError(tree, "Don't know what to do with this. Not purescala?"); - } - - // we only resolve imports once we have the full program - UnitDef(FreshIdentifier(name), pack, Nil, leonDefs, isPrintable) - } - - private def fillLeonUnit(u: ScalaUnit): Unit = { - def extractClassMembers(sym: Symbol, tpl: Template): Unit = { - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } - - classToInvariants.get(sym).foreach { bodies => - val cd = classesToClasses(sym) - - for (c <- (cd.ancestors.toSet ++ cd.root.knownDescendants + cd) if !c.methods.exists(_.isInvariant)) { - val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType) - fd.addFlag(IsADTInvariant) - fd.addFlags(c.flags.collect { case annot : purescala.Definitions.Annotation => annot }) - fd.fullBody = BooleanLiteral(true) - c.registerMethod(fd) - } - - val fd = cd.methods.find(_.isInvariant).get - val ctparams = sym.tpe match { - case TypeRef(_, _, tps) => - extractTypeParams(tps).map(_._1) - case _ => - Nil - } - - val tparamsMap = (ctparams zip cd.tparams.map(_.tp)).toMap - val dctx = DefContext(tparamsMap) - - val body = andJoin(bodies.toSeq.filter(_ != EmptyTree).map { - body => flattenBlocks(extractTreeOrNoTree(body)(dctx)) - }) - - fd.fullBody = body - } - } - - for (t <- u.defs) t match { - case t if isIgnored(t.symbol) => - // ignore - - case ExAbstractClass(_, sym, tpl) => - extractClassMembers(sym, tpl) - - case ExCaseClass(_, sym, _, tpl) => - extractClassMembers(sym, tpl) - - case ExObjectDef(n, templ) => - for (t <- templ.body if !t.isEmpty) t match { - case t if isIgnored(t.symbol) => - // ignore - None - - case ExAbstractClass(_, sym, tpl) => - extractClassMembers(sym, tpl) - - case ExCaseClass(_, sym, _, tpl) => - extractClassMembers(sym, tpl) - - case t => - extractFunOrMethodBody(None, t) - } - case _ => - } - } - - private def getSelectChain(e: Tree): List[String] = { - def rec(e: Tree): List[Name] = e match { - case Select(q, name) => name :: rec(q) - case Ident(name) => List(name) - case EmptyTree => List() - case _ => - ctx.reporter.internalError("getSelectChain: unexpected Tree:\n" + e.toString) - } - rec(e).reverseMap(_.toString) - } - - private def extractPackageRef(refPath: RefTree): PackageRef = { - (getSelectChain(refPath.qualifier) :+ refPath.name.toString).filter(_ != "<empty>") - } - - private def extractImport(i: Import, current: UnitDef)(implicit pgm: Program): Seq[LeonImport] = { - val Import(expr, sels) = i - import DefOps._ - - val prefix = getSelectChain(expr) - - val allSels = sels map { prefix :+ _.name.toString } - - // Make a different import for each selector at the end of the chain - allSels flatMap { selectors => - assert(selectors.nonEmpty) - val (thePath, isWild) = selectors.last match { - case "_" => (selectors.dropRight(1), true) - case _ => (selectors, false) - } - - Some(LeonImport(thePath, isWild)) - } - } - - private var seenClasses = Map[Symbol, (Seq[(Symbol, ValDef)], Template)]() - private var classesToClasses = Map[Symbol, LeonClassDef]() - - def oracleType(pos: Position, tpe: LeonType) = { - classesToClasses.find { - case (sym, cl) => sym.fullName.toString == "leon.lang.synthesis.Oracle" - } match { - case Some((_, cd)) => - cd.typed(List(tpe)) - case None => - outOfSubsetError(pos, "Could not find class Oracle") - } - } - - def libraryClass(pos: Position, className: String): LeonClassDef = { - classesToClasses.find{ case (s, c) => s.fullName == className }.map(_._2).getOrElse { - outOfSubsetError(pos, "Could not find class "+className) - } - } - - def libraryCaseClass(pos: Position, className: String): CaseClassDef = { - libraryClass(pos, className) match { - case ccd: CaseClassDef => ccd - case _ => - outOfSubsetError(pos, "Class "+className+" is not a case class") - } - } - - private var paramsToDefaultValues = Map[Symbol,FunDef]() - - def getClassDef(sym: Symbol, pos: Position): LeonClassDef = { - classesToClasses.get(sym) match { - case Some(cd) => cd - case None => - if (seenClasses contains sym) { - val (args, tmpl) = seenClasses(sym) - - extractClassDef(sym, args, tmpl) - } else { - outOfSubsetError(pos, "Class "+sym.fullName+" not defined?") - } - } - } - - /** For every argument that is lazy, transforms it to a lambda. */ - def defaultArgConvert(tfd: TypedFunDef, args: Seq[LeonExpr]): Seq[LeonExpr] = { - val argsByName = (tfd.fd.params zip args).map(p => if (isLazy(p._1)) Lambda(Seq(), p._2) else p._2) - argsByName - } - - /** Returns the function associated to the symbol. - * In the case of varargs, if the function is not found - * and there are others with the same name in the same scope, - * finds an equivalent function and converts the argument.*/ - def getFunDef(sym: Symbol, pos: Position, allowFreeArgs: Boolean = true): (FunDef, (TypedFunDef, Seq[LeonExpr]) => Seq[LeonExpr]) = { - defsToDefs.get(sym) match { - case Some(fd) => (fd, defaultArgConvert) - case None => - // Look for other functions accepting lists if they exist. - val similarFunction = - if(!allowFreeArgs) None - else (defsToDefs.find{ case (s, fd) => fd.id.name == sym.nameString && - sym.owner == s.owner && - fd.params.length == 1 && (fd.paramIds(0).getType match { - case AbstractClassType(ccd, tps) => ccd.id.name == "List" - case _ => false - }) - }) - similarFunction match { - case Some((sym, fd)) => - val convertArgs = (tfd: TypedFunDef, elems: Seq[LeonExpr]) => { - val allowedType = fd.paramIds.head.getType match { - case AbstractClassType(_, Seq(tpe)) => tfd.translated(tpe) - } - val cons = CaseClassType(libraryCaseClass(sym.pos, "leon.collection.Cons"), Seq(allowedType)) - val nil = CaseClassType(libraryCaseClass(sym.pos, "leon.collection.Nil"), Seq(allowedType)) - List(elems.foldRight(CaseClass(nil, Seq())) { - case (e, ls) => CaseClass(cons, Seq(e, ls)) - }) - } - (fd, convertArgs) - case None => - outOfSubsetError(pos, "Function "+sym.name+" not properly defined?") - } - } - } - - private var isMethod = Set[Symbol]() - private var ignoredMethods = Set[Symbol]() - private var isMutator = Set[Symbol]() - private var methodToClass = Map[FunDef, LeonClassDef]() - private var classToInvariants = Map[Symbol, Set[Tree]]() - - /** - * For the function in $defs with name $owner, find its parameter with index $index, - * and registers $fd as the default value function for this parameter. - */ - private def registerDefaultMethod( - defs : List[Tree], - matcher : PartialFunction[Tree,Symbol], - index : Int, - fd : FunDef - ) { - // Search tmpl to find the function that includes this parameter - val paramOwner = defs.collectFirst(matcher).get - - // assumes single argument list - if(paramOwner.paramss.length != 1) { - outOfSubsetError(paramOwner.pos, "Multiple argument lists for a function are not allowed") - } - val theParam = paramOwner.paramss.head(index) - paramsToDefaultValues += (theParam -> fd) - } - - def extractClassDef(sym: Symbol, args: Seq[(Symbol, ValDef)], tmpl: Template): LeonClassDef = { - - //println(s"Extracting $sym") - - val id = FreshIdentifier(sym.name.toString).setPos(sym.pos) - - val tparamsMap = sym.tpe match { - case TypeRef(_, _, tps) => - extractTypeParams(tps) - case _ => - Nil - } - - val parent = sym.tpe.parents.headOption match { - case Some(TypeRef(_, parentSym, tps)) if seenClasses contains parentSym => - getClassDef(parentSym, sym.pos) match { - case acd: AbstractClassDef => - val defCtx = DefContext(tparamsMap.toMap) - val newTps = tps.map(extractType(_)(defCtx, sym.pos)) - val zip = (newTps zip tparamsMap.map(_._2)) - if (newTps.size != tparamsMap.size) { - outOfSubsetError(sym.pos, "Child classes should have the same number of type parameters as their parent") - None - } else if (zip.exists { - case (TypeParameter(_), _) => false - case _ => true - }) { - outOfSubsetError(sym.pos, "Child class type params should have a simple mapping to parent params") - None - } else if (zip.exists { - case (TypeParameter(id), ctp) => id.name != ctp.id.name - case _ => false - }) { - outOfSubsetError(sym.pos, "Child type params should be identical to parent class's (e.g. C[T1,T2] extends P[T1,T2])") - None - } else { - Some(acd.typed -> acd.tparams) - } - - case cd => - outOfSubsetError(sym.pos, s"Class $id cannot extend ${cd.id}") - None - } - - case p => - None - } - - val tparams = parent match { - case Some((p, tparams)) => tparams - case None => tparamsMap.map(t => TypeParameterDef(t._2)) - } - - val defCtx = DefContext((tparamsMap.map(_._1) zip tparams.map(_.tp)).toMap) - - // Extract class - val cd = if (sym.isAbstractClass) { - new AbstractClassDef(id, tparams, parent.map(_._1)) - } else { - new CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) - } - cd.setPos(sym.pos) - //println(s"Registering $sym") - classesToClasses += sym -> cd - cd.addFlags(annotationsOf(sym).map { case (name, args) => ClassFlag.fromName(name, args) }.toSet) - - // Register parent - parent.map(_._1).foreach(_.classDef.registerChild(cd)) - - // Extract case class fields - cd match { - case ccd: CaseClassDef => - - val fields = args.map { case (fsym, t) => - val tpe = leonType(t.tpt.tpe)(defCtx, fsym.pos) - val id = cachedWithOverrides(fsym, Some(ccd), tpe) - if (tpe != id.getType) println(tpe, id.getType) - LeonValDef(id.setPos(t.pos)).setPos(t.pos).setIsVar(fsym.accessed.isVar) - } - - //println(s"Fields of $sym") - ccd.setFields(fields) - - // checks whether this type definition could lead to an infinite type - def computeChains(tpe: LeonType): Map[TypeParameterDef, Set[LeonClassDef]] = { - var seen: Set[LeonClassDef] = Set.empty - var chains: Map[TypeParameterDef, Set[LeonClassDef]] = Map.empty - - def rec(tpe: LeonType): Set[LeonClassDef] = tpe match { - case ct: ClassType => - val root = ct.classDef.root - if (!seen(ct.classDef.root)) { - seen += ct.classDef.root - for (cct <- ct.root.knownCCDescendants; - (tp, tpe) <- cct.classDef.tparams zip cct.tps) { - val relevant = rec(tpe) - chains += tp -> (chains.getOrElse(tp, Set.empty) ++ relevant) - for (cd <- relevant; vd <- cd.fields) { - rec(vd.getType) - } - } - } - Set(root) - - case Types.NAryType(tpes, _) => - tpes.flatMap(rec).toSet - } - - rec(tpe) - chains - } - - val chains = computeChains(ccd.typed) - - def check(tp: TypeParameterDef, seen: Set[LeonClassDef]): Unit = chains.get(tp) match { - case Some(classDefs) => - if ((seen intersect classDefs).nonEmpty) { - outOfSubsetError(sym.pos, "Infinite types are not allowed") - } else { - for (cd <- classDefs; tp <- cd.tparams) check(tp, seen + cd) - } - case None => - } - - for (tp <- ccd.tparams) check(tp, Set.empty) - - case _ => - } - - //println(s"Body of $sym") - - // We collect the methods and fields - for (d <- tmpl.body) d match { - case EmptyTree => - // ignore - - case t if isIgnored(t.symbol) => - // ignore - d match { - // Special case so that we can find methods with varargs. - case ExFunctionDef(fsym, _, _, _, _) => - ignoredMethods += fsym - case _ => - } - - // Normal methods - case t @ ExFunctionDef(fsym, _, _, _, _) => - isMethod += fsym - val fd = defineFunDef(fsym, Some(cd))(defCtx) - - methodToClass += fd -> cd - - cd.registerMethod(fd) - - case ExRequiredExpression(body) => - classToInvariants += sym -> (classToInvariants.getOrElse(sym, Set.empty) + body) - - // Default values for parameters - case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) => - isMethod += fsym - val fd = defineFunDef(fsym)(defCtx) - fd.addFlag(IsSynthetic) - methodToClass += fd -> cd - - cd.registerMethod(fd) - val matcher: PartialFunction[Tree, Symbol] = { - case ExFunctionDef(ownerSym, _ ,_ ,_, _) if ownerSym.name.toString == owner => ownerSym - } - registerDefaultMethod(tmpl.body, matcher, index, fd ) - - // Lazy fields - case t @ ExLazyAccessorFunction(fsym, _, _) => - isMethod += fsym - val fd = defineFieldFunDef(fsym, true, Some(cd))(defCtx) - methodToClass += fd -> cd - - cd.registerMethod(fd) - - // normal fields - case t @ ExFieldDef(fsym, _, _) => - //println(fsym + "matched as ExFieldDef") - // we will be using the accessor method of this field everywhere - isMethod += fsym - val fd = defineFieldFunDef(fsym, false, Some(cd))(defCtx) - methodToClass += fd -> cd - - cd.registerMethod(fd) - - case t @ ExMutableFieldDef(fsym, _, _) => - //println(fsym + "matched as ExMutableFieldDef") - // we will be using the accessor method of this field everywhere - //isMethod += fsym - //val fd = defineFieldFunDef(fsym, false, Some(cd))(defCtx) - //methodToClass += fd -> cd - - //cd.registerMethod(fd) - - case t@ ExMutatorAccessorFunction(fsym, _, _, _, _) => - //println("FOUND mutator: " + t) - //println("accessed: " + fsym.accessed) - isMutator += fsym - //val fd = defineFunDef(fsym, Some(cd))(defCtx) - - //methodToClass += fd -> cd - - //cd.registerMethod(fd) - - case other => - - } - - //println(s"End body $sym") - - cd - } - - // Returns the parent's method Identifier if sym overrides a symbol, otherwise a fresh Identifier - - private val funOrFieldSymsToIds = new Bijection[Symbol, Identifier] - - private def cachedWithOverrides(sym: Symbol, within: Option[LeonClassDef], tpe: LeonType = Untyped) = { - - val topOfHierarchy = sym.overrideChain.last - - funOrFieldSymsToIds.cachedB(topOfHierarchy){ - FreshIdentifier(sym.name.toString.trim, tpe) //trim because sometimes Scala names end with a trailing space, looks nicer without the space - } - } - - private val invId = FreshIdentifier("inv", BooleanType) - - private var isLazy = Set[LeonValDef]() - - private var defsToDefs = Map[Symbol, FunDef]() - - private def defineFunDef(sym: Symbol, within: Option[LeonClassDef] = None)(implicit dctx: DefContext): FunDef = { - // Type params of the function itself - val tparams = extractTypeParams(sym.typeParams.map(_.tpe)) - - val nctx = dctx.copy(tparams = dctx.tparams ++ tparams.toMap) - - val newParams = sym.info.paramss.flatten.map{ sym => - val ptpe = leonType(sym.tpe)(nctx, sym.pos) - val tpe = if (sym.isByNameParam) FunctionType(Seq(), ptpe) else ptpe - val newID = FreshIdentifier(sym.name.toString, tpe).setPos(sym.pos) - val vd = LeonValDef(newID).setPos(sym.pos) - - if (sym.isByNameParam) { - isLazy += vd - } - - vd - } - - val tparamsDef = tparams.map(t => TypeParameterDef(t._2)) - - val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos) - - // @mk: We type the identifiers of methods during code extraction because - // a possible implementing/overriding field will use this same Identifier - val idType = { - val argTypes = newParams map { _.getType } - if (argTypes.nonEmpty) FunctionType(argTypes, returnType) - else returnType - } - - val id = cachedWithOverrides(sym, within, idType) - - val fd = new FunDef(id.setPos(sym.pos), tparamsDef, newParams, returnType) - - fd.setPos(sym.pos) - - fd.addFlags(annotationsOf(sym).map { case (name, args) => FunctionFlag.fromName(name, args) }.toSet) - - if (sym.isImplicit) { - fd.addFlag(IsInlined) - } - - defsToDefs += sym -> fd - - fd - } - - private def defineFieldFunDef(sym : Symbol, isLazy : Boolean, within: Option[LeonClassDef] = None)(implicit dctx : DefContext) : FunDef = { - - val nctx = dctx.copy(tparams = dctx.tparams) - - val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos) - - // @mk: We type the identifiers of methods during code extraction because - // a possible implementing/overriding field will use this same Identifier - val id = cachedWithOverrides(sym, within, returnType) - val fd = new FunDef(id.setPos(sym.pos), Seq(), Seq(), returnType) - - fd.setPos(sym.pos) - fd.addFlag(IsField(isLazy)) - - fd.addFlags(annotationsOf(sym).map { case (name, args) => FunctionFlag.fromName(name, args) }.toSet) - - defsToDefs += sym -> fd - - fd - } - - private def extractFunOrMethodBody(ocsym: Option[Symbol], t: Tree) { - - val ctparamsMap = ocsym match { - case Some(csym) => - val cd = classesToClasses(csym) - - val ctparams = csym.tpe match { - case TypeRef(_, _, tps) => - extractTypeParams(tps).map(_._1) - case _ => - Nil - } - - ctparams zip cd.tparams.map(_.tp) - - case None => - Map[Symbol, TypeParameter]() - } - - t match { - case t if isIgnored(t.symbol) => - //ignore - - case ExFunctionDef(sym, tparams, params, _, body) => - val fd = defsToDefs(sym) - - val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap - - if(body != EmptyTree) { - extractFunBody(fd, params, body)(DefContext(tparamsMap)) - } - - // Default value functions - case ExDefaultValueFunction(sym, tparams, params, _, _, _, body) => - val fd = defsToDefs(sym) - - val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap - - if(body != EmptyTree) { - extractFunBody(fd, params, body)(DefContext(tparamsMap)) - } - - // Lazy fields - case t @ ExLazyAccessorFunction(sym, _, body) => - val fd = defsToDefs(sym) - val tparamsMap = ctparamsMap - - if(body != EmptyTree) { - extractFunBody(fd, Seq(), body)(DefContext(tparamsMap.toMap)) - } - - // normal fields - case t @ ExFieldDef(sym, _, body) => // if !sym.isSynthetic && !sym.isAccessor => - val fd = defsToDefs(sym) - val tparamsMap = ctparamsMap - - if(body != EmptyTree) { - extractFunBody(fd, Seq(), body)(DefContext(tparamsMap.toMap)) - } - - case t @ ExMutableFieldDef(sym, _, body) => // if !sym.isSynthetic && !sym.isAccessor => - //val fd = defsToDefs(sym) - //val tparamsMap = ctparamsMap - - //if(body != EmptyTree) { - // extractFunBody(fd, Seq(), body)(DefContext(tparamsMap.toMap)) - //} - - case ExMutatorAccessorFunction(sym, tparams, params, _, body) => - //val fd = defsToDefs(sym) - - //val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap ++ ctparamsMap - - //val classSym = ocsym.get - //val cd = classesToClasses(classSym).asInstanceOf[CaseClassDef] - //val classVarDefs = seenClasses(classSym)._2 - //val mutableFields = classVarDefs.zip(cd.varFields).map(p => (p._1._1, () => p._2.toVariable)) - - //val dctx = DefContext(tparamsMap) - //val pctx = dctx.withNewMutableVars(mutableFields) - - //if(body != EmptyTree) { - // extractFunBody(fd, params, body)(pctx) - //} - - case _ => - } - } - - private def extractTypeParams(tps: Seq[Type]): Seq[(Symbol, TypeParameter)] = { - tps.flatMap { - case TypeRef(_, sym, Nil) => - Some(sym -> TypeParameter.fresh(sym.name.toString)) - case t => - outOfSubsetError(t.typeSymbol.pos, "Unhandled type for parameter: "+t) - None - } - } - - //var objects2Objects = Map[Identifier, LeonModuleDef]() - - private def extractFunBody(funDef: FunDef, params: Seq[ValDef], body0 : Tree)(dctx: DefContext): FunDef = { - currentFunDef = funDef - - // Find defining function for params with default value - for ((s,vd) <- params zip funDef.params) { - vd.defaultValue = paramsToDefaultValues.get(s.symbol) - } - - val newVars = for ((s, vd) <- params zip funDef.params) yield s.symbol -> { - if (s.symbol.isByNameParam) () => Application(Variable(vd.id), Seq()) - else () => Variable(vd.id) - } - - val fctx = dctx.withNewVars(newVars).copy(isExtern = funDef.annotations("extern")) - - // If this is a lazy field definition, drop the assignment/ accessing - val body = - if (funDef.flags.contains(IsField(true))) { body0 match { - case Block(List(Assign(_, realBody)),_ ) => realBody - case _ => outOfSubsetError(body0, "Wrong form of lazy accessor") - }} else body0 - - val finalBody = try { - flattenBlocks(extractTreeOrNoTree(body)(fctx)) - } catch { - case e: ImpureCodeEncounteredException => - e.emit() - //val pos = if (body0.pos == NoPosition) NoPosition else leonPosToScalaPos(body0.pos.source, funDef.getPos) - if (ctx.findOptionOrDefault(ExtractionPhase.optStrictCompilation)) { - reporter.error(funDef.getPos, "Function "+funDef.id.name+" could not be extracted. The function likely uses features not supported by Leon.") - } else { - reporter.warning(funDef.getPos, "Function "+funDef.id.name+" is not fully available to Leon.") - } - - funDef.addFlag(IsAbstract) - NoTree(funDef.returnType) - } - - //if (fctx.isExtern && !exists(_.isInstanceOf[NoTree])(finalBody)) { - // reporter.warning(finalBody.getPos, "External function could be extracted as Leon tree: "+finalBody) - //} - - funDef.fullBody = finalBody - if(fctx.isExtern) { //extern never keeps the body, but we keep pre and post - funDef.body = None - } - - // Post-extraction sanity checks - - funDef.precondition.foreach { case e => - if(containsLetDef(e)) { - reporter.warning(e.getPos, "Function precondition should not contain nested function definition, ignoring.") - funDef.precondition = None - } - } - - funDef.postcondition.foreach { e => - if(containsLetDef(e)) { - reporter.warning(e.getPos, "Function postcondition should not contain nested function definition, ignoring.") - funDef.postcondition = None - } - } - - funDef - } - - private def extractPattern(p: Tree, binder: Option[Identifier] = None)(implicit dctx: DefContext): (Pattern, DefContext) = p match { - case b @ Bind(name, t @ Typed(pat, tpt)) => - val newID = FreshIdentifier(name.toString, extractType(tpt)).setPos(b.pos) - val pctx = dctx.withNewVar(b.symbol -> (() => Variable(newID))) - extractPattern(t, Some(newID))(pctx) - - case b @ Bind(name, pat) => - val newID = FreshIdentifier(name.toString, extractType(b)).setPos(b.pos) - val pctx = dctx.withNewVar(b.symbol -> (() => Variable(newID))) - extractPattern(pat, Some(newID))(pctx) - - case t @ Typed(Ident(nme.WILDCARD), tpt) => - extractType(tpt) match { - case ct: ClassType => - (InstanceOfPattern(binder, ct).setPos(p.pos), dctx) - - case lt => - outOfSubsetError(tpt, "Invalid type "+tpt.tpe+" for .isInstanceOf") - } - - case Ident(nme.WILDCARD) => - (WildcardPattern(binder).setPos(p.pos), dctx) - - case s @ Select(_, b) if s.tpe.typeSymbol.isCase => - // case Obj => - extractType(s) match { - case ct: CaseClassType => - assert(ct.classDef.fields.isEmpty) - (CaseClassPattern(binder, ct, Seq()).setPos(p.pos), dctx) - case _ => - outOfSubsetError(s, "Invalid type "+s.tpe+" for .isInstanceOf") - } - - case a @ Apply(fn, args) => - - extractType(a) match { - case ct: CaseClassType => - assert(args.size == ct.classDef.fields.size) - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip - - val nctx = subDctx.foldLeft(dctx)(_ union _) - - (CaseClassPattern(binder, ct, subPatterns).setPos(p.pos), nctx) - case TupleType(argsTpes) => - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip - - val nctx = subDctx.foldLeft(dctx)(_ union _) - - (TuplePattern(binder, subPatterns).setPos(p.pos), nctx) - case _ => - outOfSubsetError(a, "Invalid type "+a.tpe+" for .isInstanceOf") - } - - case ExBigIntPattern(n: Literal) => - val lit = InfiniteIntegerLiteral(BigInt(n.value.stringValue)) - (LiteralPattern(binder, lit), dctx) - - case ExInt32Literal(i) => (LiteralPattern(binder, IntLiteral(i)), dctx) - case ExBooleanLiteral(b) => (LiteralPattern(binder, BooleanLiteral(b)), dctx) - case ExUnitLiteral() => (LiteralPattern(binder, UnitLiteral()), dctx) - case ExStringLiteral(s) => (LiteralPattern(binder, StringLiteral(s)), dctx) - - case up@ExUnapplyPattern(s, args) => - implicit val p: Position = NoPosition - val (fd, _) = getFunDef(s, up.pos, allowFreeArgs=false) - val (sub, ctx) = args.map (extractPattern(_)).unzip - val unapplyMethod = defsToDefs(s) - val formalTypes = tupleTypeWrap( - unapplyMethod.params.map { _.getType } ++ - unapplyMethod.returnType.asInstanceOf[ClassType].tps - ) - val realTypes = tupleTypeWrap(Seq( - extractType(up.tpe), - tupleTypeWrap(args map { tr => extractType(tr.tpe)}) - )) - val newTps = canBeSupertypeOf(formalTypes, realTypes) match { - case Some(tmap) => - fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) } - case None => - //println(realTypes, formalTypes) - reporter.fatalError("Could not instantiate type of unapply method") - } - - (UnapplyPattern(binder, fd.typed(newTps), sub).setPos(up.pos), ctx.foldLeft(dctx)(_ union _)) - - case _ => - outOfSubsetError(p, "Unsupported pattern: "+p.getClass) - } - - private def extractMatchCase(cd: CaseDef)(implicit dctx: DefContext): MatchCase = { - val (recPattern, ndctx) = extractPattern(cd.pat) - val recBody = extractTree(cd.body)(ndctx) - - if(cd.guard == EmptyTree) { - SimpleCase(recPattern, recBody).setPos(cd.pos) - } else { - val recGuard = extractTree(cd.guard)(ndctx) - - if(isXLang(recGuard)) { - outOfSubsetError(cd.guard.pos, "Guard expression must be pure") - } - - GuardedCase(recPattern, recGuard, recBody).setPos(cd.pos) - } - } - - private def extractTreeOrNoTree(tr: Tree)(implicit dctx: DefContext): LeonExpr = { - try { - extractTree(tr) - } catch { - case e: ImpureCodeEncounteredException => - if (dctx.isExtern) { - NoTree(extractType(tr)).setPos(tr.pos) - } else { - throw e - } - } - } - - private def extractTree(tr: Tree)(implicit dctx: DefContext): LeonExpr = { - val (current, tmpRest) = tr match { - case Block(Block(e :: es1, l1) :: es2, l2) => - (e, Some(Block(es1 ++ Seq(l1) ++ es2, l2))) - case Block(e :: Nil, last) => - (e, Some(last)) - case Block(e :: es, last) => - (e, Some(Block(es, last))) - case Block(Nil, last) => - (last, None) - case e => - (e, None) - } - - var rest = tmpRest - - val res = current match { - case ExEnsuredExpression(body, contract) => - val post = extractTree(contract) - - val b = extractTreeOrNoTree(body) - - val closure = post match { - case IsTyped(_, BooleanType) => - val resId = FreshIdentifier("res", b.getType).setPos(post) - Lambda(Seq(LeonValDef(resId)), post).setPos(post) - case l: Lambda => l - case other => - val resId = FreshIdentifier("res", b.getType).setPos(post) - Lambda(Seq(LeonValDef(resId)), application(other, Seq(Variable(resId)))).setPos(post) - } - - Ensuring(b, closure) - - case t @ ExHoldsWithProofExpression(body, ExMaybeBecauseExpressionWrapper(proof)) => - val resId = FreshIdentifier("holds", BooleanType).setPos(current.pos) - val p = extractTreeOrNoTree(proof) - val post = Lambda(Seq(LeonValDef(resId)), And(Seq(p, Variable(resId)))).setPos(current.pos) - val b = extractTreeOrNoTree(body) - Ensuring(b, post) - - case t @ ExHoldsExpression(body) => - val resId = FreshIdentifier("holds", BooleanType).setPos(current.pos) - val post = Lambda(Seq(LeonValDef(resId)), Variable(resId)).setPos(current.pos) - - val b = extractTreeOrNoTree(body) - - Ensuring(b, post) - - // If the because statement encompasses a holds statement - case t @ ExBecauseExpression(ExHoldsExpression(body), proof) => - val resId = FreshIdentifier("holds", BooleanType).setPos(current.pos) - val p = extractTreeOrNoTree(proof) - val post = Lambda(Seq(LeonValDef(resId)), And(Seq(p, Variable(resId)))).setPos(current.pos) - val b = extractTreeOrNoTree(body) - Ensuring(b, post) - - case t @ ExComputesExpression(body, expected) => - val b = extractTreeOrNoTree(body).setPos(body.pos) - val expected_expr = extractTreeOrNoTree(expected).setPos(expected.pos) - - val resId = FreshIdentifier("res", b.getType).setPos(current.pos) - val post = Lambda(Seq(LeonValDef(resId)), Equals(Variable(resId), expected_expr)).setPos(current.pos) - - Ensuring(b, post) - - case t @ ExByExampleExpression(input, output) => - val input_expr = extractTreeOrNoTree(input).setPos(input.pos) - val output_expr = extractTreeOrNoTree(output).setPos(output.pos) - Passes(input_expr, output_expr, MatchCase(WildcardPattern(None), Some(BooleanLiteral(false)), NoTree(output_expr.getType))::Nil) - - case t @ ExAskExpression(input, output) => - val input_expr = extractTreeOrNoTree(input).setPos(input.pos) - val output_expr = extractTreeOrNoTree(output).setPos(output.pos) - - val resId = FreshIdentifier("res", output_expr.getType).setPos(current.pos) - val post = Lambda(Seq(LeonValDef(resId)), - Passes(input_expr, Variable(resId), MatchCase(WildcardPattern(None), Some(BooleanLiteral(false)), NoTree(output_expr.getType))::Nil)).setPos(current.pos) - - Ensuring(output_expr, post) - - case t @ ExBigLengthExpression(input) => - val input_expr = extractTreeOrNoTree(input).setPos(input.pos) - StringBigLength(input_expr) - case t @ ExBigSubstringExpression(input, start) => - val input_expr = extractTreeOrNoTree(input).setPos(input.pos) - val start_expr = extractTreeOrNoTree(start).setPos(start.pos) - val s = FreshIdentifier("s", StringType) - let(s, input_expr, - BigSubString(Variable(s), start_expr, StringBigLength(Variable(s))) - ) - - case t @ ExBigSubstring2Expression(input, start, end) => - val input_expr = extractTreeOrNoTree(input).setPos(input.pos) - val start_expr = extractTreeOrNoTree(start).setPos(start.pos) - val end_expr = extractTreeOrNoTree(end).setPos(end.pos) - BigSubString(input_expr, start_expr, end_expr) - - case ExAssertExpression(contract, oerr) => - val const = extractTree(contract) - val b = rest.map(extractTreeOrNoTree).getOrElse(UnitLiteral()) - - rest = None - - Assert(const, oerr, b) - - case ExRequiredExpression(contract) => - val pre = extractTree(contract) - - val b = rest.map(extractTreeOrNoTree).getOrElse(UnitLiteral()) - - rest = None - - Require(pre, b) - - case ExPasses(in, out, cases) => - val ine = extractTree(in) - val oute = extractTree(out) - val rc = cases.map(extractMatchCase) - - // @mk: FIXME: this whole sanity checking is very dodgy at best. - val ines = unwrapTuple(ine, ine.isInstanceOf[Tuple]) // @mk We untuple all tuples - ines foreach { - case v @ Variable(_) if currentFunDef.params.map{ _.toVariable } contains v => - case LeonThis(_) => - case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples") - } - oute match { - case Variable(_) => // FIXME: this is not strict enough, we need the bound variable of enclosing Ensuring - case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples") - } - passes(ine, oute, rc) - - case ExArrayLiteral(tpe, args) => - finiteArray(args.map(extractTree), None, extractType(tpe)(dctx, current.pos)) - - case ExCaseObject(sym) => - getClassDef(sym, current.pos) match { - case ccd: CaseClassDef => - CaseClass(CaseClassType(ccd, Seq()), Seq()) - case _ => - outOfSubsetError(current, "Unknown case object "+sym.name) - } - - case ExTuple(tpes, exprs) => - val tupleExprs = exprs.map(e => extractTree(e)) - Tuple(tupleExprs) - - case ex@ExOldExpression(t) if dctx.isVariable(t.symbol) => - val sym = t.symbol - dctx.vars.get(sym).orElse(dctx.mutableVars.get(sym)) match { - case Some(builder) => - val Variable(id) = builder() - Old(id).setPos(ex.pos) - case None => - outOfSubsetError(current, "old can only be used with variables") - } - case ex@ExOldExpression(t: This) => - extractType(t) match { - case ct: ClassType => - OldThis(ct) - case _ => - outOfSubsetError(t, "Invalid usage of `this`") - } - - //TODO: could have a case to extract Old of CaseClassSelector and map them to Selectors of OldThis. - case ex@ExOldExpression(t) => - outOfSubsetError(t, "Invalid usage of `old` with expression: " + t + ". Only works with variables and `this` keyword") - - - case ExErrorExpression(str, tpt) => - Error(extractType(tpt), str) - - case ExTupleExtract(tuple, index) => - val tupleExpr = extractTree(tuple) - - tupleExpr.getType match { - case TupleType(tpes) if tpes.size >= index => - tupleSelect(tupleExpr, index, true) - - case _ => - outOfSubsetError(current, "Invalid tuple access") - } - - case ExValDef(vs, tpt, bdy) => - val binderTpe = extractType(tpt) - val newID = FreshIdentifier(vs.name.toString, binderTpe) - val valTree = extractTree(bdy) - - val restTree = rest match { - case Some(rst) => - val nctx = dctx.withNewVar(vs -> (() => Variable(newID))) - extractTree(rst)(nctx) - case None => - UnitLiteral() - } - - rest = None - Let(newID, valTree, restTree) - - case d @ ExFunctionDef(sym, tparams, params, ret, b) => - val fd = defineFunDef(sym) - - val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap - - fd.addFlags(annotationsOf(d.symbol).map { case (name, args) => FunctionFlag.fromName(name, args) }.toSet) - - val newDctx = dctx.copy(tparams = dctx.tparams ++ tparamsMap) - - val restTree = rest match { - case Some(rst) => extractTree(rst) - case None => UnitLiteral() - } - rest = None - - val oldCurrentFunDef = currentFunDef - - val funDefWithBody = extractFunBody(fd, params, b)(newDctx) - - currentFunDef = oldCurrentFunDef - - val (other_fds, block) = restTree match { - case LetDef(fds, block) => - (fds, block) - case _ => - (Nil, restTree) - } - letDef(funDefWithBody +: other_fds, block) - - // FIXME case ExDefaultValueFunction - - /** - * XLang Extractors - */ - - case ExVarDef(vs, tpt, bdy) => { - val binderTpe = extractType(tpt) - val newID = FreshIdentifier(vs.name.toString, binderTpe) - val valTree = extractTree(bdy) - - val restTree = rest match { - case Some(rst) => { - val nv = vs -> (() => Variable(newID)) - val nctx = dctx.withNewVar(nv).withNewMutableVar(nv) - extractTree(rst)(nctx) - } - case None => UnitLiteral() - } - - rest = None - - LetVar(newID, valTree, restTree) - } - - case a@ExAssign(sym, rhs) => { - dctx.mutableVars.get(sym) match { - case Some(fun) => - val Variable(id) = fun() - val rhsTree = extractTree(rhs) - Assignment(id, rhsTree) - - case None => - outOfSubsetError(a, "Undeclared variable.") - }} - - case wh @ ExWhile(cond, body) => - val condTree = extractTree(cond) - val bodyTree = extractTree(body) - While(condTree, bodyTree) - - case wh @ ExWhileWithInvariant(cond, body, inv) => - val condTree = extractTree(cond) - val bodyTree = extractTree(body) - val invTree = extractTree(inv) - - val w = While(condTree, bodyTree) - w.invariant = Some(invTree) - w - - case epsi @ ExEpsilonExpression(tpt, varSym, predBody) => - val pstpe = extractType(tpt) - val nctx = dctx.withNewVar(varSym -> (() => EpsilonVariable(epsi.pos, pstpe))) - val c1 = extractTree(predBody)(nctx) - if(containsEpsilon(c1)) { - outOfSubsetError(epsi, "Usage of nested epsilon is not allowed") - } - Epsilon(c1, pstpe) - - case update @ ExUpdate(lhs, index, newValue) => - val lhsRec = extractTree(lhs) - val indexRec = extractTree(index) - val newValueRec = extractTree(newValue) - ArrayUpdate(lhsRec, indexRec, newValueRec) - - case ExBigIntLiteral(n: Literal) => - InfiniteIntegerLiteral(BigInt(n.value.stringValue)) - - case ExBigIntLiteral(n) => outOfSubsetError(tr, "Non-literal BigInt constructor") - - case ExIntToBigInt(tree) => - val rec = extractTree(tree) - rec match { - case IntLiteral(n) => - InfiniteIntegerLiteral(BigInt(n)) - case _ => - outOfSubsetError(tr, "Conversion from Int to BigInt") - } - - case ExRealLiteral(n, d) => - val rn = extractTree(n) - val rd = extractTree(d) - (rn, rd) match { - case (InfiniteIntegerLiteral(n), InfiniteIntegerLiteral(d)) => - FractionalLiteral(n, d) - case _ => - outOfSubsetError(tr, "Real not build from literals") - } - case ExRealIntLiteral(n) => - val rn = extractTree(n) - rn match { - case InfiniteIntegerLiteral(n) => - FractionalLiteral(n, 1) - case _ => - outOfSubsetError(tr, "Real not build from literals") - } - - case ExInt32Literal(v) => - IntLiteral(v) - - case ExBooleanLiteral(v) => - BooleanLiteral(v) - - case ExUnitLiteral() => - UnitLiteral() - - case ExLocally(body) => - extractTree(body) - - case ExTyped(e, _) => - // TODO: refine type here? - extractTree(e) - - case ex @ ExIdentifier(sym, tpt) if dctx.isVariable(sym) || defsToDefs.contains(sym) => - dctx.vars.get(sym).orElse(dctx.mutableVars.get(sym)) match { - case Some(builder) => - builder().setPos(ex.pos) - case None => - // Maybe it is a function - defsToDefs.get(sym) match { - case Some(fd) => - FunctionInvocation(fd.typed, Seq()).setPos(sym.pos) - case None => - outOfSubsetError(tr, "Unidentified variable " + sym + " " + sym.id + ".") - } - } - - case hole @ ExHoleExpression(tpt, exprs) => - Hole(extractType(tpt), exprs.map(extractTree)) - - case ops @ ExWithOracleExpression(oracles, body) => - val newOracles = oracles map { case (tpt, sym) => - val aTpe = extractType(tpt) - val oTpe = oracleType(ops.pos, aTpe) - val newID = FreshIdentifier(sym.name.toString, oTpe) - newID - } - - val newVars = (oracles zip newOracles).map { - case ((_, sym), id) => - sym -> (() => Variable(id)) - } - - val cBody = extractTree(body)(dctx.withNewVars(newVars)) - - WithOracle(newOracles, cBody) - - case chs @ ExChooseExpression(body) => - val cBody = extractTree(body) - Choose(cBody) - - case l @ ExLambdaExpression(args, body) => - val vds = args map { vd => - val aTpe = extractType(vd.tpt) - val newID = FreshIdentifier(vd.symbol.name.toString, aTpe) - LeonValDef(newID) - } - - val newVars = (args zip vds).map { case (vd, lvd) => - vd.symbol -> (() => lvd.toVariable) - } - - val exBody = extractTree(body)(dctx.withNewVars(newVars)) - - Lambda(vds, exBody) - - case ExForallExpression(args, body) => - val vds = args map { case (tpt, sym) => - val aTpe = extractType(tpt) - val newID = FreshIdentifier(sym.name.toString, aTpe) - LeonValDef(newID) - } - - val newVars = (args zip vds).map { case ((_, sym), lvd) => - sym -> (() => lvd.toVariable) - } - - val exBody = extractTree(body)(dctx.withNewVars(newVars)) - - Forall(vds, exBody) - - case ExFiniteMap(tptFrom, tptTo, args) => - FiniteMap(args.map { - case ExTuple(tpes, Seq(key, value)) => - (extractTree(key), extractTree(value)) - case tree => - val ex = extractTree(tree) - (TupleSelect(ex, 1), TupleSelect(ex, 2)) - }.toMap, extractType(tptFrom), extractType(tptTo)) - - case ExFiniteSet(tpt, args) => - FiniteSet(args.map(extractTree).toSet, extractType(tpt)) - - case ExFiniteBag(tpt, args) => - FiniteBag(args.map { - case ExTuple(tpes, Seq(key, value)) => - (extractTree(key), extractTree(value)) - case tree => - val ex = extractTree(tree) - (TupleSelect(ex, 1), TupleSelect(ex, 2)) - }.toMap, extractType(tpt)) - - case ExCaseClassConstruction(tpt, args) => - extractType(tpt) match { - case cct: CaseClassType => - CaseClass(cct, args.map(extractTree)) - - case _ => - outOfSubsetError(tr, "Construction of a non-case class.") - - } - - case ExNot(e) => Not(extractTree(e)) - case ExUMinus(e) => UMinus(extractTree(e)) - case ExRealUMinus(e) => RealUMinus(extractTree(e)) - case ExBVUMinus(e) => BVUMinus(extractTree(e)) - case ExBVNot(e) => BVNot(extractTree(e)) - - case ExNotEquals(l, r) => - val rl = extractTree(l) - val rr = extractTree(r) - - (rl, rr) match { - case (IsTyped(_, ArrayType(_)), IsTyped(_, ArrayType(_))) => - outOfSubsetError(tr, "Leon does not support array comparison") - - case (IsTyped(_, rt), IsTyped(_, lt)) if typesCompatible(lt, rt) => - Not(Equals(rl, rr)) - - case (IntLiteral(v), IsTyped(_, IntegerType)) => - Not(Equals(InfiniteIntegerLiteral(v), rr)) - - case (IsTyped(_, IntegerType), IntLiteral(v)) => - Not(Equals(rl, InfiniteIntegerLiteral(v))) - - case (IsTyped(_, rt), IsTyped(_, lt)) => - outOfSubsetError(tr, "Invalid comparison: (_: "+rt.asString+") != (_: "+lt.asString+")") - } - - case ExEquals(l, r) => - val rl = extractTree(l) - val rr = extractTree(r) - - (rl, rr) match { - case (IsTyped(_, ArrayType(_)), IsTyped(_, ArrayType(_))) => - outOfSubsetError(tr, "Leon does not support array comparison") - - case (IsTyped(_, rt), IsTyped(_, lt)) if typesCompatible(lt, rt) => - Equals(rl, rr) - - case (IntLiteral(v), IsTyped(_, IntegerType)) => - Equals(InfiniteIntegerLiteral(v), rr) - - case (IsTyped(_, IntegerType), IntLiteral(v)) => - Equals(rl, InfiniteIntegerLiteral(v)) - - case (IsTyped(_, rt), IsTyped(_, lt)) => - outOfSubsetError(tr, "Invalid comparison: (_: "+rt+") == (_: "+lt+")") - } - - case ExArrayFill(baseType, length, defaultValue) => - val lengthRec = extractTree(length) - val defaultValueRec = extractTree(defaultValue) - NonemptyArray(Map(), Some(defaultValueRec, lengthRec)) - - case ExIfThenElse(t1,t2,t3) => - val r1 = extractTree(t1) - if(containsLetDef(r1)) { - outOfSubsetError(t1, "Condition of if-then-else expression should not contain nested function definition") - } - val r2 = extractTree(t2) - val r3 = extractTree(t3) - val lub = leastUpperBound(r2.getType, r3.getType) - lub match { - case Some(lub) => - IfExpr(r1, r2, r3) - - case None => - outOfSubsetError(tr, "Both branches of ifthenelse have incompatible types ("+r2.getType.asString(ctx)+" and "+r3.getType.asString(ctx)+")") - } - - case ExAsInstanceOf(expr, tt) => - val eRec = extractTree(expr) - val checkType = extractType(tt) - checkType match { - case ct: ClassType => - AsInstanceOf(eRec, ct) - case _ => - outOfSubsetError(tr, "asInstanceOf can only cast to class types") - } - - case ExIsInstanceOf(tt, cc) => - val ccRec = extractTree(cc) - val checkType = extractType(tt) - checkType match { - case ct: ClassType => - if(!ccRec.getType.isInstanceOf[ClassType]) { - outOfSubsetError(tr, "isInstanceOf can only be used with a class") - } else { - val rootType: LeonClassDef = ct.root.classDef - val testedExprType = ccRec.getType.asInstanceOf[ClassType] - val testedExprRootType: LeonClassDef = testedExprType.root.classDef - - if(rootType != testedExprRootType) { - outOfSubsetError(tr, "isInstanceOf can only be used with compatible classes") - } else { - IsInstanceOf(ccRec, ct) - } - } - case _ => - outOfSubsetError(tr, "isInstanceOf can only be used with a class") - } - - case pm @ ExPatternMatching(sel, cses) => - val rs = extractTree(sel) - val rc = cses.map(extractMatchCase) - matchExpr(rs, rc) - - case t: This => - extractType(t) match { - case ct: ClassType => - LeonThis(ct) - case _ => - outOfSubsetError(t, "Invalid usage of `this`") - } - - case aup @ ExArrayUpdated(ar, k, v) => - val rar = extractTree(ar) - val rk = extractTree(k) - val rv = extractTree(v) - - ArrayUpdated(rar, rk, rv) - - case l @ ExListLiteral(tpe, elems) => - val rtpe = extractType(tpe) - val cons = CaseClassType(libraryCaseClass(l.pos, "leon.collection.Cons"), Seq(rtpe)) - val nil = CaseClassType(libraryCaseClass(l.pos, "leon.collection.Nil"), Seq(rtpe)) - - elems.foldRight(CaseClass(nil, Seq())) { - case (e, ls) => CaseClass(cons, Seq(extractTree(e), ls)) - } - - case chr @ ExCharLiteral(c) => - CharLiteral(c) - - case str @ ExStringLiteral(s) => - StringLiteral(s) - - case ExImplies(lhs, rhs) => - Implies(extractTree(lhs), extractTree(rhs)).setPos(current.pos) - - case c @ ExCall(rec, sym, tps, args) => - // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method. - val rrec = rec match { - case t if (defsToDefs contains sym) && !isMethod(sym) && !isMutator(sym) => - null - case _ => - extractTree(rec) - } - - val rargs = args.map(extractTree) - - //println(s"symbol $sym with id ${sym.id}") - //println(s"isMethod($sym) == ${isMethod(sym)}") - - (rrec, sym.name.decoded, rargs) match { - case (null, _, args) => - val (fd, convertArgs) = getFunDef(sym, c.pos, allowFreeArgs=true) - val newTps = tps.map(t => extractType(t)) - val tfd = fd.typed(newTps) - - FunctionInvocation(tfd, convertArgs(tfd, args)) - - case (IsTyped(rec, ct: ClassType), methodName, args) if isMethod(sym) || ignoredMethods(sym) => - val (fd, convertArgs) = getFunDef(sym, c.pos) - val cd = methodToClass(fd) - - val newTps = tps.map(t => extractType(t)) - val tfd = fd.typed(newTps) - - MethodInvocation(rec, cd, tfd, convertArgs(tfd, args)) - - case (IsTyped(rec, ft: FunctionType), _, args) => - application(rec, args) - - case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.classDef.fields.exists(_.id.name == name) => - val fieldID = cct.classDef.fields.find(_.id.name == name).get.id - - caseClassSelector(cct, rec, fieldID) - - //mutable variables - case (IsTyped(rec, cct: CaseClassType), name, List(e1)) if isMutator(sym) => - val id = cct.classDef.fields.find(_.id.name == name.dropRight(2)).get.id - FieldAssignment(rec, id, e1) - - - //String methods - case (IsTyped(a1, StringType), "toString", List()) => - a1 - case (IsTyped(a1, WithStringconverter(converter)), "toString", List()) => - converter(a1) - case (IsTyped(a1, StringType), "+", List(IsTyped(a2, StringType))) => - StringConcat(a1, a2) - case (IsTyped(a1, StringType), "+", List(IsTyped(a2, WithStringconverter(converter)))) => - StringConcat(a1, converter(a2)) - case (IsTyped(a1, WithStringconverter(converter)), "+", List(IsTyped(a2, StringType))) => - StringConcat(converter(a1), a2) - case (IsTyped(a1, StringType), "length", List()) => - StringLength(a1) - case (IsTyped(a1, StringType), "substring", List(IsTyped(start, Int32Type))) => - val s = FreshIdentifier("s", StringType) - let(s, a1, - SubString(Variable(s), start, StringLength(Variable(s))) - ) - case (IsTyped(a1, StringType), "substring", List(IsTyped(start, Int32Type), IsTyped(end, Int32Type))) => - SubString(a1, start, end) - - //BigInt methods - case (IsTyped(a1, IntegerType), "+", List(IsTyped(a2, IntegerType))) => - Plus(a1, a2) - case (IsTyped(a1, IntegerType), "-", List(IsTyped(a2, IntegerType))) => - Minus(a1, a2) - case (IsTyped(a1, IntegerType), "*", List(IsTyped(a2, IntegerType))) => - Times(a1, a2) - case (IsTyped(a1, IntegerType), "%", List(IsTyped(a2, IntegerType))) => - Remainder(a1, a2) - case (IsTyped(a1, IntegerType), "mod", List(IsTyped(a2, IntegerType))) => - Modulo(a1, a2) - case (IsTyped(a1, IntegerType), "/", List(IsTyped(a2, IntegerType))) => - Division(a1, a2) - case (IsTyped(a1, IntegerType), ">", List(IsTyped(a2, IntegerType))) => - GreaterThan(a1, a2) - case (IsTyped(a1, IntegerType), ">=", List(IsTyped(a2, IntegerType))) => - GreaterEquals(a1, a2) - case (IsTyped(a1, IntegerType), "<", List(IsTyped(a2, IntegerType))) => - LessThan(a1, a2) - case (IsTyped(a1, IntegerType), "<=", List(IsTyped(a2, IntegerType))) => - LessEquals(a1, a2) - - - //Real methods - case (IsTyped(a1, RealType), "+", List(IsTyped(a2, RealType))) => - RealPlus(a1, a2) - case (IsTyped(a1, RealType), "-", List(IsTyped(a2, RealType))) => - RealMinus(a1, a2) - case (IsTyped(a1, RealType), "*", List(IsTyped(a2, RealType))) => - RealTimes(a1, a2) - case (IsTyped(a1, RealType), "/", List(IsTyped(a2, RealType))) => - RealDivision(a1, a2) - case (IsTyped(a1, RealType), ">", List(IsTyped(a2, RealType))) => - GreaterThan(a1, a2) - case (IsTyped(a1, RealType), ">=", List(IsTyped(a2, RealType))) => - GreaterEquals(a1, a2) - case (IsTyped(a1, RealType), "<", List(IsTyped(a2, RealType))) => - LessThan(a1, a2) - case (IsTyped(a1, RealType), "<=", List(IsTyped(a2, RealType))) => - LessEquals(a1, a2) - - - // Int methods - case (IsTyped(a1, Int32Type), "+", List(IsTyped(a2, Int32Type))) => - BVPlus(a1, a2) - case (IsTyped(a1, Int32Type), "-", List(IsTyped(a2, Int32Type))) => - BVMinus(a1, a2) - case (IsTyped(a1, Int32Type), "*", List(IsTyped(a2, Int32Type))) => - BVTimes(a1, a2) - case (IsTyped(a1, Int32Type), "%", List(IsTyped(a2, Int32Type))) => - BVRemainder(a1, a2) - case (IsTyped(a1, Int32Type), "/", List(IsTyped(a2, Int32Type))) => - BVDivision(a1, a2) - - case (IsTyped(a1, Int32Type), "|", List(IsTyped(a2, Int32Type))) => - BVOr(a1, a2) - case (IsTyped(a1, Int32Type), "&", List(IsTyped(a2, Int32Type))) => - BVAnd(a1, a2) - case (IsTyped(a1, Int32Type), "^", List(IsTyped(a2, Int32Type))) => - BVXOr(a1, a2) - case (IsTyped(a1, Int32Type), "<<", List(IsTyped(a2, Int32Type))) => - BVShiftLeft(a1, a2) - case (IsTyped(a1, Int32Type), ">>", List(IsTyped(a2, Int32Type))) => - BVAShiftRight(a1, a2) - case (IsTyped(a1, Int32Type), ">>>", List(IsTyped(a2, Int32Type))) => - BVLShiftRight(a1, a2) - - case (IsTyped(a1, Int32Type), ">", List(IsTyped(a2, Int32Type))) => - GreaterThan(a1, a2) - case (IsTyped(a1, Int32Type), ">=", List(IsTyped(a2, Int32Type))) => - GreaterEquals(a1, a2) - case (IsTyped(a1, Int32Type), "<", List(IsTyped(a2, Int32Type))) => - LessThan(a1, a2) - case (IsTyped(a1, Int32Type), "<=", List(IsTyped(a2, Int32Type))) => - LessEquals(a1, a2) - - - // Boolean methods - case (IsTyped(a1, BooleanType), "&&", List(IsTyped(a2, BooleanType))) => - and(a1, a2) - - case (IsTyped(a1, BooleanType), "||", List(IsTyped(a2, BooleanType))) => - or(a1, a2) - - - // Set methods - case (IsTyped(a1, SetType(b1)), "size", Nil) => - SetCardinality(a1) - - //case (IsTyped(a1, SetType(b1)), "min", Nil) => - // SetMin(a1) - - //case (IsTyped(a1, SetType(b1)), "max", Nil) => - // SetMax(a1) - - case (IsTyped(a1, SetType(b1)), "+", List(a2)) => - SetAdd(a1, a2) - - case (IsTyped(a1, SetType(b1)), "++", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetUnion(a1, a2) - - case (IsTyped(a1, SetType(b1)), "&", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetIntersection(a1, a2) - - case (IsTyped(a1, SetType(b1)), "subsetOf", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SubsetOf(a1, a2) - - case (IsTyped(a1, SetType(b1)), "--", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetDifference(a1, a2) - - case (IsTyped(a1, SetType(b1)), "contains", List(a2)) => - ElementOfSet(a2, a1) - - case (IsTyped(a1, SetType(b1)), "isEmpty", List()) => - Equals(a1, FiniteSet(Set(), b1)) - - - // Bag methods - case (IsTyped(a1, BagType(b1)), "+", List(a2)) => - BagAdd(a1, a2) - - case (IsTyped(a1, BagType(b1)), "++", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => - BagUnion(a1, a2) - - case (IsTyped(a1, BagType(b1)), "&", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => - BagIntersection(a1, a2) - - case (IsTyped(a1, BagType(b1)), "--", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => - BagDifference(a1, a2) - - case (IsTyped(a1, BagType(b1)), "get", List(a2)) => - MultiplicityInBag(a2, a1) - - case (IsTyped(a1, BagType(b1)), "isEmpty", List()) => - Equals(a1, FiniteBag(Map(), b1)) - - - // Array methods - case (IsTyped(a1, ArrayType(vt)), "apply", List(a2)) => - ArraySelect(a1, a2) - - case (IsTyped(a1, at: ArrayType), "length", Nil) => - ArrayLength(a1) - - case (IsTyped(a1, at: ArrayType), "updated", List(k, v)) => - ArrayUpdated(a1, k, v) - - case (IsTyped(a1, at: ArrayType), "clone", Nil) => - a1 - - // Map methods - case (IsTyped(a1, MapType(_, vt)), "apply", List(a2)) => - MapApply(a1, a2) - - case (IsTyped(a1, MapType(_, vt)), "get", List(a2)) => - val someClass = CaseClassType(libraryCaseClass(sym.pos, "leon.lang.Some"), Seq(vt)) - val noneClass = CaseClassType(libraryCaseClass(sym.pos, "leon.lang.None"), Seq(vt)) - - IfExpr(MapIsDefinedAt(a1, a2).setPos(current.pos), - CaseClass(someClass, Seq(MapApply(a1, a2).setPos(current.pos))).setPos(current.pos), - CaseClass(noneClass, Seq()).setPos(current.pos)) - - case (IsTyped(a1, MapType(_, vt)), "getOrElse", List(a2, a3)) => - IfExpr(MapIsDefinedAt(a1, a2).setPos(current.pos), - MapApply(a1, a2).setPos(current.pos), - a3) - - case (IsTyped(a1, mt: MapType), "isDefinedAt", List(a2)) => - MapIsDefinedAt(a1, a2) - - case (IsTyped(a1, mt: MapType), "contains", List(a2)) => - MapIsDefinedAt(a1, a2) - - case (IsTyped(a1, mt: MapType), "updated", List(k, v)) => - MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to)) - - case (IsTyped(a1, mt: MapType), "+", List(k, v)) => - MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to)) - - case (IsTyped(a1, mt: MapType), "+", List(IsTyped(kv, TupleType(List(_, _))))) => - kv match { - case Tuple(List(k, v)) => - MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to)) - case kv => - MapUnion(a1, FiniteMap(Map(TupleSelect(kv, 1) -> TupleSelect(kv, 2)), mt.from, mt.to)) - } - - case (IsTyped(a1, mt1: MapType), "++", List(IsTyped(a2, mt2: MapType))) if mt1 == mt2 => - MapUnion(a1, a2) - - // Char operations - case (IsTyped(a1, CharType), ">", List(IsTyped(a2, CharType))) => - GreaterThan(a1, a2) - - case (IsTyped(a1, CharType), ">=", List(IsTyped(a2, CharType))) => - GreaterEquals(a1, a2) - - case (IsTyped(a1, CharType), "<", List(IsTyped(a2, CharType))) => - LessThan(a1, a2) - - case (IsTyped(a1, CharType), "<=", List(IsTyped(a2, CharType))) => - LessEquals(a1, a2) - - case (a1, name, a2) => - val typea1 = a1.getType - val typea2 = a2.map(_.getType).mkString(",") - val sa2 = a2.mkString(",") - outOfSubsetError(tr, "Unknown call to " + name + s" on $a1 ($typea1) with arguments $sa2 of type $typea2") - } - - // default behaviour is to complain :) - case _ => - outOfSubsetError(tr, "Could not extract as PureScala (Scala tree of type "+tr.getClass+")") - } - - res.setPos(current.pos) - - rest match { - case Some(r) => - leonBlock(Seq(res, extractTree(r))) - case None => - res - } - } - - private def extractType(t: Tree)(implicit dctx: DefContext): LeonType = { - extractType(t.tpe)(dctx, t.pos) - } - - private def extractType(tpt: Type)(implicit dctx: DefContext, pos: Position): LeonType = tpt match { - case tpe if tpe == CharClass.tpe => - CharType - - case tpe if tpe == IntClass.tpe => - Int32Type - - case tpe if tpe == BooleanClass.tpe => - BooleanType - - case tpe if tpe == UnitClass.tpe => - UnitType - - case tpe if tpe == NothingClass.tpe => - Untyped - - case ct: ConstantType => - extractType(ct.value.tpe) - - case TypeRef(_, sym, _) if isBigIntSym(sym) => - IntegerType - - case TypeRef(_, sym, _) if isRealSym(sym) => - RealType - - case TypeRef(_, sym, _) if isStringSym(sym) => - StringType - - case TypeRef(_, sym, btt :: Nil) if isScalaSetSym(sym) => - outOfSubsetError(pos, "Scala's Set API is no longer extracted. Make sure you import leon.lang.Set that defines supported Set operations.") - - case TypeRef(_, sym, List(a,b)) if isScalaMapSym(sym) => - outOfSubsetError(pos, "Scala's Map API is no longer extracted. Make sure you import leon.lang.Map that defines supported Map operations.") - - case TypeRef(_, sym, btt :: Nil) if isSetSym(sym) => - SetType(extractType(btt)) - - case TypeRef(_, sym, btt :: Nil) if isBagSym(sym) => - BagType(extractType(btt)) - - case TypeRef(_, sym, List(ftt,ttt)) if isMapSym(sym) => - MapType(extractType(ftt), extractType(ttt)) - - case TypeRef(_, sym, List(t1,t2)) if isTuple2(sym) => - TupleType(Seq(extractType(t1),extractType(t2))) - - case TypeRef(_, sym, List(t1,t2,t3)) if isTuple3(sym) => - TupleType(Seq(extractType(t1),extractType(t2),extractType(t3))) - - case TypeRef(_, sym, List(t1,t2,t3,t4)) if isTuple4(sym) => - TupleType(Seq(extractType(t1),extractType(t2),extractType(t3),extractType(t4))) - - case TypeRef(_, sym, List(t1,t2,t3,t4,t5)) if isTuple5(sym) => - TupleType(Seq(extractType(t1),extractType(t2),extractType(t3),extractType(t4),extractType(t5))) - - case TypeRef(_, sym, btt :: Nil) if isArrayClassSym(sym) => - ArrayType(extractType(btt)) - - // TODO: What about Function0? - case TypeRef(_, sym, subs) if subs.size >= 1 && isFunction(sym, subs.size - 1) => - val from = subs.init - val to = subs.last - FunctionType(from map extractType, extractType(to)) - - case TypeRef(_, sym, tps) if isByNameSym(sym) => - extractType(tps.head) - - case tr @ TypeRef(_, sym, tps) => - val leontps = tps.map(extractType) - - if (sym.isAbstractType) { - if(dctx.tparams contains sym) { - dctx.tparams(sym) - } else { - outOfSubsetError(pos, "Unknown type parameter "+sym) - } - } else { - getClassType(sym, leontps) - } - - case tt: ThisType => - val cd = getClassDef(tt.sym, pos) - cd.typed // Typed using own's type parameters - - case SingleType(_, sym) => - getClassType(sym.moduleClass, Nil) - - case RefinedType(parents, defs) if defs.isEmpty => - /** - * For cases like if(a) e1 else e2 where - * e1 <: C1, - * e2 <: C2, - * with C1,C2 <: C - * - * Scala might infer a type for C such as: Product with Serializable with C - * we generalize to the first known type, e.g. C. - */ - parents.flatMap { ptpe => - try { - Some(extractType(ptpe)) - } catch { - case e: ImpureCodeEncounteredException => - None - }}.headOption match { - case Some(tpe) => - tpe - - case None => - outOfSubsetError(tpt.typeSymbol.pos, "Could not extract refined type as PureScala: "+tpt+" ("+tpt.getClass+")") - } - - case AnnotatedType(_, tpe) => extractType(tpe) - - case _ => - if (tpt ne null) { - outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") - } else { - outOfSubsetError(NoPosition, "Tree with null-pointer as type found") - } - } - - private def getClassType(sym: Symbol, tps: List[LeonType])(implicit dctx: DefContext) = { - if (seenClasses contains sym) { - getClassDef(sym, NoPosition).typed(tps) - } else { - outOfSubsetError(NoPosition, "Unknown class "+sym.fullName) - } - } - - } - - def containsLetDef(expr: LeonExpr): Boolean = { - exists { - case (l: LetDef) => true - case _ => false - }(expr) - } -} diff --git a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala deleted file mode 100644 index 3a338b9f00be52c27afbf5893843a233112dd37e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import purescala.Definitions.Program -import utils._ - -import scala.tools.nsc.{Settings,CompilerCommand} -import java.io.File - -object ExtractionPhase extends SimpleLeonPhase[List[String], Program] { - - val name = "Scalac Extraction" - val description = "Extraction of trees from the Scala Compiler" - - val optStrictCompilation = LeonFlagOptionDef("strictCompilation", "Exit Leon after an error in compilation", true) - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optStrictCompilation) - - implicit val debug = DebugSectionTrees - - def apply(ctx: LeonContext, args: List[String]): Program = { - val timer = ctx.timers.frontend.start() - - val settings = new Settings - - def getFiles(path: String): Option[Array[String]] = - scala.util.Try(new File(path).listFiles().map{ _.getAbsolutePath }).toOption - - val scalaLib = Option(scala.Predef.getClass.getProtectionDomain.getCodeSource).map{ - _.getLocation.getPath - }.orElse( for { - // We are in Eclipse. Look in Eclipse plugins to find scala lib - eclipseHome <- Option(System.getenv("ECLIPSE_HOME")) - pluginsHome = eclipseHome + "/plugins" - plugins <- getFiles(pluginsHome) - path <- plugins.find{ _ contains "scala-library"} - } yield path).getOrElse( ctx.reporter.fatalError( - "No Scala library found. If you are working in Eclipse, " + - "make sure to set the ECLIPSE_HOME environment variable to your Eclipse installation home directory" - )) - - settings.classpath.value = scalaLib - settings.usejavacp.value = false - settings.deprecation.value = true - settings.Yrangepos.value = true - settings.skip.value = List("patmat") - - val compilerOpts = Build.libFiles ::: args.filterNot(_.startsWith("--")) - - val command = new CompilerCommand(compilerOpts, settings) { - override val cmdName = "leon" - } - - if(command.ok) { - // Debugging code for classpath crap - // new scala.tools.util.PathResolver(settings).Calculated.basis.foreach { cp => - // cp.foreach( p => - // println(" => "+p.toString) - // ) - // } - - - val compiler = new ScalaCompiler(settings, ctx) - val run = new compiler.Run - run.compile(command.files) - - timer.stop() - - compiler.leonExtraction.setImports(compiler.saveImports.imports ) - - compiler.leonExtraction.compiledProgram match { - case Some(pgm) => - pgm - - case None => - ctx.reporter.fatalError("Failed to extract Leon program.") - } - } else { - ctx.reporter.fatalError("No input program.") - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/FullScalaCompiler.scala b/src/main/scala/leon/frontends/scalac/FullScalaCompiler.scala deleted file mode 100644 index 72e98e01561d954abe1048d80753228b0459e935..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/FullScalaCompiler.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc.{Global,Settings=>NSCSettings} -import scala.reflect.internal.Positions - -class FullScalaCompiler(settings: NSCSettings, ctx: LeonContext) extends Global(settings, new SimpleReporter(settings, ctx.reporter)) with Positions { - - class Run extends super.Run { - override def progress(current: Int, total: Int) { - ctx.reporter.onCompilerProgress(current, total) - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/LeonExtraction.scala b/src/main/scala/leon/frontends/scalac/LeonExtraction.scala deleted file mode 100644 index 10ca2cb97214a5046dbabaedf060352402582ed8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/LeonExtraction.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc._ - -trait LeonExtraction extends SubComponent with CodeExtraction { - import global._ - - val phaseName = "leon" - - var units: List[CompilationUnit] = Nil - - implicit val ctx: LeonContext - - var imports : Map[RefTree,List[Import]] = Map() - - def setImports( imports : Map[RefTree,List[Import]] ) { - this.imports = imports - } - - def compiledProgram = { - new Extraction(units).extractProgram - } - - def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev) - - class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) { - def apply(unit: CompilationUnit): Unit = { - units ::= unit - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/SaveImports.scala b/src/main/scala/leon/frontends/scalac/SaveImports.scala deleted file mode 100644 index bed944f38b15568509de54fcb26aebf30eec127d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/SaveImports.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc._ -import leon.utils.{ - Position => LeonPosition, - RangePosition => LeonRangePosition, - OffsetPosition => LeonOffsetPosition, - DebugSectionTrees -} - -trait SaveImports extends SubComponent { - import global._ - - val phaseName = "imports" - - val ctx: LeonContext - - var imports : Map[RefTree,List[Import]] = Map() - - implicit val debugSection = DebugSectionTrees - - // FIXME : Copy pasting code is bad. - def scalaPosToLeonPos(p: global.Position): LeonPosition = { - if (p == NoPosition) { - leon.utils.NoPosition - } else if (p.isRange) { - val start = p.focusStart - val end = p.focusEnd - LeonRangePosition(start.line, start.column, start.point, - end.line, end.column, end.point, - p.source.file.file) - } else { - LeonOffsetPosition(p.line, p.column, p.point, - p.source.file.file) - } - } - - - def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev) - - class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) { - def apply(unit: CompilationUnit): Unit = { - unit.body match { - case pkg @ PackageDef(pid,lst) => - - imports += pid -> (lst collect { - case i : Import => i - }) - - for (tree <- lst if !tree.isInstanceOf[Import] ) { - tree.foreach { - case imp : Import => - ctx.reporter.debug( - scalaPosToLeonPos(imp.pos), - "Note: Imports will not be preserved in the AST unless they are at top-level" - ) - case _ => - } - } - - } - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala b/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala deleted file mode 100644 index 0ae89a4633c0e5a7c87c35d965b2653418e443cf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc.{Global,Settings=>NSCSettings} -import scala.reflect.internal.Positions - -class ScalaCompiler(settings : NSCSettings, ctx: LeonContext) extends Global(settings, new SimpleReporter(settings, ctx.reporter)) with Positions { - - object leonExtraction extends { - val global: ScalaCompiler.this.type = ScalaCompiler.this - val runsAfter = List[String]("refchecks") - val runsRightAfter = None - val ctx = ScalaCompiler.this.ctx - } with LeonExtraction - - object saveImports extends { - val global: ScalaCompiler.this.type = ScalaCompiler.this - val runsAfter = List[String]("pickler") - val runsRightAfter = None - val ctx = ScalaCompiler.this.ctx - } with SaveImports - - override protected def computeInternalPhases() : Unit = { - val phs = List( - syntaxAnalyzer -> "parse source into ASTs, perform simple desugaring", - analyzer.namerFactory -> "resolve names, attach symbols to named trees", - analyzer.packageObjects -> "load package objects", - analyzer.typerFactory -> "the meat and potatoes: type the trees", - patmat -> "translate match expressions", - superAccessors -> "add super accessors in traits and nested classes", - extensionMethods -> "add extension methods for inline classes", - pickler -> "serialize symbol tables", - saveImports -> "save imports to pass to leonExtraction", - refChecks -> "reference/override checking, translate nested objects", - leonExtraction -> "extracts leon trees out of scala trees" - ) - phs foreach { phasesSet += _._1 } - } - - class Run extends super.Run { - override def progress(current: Int, total: Int) { - ctx.reporter.onCompilerProgress(current, total) - } - } -} diff --git a/src/main/scala/leon/frontends/scalac/SimpleReporter.scala b/src/main/scala/leon/frontends/scalac/SimpleReporter.scala deleted file mode 100644 index 79a3495c64c26e7ca7111811d0a84318ce3f5ba1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/frontends/scalac/SimpleReporter.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package frontends.scalac - -import scala.tools.nsc.Settings -import scala.tools.nsc.reporters.AbstractReporter - -import scala.reflect.internal.util.{Position, NoPosition, FakePos, StringOps} -import utils.{Position => LeonPosition, NoPosition => LeonNoPosition, OffsetPosition => LeonOffsetPosition} - -/** This implements a reporter that calls the callback with every line that a -regular ConsoleReporter would display. */ -class SimpleReporter(val settings: Settings, reporter: leon.Reporter) extends AbstractReporter { - final val ERROR_LIMIT = 5 - - private def label(severity: Severity): String = severity match { - case ERROR => "error" - case WARNING => "warning" - case INFO => null - } - - private def clabel(severity: Severity): String = { - val label0 = label(severity) - if (label0 eq null) "" else label0 + ": " - } - - private def getCountString(severity: Severity): String = - StringOps.countElementsAsString(severity.count, label(severity)) - - /** Prints the message. */ - def printMessage(msg: String, pos: LeonPosition, severity: Severity) { - severity match { - case ERROR => - reporter.error(pos, msg) - case WARNING => - reporter.warning(pos, msg) - case INFO => - reporter.info(pos, msg) - } - } - - /** Prints the message with the given position indication. */ - def printMessage(posIn: Position, msg: String, severity: Severity) { - val pos = if (posIn eq null) NoPosition - else if (posIn.isDefined) posIn.finalPosition - else posIn - pos match { - case FakePos(fmsg) => - printMessage(fmsg+" "+msg, LeonNoPosition, severity) - case NoPosition => - printMessage(msg, LeonNoPosition, severity) - case _ => - val lpos = LeonOffsetPosition(pos.line, pos.column, pos.point, pos.source.file.file) - printMessage(msg, lpos, severity) - } - } - - def print(pos: Position, msg: String, severity: Severity) { - printMessage(pos, clabel(severity) + msg, severity) - } - - def display(pos: Position, msg: String, severity: Severity) { - severity.count += 1 - if (severity != ERROR || severity.count <= ERROR_LIMIT) - print(pos, msg, severity) - } - - def displayPrompt(): Unit = {} -} diff --git a/src/main/scala/leon/genc/CAST.scala b/src/main/scala/leon/genc/CAST.scala deleted file mode 100644 index a05f3beb98d21e0c77a3222004645af5e35555fe..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/CAST.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import utils.UniqueCounter - -/* - * Here are defined classes used to represent AST of C programs. - */ - -object CAST { // C Abstract Syntax Tree - - sealed abstract class Tree - case object NoTree extends Tree - - - /* ------------------------------------------------------------ Types ----- */ - abstract class Type(val rep: String) extends Tree { - override def toString = rep - - def mutable: Type = this match { - case Const(typ) => typ.mutable - case _ => this - } - } - - /* Type Modifiers */ - case class Const(typ: Type) extends Type(s"$typ const") - case class Pointer(typ: Type) extends Type(s"$typ*") - - /* Primitive Types */ - case object Int32 extends Type("int32_t") // Requires <stdint.h> - case object Bool extends Type("bool") // Requires <stdbool.h> - case object Void extends Type("void") - - /* Compound Types */ - case class Struct(id: Id, fields: Seq[Var]) extends Type(id.name) - - - /* --------------------------------------------------------- Literals ----- */ - case class IntLiteral(v: Int) extends Stmt - case class BoolLiteral(b: Boolean) extends Stmt - - - /* ----------------------------------------------------- Definitions ----- */ - abstract class Def extends Tree - - case class Prog(structs: Seq[Struct], functions: Seq[Fun]) extends Def - - case class Fun(id: Id, retType: Type, params: Seq[Var], body: Stmt) extends Def - - case class Id(name: String) extends Def { - // `|` is used as the margin delimiter and can cause trouble in some situations - def fixMargin = - if (name.size > 0 && name(0) == '|') "| " + name - else name - } - - case class Var(id: Id, typ: Type) extends Def - - /* ----------------------------------------------------------- Stmts ----- */ - abstract class Stmt extends Tree - case object NoStmt extends Stmt - - case class Compound(stmts: Seq[Stmt]) extends Stmt - - case class Assert(pred: Stmt, error: Option[String]) extends Stmt { // Requires <assert.h> - require(pred.isValue) - } - - case class DeclVar(x: Var) extends Stmt - - case class DeclInitVar(x: Var, value: Stmt) extends Stmt { - require(value.isValue) - } - - case class Assign(lhs: Stmt, rhs: Stmt) extends Stmt { - require(lhs.isValue && rhs.isValue) - } - - // Note: we don't need to differentiate between specific - // operators so we only keep track of the "kind" of operator - // with an Id. - case class UnOp(op: Id, rhs: Stmt) extends Stmt { - require(rhs.isValue) - } - - case class MultiOp(op: Id, stmts: Seq[Stmt]) extends Stmt { - require(stmts.length > 1 && stmts.forall { _.isValue }) - } - - case class SubscriptOp(ptr: Stmt, idx: Stmt) extends Stmt { - require(ptr.isValue && idx.isValue) - } - - case object Break extends Stmt - - case class Return(stmt: Stmt) extends Stmt { - require(stmt.isValue) - } - - case class IfElse(cond: Stmt, thn: Stmt, elze: Stmt) extends Stmt { - require(cond.isValue) - } - - case class While(cond: Stmt, body: Stmt) extends Stmt { - require(cond.isValue) - } - - case class AccessVar(id: Id) extends Stmt - case class AccessRef(id: Id) extends Stmt - case class AccessAddr(id: Id) extends Stmt - case class AccessField(struct: Stmt, field: Id) extends Stmt { - require(struct.isValue) - } - - case class Call(id: Id, args: Seq[Stmt]) extends Stmt { - require(args forall { _.isValue }) - } - - case class StructInit(args: Seq[(Id, Stmt)], struct: Struct) extends Stmt { - require(args forall { _._2.isValue }) - } - - case class ArrayInit(length: Stmt, valueType: Type, defaultValue: Stmt) extends Stmt { - require(length.isValue && defaultValue.isValue) - } - - case class ArrayInitWithValues(valueType: Type, values: Seq[Stmt]) extends Stmt { - require(values forall { _.isValue }) - - lazy val length = values.length - } - - - /* -------------------------------------------------------- Factories ----- */ - object Op { - def apply(op: String, rhs: Stmt) = UnOp(Id(op), rhs) - def apply(op: String, rhs: Stmt, lhs: Stmt) = MultiOp(Id(op), rhs :: lhs :: Nil) - def apply(op: String, stmts: Seq[Stmt]) = MultiOp(Id(op), stmts) - } - - object Val { - def apply(id: Id, typ: Type) = typ match { - case Const(_) => Var(id, typ) // avoid const of const - case _ => Var(id, Const(typ)) - } - } - - /* "Templatetized" Types */ - object Tuple { - def apply(bases: Seq[Type]) = { - val name = Id("__leon_tuple_" + bases.mkString("_") + "_t") - - val fields = bases.zipWithIndex map { - case (typ, idx) => Var(getNthId(idx + 1), typ) - } - - Struct(name, fields) - } - - // Indexes start from 1, not 0! - def getNthId(n: Int) = Id("_" + n) - } - - object Array { - def apply(base: Type) = { - val name = Id("__leon_array_" + base + "_t") - val data = Var(dataId, Pointer(base)) - val length = Var(lengthId, Int32) - val fields = data :: length :: Nil - - Struct(name, fields) - } - - def lengthId = Id("length") - def dataId = Id("data") - } - - - /* ---------------------------------------------------- Introspection ----- */ - implicit class IntrospectionOps(val stmt: Stmt) { - def isLiteral = stmt match { - case _: IntLiteral => true - case _: BoolLiteral => true - case _ => false - } - - // True if statement can be used as a value - def isValue: Boolean = isLiteral || { - stmt match { - //case _: Assign => true it's probably the case but for now let's ignore it - case c: Compound => c.stmts.size == 1 && c.stmts.head.isValue - case _: UnOp => true - case _: MultiOp => true - case _: SubscriptOp => true - case _: AccessVar => true - case _: AccessRef => true - case _: AccessAddr => true - case _: AccessField => true - case _: Call => true - case _: StructInit => true - case _: ArrayInit => true - case _: ArrayInitWithValues => true - case _ => false - } - } - - def isPure: Boolean = isLiteral || { - stmt match { - case NoStmt => true - case Compound(stmts) => stmts forall { _.isPure } - case Assert(pred, _) => pred.isPure - case UnOp(_, rhs) => rhs.isPure - case MultiOp(_, stmts) => Compound(stmts).isPure - case SubscriptOp(ptr, idx) => (ptr ~ idx).isPure - case IfElse(c, t, e) => (c ~ t ~ e).isPure - case While(c, b) => (c ~ b).isPure - case AccessVar(_) => true - case AccessRef(_) => true - case AccessAddr(_) => true - case AccessField(s, _) => s.isPure - // case Call(id, args) => true if args are pure and function `id` is pure too - case _ => false - } - } - - def isPureValue = isValue && isPure - } - - - /* ------------------------------------------------------------- DSL ----- */ - // Operator ~~ appends and flattens nested compounds - implicit class StmtOps(val stmt: Stmt) { - // In addition to combining statements together in a compound - // we remove the empty ones and if the resulting compound - // has only one statement we return this one without being - // wrapped into a Compound - def ~(other: Stmt) = { - val stmts = (stmt, other) match { - case (Compound(stmts), Compound(others)) => stmts ++ others - case (stmt , Compound(others)) => stmt +: others - case (Compound(stmts), other ) => stmts :+ other - case (stmt , other ) => stmt :: other :: Nil - } - - def isNoStmt(s: Stmt) = s match { - case NoStmt => true - case _ => false - } - - val compound = Compound(stmts filterNot isNoStmt) - compound match { - case Compound(stmts) if stmts.length == 0 => NoStmt - case Compound(stmts) if stmts.length == 1 => stmts.head - case compound => compound - } - } - - def ~~(others: Seq[Stmt]) = stmt ~ Compound(others) - } - - implicit class StmtsOps(val stmts: Seq[Stmt]) { - def ~~(other: Stmt) = other match { - case Compound(others) => Compound(stmts) ~~ others - case other => Compound(stmts) ~ other - } - - def ~~~(others: Seq[Stmt]) = Compound(stmts) ~~ others - } - - val True = BoolLiteral(true) - val False = BoolLiteral(false) - - - /* ------------------------------------------------ Fresh Generators ----- */ - object FreshId { - private var counter = -1 - private val leonPrefix = "__leon_" - - def apply(prefix: String = ""): Id = { - counter += 1 - Id(leonPrefix + prefix + counter) - } - } - - object FreshVar { - def apply(typ: Type, prefix: String = "") = Var(FreshId(prefix), typ) - } - - object FreshVal { - def apply(typ: Type, prefix: String = "") = Val(FreshId(prefix), typ) - } -} - diff --git a/src/main/scala/leon/genc/CConverter.scala b/src/main/scala/leon/genc/CConverter.scala deleted file mode 100644 index ea3130d4c337bbfb4812269dcbe73fa63c9420eb..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/CConverter.scala +++ /dev/null @@ -1,708 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Types._ -import xlang.Expressions._ - -import scala.reflect.ClassTag - -// don't import CAST._ to decrease possible confusion between the two ASTs - -class CConverter(val ctx: LeonContext, val prog: Program) { - def convert: CAST.Prog = convertToProg(prog) - - // Global data: keep track of the custom types and function of the input program - // Using sequences and not sets to keep track of order/dependencies - private var typeDecls = Seq[CAST.Struct]() - private var functions = Seq[CAST.Fun]() - - // Extra information about inner functions' context - // See classes VarInfo and FunCtx and functions convertToFun and - // FunctionInvocation conversion - private var funExtraArgss = Map[CAST.Id, Seq[CAST.Id]]() - private val emptyFunCtx = FunCtx(Seq()) - - private def registerType(typ: CAST.Struct) { - // Types might be processed more than once as the corresponding CAST type - // is not cached and need to be reconstructed several time if necessary - if (!typeDecls.contains(typ)) { - typeDecls = typeDecls :+ typ - debug(s"New type registered: $typ") - } - } - - private def registerFun(fun: CAST.Fun) { - // Unlike types, functions should not get defined multiple times as this - // would imply invalidating funExtraArgss - if (functions contains fun) - internalError("Function ${fun.id} defined more than once") - else - functions = functions :+ fun - } - - // Register extra function argument for the function named `id` - private def registerFunExtraArgs(id: CAST.Id, params: Seq[CAST.Id]) { - funExtraArgss = funExtraArgss + ((id, params)) - } - - // Get the extra argument identifiers for the function named `id` - private def getFunExtraArgs(id: CAST.Id) = funExtraArgss.getOrElse(id, Seq()) - - // Apply the conversion function and make sure the resulting AST matches our expectation - private def convertTo[T](tree: Tree)(implicit funCtx: FunCtx, ct: ClassTag[T]): T = convert(tree) match { - case t: T => t - case x => internalError(s"Expected an instance of $ct when converting $tree but got $x") - } - - // Generic conversion function - // Currently simple aliases in case we need later to have special treatment instead - private def convertToType (tree: Tree)(implicit funCtx: FunCtx) = convertTo[CAST.Type](tree) - private def convertToStruct(tree: Tree)(implicit funCtx: FunCtx) = convertTo[CAST.Struct](tree) - private def convertToId (tree: Tree)(implicit funCtx: FunCtx) = convertTo[CAST.Id](tree) - private def convertToStmt (tree: Tree)(implicit funCtx: FunCtx) = convertTo[CAST.Stmt](tree) - private def convertToVar (tree: Tree)(implicit funCtx: FunCtx) = convertTo[CAST.Var](tree) - - private def convertToProg(prog: Program): CAST.Prog = { - // Only process the main unit - val (mainUnits, _) = prog.units partition { _.isMainUnit } - - if (mainUnits.size == 0) fatalError("No main unit in the program") - if (mainUnits.size >= 2) fatalError("Multiple main units in the program") - - val mainUnit = mainUnits.head - - debug(s"Converting the main unit:\n$mainUnit") - collectSymbols(mainUnit) - - CAST.Prog(typeDecls, functions) - } - - // Look for function and structure definitions - private def collectSymbols(unit: UnitDef) { - implicit val defaultFunCtx = emptyFunCtx - - unit.defs.foreach { - case ModuleDef(_, funDefs, _) => - funDefs.foreach { - case fd: FunDef => convertToFun(fd) // the function gets registered here - case cc: CaseClassDef => convertToStruct(cc) // the type declaration gets registered here - - case x => internalError(s"Unknown function definition $x: ${x.getClass}") - } - - case x => internalError(s"Unexpected definition $x instead of a ModuleDef") - } - } - - // A variable can be locally declared (e.g. function parameter or local variable) - // or it can be "inherited" from a more global context (e.g. inner functions have - // access to their outer function parameters). - private case class VarInfo(x: CAST.Var, local: Boolean) { - // Transform a local variable into a global variable - def lift = VarInfo(x, false) - - // Generate CAST variable declaration for function signature - def toParam = CAST.Var(x.id, CAST.Pointer(x.typ)) - - // Generate CAST access statement - def toArg = if (local) CAST.AccessAddr(x.id) else CAST.AccessVar(x.id) - } - - private case class FunCtx(vars: Seq[VarInfo]) { - // Transform local variables into "outer" variables - def lift = FunCtx(vars map { _.lift }) - - // Create a new context with more local variables - def extend(x: CAST.Var): FunCtx = extend(Seq(x)) - def extend(xs: Seq[CAST.Var]): FunCtx = { - val newVars = xs map { VarInfo(_, true) } - FunCtx(vars ++ newVars) - } - - // Check if a given variable's identifier exists in the context and is an "outer" variable - def hasOuterVar(id: CAST.Id) = vars exists { vi => !vi.local && vi.x.id == id } - - // List all variables' ids - def extractIds = vars map { _.x.id } - - // Generate arguments for the given identifiers according to the current context - def toArgs(ids: Seq[CAST.Id]) = { - val filtered = vars filter { ids contains _.x.id } - filtered map { _.toArg } - } - - // Generate parameters (var + type) - def toParams = vars map { _.toParam } - } - - // Extract inner functions too - private def convertToFun(fd: FunDef)(implicit funCtx: FunCtx) = { - // Forbid return of array as they are allocated on the stack - if (containsArrayType(fd.returnType)) { - fatalError("Returning arrays is currently not allowed") - } - - val id = convertToId(fd.id) - val retType = convertToType(fd.returnType) - val stdParams = fd.params map convertToVar - - // Prepend existing variables from the outer function context to - // this function's parameters - val extraParams = funCtx.toParams - val params = extraParams ++ stdParams - - // Function LeonContext: - // 1) Save the variables of the current context for later function invocation - // 2) Lift & augment funCtx with the current function's arguments - // 3) Propagate it to the current function's body - - registerFunExtraArgs(id, funCtx.extractIds) - - val funCtx2 = funCtx.lift.extend(stdParams) - - val b = convertToStmt(fd.fullBody)(funCtx2) - val body = retType match { - case CAST.Void => b - case _ => injectReturn(b) - } - - val fun = CAST.Fun(id, retType, params, body) - registerFun(fun) - - fun - } - - private def convert(tree: Tree)(implicit funCtx: FunCtx): CAST.Tree = tree match { - /* ---------------------------------------------------------- Types ----- */ - case Int32Type => CAST.Int32 - case BooleanType => CAST.Bool - case UnitType => CAST.Void - - case ArrayType(base) => - val typ = CAST.Array(convertToType(base)) - registerType(typ) - typ - - case TupleType(bases) => - val typ = CAST.Tuple(bases map convertToType) - registerType(typ) - typ - - case cd: CaseClassDef => - if (cd.isAbstract) fatalError("Abstract types are not supported") - if (cd.hasParent) fatalError("Inheritance is not supported") - if (cd.isCaseObject) fatalError("Case Objects are not supported") - if (cd.tparams.length > 0) fatalError("Type Parameters are not supported") - if (cd.methods.length > 0) fatalError("Methods are not yet supported") - - val id = convertToId(cd.id) - val fields = cd.fields map convertToVar - val typ = CAST.Struct(id, fields) - - registerType(typ) - typ - - case CaseClassType(cd, _) => convertToStruct(cd) // reuse `case CaseClassDef` - - /* ------------------------------------------------------- Literals ----- */ - case IntLiteral(v) => CAST.IntLiteral(v) - case BooleanLiteral(b) => CAST.BoolLiteral(b) - case UnitLiteral() => CAST.NoStmt - - /* ------------------------------------ Definitions and Statements ----- */ - case id: Identifier => - if (id.name == "main") CAST.Id("main") // and not `main0` - else CAST.Id(id.uniqueName) - - // Function parameter - case vd: ValDef => buildVal(vd.id, vd.getType) - - // Accessing variable - case v: Variable => buildAccessVar(v.id) - - case Block(exprs, last) => - // Interleave the "bodies" of flatten expressions and their values - // and generate a Compound statement - (exprs :+ last) map convertToStmt reduce { _ ~ _ } - - case Let(b, v, r) => buildLet(b, v, r, false) - case LetVar(b, v, r) => buildLet(b, v, r, true) - - case LetDef(fds, rest) => - fds foreach convertToFun // The functions get registered there - convertToStmt(rest) - - case Assignment(varId, expr) => - val f = convertAndFlatten(expr) - val x = buildAccessVar(varId) - - val assign = CAST.Assign(x, f.value) - - f.body ~~ assign - - case tuple @ Tuple(exprs) => - val struct = convertToStruct(tuple.getType) - val types = struct.fields map { _.typ } - val fs = convertAndNormaliseExecution(exprs, types) - val args = fs.values.zipWithIndex map { - case (arg, idx) => (CAST.Tuple.getNthId(idx + 1), arg) - } - - fs.bodies ~~ CAST.StructInit(args, struct) - - case TupleSelect(tuple1, idx) => // here idx is already 1-based - val struct = convertToStruct(tuple1.getType) - val tuple2 = convertToStmt(tuple1) - - val fs = normaliseExecution((tuple2, struct) :: Nil) - - val tuple = fs.values.head - - fs.bodies ~~ CAST.AccessField(tuple, CAST.Tuple.getNthId(idx)) - - case ArrayLength(array1) => - val array2 = convertToStmt(array1) - val arrayType = convertToType(array1.getType) - - val fs = normaliseExecution((array2, arrayType) :: Nil) - - val array = fs.values.head - - fs.bodies ~~ CAST.AccessField(array, CAST.Array.lengthId) - - case ArraySelect(array1, index1) => - val array2 = convertToStmt(array1) - val arrayType = convertToType(array1.getType) - val index2 = convertToStmt(index1) - - val fs = normaliseExecution((array2, arrayType) :: (index2, CAST.Int32) :: Nil) - - val array = fs.values(0) - val index = fs.values(1) - val ptr = CAST.AccessField(array, CAST.Array.dataId) - val select = CAST.SubscriptOp(ptr, index) - - fs.bodies ~~ select - - case NonemptyArray(elems, Some((value1, length1))) if elems.isEmpty => - val length2 = convertToStmt(length1) - val valueType = convertToType(value1.getType) - val value2 = convertToStmt(value1) - - val fs = normaliseExecution((length2, CAST.Int32) :: (value2, valueType) :: Nil) - val length = fs.values(0) - val value = fs.values(1) - - fs.bodies ~~ CAST.ArrayInit(length, valueType, value) - - case NonemptyArray(elems, Some(_)) => - fatalError("NonemptyArray with non empty elements is not supported") - - case NonemptyArray(elems, None) => // Here elems is non-empty - // Sort values according the the key (aka index) - val indexes = elems.keySet.toSeq.sorted - val values = indexes map { elems(_) } - - // Assert all types are the same - val types = values map { e => convertToType(e.getType) } - val typ = types(0) - val allSame = types forall { _ == typ } - if (!allSame) fatalError("Heterogenous arrays are not supported") - - val fs = convertAndNormaliseExecution(values, types) - - fs.bodies ~~ CAST.ArrayInitWithValues(typ, fs.values) - - case ArrayUpdate(array1, index1, newValue1) => - val array2 = convertToStmt(array1) - val index2 = convertToStmt(index1) - val newValue2 = convertToStmt(newValue1) - val values = array2 :: index2 :: newValue2 :: Nil - - val arePure = values forall { _.isPure } - val areValues = array2.isValue && index2.isValue // no newValue here - - newValue2 match { - case CAST.IfElse(cond, thn, elze) if arePure && areValues => - val array = array2 - val index = index2 - val ptr = CAST.AccessField(array, CAST.Array.dataId) - val select = CAST.SubscriptOp(ptr, index) - - val ifelse = buildIfElse(cond, injectAssign(select, thn), - injectAssign(select, elze)) - - ifelse - - case _ => - val arrayType = convertToType(array1.getType) - val indexType = CAST.Int32 - val valueType = convertToType(newValue1.getType) - val types = arrayType :: indexType :: valueType :: Nil - - val fs = normaliseExecution(values, types) - - val array = fs.values(0) - val index = fs.values(1) - val newValue = fs.values(2) - - val ptr = CAST.AccessField(array, CAST.Array.dataId) - val select = CAST.SubscriptOp(ptr, index) - val assign = CAST.Assign(select, newValue) - - fs.bodies ~~ assign - } - - case CaseClass(typ, args1) => - val struct = convertToStruct(typ) - val types = struct.fields map { _.typ } - val argsFs = convertAndNormaliseExecution(args1, types) - val fieldsIds = typ.classDef.fieldsIds map convertToId - val args = fieldsIds zip argsFs.values - - argsFs.bodies ~~ CAST.StructInit(args, struct) - - case CaseClassSelector(_, x1, fieldId) => - val struct = convertToStruct(x1.getType) - val x2 = convertToStmt(x1) - - val fs = normaliseExecution((x2, struct) :: Nil) - val x = fs.values.head - - fs.bodies ~~ CAST.AccessField(x, convertToId(fieldId)) - - case LessThan(lhs, rhs) => buildBinOp(lhs, "<", rhs) - case GreaterThan(lhs, rhs) => buildBinOp(lhs, ">", rhs) - case LessEquals(lhs, rhs) => buildBinOp(lhs, "<=", rhs) - case GreaterEquals(lhs, rhs) => buildBinOp(lhs, ">=", rhs) - case Equals(lhs, rhs) => buildBinOp(lhs, "==", rhs) - - case Not(rhs) => buildUnOp ( "!", rhs) - - case And(exprs) => buildMultiOp("&&", exprs) - case Or(exprs) => buildMultiOp("||", exprs) - - case BVPlus(lhs, rhs) => buildBinOp(lhs, "+", rhs) - case BVMinus(lhs, rhs) => buildBinOp(lhs, "-", rhs) - case BVUMinus(rhs) => buildUnOp ( "-", rhs) - case BVTimes(lhs, rhs) => buildBinOp(lhs, "*", rhs) - case BVDivision(lhs, rhs) => buildBinOp(lhs, "/", rhs) - case BVRemainder(lhs, rhs) => buildBinOp(lhs, "%", rhs) - case BVNot(rhs) => buildUnOp ( "~", rhs) - case BVAnd(lhs, rhs) => buildBinOp(lhs, "&", rhs) - case BVOr(lhs, rhs) => buildBinOp(lhs, "|", rhs) - case BVXOr(lhs, rhs) => buildBinOp(lhs, "^", rhs) - case BVShiftLeft(lhs, rhs) => buildBinOp(lhs, "<<", rhs) - case BVAShiftRight(lhs, rhs) => buildBinOp(lhs, ">>", rhs) - case BVLShiftRight(lhs, rhs) => fatalError("operator >>> not supported") - - // Ignore assertions for now - case Ensuring(body, _) => convert(body) - case Require(_, body) => convert(body) - case Assert(_, _, body) => convert(body) - - case IfExpr(cond1, thn1, elze1) => - val condF = convertAndFlatten(cond1) - val thn = convertToStmt(thn1) - val elze = convertToStmt(elze1) - - condF.body ~~ buildIfElse(condF.value, thn, elze) - - case While(cond1, body1) => - val cond = convertToStmt(cond1) - val body = convertToStmt(body1) - - if (cond.isPureValue) { - CAST.While(cond, body) - } else { - // Transform while (cond) { body } into - // while (true) { if (cond) { body } else { break } } - val condF = flatten(cond) - val ifelse = condF.body ~~ buildIfElse(condF.value, CAST.NoStmt, CAST.Break) - CAST.While(CAST.True, ifelse ~ body) - } - - case FunctionInvocation(tfd @ TypedFunDef(fd, _), stdArgs) => - // In addition to regular function parameters, add the callee's extra parameters - val id = convertToId(fd.id) - val types = tfd.params map { p => convertToType(p.getType) } - val fs = convertAndNormaliseExecution(stdArgs, types) - val extraArgs = funCtx.toArgs(getFunExtraArgs(id)) - val args = extraArgs ++ fs.values - - fs.bodies ~~ CAST.Call(id, args) - - case unsupported => - fatalError(s"$unsupported (of type ${unsupported.getClass}) is currently not supported by GenC") - } - - private def buildVar(id: Identifier, typ: TypeTree)(implicit funCtx: FunCtx) = - CAST.Var(convertToId(id), convertToType(typ)) - - private def buildVal(id: Identifier, typ: TypeTree)(implicit funCtx: FunCtx) = - CAST.Val(convertToId(id), convertToType(typ)) - - private def buildAccessVar(id1: Identifier)(implicit funCtx: FunCtx) = { - // Depending on the context, we have to deference the variable - val id = convertToId(id1) - if (funCtx.hasOuterVar(id)) CAST.AccessRef(id) - else CAST.AccessVar(id) - } - - private def buildLet(id: Identifier, value: Expr, rest1: Expr, forceVar: Boolean) - (implicit funCtx: FunCtx): CAST.Stmt = { - val (x, stmt) = buildDeclInitVar(id, value, forceVar) - - // Augment ctx for the following instructions - val funCtx2 = funCtx.extend(x) - val rest = convertToStmt(rest1)(funCtx2) - - stmt ~ rest - } - - - // Create a new variable for the given value, potentially immutable, and initialize it - private def buildDeclInitVar(id: Identifier, v: Expr, forceVar: Boolean) - (implicit funCtx: FunCtx): (CAST.Var, CAST.Stmt) = { - val valueF = convertAndFlatten(v) - val typ = v.getType - - valueF.value match { - case CAST.IfElse(cond, thn, elze) => - val x = buildVar(id, typ) - val decl = CAST.DeclVar(x) - val ifelse = buildIfElse(cond, injectAssign(x, thn), injectAssign(x, elze)) - val init = decl ~ ifelse - - (x, valueF.body ~~ init) - - case value => - val x = if (forceVar) buildVar(id, typ) else buildVal(id, typ) - val init = CAST.DeclInitVar(x, value) - - (x, valueF.body ~~ init) - } - } - - private def buildBinOp(lhs: Expr, op: String, rhs: Expr)(implicit funCtx: FunCtx) = { - buildMultiOp(op, lhs :: rhs :: Nil) - } - - private def buildUnOp(op: String, rhs1: Expr)(implicit funCtx: FunCtx) = { - val rhsF = convertAndFlatten(rhs1) - rhsF.body ~~ CAST.Op(op, rhsF.value) - } - - private def buildMultiOp(op: String, exprs: Seq[Expr])(implicit funCtx: FunCtx): CAST.Stmt = { - require(exprs.length >= 2) - - val stmts = exprs map convertToStmt - val types = exprs map { e => convertToType(e.getType) } - - buildMultiOp(op, stmts, types) - } - - private def buildMultiOp(op: String, stmts: Seq[CAST.Stmt], types: Seq[CAST.Type]): CAST.Stmt = { - // Default operator constuction when either pure statements are involved - // or no shortcut can happen - def defaultBuild = { - val fs = normaliseExecution(stmts, types) - fs.bodies ~~ CAST.Op(op, fs.values) - } - - if (stmts forall { _.isPureValue }) defaultBuild - else op match { - case "&&" => - // Apply short-circuit if needed - if (stmts.length == 2) { - // Base case: - // { { a; v } && { b; w } } - // is mapped onto - // { a; if (v) { b; w } else { false } } - val av = flatten(stmts(0)) - val bw = stmts(1) - - if (bw.isPureValue) defaultBuild - else av.body ~~ buildIfElse(av.value, bw, CAST.False) - } else { - // Recursive case: - // { { a; v } && ... } - // is mapped onto - // { a; if (v) { ... } else { false } } - val av = flatten(stmts(0)) - val rest = buildMultiOp(op, stmts.tail, types.tail) - - if (rest.isPureValue) defaultBuild - else av.body ~~ buildIfElse(av.value, rest, CAST.False) - } - - case "||" => - // Apply short-circuit if needed - if (stmts.length == 2) { - // Base case: - // { { a; v } || { b; w } } - // is mapped onto - // { a; if (v) { true } else { b; w } } - val av = flatten(stmts(0)) - val bw = stmts(1) - - if (bw.isPureValue) defaultBuild - else av.body ~~ buildIfElse(av.value, CAST.True, bw) - } else { - // Recusrive case: - // { { a; v } || ... } - // is mapped onto - // { a; if (v) { true } else { ... } } - val av = flatten(stmts(0)) - val rest = buildMultiOp(op, stmts.tail, types.tail) - - if (rest.isPureValue) defaultBuild - else av.body ~~ buildIfElse(av.value, CAST.True, rest) - } - - case _ => - defaultBuild - } - } - - // Flatten `if (if (cond1) thn1 else elze1) thn2 else elze2` - // into `if (cond1) { if (thn1) thn2 else elz2 } else { if (elz1) thn2 else elze2 }` - // or, if possible, into `if ((cond1 && thn1) || elz1) thn2 else elz2` - // - // Flatten `if (true) thn else elze` into `thn` - // Flatten `if (false) thn else elze` into `elze` - private def buildIfElse(cond: CAST.Stmt, thn2: CAST.Stmt, elze2: CAST.Stmt): CAST.Stmt = { - val condF = flatten(cond) - - val ifelse = condF.value match { - case CAST.IfElse(cond1, thn1, elze1) => - if (cond1.isPure && thn1.isPure && elze1.isPure) { - val bools = CAST.Bool :: CAST.Bool :: Nil - val ands = cond1 :: thn1 :: Nil - val ors = buildMultiOp("&&", ands, bools) :: elze1 :: Nil - val condX = buildMultiOp("||", ors, bools) - CAST.IfElse(condX, thn2, elze2) - } else { - buildIfElse(cond1, buildIfElse(thn1, thn2, elze2), buildIfElse(elze1, thn2, elze2)) - } - - case CAST.True => thn2 - case CAST.False => elze2 - case cond2 => CAST.IfElse(cond2, thn2, elze2) - } - - condF.body ~~ ifelse - } - - private def injectReturn(stmt: CAST.Stmt): CAST.Stmt = { - val f = flatten(stmt) - - f.value match { - case CAST.IfElse(cond, thn, elze) => - f.body ~~ CAST.IfElse(cond, injectReturn(thn), injectReturn(elze)) - - case _ => - f.body ~~ CAST.Return(f.value) - } - } - - private def injectAssign(x: CAST.Var, stmt: CAST.Stmt): CAST.Stmt = { - injectAssign(CAST.AccessVar(x.id), stmt) - } - - private def injectAssign(x: CAST.Stmt, stmt: CAST.Stmt): CAST.Stmt = { - val f = flatten(stmt) - - f.value match { - case CAST.IfElse(cond, thn, elze) => - f.body ~~ CAST.IfElse(cond, injectAssign(x, thn), injectAssign(x, elze)) - - case _ => - f.body ~~ CAST.Assign(x, f.value) - } - } - - // Flattened represents a non-empty statement { a; b; ...; y; z } - // split into body { a; b; ...; y } and value z - private case class Flattened(value: CAST.Stmt, body: Seq[CAST.Stmt]) - - // FlattenedSeq does the same as Flattened for a sequence of non-empty statements - private case class FlattenedSeq(values: Seq[CAST.Stmt], bodies: Seq[CAST.Stmt]) - - private def flatten(stmt: CAST.Stmt) = stmt match { - case CAST.Compound(stmts) if stmts.isEmpty => internalError(s"Empty compound cannot be flattened") - case CAST.Compound(stmts) => Flattened(stmts.last, stmts.init) - case stmt => Flattened(stmt, Seq()) - } - - private def convertAndFlatten(expr: Expr)(implicit funCtx: FunCtx) = flatten(convertToStmt(expr)) - - // Normalise execution order of, for example, function parameters; - // `types` represents the expected type of the corresponding values - // in case an intermediary variable needs to be created - private def convertAndNormaliseExecution(exprs: Seq[Expr], types: Seq[CAST.Type]) - (implicit funCtx: FunCtx) = { - require(exprs.length == types.length) - normaliseExecution(exprs map convertToStmt, types) - } - - private def normaliseExecution(typedStmts: Seq[(CAST.Stmt, CAST.Type)]): FlattenedSeq = - normaliseExecution(typedStmts map { _._1 }, typedStmts map { _._2 }) - - private def normaliseExecution(stmts: Seq[CAST.Stmt], types: Seq[CAST.Type]): FlattenedSeq = { - require(stmts.length == types.length) - - // Create temporary variables if needed - val stmtsFs = stmts map flatten - val fs = (stmtsFs zip types) map { - case (f, _) if f.value.isPureValue => f - - case (f, typ) => - // Similarly to buildDeclInitVar: - val (tmp, body) = f.value match { - case CAST.IfElse(cond, thn, elze) => - val tmp = CAST.FreshVar(typ.mutable, "normexec") - val decl = CAST.DeclVar(tmp) - val ifelse = buildIfElse(cond, injectAssign(tmp, thn), injectAssign(tmp, elze)) - val body = f.body ~~ decl ~ ifelse - - (tmp, body) - - case value => - val tmp = CAST.FreshVal(typ, "normexec") - val body = f.body ~~ CAST.DeclInitVar(tmp, f.value) - - (tmp, body) - } - - val value = CAST.AccessVar(tmp.id) - flatten(body ~ value) - } - - val empty = Seq[CAST.Stmt]() - val bodies = fs.foldLeft(empty){ _ ++ _.body } - val values = fs map { _.value } - - FlattenedSeq(values, bodies) - } - - private def containsArrayType(typ: TypeTree): Boolean = typ match { - case Int32Type => false - case BooleanType => false - case UnitType => false - case ArrayType(_) => true - case TupleType(bases) => bases exists containsArrayType - case CaseClassType(cd, _) => cd.fields map { _.getType } exists containsArrayType - } - - private def internalError(msg: String) = ctx.reporter.internalError(msg) - private def fatalError(msg: String) = ctx.reporter.fatalError(msg) - private def debug(msg: String) = ctx.reporter.debug(msg)(utils.DebugSectionGenC) - -} - diff --git a/src/main/scala/leon/genc/CFileOutputPhase.scala b/src/main/scala/leon/genc/CFileOutputPhase.scala deleted file mode 100644 index 2e687e0b7056bc2c777e0516c304b02014e4ed1e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/CFileOutputPhase.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import java.io.File -import java.io.FileWriter -import java.io.BufferedWriter - -object CFileOutputPhase extends UnitPhase[CAST.Prog] { - - val name = "C File Output" - val description = "Output converted C program to the specified file (default leon.c)" - - val optOutputFile = new LeonOptionDef[String] { - val name = "o" - val description = "Output file" - val default = "leon.c" - val usageRhs = "file" - val parser = OptionParsers.stringParser - } - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optOutputFile) - - def apply(ctx: LeonContext, program: CAST.Prog) { - // Get the output file name from command line options, or use default - val outputFile = new File(ctx.findOptionOrDefault(optOutputFile)) - val parent = outputFile.getParentFile() - try { - if (parent != null) { - parent.mkdirs() - } - } catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not create directory " + parent) - } - - // Output C code to the file - try { - val fstream = new FileWriter(outputFile) - val out = new BufferedWriter(fstream) - - val p = new CPrinter - p.print(program) - - out.write(p.toString) - out.close() - - ctx.reporter.info(s"Output written to $outputFile") - } catch { - case _ : java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) - } - } - -} diff --git a/src/main/scala/leon/genc/CPrinter.scala b/src/main/scala/leon/genc/CPrinter.scala deleted file mode 100644 index 9c38fcf81574a2223c0f084f7ddb2f8e45f8c537..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/CPrinter.scala +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import CAST._ -import CPrinterHelpers._ - -class CPrinter(val sb: StringBuffer = new StringBuffer) { - override def toString = sb.toString - - def print(tree: Tree) = pp(tree)(PrinterContext(0, this)) - - def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = tree match { - /* ---------------------------------------------------------- Types ----- */ - case typ: Type => c"${typ.toString}" - - - /* ------------------------------------------------------- Literals ----- */ - case IntLiteral(v) => c"$v" - case BoolLiteral(b) => c"$b" - - - /* --------------------------------------------------- Definitions ----- */ - case Prog(structs, functions) => - c"""|/* ------------------------------------ includes ----- */ - | - |${nary(includeStmts, sep = "\n")} - | - |/* ---------------------- data type declarations ----- */ - | - |${nary(structs map StructDecl, sep = "\n")} - | - |/* ----------------------- data type definitions ----- */ - | - |${nary(structs map StructDef, sep = "\n")} - | - |/* ----------------------- function declarations ----- */ - | - |${nary(functions map FunDecl, sep = "\n")} - | - |/* ------------------------ function definitions ----- */ - | - |${nary(functions, sep = "\n")} - |""" - - case f @ Fun(_, _, _, body) => - c"""|${FunSign(f)} - |{ - | $body - |} - |""" - - case Id(name) => c"$name" - - - /* --------------------------------------------------------- Stmts ----- */ - case NoStmt => c"/* empty */" - - case Compound(stmts) => - val lastIdx = stmts.length - 1 - - for ((stmt, idx) <- stmts.zipWithIndex) { - if (stmt.isValue) c"$stmt;" - else c"$stmt" - - if (idx != lastIdx) - c"$NewLine" - } - - case Assert(pred, Some(error)) => c"assert($pred); /* $error */" - case Assert(pred, None) => c"assert($pred);" - - case Var(id, _) => c"$id" - case DeclVar(Var(id, typ)) => c"$typ $id;" - - // If the length is a literal we don't need VLA - case DeclInitVar(Var(id, typ), ai @ ArrayInit(IntLiteral(length), _, _)) => - val buffer = FreshId("buffer") - val values = for (i <- 0 until length) yield ai.defaultValue - c"""|${ai.valueType} $buffer[${ai.length}] = { $values }; - |$typ $id = { .length = ${ai.length}, .data = $buffer }; - |""" - - // TODO depending on the type of array (e.g. `char`) or the value (e.g. `0`), we could use `memset`. - case DeclInitVar(Var(id, typ), ai: ArrayInit) => // Note that `typ` is a struct here - val buffer = FreshId("vla_buffer") - val i = FreshId("i") - c"""|${ai.valueType} $buffer[${ai.length}]; - |for (${Int32} $i = 0; $i < ${ai.length}; ++$i) { - | $buffer[$i] = ${ai.defaultValue}; - |} - |$typ $id = { .length = ${ai.length}, .data = $buffer }; - |""" - - case DeclInitVar(Var(id, typ), ai: ArrayInitWithValues) => // Note that `typ` is a struct here - val buffer = FreshId("buffer") - c"""|${ai.valueType} $buffer[${ai.length}] = { ${ai.values} }; - |$typ $id = { .length = ${ai.length}, .data = $buffer }; - |""" - - case DeclInitVar(Var(id, typ), value) => - c"$typ $id = $value;" - - case Assign(lhs, rhs) => - c"$lhs = $rhs;" - - case UnOp(op, rhs) => c"($op$rhs)" - case MultiOp(op, stmts) => c"""${nary(stmts, sep = s" ${op.fixMargin} ", - opening = "(", closing = ")")}""" - case SubscriptOp(ptr, idx) => c"$ptr[$idx]" - - case Break => c"break;" - case Return(stmt) => c"return $stmt;" - - case IfElse(cond, thn, elze) => - c"""|if ($cond) - |{ - | $thn - |} - |else - |{ - | $elze - |} - |""" - - case While(cond, body) => - c"""|while ($cond) - |{ - | $body - |} - |""" - - case AccessVar(id) => c"$id" - case AccessRef(id) => c"(*$id)" - case AccessAddr(id) => c"(&$id)" - case AccessField(struct, field) => c"$struct.$field" - case Call(id, args) => c"$id($args)" - - case StructInit(args, struct) => - c"(${struct.id}) { " - for ((id, stmt) <- args.init) { - c".$id = $stmt, " - } - if (!args.isEmpty) { - val (id, stmt) = args.last - c".$id = $stmt " - } - c"}" - - /* --------------------------------------------------------- Error ----- */ - case tree => sys.error(s"CPrinter: <<$tree>> was not handled properly") - } - - - def pp(wt: WrapperTree)(implicit ctx: PrinterContext): Unit = wt match { - case FunDecl(f) => - c"${FunSign(f)};$NewLine" - - case FunSign(Fun(id, retType, Nil, _)) => - c"""|$retType - |$id($Void)""" - - case FunSign(Fun(id, retType, params, _)) => - c"""|$retType - |$id(${nary(params map DeclParam)})""" - - case DeclParam(Var(id, typ)) => - c"$typ $id" - - case StructDecl(s) => - c"struct $s;" - - case StructDef(Struct(name, fields)) => - c"""|typedef struct $name { - | ${nary(fields map DeclParam, sep = ";\n", closing = ";")} - |} $name; - |""" - - case NewLine => - c"""| - |""" - } - - /** Hardcoded list of required include files from C standard library **/ - lazy val includes = "assert.h" :: "stdbool.h" :: "stdint.h" :: Nil - lazy val includeStmts = includes map { i => s"#include <$i>" } - - /** Wrappers to distinguish how the data should be printed **/ - sealed abstract class WrapperTree - case class FunDecl(f: Fun) extends WrapperTree - case class FunSign(f: Fun) extends WrapperTree - case class DeclParam(x: Var) extends WrapperTree - case class StructDecl(s: Struct) extends WrapperTree - case class StructDef(s: Struct) extends WrapperTree - case object NewLine extends WrapperTree -} - diff --git a/src/main/scala/leon/genc/CPrinterHelper.scala b/src/main/scala/leon/genc/CPrinterHelper.scala deleted file mode 100644 index a173aa24b5df3ef352f2ef4bb49a79d88ccefc7f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/CPrinterHelper.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import CAST.Tree - -/* Printer helpers adapted to C code generation */ - -case class PrinterContext( - indent: Int, - printer: CPrinter -) - -object CPrinterHelpers { - implicit class Printable(val f: PrinterContext => Any) extends AnyVal { - def print(ctx: PrinterContext) = f(ctx) - } - - implicit class PrinterHelper(val sc: StringContext) extends AnyVal { - def c(args: Any*)(implicit ctx: PrinterContext): Unit = { - val printer = ctx.printer - import printer.WrapperTree - val sb = printer.sb - - val strings = sc.parts.iterator - val expressions = args.iterator - - var extraInd = 0 - var firstElem = true - - while(strings.hasNext) { - val s = strings.next.stripMargin - - // Compute indentation - val start = s.lastIndexOf('\n') - if(start >= 0 || firstElem) { - var i = start + 1 - while(i < s.length && s(i) == ' ') { - i += 1 - } - extraInd = (i - start - 1) / 2 - } - - firstElem = false - - // Make sure new lines are also indented - sb.append(s.replaceAll("\n", "\n" + (" " * ctx.indent))) - - val nctx = ctx.copy(indent = ctx.indent + extraInd) - - if (expressions.hasNext) { - val e = expressions.next - - e match { - case ts: Seq[Any] => - nary(ts).print(nctx) - - case t: Tree => - printer.pp(t)(nctx) - - case wt: WrapperTree => - printer.pp(wt)(nctx) - - case p: Printable => - p.print(nctx) - - case e => - sb.append(e.toString) - } - } - } - } - } - - def nary(ls: Seq[Any], sep: String = ", ", opening: String = "", closing: String = ""): Printable = { - val (o, c) = if(ls.isEmpty) ("", "") else (opening, closing) - val strs = o +: List.fill(ls.size-1)(sep) :+ c - - implicit pctx: PrinterContext => - new StringContext(strs: _*).c(ls: _*) - } - -} - - diff --git a/src/main/scala/leon/genc/GenerateCPhase.scala b/src/main/scala/leon/genc/GenerateCPhase.scala deleted file mode 100644 index 2d66adbc969d0fe8ad8488c25e6cf9a12362fb69..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/genc/GenerateCPhase.scala +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package genc - -import purescala.Definitions.Program - -object GenerateCPhase extends SimpleLeonPhase[Program, CAST.Prog] { - - val name = "Generate C" - val description = "Generate equivalent C code from Leon's AST" - - def apply(ctx: LeonContext, program: Program) = { - ctx.reporter.debug("Running code conversion phase: " + name)(utils.DebugSectionLeon) - val cprogram = new CConverter(ctx, program).convert - cprogram - } - -} - diff --git a/src/main/scala/leon/invariant/datastructure/DisjointSets.scala b/src/main/scala/leon/invariant/datastructure/DisjointSets.scala deleted file mode 100644 index d9bb73e39c70072a1719549470db10f0fe41875f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/datastructure/DisjointSets.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.datastructure - -import scala.collection.mutable.{Map => MutableMap} - -class DisjointSets[T] { - // A map from elements to their parent and rank - private var disjTree = MutableMap[T, (T, Int)]() - - private def findInternal(x: T): (T, Int) = { - val (p, rank) = disjTree(x) - if (p == x) - (x, rank) - else { - val root = findInternal(p) - // compress path - disjTree(x) = root - root - } - } - - private def findOrCreateInternal(x: T) = - if (!disjTree.contains(x)) { - disjTree += (x -> (x, 1)) - (x, 1) - } else findInternal(x) - - def findOrCreate(x: T) = findOrCreateInternal(x)._1 - - def find(x: T) = findInternal(x)._1 - - def union(x: T, y: T) { - val (rep1, rank1) = findOrCreateInternal(x) - val (rep2, rank2) = findOrCreateInternal(y) - if (rank1 < rank2) { - disjTree(rep1) = (rep2, rank2) - } else if (rank2 < rank1) { - disjTree(rep2) = (rep1, rank1) - } else - disjTree(rep1) = (rep2, rank2 + 1) - } - - def toMap = { - val repToSet = disjTree.keys.foldLeft(MutableMap[T, Set[T]]()) { - case (acc, k) => - val root = find(k) - if (acc.contains(root)) - acc(root) = acc(root) + k - else - acc += (root -> Set(k)) - acc - } - disjTree.keys.map {k => (k -> repToSet(find(k)))}.toMap - } - - override def toString = { - disjTree.toString - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/datastructure/Graph.scala b/src/main/scala/leon/invariant/datastructure/Graph.scala deleted file mode 100644 index 8af783f52d8cf033fc471a7c4ed522195547d514..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/datastructure/Graph.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.datastructure - -class DirectedGraph[T] { - - var adjlist = scala.collection.mutable.Map[T, Set[T]]() - var edgeCount: Int = 0 - - def addNode(n: T) { - if (!adjlist.contains(n)) { - adjlist.update(n, Set()) - } - } - - def addEdge(src: T, dest: T): Unit = { - val newset = if (adjlist.contains(src)) adjlist(src) + dest - else Set(dest) - - //this has some side-effects - adjlist.update(src, newset) - - edgeCount += 1 - } - - def BFSReach(src: T, dest: T, excludeSrc: Boolean = false): Boolean = { - var queue = List[T]() - var visited = Set[T]() - visited += src - - //TODO: is there a better (and efficient) way to implement BFS without using side-effects - def BFSReachRecur(cur: T): Boolean = { - var found: Boolean = false - if (adjlist.contains(cur)) { - adjlist(cur).foreach((fi) => { - if (fi == dest) found = true - else if (!visited.contains(fi)) { - visited += fi - queue ::= fi - } - }) - } - if (found) true - else if (queue.isEmpty) false - else { - val (head :: tail) = queue - queue = tail - BFSReachRecur(head) - } - } - - if (!excludeSrc && src == dest) true - else BFSReachRecur(src) - } - - def BFSReachables(srcs: Seq[T]): Set[T] = { - var queue = List[T]() - var visited = Set[T]() - visited ++= srcs.toSet - - def BFSReachRecur(cur: T): Unit = { - if (adjlist.contains(cur)) { - adjlist(cur).foreach((neigh) => { - if (!visited.contains(neigh)) { - visited += neigh - queue ::= neigh - } - }) - } - if (queue.nonEmpty) { - val (head :: tail) = queue - queue = tail - BFSReachRecur(head) - } - } - - srcs.foreach{src => BFSReachRecur(src) } - visited - } - - def containsEdge(src: T, dest: T): Boolean = { - if (adjlist.contains(src)) { - adjlist(src).contains(dest) - } else false - } - - def getEdgeCount: Int = edgeCount - def getNodes: Set[T] = adjlist.keySet.toSet - def getSuccessors(src: T): Set[T] = adjlist(src) - - /** - * TODO: Change this to the verified component - * The computed nodes are also in reverse topological order. - */ - def sccs: List[List[T]] = { - - type Component = List[T] - - case class State(count: Int, - visited: Set[T], - dfNumber: Map[T, Int], - lowlinks: Map[T, Int], - stack: List[T], - components: List[Component]) - - def search(vertex: T, state: State): State = { - val newState = state.copy(visited = state.visited + vertex, - dfNumber = state.dfNumber + (vertex -> state.count), - count = state.count + 1, - lowlinks = state.lowlinks + (vertex -> state.count), - stack = vertex :: state.stack) - - def processNeighbor(st: State, w: T): State = { - if (!st.visited(w)) { - val st1 = search(w, st) - val min = Math.min(st1.lowlinks(w), st1.lowlinks(vertex)) - st1.copy(lowlinks = st1.lowlinks + (vertex -> min)) - } else { - if ((st.dfNumber(w) < st.dfNumber(vertex)) && st.stack.contains(w)) { - val min = Math.min(st.dfNumber(w), st.lowlinks(vertex)) - st.copy(lowlinks = st.lowlinks + (vertex -> min)) - } else st - } - } - val strslt = getSuccessors(vertex).foldLeft(newState)(processNeighbor) - if (strslt.lowlinks(vertex) == strslt.dfNumber(vertex)) { - val index = strslt.stack.indexOf(vertex) - val (comp, rest) = strslt.stack.splitAt(index + 1) - strslt.copy(stack = rest, - components = strslt.components :+ comp) - } else strslt - } - val initial = State( - count = 1, - visited = Set(), - dfNumber = Map(), - lowlinks = Map(), - stack = Nil, - components = Nil) - - var state = initial - val totalNodes = getNodes - while (state.visited.size < totalNodes.size) { - totalNodes.find(n => !state.visited.contains(n)).foreach { n => - state = search(n, state) - } - } - state.components - } - - /** - * Reverses the direction of the edges in the graph - */ - def reverse : DirectedGraph[T] = { - val revg = new DirectedGraph[T]() - adjlist.foreach{ - case (src, dests) => - dests.foreach { revg.addEdge(_, src) } - } - revg - } -} - -class UndirectedGraph[T] extends DirectedGraph[T] { - - override def addEdge(src: T, dest: T): Unit = { - val newset1 = - if (adjlist.contains(src)) adjlist(src) + dest - else Set(dest) - val newset2 = - if (adjlist.contains(dest)) adjlist(dest) + src - else Set(src) - //this has some side-effects - adjlist.update(src, newset1) - adjlist.update(dest, newset2) - edgeCount += 1 - } -} diff --git a/src/main/scala/leon/invariant/datastructure/Maps.scala b/src/main/scala/leon/invariant/datastructure/Maps.scala deleted file mode 100644 index 7bd237a74c95a5da247442c13fa463a1a6889042..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/datastructure/Maps.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.datastructure - -import scala.annotation.tailrec - -class MultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.Set[B]] with scala.collection.mutable.MultiMap[A, B] { - /** - * Creates a new map and does not change the existing map - */ - def append(that: MultiMap[A, B]): MultiMap[A, B] = { - val newmap = new MultiMap[A, B]() - this.foreach { case (k, vset) => newmap += (k -> vset) } - that.foreach { - case (k, vset) => vset.foreach(v => newmap.addBinding(k, v)) - } - newmap - } -} - -/** - * A multimap that allows duplicate entries - */ -class OrderedMultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.ListBuffer[B]] { - - def addBinding(key: A, value: B): this.type = { - get(key) match { - case None => - val list = new scala.collection.mutable.ListBuffer[B]() - list += value - this(key) = list - case Some(list) => - list += value - } - this - } - - /** - * Creates a new map and does not change the existing map - */ - def append(that: OrderedMultiMap[A, B]): OrderedMultiMap[A, B] = { - val newmap = new OrderedMultiMap[A, B]() - this.foreach { case (k, vlist) => newmap += (k -> vlist) } - that.foreach { - case (k, vlist) => vlist.foreach(v => newmap.addBinding(k, v)) - } - newmap - } - - /** - * Make the value of every key distinct - */ - def distinct: OrderedMultiMap[A, B] = { - val newmap = new OrderedMultiMap[A, B]() - this.foreach { case (k, vlist) => newmap += (k -> vlist.distinct) } - newmap - } -} - -/** - * Implements a mapping from Seq[A] to B where Seq[A] - * is stored as a Trie - */ -final class TrieMap[A, B] { - var childrenMap = Map[A, TrieMap[A, B]]() - var dataMap = Map[A, B]() - - @tailrec def addBinding(key: Seq[A], value: B) { - key match { - case Seq() => - throw new IllegalStateException("Key is empty!!") - case Seq(x) => - //add the value to the dataMap - if (dataMap.contains(x)) - throw new IllegalStateException("A mapping for key already exists: " + x + " --> " + dataMap(x)) - else - dataMap += (x -> value) - case head +: tail => //here, tail has at least one element - //check if we have an entry for seq(0) if yes go to the children, if not create one - val child = childrenMap.getOrElse(head, { - val ch = new TrieMap[A, B]() - childrenMap += (head -> ch) - ch - }) - child.addBinding(tail, value) - } - } - - @tailrec def lookup(key: Seq[A]): Option[B] = { - key match { - case Seq() => - throw new IllegalStateException("Key is empty!!") - case Seq(x) => - dataMap.get(x) - case head +: tail => //here, tail has at least one element - childrenMap.get(head) match { - case Some(child) => - child.lookup(tail) - case _ => None - } - } - } -} - -class CounterMap[T] extends scala.collection.mutable.HashMap[T, Int] { - def inc(v: T) = { - if (this.contains(v)) - this(v) += 1 - else this += (v -> 1) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala b/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala deleted file mode 100644 index 576d28dbd7fbcf4d2919152a83f22433dff519ab..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import transformations._ -import invariant.structure.FunctionUtils._ -import leon.invariant.structure.Formula -import leon.invariant.structure.Call -import leon.invariant.util._ -import leon.invariant.factories.TemplateSolverFactory -import leon.invariant.util.Minimizer -import leon.solvers.Model -import Util._ -import PredicateUtil._ -import ProgramUtil._ -import invariant.factories.TemplateInstantiator._ - -class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: FunDef) - extends FunctionTemplateSolver { - - val printIntermediatePrograms = false - val debugDecreaseConstraints = false - val debugComposition = false - val reporter = ctx.reporter - - def inferTemplate(instProg: Program) = { - (new UnfoldingTemplateSolver(ctx, instProg, findRoot(instProg)))() - } - - def findRoot(prog: Program) = { - functionByName(rootFd.id.name, prog).get - } - - def apply() = { - // Check if all the three templates have different template variable sets - val (Some(tprTmpl), Some(recTmpl), Some(timeTmpl), othersTmpls) = extractSeparateTemplates(rootFd) - val tmplIds = (Seq(tprTmpl, recTmpl, timeTmpl) ++ othersTmpls) flatMap getTemplateIds - if (tmplIds.toSet.size < tmplIds.size) - throw new IllegalStateException("Templates for tpr, rec, time as well as all other templates " + - " taken together should not have the any common template variables for compositional analysis") - - val origProg = prog - // add only rec templates for all functions - val funToRecTmpl = userLevelFunctions(origProg).collect { - case fd if fd.hasTemplate && fd == rootFd => - fd -> recTmpl - case fd if fd.hasTemplate => - fd -> fd.getTemplate - }.toMap - val recProg = assignTemplateAndCojoinPost(funToRecTmpl, origProg) - - // add only tpr template for all functions - val funToNonRecTmpl = userLevelFunctions(origProg).collect { - case fd if fd.hasTemplate && fd == rootFd => - fd -> tprTmpl - case fd if fd.hasTemplate => - fd -> fd.getTemplate - }.toMap - val tprProg = assignTemplateAndCojoinPost(funToNonRecTmpl, origProg) - - if (printIntermediatePrograms) { - reporter.info("RecProg:\n" + recProg) - reporter.info("TRPProg: \n" + tprProg) - } - val recInfRes = inferTemplate(recProg) - val tprInfRes = inferTPRTemplate(tprProg) - - (recInfRes, tprInfRes) match { - case (Some(InferResult(true, Some(recModel), _)), - Some(InferResult(true, Some(tprModel), _))) => - // create a new program by omitting the templates of the root function - val funToTmpl = userLevelFunctions(origProg).collect { - case fd if fd.hasTemplate && fd != rootFd => - (fd -> fd.getTemplate) - }.toMap - val compProg = assignTemplateAndCojoinPost(funToTmpl, origProg) - val compFunDef = findRoot(compProg) - //val nctx = ctx.copy(program = compProg) - - // construct the instantiated tpr bound and check if it monotonically decreases - val Operator(Seq(_, tprFun), _) = tprTmpl - val tprFunInst = (new RealToInt()).mapExpr( - replace(tprModel.map { case (k, v) => (k.toVariable -> v) }.toMap, tprFun)) - // TODO: this would fail on non-integers, handle these by approximating to the next bigger integer - - // Upper bound on time time <= recFun * tprFun + tprFun - val (_, multFun) = MultFuncs.getMultFuncs(if (ctx.usereals) RealType else IntegerType) - val Operator(Seq(_, recFun), _) = recTmpl - val recFunInst = (new RealToInt()).mapExpr( - replace(recModel.map { case (k, v) => (k.toVariable -> v) }.toMap, recFun)) - - val timeUpperBound = ExpressionTransformer.normalizeMultiplication( - Plus(FunctionInvocation(TypedFunDef(multFun, Seq()), - Seq(recFunInst, tprFunInst)), tprFunInst), ctx.multOp) - - // map the old functions in the vc using the new functions - val substMap = origProg.definedFunctions.collect { - case fd => (fd -> functionByName(fd.id.name, compProg).get) - }.toMap - // res = body - val body = mapFunctionsInExpr(substMap)(Equals(getResId(rootFd).get.toVariable, rootFd.body.get)) - val pre = rootFd.precondition.getOrElse(tru) - val Operator(Seq(timeInstExpr, _), _) = timeTmpl - val trans = mapFunctionsInExpr(substMap) _ - val assump = trans(createAnd(Seq(LessEquals(timeInstExpr, timeUpperBound), pre))) - val conseq = trans(timeTmpl) - - if (printIntermediatePrograms) reporter.info("Comp prog: " + compProg) - if (debugComposition) reporter.info("Compositional VC: " + createAnd(Seq(assump, body, Not(conseq)))) - - val recTempSolver = new UnfoldingTemplateSolver(ctx, compProg, compFunDef) { - val minFunc = { - val mizer = new Minimizer(ctx, compProg) - Some(mizer.minimizeBounds(mizer.computeCompositionLevel(timeTmpl)) _) - } - override lazy val templateSolver = - TemplateSolverFactory.createTemplateSolver(ctx, compProg, constTracker, rootFd, minFunc) - override def instantiateModel(model: Model, funcs: Seq[FunDef]) = { - funcs.collect { - case `compFunDef` => compFunDef -> timeTmpl - case fd if fd.hasTemplate => - fd -> instantiateNormTemplates(model, fd.normalizedTemplate.get) - }.toMap - } - } - recTempSolver.solveParametricVC(assump, body, conseq) match { - case Some(InferResult(true, Some(timeModel),timeInferredFuncs)) => - val inferredFuns = (recInfRes.get.inferredFuncs ++ tprInfRes.get.inferredFuncs ++ timeInferredFuncs).distinct - Some(InferResult(true, Some(recModel ++ tprModel.toMap ++ timeModel.toMap), - inferredFuns.map(ifd => functionByName(ifd.id.name, origProg).get).distinct)) - case res @ _ => - res - } - case _ => - reporter.info("Could not infer bounds on rec and(or) tpr. Cannot precced with composition.") - None - } - } - - def extractSeparateTemplates(funDef: FunDef): (Option[Expr], Option[Expr], Option[Expr], Seq[Expr]) = { - if (!funDef.hasTemplate) (None, None, None, Seq[Expr]()) - else { - val template = ExpressionTransformer.pullAndOrs(And(funDef.getTemplate, - funDef.getPostWoTemplate)) // note that some bounds can occur in post and not in tmpl - def extractTmplConjuncts(tmpl: Expr): Seq[Expr] = { - tmpl match { - case And(seqExprs) => - seqExprs - case _ => - throw new IllegalStateException("Compositional reasoning requires templates to be conjunctions!" + tmpl) - } - } - val tmplConjuncts = extractTmplConjuncts(template) - val tupleSelectToInst = InstUtil.getInstMap(funDef) - var tprTmpl: Option[Expr] = None - var timeTmpl: Option[Expr] = None - var recTmpl: Option[Expr] = None - var othersTmpls: Seq[Expr] = Seq[Expr]() - tmplConjuncts.foreach { - case conj@Operator(Seq(lhs, _), _) if (tupleSelectToInst.contains(lhs)) => - tupleSelectToInst(lhs) match { - case n if n == TPR.name => - tprTmpl = Some(conj) - case n if n == Time.name => - timeTmpl = Some(conj) - case n if n == Rec.name => - recTmpl = Some(conj) - case _ => - othersTmpls = othersTmpls :+ conj - } - case conj => - othersTmpls = othersTmpls :+ conj - } - (tprTmpl, recTmpl, timeTmpl, othersTmpls) - } - } - - def inferTPRTemplate(tprProg: Program) = { - val tempSolver = new UnfoldingTemplateSolver(ctx, tprProg, findRoot(tprProg)) { - override def constructVC(rootFd: FunDef): (Expr, Expr, Expr) = { - val body = Equals(getResId(rootFd).get.toVariable, rootFd.body.get) - val preExpr = rootFd.precondition.getOrElse(tru) - val tprTmpl = rootFd.getTemplate - val postWithTemplate = And(rootFd.getPostWoTemplate, tprTmpl) - // generate constraints characterizing decrease of the tpr function with recursive calls - val Operator(Seq(_, tprFun), op) = tprTmpl - val bodyFormula = new Formula(rootFd, ExpressionTransformer.normalizeExpr(body, ctx.multOp), ctx) - val constraints = bodyFormula.callsInFormula.collect { - case call @ Call(_, FunctionInvocation(TypedFunDef(`rootFd`, _), _)) => //direct recursive call ? - val cdata = bodyFormula.callData(call) - Implies(cdata.guard, LessEquals(replace(formalToActual(call), tprFun), tprFun)) - } - if (debugDecreaseConstraints) - reporter.info("Decrease constraints: " + createAnd(constraints.toSeq)) - - val fullPost = createAnd(postWithTemplate +: constraints.toSeq) - (bodyFormula.toExpr, preExpr, fullPost) - } - } - tempSolver() - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala deleted file mode 100644 index 488ce7e3a6d952005c4ae0dcf28a3f1913eb0973..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Definitions._ -import purescala.Expressions._ -import invariant.structure._ -import invariant.util.ExpressionTransformer._ -import purescala.ExprOps._ -import invariant.util.PredicateUtil._ - -object ConstraintTracker { - val debugVC = false -} -class ConstraintTracker(ctx : InferenceContext, program: Program, rootFun : FunDef/*, temFactory: TemplateFactory*/) { - - import ConstraintTracker._ - //a mapping from functions to its VCs represented as a CNF formula - protected var funcVCs = Map[FunDef,Formula]() - - val vcRefiner = new RefinementEngine(ctx, program, this) - val specInstantiator = new SpecInstantiator(ctx, program, this) - - def getFuncs : Seq[FunDef] = funcVCs.keys.toSeq - def hasVC(fdef: FunDef) = funcVCs.contains(fdef) - def getVC(fd: FunDef) : Formula = funcVCs(fd) - - /** - * @param body the body part of the VC that may possibly have instrumentation - * @param assump is the additional assumptions e.g. pre and conseq - * is the goal e.g. post - * The VC constructed is assump ^ body ^ Not(conseq) - */ - def addVC(fd: FunDef, assump: Expr, body: Expr, conseq: Expr) = { - if(debugVC) { - println(s"Init VC \n assumption: $assump \n body: $body \n conseq: $conseq") - } - val flatBody = normalizeExpr(body, ctx.multOp) - val flatAssump = normalizeExpr(assump, ctx.multOp) - val conseqNeg = normalizeExpr(Not(conseq), ctx.multOp) - val callCollect = collect { - case c @ Equals(_, _: FunctionInvocation) => Set[Expr](c) - case _ => Set[Expr]() - } _ - val specCalls = callCollect(flatAssump) ++ callCollect(conseqNeg) - val vc = createAnd(Seq(flatAssump, flatBody, conseqNeg)) - funcVCs += (fd -> new Formula(fd, vc, ctx, specCalls)) - } - - def initialize = { - //assume specifications - specInstantiator.instantiate - } - - def refineVCs(toUnrollCalls: Option[Set[Call]]) : Set[Call] = { - val unrolledCalls = vcRefiner.refineAbstraction(toUnrollCalls) - specInstantiator.instantiate - unrolledCalls - } -} diff --git a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala deleted file mode 100644 index e5866f5ee92d9a3810e29729310ea3f20beec7b9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Definitions._ - -/** - * @author ravi - * This phase performs automatic invariant inference. - */ -object InferInvariantsPhase extends SimpleLeonPhase[Program, InferenceReport] { - val name = "InferInv" - val description = "Invariant Inference" - - val optFunctionUnroll = LeonFlagOptionDef("fullunroll", "Unroll all calls in every unroll step", false) - val optWithMult = LeonFlagOptionDef("withmult", "Multiplication is not converted to a recursive function in VCs", false) - val optUseReals = LeonFlagOptionDef("usereals", "Interpret the input program as a real program", false) - val optMinBounds = LeonFlagOptionDef("minbounds", "tighten time bounds", false) - val optInferTemp = LeonFlagOptionDef("inferTemp", "Infer templates by enumeration", false) - val optCegis = LeonFlagOptionDef("cegis", "use cegis instead of farkas", false) - val optStatsSuffix = LeonStringOptionDef("stats-suffix", "the suffix of the statistics file", "", "s") - val optVCTimeout = LeonLongOptionDef("vcTimeout", "Timeout after T seconds when trying to prove a verification condition", 20, "s") - val optNLTimeout = LeonLongOptionDef("nlTimeout", "Timeout after T seconds when trying to solve nonlinear constraints", 20, "s") - val optDisableInfer = LeonFlagOptionDef("disableInfer", "Disable automatic inference of auxiliary invariants", false) - val optAssumePre = LeonFlagOptionDef("assumepreInf", "Assume preconditions of callees during unrolling", false) - - override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optFunctionUnroll, optWithMult, optUseReals, - optMinBounds, optInferTemp, optCegis, optStatsSuffix, optVCTimeout, - optNLTimeout, optDisableInfer, optAssumePre) - - def apply(ctx: LeonContext, program: Program): InferenceReport = { - val inferctx = new InferenceContext(program, ctx) - val report = (new InferenceEngine(inferctx)).runWithTimeout() - //println("Final Program: \n" +PrettyPrinter.apply(InferenceReportUtil.pushResultsToInput(inferctx, report.conditions))) - if(!ctx.findOption(GlobalOptions.optSilent).getOrElse(false)) { - println("Inference Result: \n"+report.summaryString) - } - report - } -} diff --git a/src/main/scala/leon/invariant/engine/InferenceContext.scala b/src/main/scala/leon/invariant/engine/InferenceContext.scala deleted file mode 100644 index 750c156e00fa795b6d60a4959be94a33f471b7c3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/InferenceContext.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import scala.collection.mutable.{ Map => MutableMap } -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.ExprOps._ -import transformations._ -import invariant.structure.FunctionUtils._ -import invariant.util._ -import verification._ -import verification.VCKinds -import InferInvariantsPhase._ -import ProgramUtil._ - -/** - * @author ravi - */ -class InferenceContext(val initProgram: Program, val leonContext: LeonContext) { - - var abort = false // a flag for aborting - - // get options from ctx or initialize them to default values - // the following options are enabled by default - val targettedUnroll = !(leonContext.findOption(optFunctionUnroll).getOrElse(false)) - val autoInference = !(leonContext.findOption(optDisableInfer).getOrElse(false)) - val assumepre = leonContext.findOption(optAssumePre).getOrElse(false) - - // the following options are disabled by default - val tightBounds = leonContext.findOption(optMinBounds).getOrElse(false) - val inferTemp = leonContext.findOption(optInferTemp).getOrElse(false) - val withmult = leonContext.findOption(optWithMult).getOrElse(false) - val usereals = leonContext.findOption(optUseReals).getOrElse(false) - val useCegis: Boolean = leonContext.findOption(optCegis).getOrElse(false) - val dumpStats = leonContext.findOption(GlobalOptions.optBenchmark).getOrElse(false) - - // the following options have default values - val vcTimeout = leonContext.findOption(optVCTimeout).getOrElse(15L) // in secs - val nlTimeout = leonContext.findOption(optNLTimeout).getOrElse(15L) - val totalTimeout = leonContext.findOption(GlobalOptions.optTimeout) // in secs - val functionsToInfer = leonContext.findOption(GlobalOptions.optFunctions) - val reporter = leonContext.reporter - val maxCegisBound = 1000 - val statsSuffix = leonContext.findOption(optStatsSuffix).getOrElse("-stats" + FileCountGUID.getID) - - val instrumentedProg = InstrumentationPhase(leonContext, initProgram) - // converts qmarks to templates - val qMarksRemovedProg = { - val funToTmpl = userLevelFunctions(instrumentedProg).collect { - case fd if fd.hasTemplate => - fd -> fd.getTemplate - }.toMap - assignTemplateAndCojoinPost(funToTmpl, instrumentedProg, Map()) - } - - val nlelim = new NonlinearityEliminator(withmult, if (usereals) RealType else IntegerType) - - val inferProgram = { - // convert nonlinearity to recursive functions - nlelim(if (usereals) (new IntToRealProgram())(qMarksRemovedProg) else qMarksRemovedProg) - } - - // other utilities - lazy val enumerationRelation = { - // collect strongest relation for enumeration if defined - var foundStrongest = false - var rel: (Expr, Expr) => Expr = LessEquals.apply _ - //go over all post-conditions and pick the strongest relation - instrumentedProg.definedFunctions.foreach((fd) => { - if (!foundStrongest && fd.hasPostcondition) { - val cond = fd.postcondition.get - postTraversal { - case Equals(_, _) => { - rel = Equals.apply _ - foundStrongest = true - } - case _ => ; - }(cond) - } - }) - rel - } - - def multOp(e1: Expr, e2: Expr) = { - FunctionInvocation(TypedFunDef(nlelim.multFun, nlelim.multFun.tparams.map(_.tp)), Seq(e1, e2)) - } - - val validPosts = MutableMap[String, VCResult]() - - /** - * There should be only one function with funName in the - * program - */ - def isFunctionPostVerified(funName: String) = { - if (validPosts.contains(funName)) { - validPosts(funName).isValid - } - else if (abort) false - else { - val verifyPipe = VerificationPhase - val ctxWithTO = createLeonContext(leonContext, s"--timeout=$vcTimeout", s"--functions=$funName") - (true /: verifyPipe.run(ctxWithTO, qMarksRemovedProg)._2.results) { - case (acc, (VC(_, _, vckind), Some(vcRes))) if vcRes.isInvalid => - throw new IllegalStateException(s"$vckind invalid for function $funName") // TODO: remove the exception - case (acc, (VC(_, _, VCKinds.Postcondition), None)) => - throw new IllegalStateException(s"Postcondition verification returned unknown for function $funName") // TODO: remove the exception - case (acc, (VC(_, _, VCKinds.Postcondition), _)) if validPosts.contains(funName) => - throw new IllegalStateException(s"Multiple postcondition VCs for function $funName") // TODO: remove the exception - case (acc, (VC(_, _, VCKinds.Postcondition), Some(vcRes))) => - validPosts(funName) = vcRes - vcRes.isValid - case (acc, _) => acc - } - } - } -} diff --git a/src/main/scala/leon/invariant/engine/InferenceEngine.scala b/src/main/scala/leon/invariant/engine/InferenceEngine.scala deleted file mode 100644 index 4d73ffae38fb236a8c2690b1f62d0025b64379a8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/InferenceEngine.scala +++ /dev/null @@ -1,251 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala._ -import purescala.Definitions._ -import purescala.ExprOps._ -import java.io._ -import verification.VC -import scala.util.control.Breaks._ -import invariant.factories._ -import invariant.util._ -import invariant.structure.FunctionUtils._ -import transformations._ -import leon.utils._ -import Util._ -import ProgramUtil._ -import Stats._ - -/** - * @author ravi - * This phase performs automatic invariant inference. - * TODO: should time be implicitly made positive - */ -class InferenceEngine(val ctx: InferenceContext) extends Interruptible { - - val debugBottomupIterations = false - val debugAnalysisOrder = false - - val ti = new TimeoutFor(this) - val reporter = ctx.reporter - - def interrupt() = { - ctx.abort = true - ctx.leonContext.interruptManager.interrupt() - } - - def recoverInterrupt() = { - ctx.abort = false - } - - def runWithTimeout(progressCallback: Option[InferenceCondition => Unit] = None) = { - ctx.totalTimeout match { - case Some(t) => // timeout in secs - ti.interruptAfter(t * 1000) { - run(progressCallback) - } - case None => - run(progressCallback) - } - } - - private def run(progressCallback: Option[InferenceCondition => Unit] = None): InferenceReport = { - val program = ctx.inferProgram - reporter.info("Running Inference Engine...") - if (ctx.dumpStats) { //register a shutdownhook - sys.ShutdownHookThread({ dumpStats(ctx.statsSuffix) }) - } - val relfuns = ctx.functionsToInfer.getOrElse(program.definedFunctions.map(InstUtil.userFunctionName)) - var results: Map[FunDef, InferenceCondition] = null - time { - if (!ctx.useCegis) { - results = analyseProgram(program, relfuns, defaultVCSolver, progressCallback) - //println("Inferrence did not succeeded for functions: "+functionsToAnalyze.filterNot(succeededFuncs.contains _).map(_.id)) - } else { - var remFuncs = relfuns - var b = 200 - val maxCegisBound = 200 - breakable { - while (b <= maxCegisBound) { - Stats.updateCumStats(1, "CegisBoundsTried") - val succeededFuncs = analyseProgram(program, remFuncs, defaultVCSolver, progressCallback) - val successes = succeededFuncs.keySet.map(InstUtil.userFunctionName) - remFuncs = remFuncs.filterNot(successes.contains _) - if (remFuncs.isEmpty) break - b += 5 //increase bounds in steps of 5 - } - //println("Inferrence did not succeeded for functions: " + remFuncs.map(_.id)) - } - } - } { totTime => updateCumTime(totTime, "TotalTime") } - if (ctx.dumpStats) { - reporter.info("- Dumping statistics") - dumpStats(ctx.statsSuffix) - } - new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, program)(ctx) - } - - def dumpStats(statsSuffix: String) = { - //pick the module id. - val modid = ctx.inferProgram.units.find(_.isMainUnit).get.id - val filename = modid + statsSuffix + ".txt" - val pw = new PrintWriter(filename) - Stats.dumpStats(pw) - SpecificStats.dumpOutputs(pw) - if (ctx.tightBounds) { - SpecificStats.dumpMinimizationStats(pw) - } - pw.close() - ctx.reporter.info("Stats dumped to file: " + filename) - } - - def defaultVCSolver = - (funDef: FunDef, prog: Program) => { - if (funDef.annotations.contains("compose")) //compositional inference ? - new CompositionalTimeBoundSolver(ctx, prog, funDef) - else - new UnfoldingTemplateSolver(ctx, prog, funDef) - } - - /** - * sort the given functions based on ascending topological order of the callgraph. - * For SCCs, preserve the order in which the functions are called in the program - */ - def sortByTopologicalOrder(program: Program, relfuns: Seq[String]) = { - val callgraph = CallGraphUtil.constructCallGraph(program, onlyBody = true) - val relset = relfuns.toSet - val relfds = program.definedFunctions.filter(fd => relset(InstUtil.userFunctionName(fd))) - val funsToAnalyze = relfds.flatMap(callgraph.transitiveCallees _).toSet - // note: the order preserves the order in which functions appear in the program within an SCC - val funsInOrder = callgraph.reverseTopologicalOrder(program.definedFunctions).filter(funsToAnalyze) - if (debugAnalysisOrder) - reporter.info("Analysis Order: " + funsInOrder.map(_.id.uniqueName)) - funsInOrder - } - - /** - * Returns map from analyzed functions to their inference conditions. - * @param - a list of user-level function names that need to analyzed. The names should not - * include the instrumentation suffixes - * TODO: use function names in inference conditions, so that - * we an get rid of dependence on origFd in many places. - */ - def analyseProgram(startProg: Program, relfuns: Seq[String], - vcSolver: (FunDef, Program) => FunctionTemplateSolver, - progressCallback: Option[InferenceCondition => Unit]): Map[FunDef, InferenceCondition] = { - val functionsToAnalyze = sortByTopologicalOrder(startProg, relfuns) - val funToTmpl = - if (ctx.autoInference) { - //A template generator that generates templates for the functions (here we are generating templates by enumeration) - // not usef for now - /*val tempFactory = new TemplateFactory(Some(new TemplateEnumerator(ctx, startProg)), - startProg, ctx.reporter)*/ - userLevelFunctions(startProg).map(fd => fd -> getOrCreateTemplateForFun(fd)).toMap - } else - userLevelFunctions(startProg).collect { case fd if fd.hasTemplate => fd -> fd.getTemplate }.toMap - val progWithTemplates = assignTemplateAndCojoinPost(funToTmpl, startProg) - var analyzedSet = Map[FunDef, InferenceCondition]() - - functionsToAnalyze.filterNot(fd => { - (fd.annotations contains "verified") || - (fd.annotations contains "library") || - (fd.annotations contains "theoryop") || - (fd.annotations contains "extern") - }).foldLeft(progWithTemplates) { (prog, origFun) => - - if (debugBottomupIterations) { - println("Current Program: \n", - ScalaPrinter.apply(prog, purescala.PrinterOptions(printUniqueIds = true))) - scala.io.StdIn.readLine() - } - - if (ctx.abort) { - reporter.info("- Aborting analysis of " + origFun.id.name) - val ic = new InferenceCondition(Seq(), origFun) - ic.time = Some(0) - prog - } else if (origFun.getPostWoTemplate == tru && !origFun.hasTemplate) { - reporter.info("- Nothing to solve for " + origFun.id.name) - prog - } else { - val funDef = functionByName(origFun.id.name, prog).get - reporter.info("- considering function " + funDef.id.name + "...") - //skip the function if it has been analyzed - if (!analyzedSet.contains(origFun)) { - if (funDef.hasBody && funDef.hasPostcondition) { - // for stats - Stats.updateCounter(1, "procs") - val solver = vcSolver(funDef, prog) - val (infRes, funcTime) = getTime { solver() } - infRes match { - case Some(InferResult(true, model, inferredFuns)) => - val origFds = inferredFuns.map { fd => - (fd -> functionByName(fd.id.name, startProg).get) - }.toMap - // find functions in the source that had a user-defined template and was solved - // and it was not previously solved - val funsWithTemplates = inferredFuns.filter { fd => - val origFd = origFds(fd) - !analyzedSet.contains(origFd) && origFd.hasTemplate - } - // now the templates of these functions will be replaced by inferred invariants - val invs = TemplateInstantiator.getAllInvariants(model.get, funsWithTemplates) - // collect templates of remaining functions - val funToTmpl = userLevelFunctions(prog).collect { - case fd if !invs.contains(fd) && fd.hasTemplate => - fd -> fd.getTemplate - }.toMap - val nextProg = assignTemplateAndCojoinPost(funToTmpl, prog, invs) - // create a inference condition for reporting - var first = true - inferredFuns.foreach { fd => - val origFd = origFds(fd) - val invOpt = if (funsWithTemplates.contains(fd)) { - Some(TemplateInstantiator.getAllInvariants(model.get, Seq(origFd), prettyInv = true)(origFd)) - } else if (fd.hasTemplate) { - val currentInv = TemplateInstantiator.getAllInvariants(model.get, Seq(fd), prettyInv = true)(fd) - // map result variable in currentInv - val repInv = replace(Map(getResId(fd).get.toVariable -> getResId(origFd).get.toVariable), currentInv) - Some(translateExprToProgram(repInv, prog, startProg)) - } else None - invOpt match { - case Some(inv) => - // record the inferred invariants - val inferCond = if (analyzedSet.contains(origFd)) { - val ic = analyzedSet(origFd) - ic.addInv(Seq(inv)) - ic - } else { - val ic = new InferenceCondition(Seq(inv), origFd) - ic.time = if (first) Some(funcTime / 1000.0) else Some(0.0) - // update analyzed set - analyzedSet += (origFd -> ic) - first = false - ic - } - progressCallback.foreach(cb => cb(inferCond)) - case _ => - } - } - nextProg - - case _ => - reporter.info("- Exhausted all templates, cannot infer invariants") - val ic = new InferenceCondition(Seq(), origFun) - ic.time = Some(funcTime / 1000.0) - analyzedSet += (origFun -> ic) - prog - } - } else { - //nothing needs to be done here - reporter.info("Function does not have a body or postcondition") - prog - } - } else prog - } - } - analyzedSet - } -} diff --git a/src/main/scala/leon/invariant/engine/InferenceReport.scala b/src/main/scala/leon/invariant/engine/InferenceReport.scala deleted file mode 100644 index b782aa5e0b82915b248af0114bc8714d8b2bd7eb..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/InferenceReport.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Definitions.FunDef -import verification._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Definitions._ -import purescala.Common._ -import invariant.util._ -import invariant.structure._ -import leon.transformations.InstUtil -import leon.purescala.PrettyPrinter -import Util._ -import PredicateUtil._ -import ProgramUtil._ -import FunctionUtils._ -import purescala._ - -class InferenceCondition(invs: Seq[Expr], funDef: FunDef) - extends VC(BooleanLiteral(true), funDef, null) { - - var time: Option[Double] = None - var invariants = invs - - def addInv(invs: Seq[Expr]) { - invariants ++= invs - } - - lazy val prettyInv = invariants.map(inv => - simplifyArithmetic(InstUtil.replaceInstruVars(multToTimes(inv), fd))) match { - case Seq() => None - case invs => - invs.map(ExpressionTransformer.simplify _).filter(_ != tru) match { - case Seq() => Some(tru) - case Seq(ninv) => Some(ninv) - case ninvs => Some(And(ninvs)) - } - } - - def status: String = prettyInv match { - case None => "unknown" - case Some(inv) => - PrettyPrinter(inv) - } -} - -class InferenceReport(fvcs: Map[FunDef, List[VC]], program: Program)(implicit ctx: InferenceContext) - extends VerificationReport(program, Map()) { - - import scala.math.Ordering.Implicits._ - val conditions: Seq[InferenceCondition] = - fvcs.flatMap(_._2.map(_.asInstanceOf[InferenceCondition])).toSeq.sortBy(vc => vc.fd.id.name) - - private def infoSep(size: Int): String = "╟" + ("┄" * size) + "╢\n" - private def infoFooter(size: Int): String = "╚" + ("═" * size) + "╝" - private def infoHeader(size: Int): String = ". ┌─────────┐\n" + - "╔═╡ Summary ╞" + ("═" * (size - 12)) + "╗\n" + - "║ └─────────┘" + (" " * (size - 12)) + "║" - - private def infoLine(str: String, size: Int): String = { - "║ " + str + (" " * (size - str.length - 2)) + " ║" - } - - private def fit(str: String, maxLength: Int): String = { - if (str.length <= maxLength) { - str - } else { - str.substring(0, maxLength - 1) + "…" - } - } - - private def funName(fd: FunDef) = InstUtil.userFunctionName(fd) - - override def summaryString: String = if (conditions.nonEmpty) { - val maxTempSize = (conditions.map(_.status.length).max + 3) - val outputStrs = conditions.map(vc => { - val timeStr = vc.time.map(t => "%-3.3f".format(t)).getOrElse("") - "%-15s %s %-4s".format(fit(funName(vc.fd), 15), vc.status + (" " * (maxTempSize - vc.status.length)), timeStr) - }) - val summaryStr = { - val totalTime = conditions.foldLeft(0.0)((a, ic) => a + ic.time.getOrElse(0.0)) - val inferredConds = conditions.count((ic) => ic.prettyInv.isDefined) - "total: %-4d inferred: %-4d unknown: %-4d time: %-3.3f".format( - conditions.size, inferredConds, conditions.size - inferredConds, totalTime) - } - val entrySize = (outputStrs :+ summaryStr).map(_.length).max + 2 - - infoHeader(entrySize) + - outputStrs.map(str => infoLine(str, entrySize)).mkString("\n", "\n", "\n") + - infoSep(entrySize) + - infoLine(summaryStr, entrySize) + "\n" + - infoFooter(entrySize) - - } else { - "No user provided templates were solved." - } - - def finalProgram: Program = { - val funToTmpl = conditions.collect { - case cd if cd.prettyInv.isDefined => - cd.fd -> cd.prettyInv.get - }.toMap - assignTemplateAndCojoinPost(funToTmpl, program) - } -} - -object InferenceReportUtil { - - def pushResultsToInput(ctx: InferenceContext, ics: Seq[InferenceCondition]) = { - - val initFuns = functionsWOFields(ctx.initProgram.definedFunctions).filter { fd => - !fd.isTheoryOperation && !fd.annotations.contains("library") - } - val solvedICs = ics.filter { _.prettyInv.isDefined } - - // mapping from init to output - val initToOutput = - initFuns.map { fd => - val freshId = FreshIdentifier(fd.id.name, fd.returnType) - val newfd = new FunDef(freshId, fd.tparams, fd.params, fd.returnType) - fd -> newfd - }.toMap - - def fullNameWoInst(fd: FunDef) = { - val splits = DefOps.fullName(fd)(ctx.inferProgram).split("-") - if (splits.nonEmpty) splits(0) - else "" - } - - val nameToInitFun = initFuns.map { fd => - DefOps.fullName(fd)(ctx.initProgram) -> fd - }.toMap - - // mapping from init to ic - val initICMap = (Map[FunDef, InferenceCondition]() /: solvedICs) { - case (acc, ic) => - nameToInitFun.get(fullNameWoInst(ic.fd)) match { - case Some(initfd) => - acc + (initfd -> ic) - case _ => acc - } - } - - def mapExpr(ine: Expr): Expr = { - val replaced = simplePostTransform { - case e@FunctionInvocation(TypedFunDef(fd, targs), args) => - if (initToOutput.contains(fd)) { - FunctionInvocation(TypedFunDef(initToOutput(fd), targs), args) - } else { - nameToInitFun.get(fullNameWoInst(fd)) match { - case Some(ifun) if initToOutput.contains(ifun) => - FunctionInvocation(TypedFunDef(initToOutput(ifun), targs), args) - case _ => e - } - } - case e => e - }(ine) - replaced - } - // copy bodies and specs - for ((from, to) <- initToOutput) { - to.body = from.body.map(mapExpr) - to.precondition = from.precondition.map(mapExpr) - val icOpt = initICMap.get(from) - if (icOpt.isDefined) { - val ic = icOpt.get - val paramMap = (ic.fd.params zip from.params).map { - case (p1, p2) => - (p1.id.toVariable -> p2.id.toVariable) - }.toMap - val icres = getResId(ic.fd).get - val npost = - if (from.hasPostcondition) { - val resid = getResId(from).get - val inv = replace(Map(icres.toVariable -> resid.toVariable) ++ paramMap, ic.prettyInv.get) - val postBody = from.postWoTemplate.map(post => createAnd(Seq(post, inv))).getOrElse(inv) - Lambda(Seq(ValDef(resid)), postBody) - } else { - val resid = FreshIdentifier(icres.name, icres.getType) - val inv = replace(Map(icres.toVariable -> resid.toVariable) ++ paramMap, ic.prettyInv.get) - Lambda(Seq(ValDef(resid)), inv) - } - to.postcondition = Some(mapExpr(npost)) - } else - to.postcondition = from.postcondition.map(mapExpr) - //copy annotations - from.flags.foreach(to.addFlag(_)) - } - - copyProgram(ctx.initProgram, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if initToOutput.contains(fd) => - initToOutput(fd) - case d => d - }) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/RefinementEngine.scala b/src/main/scala/leon/invariant/engine/RefinementEngine.scala deleted file mode 100644 index 1be74fb8ffc882a6b96fd68886dbedd7364cb50c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/RefinementEngine.scala +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.TypeOps.instantiateType -import purescala.Extractors._ -import purescala.Types._ -import java.io._ - -import invariant.templateSolvers._ -import invariant.factories._ -import invariant.util._ -import invariant.util.Util._ -import invariant.structure._ -import FunctionUtils._ -import Util._ -import PredicateUtil._ -import ProgramUtil._ - -//TODO: the parts of the code that collect the new head functions is ugly and has many side-effects. Fix this. -//TODO: there is a better way to compute heads, which is to consider all guards not previous seen -class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: ConstraintTracker) { - - val tru = BooleanLiteral(true) - val reporter = ctx.reporter - val cg = CallGraphUtil.constructCallGraph(prog) - - //this count indicates the number of times we unroll a recursive call - private val MAX_UNROLLS = 2 - - //debugging flags - private val dumpInlinedSummary = false - - //print flags - val verbose = false - - //the guards of disjuncts that were already processed - private var exploredGuards = Set[Variable]() - - //a set of calls that have not been unrolled (these are potential unroll candidates) - //However, these calls except those given by the unspecdCalls have been assumed specifications - private var headCalls = Map[FunDef, Set[Call]]() - def getHeads(fd: FunDef) = if (headCalls.contains(fd)) headCalls(fd) else Set() - def resetHeads(fd: FunDef, heads: Set[Call]) = { - if (headCalls.contains(fd)) { - headCalls -= fd - headCalls += (fd -> heads) - } else { - headCalls += (fd -> heads) - } - } - - /** - * This procedure refines the existing abstraction. - * Currently, the refinement happens by unrolling the head functions. - */ - def refineAbstraction(toRefineCalls: Option[Set[Call]]): Set[Call] = { - - ctrTracker.getFuncs.flatMap((fd) => { - val formula = ctrTracker.getVC(fd) - val disjuncts = formula.disjunctsInFormula - val newguards = formula.disjunctsInFormula.keySet.diff(exploredGuards) - exploredGuards ++= newguards - - val newheads = formula.getCallsOfGuards(newguards.toSeq) //.flatMap(g => disjuncts(g).collect { case c: Call => c }) - val allheads = getHeads(fd) ++ newheads - - //unroll each call in the head pointers and in toRefineCalls - val callsToProcess = if (toRefineCalls.isDefined) { - //pick only those calls that have been least unrolled - val relevCalls = allheads.intersect(toRefineCalls.get) - var minCalls = Set[Call]() - var minUnrollings = MAX_UNROLLS - relevCalls.foreach((call) => { - val calldata = formula.callData(call) - val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd) - if (recInvokes < minUnrollings) { - minUnrollings = recInvokes - minCalls = Set(call) - } else if (recInvokes == minUnrollings) { - minCalls += call - } - }) - minCalls - } else allheads - - if (verbose) - reporter.info("Unrolling: " + callsToProcess.size + "/" + allheads.size) - - val unrolls = callsToProcess.foldLeft(Set[Call]())((acc, call) => { - val calldata = formula.callData(call) - val recInvokes = calldata.parents.count(_ == call.fi.tfd.fd) - //if the call is not a recursive call, unroll it unconditionally - if (recInvokes == 0) { - unrollCall(call, formula) - acc + call - } else { - //if the call is recursive, unroll iff the number of times the recursive function occurs in the context is < MAX-UNROLL - if (recInvokes < MAX_UNROLLS) { - unrollCall(call, formula) - acc + call - } else { - //otherwise, do not unroll the call - acc - } - } - //TODO: are there better ways of unrolling ?? Yes. Akask Lal "dag Inlining". Implement that! - }) - - //update the head functions - resetHeads(fd, allheads.diff(callsToProcess)) - unrolls - }).toSet - } - import leon.transformations.InstUtil._ - - def shouldCreateVC(recFun: FunDef, inSpec: Boolean): Boolean = { - if (ctrTracker.hasVC(recFun)) false - else { - //need not create vcs for theory operations and library methods - !recFun.isTheoryOperation && !recFun.annotations.contains("library") && - (recFun.template match { - case Some(temp) if inSpec && isResourceBoundOf(recFun)(temp) => false // TODO: here we can also drop resource templates if it is used with other templates - case Some(_) => true - case _ => false - }) - } - } - - /** - * Returns a set of unrolled calls and a set of new head functions - * here we unroll the methods in the current abstraction by one step. - * This procedure has side-effects on 'headCalls' and 'callDataMap' - */ - def unrollCall(call: Call, formula: Formula) { - val fi = call.fi - val calldata = formula.callData(call) - val callee = fi.tfd.fd - if (fi.tfd.fd.hasBody) { - //freshen the body and the post - val isRecursive = cg.isRecursive(callee) - if (isRecursive) { - val recFun = callee - val recFunTyped = fi.tfd - //check if we need to create a VC formula for the call's target - if (shouldCreateVC(recFun, calldata.inSpec)) { - reporter.info("Creating VC for " + recFun.id) - // instantiate the body with new types - val tparamMap = (recFun.typeArgs zip recFunTyped.tps).toMap - val paramMap = recFun.params.map { pdef => - pdef.id -> FreshIdentifier(pdef.id.name, instantiateType(pdef.id.getType, tparamMap)) - }.toMap - val freshBody = instantiateType(freshenLocals(recFun.body.get), tparamMap, paramMap) - val resname = if (recFun.hasPostcondition) getResId(recFun).get.name else "res" - //create a new result variable here for the same reason as freshening the locals, - //which is to avoid variable capturing during unrolling - val resvar = Variable(FreshIdentifier(resname, recFunTyped.returnType, true)) - val bodyExpr = Equals(resvar, freshBody) - val pre = recFun.precondition.map(p => instantiateType(p, tparamMap, paramMap)).getOrElse(tru) - //note: here we are only adding the template as the postcondition (other post need not be proved again) - val idmap = formalToActual(Call(resvar, FunctionInvocation(recFunTyped, paramMap.values.toSeq.map(_.toVariable)))) - val postTemp = replace(idmap, recFun.getTemplate) - //val vcExpr = ExpressionTransformer.normalizeExpr(And(bodyExpr, Not(postTemp)), ctx.multOp) - ctrTracker.addVC(recFun, pre, bodyExpr, postTemp) - } - //Here, unroll the call into the caller tree - if (verbose) reporter.info("Unrolling " + Equals(call.retexpr, call.fi)) - inilineCall(call, calldata, formula) - } else { - //here we are unrolling a function without template - if (verbose) reporter.info("Unfolding " + Equals(call.retexpr, call.fi)) - inilineCall(call, calldata, formula) - } - } else Set() - } - - def inilineCall(call: Call, calldata: CallData, formula: Formula) { - val tfd = call.fi.tfd - val callee = tfd.fd - if (callee.isBodyVisible) { - //here inline the body and conjoin it with the guard - //Important: make sure we use a fresh body expression here, and freshenlocals - val tparamMap = (callee.typeArgs zip tfd.tps).toMap - val freshBody = instantiateType(replace(formalToActual(call), - Equals(getFunctionReturnVariable(callee), freshenLocals(callee.body.get))), - tparamMap, Map()) - val inlinedSummary = ExpressionTransformer.normalizeExpr(freshBody, ctx.multOp) - if (this.dumpInlinedSummary) - println(s"Call: ${call} \n FunDef: $callee \n Inlined Summary of ${callee.id}: $inlinedSummary") - //conjoin the summary with the disjunct corresponding to the 'guard' - //note: the parents of the summary are the parents of the call plus the callee function - formula.conjoinWithDisjunct(calldata.guard, inlinedSummary, (callee +: calldata.parents), calldata.inSpec) - } else { - if (verbose) - reporter.info(s"Not inlining ${call.fi}: body invisible!") - } - } -} diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala deleted file mode 100644 index 84dbcb2bac1b19311ad40ba227eeb951b4393d47..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala +++ /dev/null @@ -1,284 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine -import z3.scala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import leon.purescala.TypeOps.instantiateType -import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import leon.invariant.templateSolvers.ExtendedUFSolver -import invariant._ -import scala.util.control.Breaks._ -import solvers._ -import scala.concurrent._ -import scala.concurrent.duration._ -import leon.evaluators.DefaultEvaluator - -import invariant.templateSolvers._ -import invariant.factories._ -import invariant.util._ -import invariant.structure._ -import FunctionUtils._ -import Util._ -import PredicateUtil._ -import ProgramUtil._ -import SolverUtil._ - -class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: ConstraintTracker) { - - val verbose = false - - protected val disableAxioms = false - protected val debugAxiomInstantiation = false - - val tru = BooleanLiteral(true) - val axiomFactory = new AxiomFactory(ctx) //handles instantiation of axiomatic specification - - //the guards of the set of calls that were already processed - protected var exploredGuards = Set[Variable]() - - def instantiate() = { - val funcs = ctrTracker.getFuncs - - funcs.foreach((fd) => { - val formula = ctrTracker.getVC(fd) - val disjuncts = formula.disjunctsInFormula - val newguards = disjuncts.keySet.diff(exploredGuards) - exploredGuards ++= newguards - - val newcalls = formula.getCallsOfGuards(newguards.toSeq).toSet //flatMap(g => disjuncts(g).collect { case c: Call => c }) - instantiateSpecs(formula, newcalls, funcs.toSet) - - if (!disableAxioms) { - //remove all multiplication if "withmult" is specified - val relavantCalls = if (ctx.withmult) { - newcalls.filter(call => !isMultFunctions(call.fi.tfd.fd)) - } else newcalls - instantiateAxioms(formula, relavantCalls) - } - }) - } - - /** - * This function refines the formula by assuming the specifications/templates for calls in the formula - * Here, we assume (pre => post ^ template) for each call (templates only for calls with VC) - * Important: adding templates for 'newcalls' of the previous iterations is empirically more effective - */ - //a set of calls for which templates or specifications have not been assumed - private var untemplatedCalls = Map[FunDef, Set[Call]]() - def getUntempCalls(fd: FunDef) = if (untemplatedCalls.contains(fd)) untemplatedCalls(fd) else Set() - def resetUntempCalls(fd: FunDef, calls: Set[Call]) = { - if (untemplatedCalls.contains(fd)) { - untemplatedCalls -= fd - untemplatedCalls += (fd -> calls) - } else { - untemplatedCalls += (fd -> calls) - } - } - - def instantiateSpecs(formula: Formula, calls: Set[Call], funcsWithVC: Set[FunDef]) = { - - //assume specifications - calls.foreach((call) => { - //first get the spec for the call if it exists - val spec = specForCall(call) - if (spec.isDefined && spec.get != tru) { - val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, spec.get, cdata.parents, inSpec = true) - } - }) - - //try to assume templates for all the current un-templated calls - var newUntemplatedCalls = Set[Call]() - getUntempCalls(formula.fd).foreach { call => - if (funcsWithVC.contains(call.fi.tfd.fd)) { // add templates of only functions for which there exists a VC - templateForCall(call) match { - case Some(temp) => - val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, temp, cdata.parents, inSpec = true) - case _ => - ; // here there is no template for the call - } - } else { - newUntemplatedCalls += call - } - } - resetUntempCalls(formula.fd, newUntemplatedCalls ++ calls) - } - - def specForCall(call: Call): Option[Expr] = { - val tfd = call.fi.tfd - val callee = tfd.fd - if (callee.hasPostcondition) { - //get the postcondition without templates - val rawpost = freshenLocals(callee.getPostWoTemplate) - val rawspec = - if (callee.hasPrecondition) { - val pre = freshenLocals(callee.precondition.get) - if (ctx.assumepre) - And(pre, rawpost) - else - Implies(pre, rawpost) - } else { - rawpost - } - // instantiate the post - val tparamMap = (callee.typeArgs zip tfd.tps).toMap - val instSpec = instantiateType(replace(formalToActual(call), rawspec), tparamMap, Map()) - val inlinedSpec = ExpressionTransformer.normalizeExpr(instSpec, ctx.multOp) - Some(inlinedSpec) - } else { - None - } - } - - def templateForCall(call: Call): Option[Expr] = { - val tfd = call.fi.tfd - val callee = tfd.fd - if (callee.hasTemplate) { - val argmap = formalToActual(call) - val tparamMap = (callee.typeArgs zip tfd.tps).toMap - val tempExpr = instantiateType(replace(argmap, freshenLocals(callee.getTemplate)), tparamMap, Map()) - val template = if (callee.hasPrecondition) { - val pre = instantiateType(replace(argmap, freshenLocals(callee.precondition.get)), tparamMap, Map()) - if (ctx.assumepre) - And(pre, tempExpr) - else - Implies(pre, tempExpr) - } else { - tempExpr - } - //TODO: should we freshen locals of template here ?? - Some(ExpressionTransformer.normalizeExpr(template, ctx.multOp)) - } else None - } - - //axiomatic specification - protected var axiomRoots = Map[Seq[Call], Variable]() //a mapping from axioms keys (a sequence of calls) to the guards - def instantiateAxioms(formula: Formula, calls: Set[Call]) = { - - val debugSolver = if (this.debugAxiomInstantiation) { - val sol = new ExtendedUFSolver(ctx.leonContext, program) - sol.assertCnstr(formula.toExpr) - Some(sol) - } else None - - val inst1 = instantiateUnaryAxioms(formula, calls) - val inst2 = instantiateBinaryAxioms(formula, calls) - val axiomInsts = inst1 ++ inst2 - - Stats.updateCounterStats(atomNum(createAnd(axiomInsts)), "AxiomBlowup", "VC-refinement") - if(verbose) ctx.reporter.info("Number of axiom instances: " + axiomInsts.size) - - if (this.debugAxiomInstantiation) { - println("Instantianting axioms over: " + calls) - println("Instantiated Axioms: ") - axiomInsts.foreach((ainst) => { - println(ainst) - debugSolver.get.assertCnstr(ainst) - val res = debugSolver.get.check - res match { - case Some(false) => - println("adding axiom made formula unsat!!") - case _ => ; - } - }) - debugSolver.get.free - } - } - - //this code is similar to assuming specifications - def instantiateUnaryAxioms(formula: Formula, calls: Set[Call]) = { - val axioms = calls.collect { - case call @ _ if axiomFactory.hasUnaryAxiom(call) => { - val (ant, conseq) = axiomFactory.unaryAxiom(call) - val axiomInst = Implies(ant, conseq) - val nnfAxiom = ExpressionTransformer.normalizeExpr(axiomInst, ctx.multOp) - val cdata = formula.callData(call) - formula.conjoinWithDisjunct(cdata.guard, nnfAxiom, cdata.parents, inSpec = true) - axiomInst - } - } - axioms.toSeq - } - - /** - * Here, we assume that axioms do not introduce calls. - * If this does not hold, 'guards' have to be used while instantiating axioms so as - * to compute correct verification conditions. - * TODO: Use least common ancestor etc. to avoid axiomatizing calls along different disjuncts - * TODO: can we avoid axioms like (a <= b ^ x<=y => p <= q), (x <= y ^ a<=b => p <= q), ... - * TODO: can we have axiomatic specifications relating two different functions ? - */ - protected var binaryAxiomCalls = Map[FunDef, Set[Call]]() //calls with axioms so far seen - def getBinaxCalls(fd: FunDef) = if (binaryAxiomCalls.contains(fd)) binaryAxiomCalls(fd) else Set[Call]() - def appendBinaxCalls(fd: FunDef, calls: Set[Call]) = { - if (binaryAxiomCalls.contains(fd)) { - val oldcalls = binaryAxiomCalls(fd) - binaryAxiomCalls -= fd - binaryAxiomCalls += (fd -> (oldcalls ++ calls)) - } else { - binaryAxiomCalls += (fd -> calls) - } - } - - def instantiateBinaryAxioms(formula: Formula, calls: Set[Call]) = { - - val newCallsWithAxioms = calls.filter(axiomFactory.hasBinaryAxiom _) - - def isInstantiable(call1: Call, call2: Call): Boolean = { - //important: check if the two calls refer to the same function - (call1.fi.tfd.id == call2.fi.tfd.id) && (call1 != call2) - } - - val product = cross[Call, Call](newCallsWithAxioms, getBinaxCalls(formula.fd), Some(isInstantiable)).flatMap( - p => Seq((p._1, p._2), (p._2, p._1))) ++ - cross[Call, Call](newCallsWithAxioms, newCallsWithAxioms, Some(isInstantiable)).map(p => (p._1, p._2)) - - //ctx.reporter.info("# of pairs with axioms: "+product.size) - //Stats.updateCumStats(product.size, "Call-pairs-with-axioms") - - val addedAxioms = product.flatMap(pair => { - //union the parents of the two calls - val cdata1 = formula.callData(pair._1) - val cdata2 = formula.callData(pair._2) - val parents = cdata1.parents ++ cdata2.parents - val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) - - axiomInsts.foldLeft(Seq[Expr]())((acc, inst) => { - val (ant, conseq) = inst - val axiom = Implies(ant, conseq) - val nnfAxiom = ExpressionTransformer.normalizeExpr(axiom, ctx.multOp) - val axroot = formula.conjoinWithRoot(nnfAxiom, parents, true) - //important: here we need to update the axiom roots - axiomRoots += (Seq(pair._1, pair._2) -> axroot) - acc :+ axiom - }) - }) - appendBinaxCalls(formula.fd, newCallsWithAxioms) - addedAxioms - } - - /** - * Note: taking a formula as input may not be necessary. We can store it as a part of the state - * TODO: can we use transitivity here to optimize ? - */ - def axiomsForCalls(formula: Formula, calls: Set[Call], model: LazyModel, tmplMap: Map[Identifier,Expr], eval: DefaultEvaluator): Seq[Constraint] = { - //note: unary axioms need not be instantiated - //consider only binary axioms - (for (x <- calls; y <- calls) yield (x, y)).foldLeft(Seq[Constraint]())((acc, pair) => { - val (c1, c2) = pair - if (c1 != c2) { - val axRoot = axiomRoots.get(Seq(c1, c2)) - if (axRoot.isDefined) - acc ++ formula.pickSatDisjunct(axRoot.get, model, tmplMap, eval) - else acc - } else acc - }) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala b/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala deleted file mode 100644 index e830d91b2e7f2bff1268b6e00f03f121539b7b59..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Types._ - -import invariant.factories._ -import invariant.util._ -import ProgramUtil._ - -import scala.collection.mutable.{Set => MutableSet} -/** - * An enumeration based template generator. - * Enumerates all numerical terms in some order (this enumeration is incomplete for termination). - * TODO: Feature: - * (a) allow template functions and functions with template variables ? - * (b) should we unroll algebraic data types ? - * - * The following function may potentially have complexity O(n^i) where 'n' is the number of functions - * and 'i' is the increment step - * TODO: optimize the running and also reduce the size of the input templates - * - * For now this is incomplete - */ -class TemplateEnumerator(ctx: InferenceContext, prog: Program) extends TemplateGenerator { - val reporter = ctx.reporter - - //create a call graph for the program - //Caution: this call-graph could be modified later while call the 'getNextTemplate' method - private val callGraph = { - val cg = CallGraphUtil.constructCallGraph(prog) - cg - } - - private var tempEnumMap = Map[FunDef, FunctionTemplateEnumerator]() - - def getNextTemplate(fd: FunDef): Expr = { - if (tempEnumMap.contains(fd)) tempEnumMap(fd).getNextTemplate() - else { - val enumerator = new FunctionTemplateEnumerator(fd, prog, ctx.enumerationRelation, callGraph, reporter) - tempEnumMap += (fd -> enumerator) - enumerator.getNextTemplate() - } - } -} - -/** - * This class manages templates for the given function - * 'op' is a side-effects parameter - * Caution: The methods of this class has side-effects on the 'callGraph' parameter - */ -class FunctionTemplateEnumerator(rootFun: FunDef, prog: Program, op: (Expr, Expr) => Expr, - callGraph: CallGraph, reporter: Reporter) { - private val MAX_INCREMENTS = 2 - private val zero = InfiniteIntegerLiteral(0) - //using default op as <= or == (manually adjusted) - //private val op = LessEquals - //LessThan - //LessEquals - //Equals.apply _ - private var currTemp: Expr = null - private var incrStep: Int = 0 - private var typeTermMap = Map[TypeTree, MutableSet[Expr]]() - private var ttCurrent = Map[TypeTree, MutableSet[Expr]]() - - //get all functions that are not the current function. - //the value of the current function is given by res and its body - //itself characterizes how it is defined recursively w.r.t itself. - //Need to also avoid mutual recursion as it may lead to proving of invalid facts - private val fds = prog.definedFunctions.filter(_ != rootFun) - - def getNextTemplate(): Expr = { - //println("Getting next template for function: "+fd.id) - - if (incrStep == MAX_INCREMENTS) { - //exhausted the templates, so return - op(currTemp, zero) - } else { - - incrStep += 1 - - var newTerms = Map[TypeTree, MutableSet[Expr]]() - if (currTemp == null) { - //initialize - //add all the arguments and results of fd to 'typeTermMap' - rootFun.params.foreach((vardecl) => { - val tpe = vardecl.getType - val v = vardecl.id.toVariable - if (newTerms.contains(tpe)) { - newTerms(tpe).add(v) - } else { - newTerms += (tpe -> MutableSet(v)) - } - }) - - val resVar = getFunctionReturnVariable(rootFun) - if (newTerms.contains(rootFun.returnType)) { - newTerms(rootFun.returnType).add(resVar) - } else { - newTerms += (rootFun.returnType -> MutableSet(resVar)) - } - - //also 'assignCurrTemp' to a template variable - currTemp = TemplateIdFactory.freshTemplateVar() - } else { - - //apply the user-defined functions to the compatible terms in typeTermMap - //Important: Make sure that the recursive calls are not introduced in the templates - //TODO: this is a hack to prevent infinite recursion in specification. However, it is not clear if this will prevent inferrence of - //any legitimate specifications (however this can be modified). - fds.foreach((fun) => { - //Check if adding a call from 'rootFun' to 'fun' creates a mutual recursion by checking if - //'fun' transitively calls 'rootFun' - if (fun != rootFun && !callGraph.transitivelyCalls(fun, rootFun)) { - - //check if every argument has at least one satisfying assignment? - if (!fun.params.exists((vardecl) => !ttCurrent.contains(vardecl.getType))) { - - //here compute all the combinations - val newcalls = generateFunctionCalls(fun) - if (newTerms.contains(fun.returnType)) { - newTerms(fun.returnType) ++= newcalls - } else { - var muset = MutableSet[Expr]() - muset ++= newcalls - newTerms += (fun.returnType -> muset) - } - } - } - - }) - - } - //add all the newly generated expression to the typeTermMap - ttCurrent = newTerms - typeTermMap ++= newTerms - - //statistics - reporter.info("- Number of new terms enumerated: " + newTerms.size) - - //return all the integer valued terms of newTerms - //++ newTerms.getOrElse(Int32Type, Seq[Expr]()) (for now not handling int 32 terms) - val numericTerms = (newTerms.getOrElse(RealType, Seq[Expr]()) ++ newTerms.getOrElse(IntegerType, Seq[Expr]())).toSeq - if (numericTerms.nonEmpty) { - //create a linear combination of intTerms - val newTemp = numericTerms.foldLeft(null: Expr)((acc, t: Expr) => { - val summand = Times(t, TemplateIdFactory.freshTemplateVar(): Expr) - if (acc == null) summand - else - Plus(summand, acc) - }) - //add newTemp to currTemp - currTemp = Plus(newTemp, currTemp) - - //get all the calls in the 'newTemp' and add edges from 'rootFun' to the callees to the call-graph - val callees = CallGraphUtil.getCallees(newTemp) - callees.foreach(callGraph.addEdgeIfNotPresent(rootFun, _)) - } - op(currTemp, zero) - } - } - - /** - * Generate a set of function calls of fun using the terms in ttCurrent - */ - def generateFunctionCalls(fun: FunDef): Set[Expr] = { - /** - * To be called with argIndex of zero and an empty argList - */ - def genFunctionCallsRecur(argIndex: Int, argList: Seq[Expr]): Set[Expr] = { - if (argIndex == fun.params.size) { - //create a call using argList - //TODO: how should we handle generics - Set(FunctionInvocation(TypedFunDef(fun, fun.tparams.map(_.tp)), argList)) - } else { - val arg = fun.params(argIndex) - val tpe = arg.getType - ttCurrent(tpe).foldLeft(Set[Expr]())((acc, term) => acc ++ genFunctionCallsRecur(argIndex + 1, argList :+ term)) - } - } - - genFunctionCallsRecur(0, Seq()) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala deleted file mode 100644 index 336b7960303946553f02b37ad0c0250dd1d06100..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.engine - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.DefOps._ -import purescala.ScalaPrinter -import purescala.Constructors._ - -import solvers._ -import verification._ -import invariant.factories._ -import invariant.util._ -import invariant.structure._ -import transformations._ -import FunctionUtils._ -import Util._ -import PredicateUtil._ -import ProgramUtil._ - -/** - * @author ravi - * This phase performs automatic invariant inference. - * TODO: Do we need to also assert that time is >= 0 - */ -case class InferResult(res: Boolean, model: Option[Model], inferredFuncs: List[FunDef]) { -} - -trait FunctionTemplateSolver { - def apply(): Option[InferResult] -} - -class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: FunDef) extends FunctionTemplateSolver { - - val reporter = ctx.reporter - val debugVCs = false - - lazy val constTracker = new ConstraintTracker(ctx, program, rootFd) - lazy val templateSolver = TemplateSolverFactory.createTemplateSolver(ctx, program, constTracker, rootFd) - - def constructVC(funDef: FunDef): (Expr, Expr, Expr) = { - val Lambda(Seq(ValDef(resid)), _) = funDef.postcondition.get - val body = Equals(resid.toVariable, funDef.body.get) - val funName = fullName(funDef, useUniqueIds = false)(program) - val assumptions = - if (funDef.usePost && ctx.isFunctionPostVerified(funName)) - createAnd(Seq(funDef.getPostWoTemplate, funDef.precOrTrue)) - else funDef.precOrTrue - val fullPost = - if (funDef.hasTemplate) { - // if the postcondition is verified do not include it in the sequent - if (ctx.isFunctionPostVerified(funName)) - funDef.getTemplate - else - And(funDef.getPostWoTemplate, funDef.getTemplate) - } else if (!ctx.isFunctionPostVerified(funName)) - funDef.getPostWoTemplate - else tru - (body, assumptions, fullPost) - } - - def solveParametricVC(assump: Expr, body: Expr, conseq: Expr) = { - // initialize the constraint tracker - constTracker.addVC(rootFd, assump, body, conseq) - - var refinementStep: Int = 0 - var toRefineCalls: Option[Set[Call]] = None - var infRes: Option[InferResult] = None - do { - infRes = - if(ctx.abort) - Some(InferResult(false, None, List())) - else{ - Stats.updateCounter(1, "VC-refinement") - /* - * uncomment if we want to bound refinements - * if (refinementStep >= 5) - * throw new IllegalStateException("Done 4 refinements") - */ - val refined = - if (refinementStep >= 1) { - reporter.info("- More unrollings for invariant inference") - - val toUnrollCalls = if (ctx.targettedUnroll) toRefineCalls else None - val unrolledCalls = constTracker.refineVCs(toUnrollCalls) - if (unrolledCalls.isEmpty) { - reporter.info("- Cannot do more unrollings, reached unroll bound") - false - } else true - } else { - constTracker.initialize - true - } - refinementStep += 1 - if (!refined) - Some(InferResult(false, None, List())) - else { - //solve for the templates in this unroll step - templateSolver.solveTemplates() match { - case (Some(model), callsInPath) => - toRefineCalls = callsInPath - //Validate the model here - instantiateAndValidateModel(model, constTracker.getFuncs) - Some(InferResult(true, Some(model), constTracker.getFuncs.toList)) - case (None, callsInPath) => - toRefineCalls = callsInPath - //here, we do not know if the template is solvable or not, we need to do more unrollings. - None - } - } - } - } while (infRes.isEmpty) - infRes - } - - def apply() = { - if(ctx.abort) { - Some(InferResult(false, None, List())) - } else { - val (body, pre, post) = constructVC(rootFd) - if (post == tru) - Some(InferResult(true, Some(Model.empty), List())) - else - solveParametricVC(pre, body, post) - } - } - - def instantiateModel(model: Model, funcs: Seq[FunDef]) = { - funcs.collect { - case fd if fd.hasTemplate => - fd -> TemplateInstantiator.instantiateNormTemplates(model, fd.normalizedTemplate.get) - }.toMap - } - - def instantiateAndValidateModel(model: Model, funcs: Seq[FunDef]) = { - val sols = instantiateModel(model, funcs) - var output = "Invariants for Function: " + rootFd.id + "\n" - sols foreach { - case (fd, inv) => - val simpInv = simplifyArithmetic(InstUtil.replaceInstruVars(multToTimes(inv), fd)) - reporter.info("- Found inductive invariant: " + fd.id + " --> " + ScalaPrinter(simpInv)) - output += fd.id + " --> " + simpInv + "\n" - } - SpecificStats.addOutput(output) - - reporter.info("- Verifying Invariants... ") - val verifierRes = verifyInvariant(sols) - val finalRes = verifierRes._1 match { - case Some(false) => - reporter.info("- Invariant verified") - sols - case Some(true) => - reporter.error("- Invalid invariant, model: " + verifierRes._2.toMap) - throw new IllegalStateException("") - case _ => - //the solver timed out here - reporter.error("- Unable to prove or disprove invariant, the invariant is probably true") - sols - } - finalRes - } - - /** - * This function creates a new program with each function postcondition strengthened by - * the inferred postcondition - */ - def verifyInvariant(newposts: Map[FunDef, Expr]): (Option[Boolean], Model) = { - val augProg = assignTemplateAndCojoinPost(Map(), program, newposts, uniqueIdDisplay = false) - //convert the program back to an integer program if necessary - val newprog = - if (ctx.usereals) new RealToIntProgram()(augProg) - else augProg - val newroot = functionByFullName(fullName(rootFd)(program), newprog).get - verifyVC(newprog, newroot) - } - - /** - * Uses default postcondition VC, but can be overriden in the case of non-standard VCs - */ - def verifyVC(newprog: Program, newroot: FunDef) = { - val post = newroot.postcondition.get - val body = newroot.body.get - val vc = implies(newroot.precOrTrue, application(post, Seq(body))) - solveUsingLeon(ctx.leonContext, newprog, VC(vc, newroot, VCKinds.Postcondition)) - } - - import leon.solvers._ - import leon.solvers.unrolling.UnrollingSolver - def solveUsingLeon(leonctx: LeonContext, p: Program, vc: VC) = { - val solFactory = SolverFactory.uninterpreted(leonctx, program) - val smtUnrollZ3 = new UnrollingSolver(ctx.leonContext.toSctx, program, solFactory.getNewSolver()) with TimeoutSolver - smtUnrollZ3.setTimeout(ctx.vcTimeout * 1000) - smtUnrollZ3.assertVC(vc) - smtUnrollZ3.check match { - case Some(true) => - (Some(true), smtUnrollZ3.getModel) - case r => - (r, Model.empty) - } - } -} diff --git a/src/main/scala/leon/invariant/factories/AxiomFactory.scala b/src/main/scala/leon/invariant/factories/AxiomFactory.scala deleted file mode 100644 index 126e3d1db01c122a2a9b115b19dd218dd67f5ec2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/factories/AxiomFactory.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.factories - -import purescala.Expressions._ -import purescala.Types._ - -import invariant.engine._ -import invariant.util._ -import invariant.structure._ -import FunctionUtils._ -import PredicateUtil._ - -class AxiomFactory(ctx : InferenceContext) { - - val tru = BooleanLiteral(true) - - //Add more axioms here, if necessary - def hasUnaryAxiom(call: Call) : Boolean = { - //important: here we need to avoid applying commutativity on the calls produced by axioms instantiation - call.fi.tfd.fd.isCommutative - } - - def hasBinaryAxiom(call: Call) : Boolean = { - val callee = call.fi.tfd.fd - (callee.isMonotonic || callee.isDistributive) - } - - def unaryAxiom(call: Call): (Expr, Expr) = { - val callee = call.fi.tfd.fd - val tfd = call.fi.tfd - - if (callee.isCommutative) { - //note: commutativity is defined only for binary operations - val Seq(a1, a2) = call.fi.args - val newret = TVarFactory.createTempDefault("cm").toVariable - val newfi = FunctionInvocation(tfd, Seq(a2, a1)) - val newcall = Call(newret, newfi) - (tru, And(newcall.toExpr, Equals(newret, call.retexpr))) - } else - throw new IllegalStateException("Call does not have unary axiom: " + call) - } - - def binaryAxiom(call1: Call, call2: Call): Seq[(Expr,Expr)] = { - - if (call1.fi.tfd.id != call2.fi.tfd.id) - throw new IllegalStateException("Instantiating binary axiom on calls to different functions: " + call1 + "," + call2) - - if (!hasBinaryAxiom(call1)) - throw new IllegalStateException("Call does not have binary axiom: " + call1) - - val callee = call1.fi.tfd.fd - //monotonicity - var axioms = if (callee.isMonotonic) { - Seq(monotonizeCalls(call1, call2)) - } else Seq() - - //distributivity - axioms ++= (if (callee.isDistributive) { - //println("Applying distributivity on: "+(call1,call2)) - Seq(undistributeCalls(call1, call2)) - } else Seq()) - - axioms - } - - def monotonizeCalls(call1: Call, call2: Call): (Expr,Expr) = { - val ants = (call1.fi.args zip call2.fi.args).foldLeft(Seq[Expr]())((acc, pair) => { - val lesse = LessEquals(pair._1, pair._2) - lesse +: acc - }) - val conseq = LessEquals(call1.retexpr, call2.retexpr) - (createAnd(ants), conseq) - } - - //this is applicable only to binary operations - def undistributeCalls(call1: Call, call2: Call): (Expr,Expr) = { - val tfd = call1.fi.tfd - - val Seq(a1,b1) = call1.fi.args - val Seq(a2,b2) = call2.fi.args - val r1 = call1.retexpr - val r2 = call2.retexpr - - val dret2 = TVarFactory.createTempDefault("dt", IntegerType).toVariable - val dcall2 = Call(dret2, FunctionInvocation(tfd,Seq(Plus(a1,a2),b2))) - (LessEquals(b1,b2), And(LessEquals(Plus(r1,r2),dret2), dcall2.toExpr)) - } -} diff --git a/src/main/scala/leon/invariant/factories/TemplateFactory.scala b/src/main/scala/leon/invariant/factories/TemplateFactory.scala deleted file mode 100644 index 06798e870d2fed8e3e0b5ed77ca6a0e773092b39..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/factories/TemplateFactory.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.factories - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import scala.collection.mutable.{Map => MutableMap} - -import invariant.util._ -import invariant.structure._ -import FunctionUtils._ -import PredicateUtil._ -import ProgramUtil._ -import TypeUtil._ - -object TemplateIdFactory { - //a set of template ids - private var ids = Set[Identifier]() - - def getTemplateIds : Set[Identifier] = ids - - def freshIdentifier(name : String = "", idType: TypeTree = RealType) : Identifier = { - val idname = if(name.isEmpty) "a?" - else name + "?" - val freshid = FreshIdentifier(idname, idType, true) - ids += freshid - freshid - } - - def copyIdentifier(id: Identifier) : Identifier = { - val freshid = FreshIdentifier(id.name, id.getType, false) - ids += freshid - freshid - } - - /** - * Template variables have real type - */ - def IsTemplateIdentifier(id : Identifier) : Boolean = { - ids.contains(id) - } - - def IsTemplateVar(v : Variable) : Boolean = { - IsTemplateIdentifier(v.id) - } - - def freshTemplateVar(name : String= "") : Variable = { - Variable(freshIdentifier(name)) - } -} - -trait TemplateGenerator { - def getNextTemplate(fd : FunDef): Expr -} - -/** - * Templates are expressions with template variables. - * The program variables that can be free in the templates are only the arguments and - * the result variable. - * Note: the program logic depends on the mutability here. - */ -class TemplateFactory(tempGen : Option[TemplateGenerator], prog: Program, reporter : Reporter) { - - //a mapping from function definition to the template - private var templateMap = { - //initialize the template map with predefined user maps - val muMap = MutableMap[FunDef, Expr]() - functionsWOFields(prog.definedFunctions).foreach { fd => - val tmpl = fd.template - if (tmpl.isDefined) { - muMap.update(fd, tmpl.get) - } - } - muMap - } - - def setTemplate(fd:FunDef, tempExpr :Expr) = { - templateMap += (fd -> tempExpr) - } - - /** - * This is the default template generator. - * - */ - def getDefaultTemplate(fd : FunDef): Expr = { - - //just consider all the arguments, return values that are integers - val baseTerms = fd.params.filter((vardecl) => isNumericType(vardecl.getType)).map(_.toVariable) ++ - (if(isNumericType(fd.returnType)) Seq(getFunctionReturnVariable(fd)) - else Seq()) - - val lhs = baseTerms.foldLeft(TemplateIdFactory.freshTemplateVar() : Expr)((acc, t)=> { - Plus(Times(TemplateIdFactory.freshTemplateVar(),t),acc) - }) - val tempExpr = LessEquals(lhs,InfiniteIntegerLiteral(0)) - tempExpr - } - - /** - * Constructs a template using a mapping from the formals to actuals. - * Uses default template if a template does not exist for the function and no template generator is provided. - * Otherwise, use the provided template generator - */ - var refinementSet = Set[FunDef]() - def constructTemplate(argmap: Map[Expr,Expr], fd: FunDef): Expr = { - - //initialize the template for the function - if (!templateMap.contains(fd)) { - if(tempGen.isEmpty) templateMap += (fd -> getDefaultTemplate(fd)) - else { - templateMap += (fd -> tempGen.get.getNextTemplate(fd)) - refinementSet += fd - //for information - reporter.info("- Template generated for function "+fd.id+" : "+templateMap(fd)) - } - } - replace(argmap,templateMap(fd)) - } - - /** - * Refines the templates of the functions that were assigned templates using the template generator. - */ - def refineTemplates(): Boolean = { - - if(tempGen.isDefined) { - var modifiedTemplate = false - refinementSet.foreach((fd) => { - val oldTemp = templateMap(fd) - val newTemp = tempGen.get.getNextTemplate(fd) - - if (oldTemp != newTemp) { - modifiedTemplate = true - templateMap.update(fd, newTemp) - reporter.info("- New template for function " + fd.id + " : " + newTemp) - } - }) - modifiedTemplate - } else false - } - - def getTemplate(fd : FunDef) : Option[Expr] = { - templateMap.get(fd) - } - - def getFunctionsWithTemplate : Seq[FunDef] = templateMap.keys.toSeq - -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala deleted file mode 100644 index ea810b82843b9eeb25464ac70dc6c75f78000d3c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.factories - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import invariant.util._ -import invariant.structure._ -import invariant.engine.InferenceContext -import leon.solvers.Model -import leon.invariant.util.RealValuedExprEvaluator -import PredicateUtil._ -import FunctionUtils._ -import ExpressionTransformer._ - -object TemplateInstantiator { - - /** - * Computes the invariant for all the procedures given a model for the template variables. - * (Undone) If the mapping does not have a value for an id, then the id is bound to the simplest value - */ - def getAllInvariants(model: Model, funs: Seq[FunDef], prettyInv: Boolean = false): Map[FunDef, Expr] = { - val invs = funs.collect { - case fd if fd.hasTemplate => - (fd, instantiateNormTemplates(model, fd.normalizedTemplate.get, prettyInv)) - }.toMap - invs - } - - /** - * This function expects a template in a normalized form. - */ - def instantiateNormTemplates(model: Model, template: Expr, prettyInv: Boolean = false): Expr = { - val tempvars = getTemplateVars(template) - val instTemplate = instantiate(template, tempvars.map { v => (v, model(v.id)) }.toMap, prettyInv) - unflatten(instTemplate) - } - - /** - * Instantiates templated subexpressions of the given expression (expr) using the given mapping for the template variables. - * The instantiation also takes care of converting the rational coefficients to integer coefficients. - */ - def instantiate(expr: Expr, tempVarMap: Map[Expr, Expr], prettyInv: Boolean = false): Expr = { - //do a simple post transform and replace the template vars by their values - val inv = simplePostTransform { - case tempExpr@(e@Operator(Seq(lhs, rhs), op)) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] - || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] - || e.isInstanceOf[GreaterEquals]) - && - getTemplateVars(tempExpr).nonEmpty) => { - val linearTemp = LinearConstraintUtil.exprToTemplate(tempExpr) - instantiateTemplate(linearTemp, tempVarMap, prettyInv) - } - case tempExpr => tempExpr - }(expr) - inv - } - - def validateLiteral(e: Expr) = e match { - case FractionalLiteral(num, denom) => { - if (denom == 0) - throw new IllegalStateException("Denominator is zero !! " + e) - if (denom < 0) - throw new IllegalStateException("Denominator is negative: " + denom) - true - } - case IntLiteral(_) => true - case InfiniteIntegerLiteral(_) => true - case _ => throw new IllegalStateException("Not a real literal: " + e) - } - - def instantiateTemplate(linearTemp: LinearTemplate, tempVarMap: Map[Expr, Expr], prettyInv: Boolean = false): Expr = { - val bigone = BigInt(1) - val coeffMap = linearTemp.coeffTemplate.map((entry) => { - val (term, coeffTemp) = entry - val coeffE = replace(tempVarMap, coeffTemp) - val coeff = RealValuedExprEvaluator.evaluate(coeffE) - - validateLiteral(coeff) - - (term -> coeff) - }) - val const = if (linearTemp.constTemplate.isDefined) { - val constE = replace(tempVarMap, linearTemp.constTemplate.get) - val constV = RealValuedExprEvaluator.evaluate(constE) - - validateLiteral(constV) - Some(constV) - } else None - - val realValues: Seq[Expr] = coeffMap.values.toSeq ++ { if (const.isDefined) Seq(const.get) else Seq() } - //the coefficients could be fractions ,so collect all the denominators - val getDenom = (t: Expr) => t match { - case FractionalLiteral(num, denum) => denum - case _ => bigone - } - - val denoms = realValues.foldLeft(Set[BigInt]())((acc, entry) => { acc + getDenom(entry) }) - - //compute the LCM of the denominators - val gcd = denoms.foldLeft(bigone)((acc, d) => acc.gcd(d)) - val lcm = denoms.foldLeft(BigInt(1))((acc, d) => { - val product = (acc * d) - if (product % gcd == 0) - product / gcd - else product - }) - - //scale the numerator by lcm - val scaleNum = (t: Expr) => t match { - case FractionalLiteral(num, denum) => - InfiniteIntegerLiteral(num * (lcm / denum)) - case InfiniteIntegerLiteral(n) => - InfiniteIntegerLiteral(n * lcm) - case _ => throw new IllegalStateException("Coefficient not assigned to any value") - } - val intCoeffMap = coeffMap.map((entry) => (entry._1, scaleNum(entry._2))) - val intConst = if (const.isDefined) Some(scaleNum(const.get)) else None - - val linearCtr = new LinearConstraint(linearTemp.op, intCoeffMap, intConst) - if (prettyInv) - linearCtr.toPrettyExpr - else linearCtr.toExpr - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala b/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala deleted file mode 100644 index 888c579bf047c017dfa2c68baa5e7dc5441955d7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.factories - -import purescala.Definitions._ -import purescala.Expressions._ -import invariant._ -import invariant.engine._ -import invariant.util._ -import invariant.structure._ -import FunctionUtils._ -import templateSolvers._ -import leon.solvers.Model - -object TemplateSolverFactory { - - def createTemplateSolver(ctx: InferenceContext, prog: Program, ctrack: ConstraintTracker, rootFun: FunDef, - // options to solvers - minopt: Option[(Expr, Model) => Model] = None, - bound: Option[Int] = None): TemplateSolver = { - if (ctx.useCegis) { - // TODO: find a better way to specify CEGIS total time bound - new CegisSolver(ctx, prog, rootFun, ctrack, 10000, bound) - } else { - val minimizer = if (ctx.tightBounds && rootFun.hasTemplate) { - if (minopt.isDefined) - minopt - else { - //TODO: need to assert that the templates are resource templates - Some((new Minimizer(ctx, prog)).tightenTimeBounds(rootFun.getTemplate) _) - } - } else - None - if (ctx.withmult) { - new NLTemplateSolverWithMult(ctx, prog, rootFun, ctrack, minimizer) - } else { - new NLTemplateSolver(ctx, prog, rootFun, ctrack, minimizer) - } - } - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala deleted file mode 100644 index e55481b359de198b1f41fdc76b8bd1ea8ad1d13e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/structure/Constraint.scala +++ /dev/null @@ -1,388 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.structure - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.util._ -import Util._ -import PredicateUtil._ -import TypeUtil._ -import purescala.Extractors._ -import ExpressionTransformer._ -import solvers.Model -import purescala.Common._ -import leon.evaluators._ - -trait Constraint { - def toExpr: Expr -} - -trait ExtendedConstraint extends Constraint { - def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator): Constraint -} - -object LinearTemplate { - val debug = false - val debugPickSat = false -} - -/** - * Class representing linear templates which is a constraint of the form - * a1*v1 + a2*v2 + .. + an*vn + a0 <= 0 or = 0 or < 0 where ai's are unknown coefficients - * which could be any arbitrary expression with template variables as free variables - * and vi's are variables. - * Note: we need atleast one coefficient or one constant to be defined. - * Otherwise a NPE will be thrown (in the computation of 'template') - */ -class LinearTemplate(oper: Seq[Expr] => Expr, - coeffTemp: Map[Expr, Expr], - constTemp: Option[Expr]) extends Constraint { - - import LinearTemplate._ - - val zero = InfiniteIntegerLiteral(0) - val op = oper - - val coeffTemplate = { - if(debug) assert(coeffTemp.values.forall(e => isTemplateExpr(e))) - coeffTemp - } - - val constTemplate = { - if(debug) assert(constTemp.map(isTemplateExpr).getOrElse(true)) - constTemp - } - - val lhsExpr = { - //construct the expression corresponding to the template here - var lhs = coeffTemp.foldLeft(null: Expr) { - case (acc, (term, coeff)) => - val minterm = Times(coeff, term) - if (acc == null) minterm else Plus(acc, minterm) - } - if (constTemp.isDefined) { - if (lhs == null) constTemp.get - else Plus(lhs, constTemp.get) - } else lhs - } - - val template = oper(Seq(lhsExpr, zero)) - - def templateVars: Set[Variable] = getTemplateVars(template) - - /** - * Picks a sat disjunct of the negation of the template w.r.t to the - * given model. - */ - lazy val negTmpls = { - val args = template match { - case _: Equals => Seq(GreaterThan(lhsExpr, zero), LessThan(lhsExpr,zero)) - case _: LessEquals => Seq(GreaterThan(lhsExpr, zero)) - case _: LessThan => Seq(GreaterEquals(lhsExpr, zero)) - case _: GreaterEquals => Seq(LessThan(lhsExpr, zero)) - case _: GreaterThan => Seq(LessEquals(lhsExpr, zero)) - } - args map LinearConstraintUtil.exprToTemplate - } - - def pickSatDisjunctOfNegation(model: LazyModel, tmplModel: Map[Identifier, Expr], eval: DefaultEvaluator) = { - val err = new IllegalStateException(s"Cannot pick a sat disjunct of negation: ${toString} is sat!") - template match { - case _: Equals => // here, negation is a disjunction - UnflatHelper.evaluate(replaceFromIDs(tmplModel, lhsExpr), model, eval) match { - case InfiniteIntegerLiteral(lval) => - val Seq(grt, less) = negTmpls - if (lval > 0) grt - else if (lval < 0) less - else throw err - } - case _ => // here, the negation must be sat - if (debugPickSat) { - if (UnflatHelper.evaluate(replaceFromIDs(tmplModel, negTmpls.head.toExpr), model, eval) != tru) - throw err - } - negTmpls.head - } - } - - def coeffEntryToString(coeffEntry: (Expr, Expr)): String = { - val (e, i) = coeffEntry - i match { - case InfiniteIntegerLiteral(x) if (x == 1) => e.toString - case InfiniteIntegerLiteral(x) if (x == -1) => "-" + e.toString - case InfiniteIntegerLiteral(v) => v + e.toString - case IntLiteral(1) => e.toString - case IntLiteral(-1) => "-" + e.toString - case IntLiteral(v) => v + e.toString - case _ => i + " * " + e.toString - } - } - - override def toExpr: Expr = template - - /** - * Converts the template to a more human readable form - * by group positive (and negative) terms together - */ - def toPrettyExpr = { - val (lhsCoeff, rhsCoeff) = coeffTemplate.partition { - case (term, InfiniteIntegerLiteral(v)) => - v >= 0 - case _ => true - } - var lhsExprs: Seq[Expr] = lhsCoeff.map(e => Times(e._2, e._1)).toSeq - var rhsExprs: Seq[Expr] = rhsCoeff.map { - case (term, InfiniteIntegerLiteral(v)) => - Times(InfiniteIntegerLiteral(-v), term) // make the coeff +ve - }.toSeq - constTemplate match { - case Some(InfiniteIntegerLiteral(v)) if v < 0 => - rhsExprs :+= InfiniteIntegerLiteral(-v) - case Some(c) => - lhsExprs :+= c - case _ => - } - val lhsExprOpt = ((None: Option[Expr]) /: lhsExprs) { - case (acc, minterm) => - if (acc.isDefined) - Some(Plus(acc.get, minterm)) - else Some(minterm) - } - val rhsExprOpt = ((None: Option[Expr]) /: rhsExprs) { - case (acc, minterm) => - if (acc.isDefined) - Some(Plus(acc.get, minterm)) - else Some(minterm) - } - val lhs = lhsExprOpt.getOrElse(InfiniteIntegerLiteral(0)) - val rhs = rhsExprOpt.getOrElse(InfiniteIntegerLiteral(0)) - oper(Seq(lhs, rhs)) - } - - override def toString(): String = { - val coeffStr = if (coeffTemplate.isEmpty) "" - else { - val (head :: tail) = coeffTemplate.toList - tail.foldLeft(coeffEntryToString(head))((str, pair) => { - - val termStr = coeffEntryToString(pair) - (str + " + " + termStr) - }) - } - val constStr = if (constTemplate.isDefined) constTemplate.get.toString else "" - val str = if (!coeffStr.isEmpty() && !constStr.isEmpty()) coeffStr + " + " + constStr - else coeffStr + constStr - str + (template match { - case t: Equals => " = " - case t: LessThan => " < " - case t: GreaterThan => " > " - case t: LessEquals => " <= " - case t: GreaterEquals => " >= " - }) + "0" - } - - override def hashCode(): Int = template.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case lit: LinearTemplate => lit.template.equals(this.template) - case _ => false - } -} - -/** - * class representing a linear constraint. This is a linear template wherein the coefficients are constants - */ -class LinearConstraint(opr: Seq[Expr] => Expr, cMap: Map[Expr, Expr], constant: Option[Expr]) - extends LinearTemplate(opr, cMap, constant) { - val coeffMap = cMap - val const = constant -} - -/** - * Class representing Equality or disequality of a boolean variable and an linear template. - * Used for efficiently choosing a disjunct - */ -case class ExtendedLinearTemplate(v: Variable, tmpl: LinearTemplate, diseq: Boolean) extends ExtendedConstraint { - val expr = { - val eqExpr = Equals(v, tmpl.toExpr) - if(diseq) Not(eqExpr) else eqExpr - } - override def toExpr = expr - override def toString: String = expr.toString - - /** - * Chooses a sat disjunct of the constraint - */ - override def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator) = { - if((model(v.id) == tru && !diseq) || (model(v.id) == fls && diseq)) tmpl - else { - //println(s"Picking sat disjunct of: ${toExpr} model($v) = ${model(v.id)}") - tmpl.pickSatDisjunctOfNegation(model, tmplModel, eval) - } - } -} - -object BoolConstraint { - def isBoolConstraint(e: Expr): Boolean = e match { - case _: Variable | _: BooleanLiteral if e.getType == BooleanType => true - case Equals(l, r) => isBoolConstraint(l) && isBoolConstraint(r) //enabling makes the system slower!! surprising - case Not(arg) => isBoolConstraint(arg) - case And(args) => args forall isBoolConstraint - case Or(args) => args forall isBoolConstraint - case _ => false - } -} - -case class BoolConstraint(e: Expr) extends Constraint { - import BoolConstraint._ - val expr = { - assert(isBoolConstraint(e)) - e - } - override def toString(): String = expr.toString - def toExpr: Expr = expr -} - -object ADTConstraint { - def apply(e: Expr): ADTConstraint = e match { - case Equals(_: Variable, _: CaseClassSelector | _: TupleSelect) => - new ADTConstraint(e, sel = true) - case Equals(_: Variable, _: CaseClass | _: Tuple) => - new ADTConstraint(e, cons = true) - case Equals(_: Variable, _: IsInstanceOf) => - new ADTConstraint(e, inst = true) - case Equals(lhs @ Variable(_), AsInstanceOf(rhs @ Variable(_), _)) => - new ADTConstraint(Equals(lhs, rhs), comp= true) - case Equals(lhs: Variable, _: Variable) if adtType(lhs) => - new ADTConstraint(e, comp = true) - case Not(Equals(lhs: Variable, _: Variable)) if adtType(lhs) => - new ADTConstraint(e, comp = true) - case _ => - throw new IllegalStateException(s"Expression not an ADT constraint: $e") - } -} - -class ADTConstraint(val expr: Expr, - val cons: Boolean = false, - val inst: Boolean = false, - val comp: Boolean = false, - val sel: Boolean = false) extends Constraint { - - override def toString(): String = expr.toString - override def toExpr = expr -} - -case class ExtendedADTConstraint(v: Variable, adtCtr: ADTConstraint, diseq: Boolean) extends ExtendedConstraint { - val expr = { - assert(adtCtr.comp) - val eqExpr = Equals(v, adtCtr.toExpr) - if(diseq) Not(eqExpr) else eqExpr - } - override def toExpr = expr - override def toString: String = expr.toString - - /** - * Chooses a sat disjunct of the constraint - */ - override def pickSatDisjunct(model: LazyModel, tmplModel: Map[Identifier,Expr], eval: DefaultEvaluator) = { - if((model(v.id) == tru && !diseq) || (model(v.id) == fls && diseq)) adtCtr - else ADTConstraint(Not(adtCtr.toExpr)) - } -} - -case class Call(retexpr: Expr, fi: FunctionInvocation) extends Constraint { - val expr = Equals(retexpr, fi) - override def toExpr = expr -} - -/** - * If-then-else constraint - */ -case class ITE(cond: BoolConstraint, ths: Seq[Constraint], elzs: Seq[Constraint]) extends Constraint { - val expr = IfExpr(cond.toExpr, createAnd(ths.map(_.toExpr)), createAnd(elzs.map(_.toExpr))) - override def toExpr = expr -} - -object SetConstraint { - def setConstraintOfBase(e: Expr) = e match { - case Equals(lhs@Variable(_), _) if lhs.getType.isInstanceOf[SetType] => - true - case Equals(Variable(_), SetUnion(_, _) | FiniteSet(_, _) | ElementOfSet(_, _) | SubsetOf(_, _)) => - true - case _ => false - } - - def isSetConstraint(e: Expr) = { - val base = e match { - case Not(b) => b - case _ => e - } - setConstraintOfBase(base) - } -} - -case class SetConstraint(expr: Expr) extends Constraint { - var union = false - var newset = false - var equal = false - var elemof = false - var subset = false - // TODO: add more operations here - expr match { - case Equals(Variable(_), rhs) => - rhs match { - case SetUnion(_, _) => union = true - case FiniteSet(_, _) => newset = true - case ElementOfSet(_, _) => elemof = true - case SubsetOf(_, _) => subset = true - case Variable(_) => equal = true - } - } - override def toString(): String = { - expr.toString - } - override def toExpr = expr -} - -object ConstraintUtil { - def toLinearTemplate(ie: Expr) = { - simplifyArithmetic(ie) match { - case b: BooleanLiteral => BoolConstraint(b) - case _ => { - val template = LinearConstraintUtil.exprToTemplate(ie) - LinearConstraintUtil.evaluate(template) match { - case Some(v) => BoolConstraint(BooleanLiteral(v)) - case _ => template - } - } - } - } - - def toExtendedTemplate(v: Variable, ie: Expr, diseq: Boolean) = { - toLinearTemplate(ie) match { - case bc: BoolConstraint => BoolConstraint(Equals(v, bc.toExpr)) - case t: LinearTemplate => ExtendedLinearTemplate(v, t, diseq) - } - } - - def createConstriant(ie: Expr): Constraint = { - ie match { - case _ if BoolConstraint.isBoolConstraint(ie) => BoolConstraint(ie) - case Equals(v @ Variable(_), fi @ FunctionInvocation(_, _)) => Call(v, fi) - case Equals(_: Variable, _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple | _: IsInstanceOf) => - ADTConstraint(ie) - case _ if SetConstraint.isSetConstraint(ie) => SetConstraint(ie) - case Equals(v: Variable, rhs) if (isArithmeticRelation(rhs) != Some(false)) => toExtendedTemplate(v, rhs, false) - case Not(Equals(v: Variable, rhs)) if (isArithmeticRelation(rhs) != Some(false)) => toExtendedTemplate(v, rhs, true) - case _ if (isArithmeticRelation(ie) != Some(false)) => toLinearTemplate(ie) - case Equals(v: Variable, rhs@Equals(l, _)) if adtType(l) => ExtendedADTConstraint(v, ADTConstraint(rhs), false) - - // every other equality will be considered an ADT constraint (including TypeParameter equalities) - case Equals(lhs, rhs) if !isNumericType(lhs.getType) => ADTConstraint(ie) - case Not(Equals(lhs, rhs)) if !isNumericType(lhs.getType) => ADTConstraint(ie) - } - } -} diff --git a/src/main/scala/leon/invariant/structure/Formula.scala b/src/main/scala/leon/invariant/structure/Formula.scala deleted file mode 100644 index 2480b75c53249c52d40463cb8d42001d6306c78f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/structure/Formula.scala +++ /dev/null @@ -1,479 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.structure - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import invariant.engine._ -import invariant.util._ -import solvers.Model -import Util._ -import PredicateUtil._ -import TVarFactory._ -import ExpressionTransformer._ -import invariant.factories._ -import evaluators._ -import EvaluationResults._ - -/** - * Data associated with a call. - * @param inSpec true if the call (transitively) made within specifications - */ -class CallData(val guard : Variable, val parents: List[FunDef], val inSpec: Boolean) - -object Formula { - val debugUnflatten = false - val dumpUnflatFormula = false - // a context for creating blockers - val blockContext = newContext -} - -/** - * Representation of an expression as a set of implications. - * 'initexpr' is required to be in negation normal form and And/Ors have been pulled up - * TODO: optimize the representation so that we use fewer guards. - * @param initSpecCalls when specified it optimizes the handling of calls made in the specification. - */ -class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext, initSpecCalls: Set[Expr] = Set()) { - - import Formula._ - - val fls = BooleanLiteral(false) - val tru = BooleanLiteral(true) - val useImplies = false // note: we have to use equality for 'cond' blockers (no matter what!) - - val combiningOp = if(useImplies) Implies.apply _ else Equals.apply _ - protected var disjuncts = Map[Variable, Seq[Constraint]]() //a mapping from guards to conjunction of atoms - protected var conjuncts = Map[Variable, Expr]() //a mapping from guards to disjunction of atoms - private var paramBlockers = Set[Variable]() - private var callDataMap = Map[Call, CallData]() //a mapping from a 'call' to the 'guard' guarding the call plus the list of transitive callers of 'call' - - val firstRoot: Variable = addConstraints(initexpr, List(fd), c => initSpecCalls(c.toExpr))._1 - protected var roots : Seq[Variable] = Seq(firstRoot) //a list of roots, the formula is a conjunction of formula of each root - - def disjunctsInFormula = disjuncts - - def callData(call: Call) : CallData = callDataMap(call) - - //return the root variable and the sequence of disjunct guards added - //(which includes the root variable incase it respresents a disjunct) - def addConstraints(ine: Expr, callParents: List[FunDef], inSpec: Call => Boolean): (Variable, Seq[Variable]) = { - def atoms(e: Expr) = e match { - case And(atms) => atms - case _ => Seq(e) - } - var newDisjGuards = Seq[Variable]() - var condBlockers = Map[Variable, (Variable, Variable)]() // a mapping from condition constraint to then and else blockers - - def getCtrsFromExprs(guard: Variable, exprs: Seq[Expr]): Seq[Constraint] = { - var break = false - exprs.foldLeft(Seq[Constraint]()) { - case (acc, _) if break => acc - case (acc, ife @ IfExpr(cond: Variable, th, elze)) => - val (thBlock, elseBlock) = condBlockers(cond) - acc :+ ITE(BoolConstraint(cond), BoolConstraint(thBlock) +: getCtrsFromExprs(thBlock, atoms(th)), - BoolConstraint(elseBlock) +: getCtrsFromExprs(elseBlock, atoms(elze))) - case (acc, e) => - ConstraintUtil.createConstriant(e) match { - case BoolConstraint(BooleanLiteral(true)) => acc - case fls @ BoolConstraint(BooleanLiteral(false)) => - break = true - Seq(fls) - case call @ Call(_, _) => - if (callParents.isEmpty) throw new IllegalArgumentException("Parent not specified for call: " + call) - else callDataMap += (call -> new CallData(guard, callParents, inSpec(call))) - acc :+ call - case ctr => acc :+ ctr - } - } - } - /** - * Creates disjunct of the form b == exprs and updates the necessary mutable states - */ - def addToDisjunct(exprs: Seq[Expr], isTemplate: Boolean) = { - val g = createTemp("b", BooleanType, blockContext).toVariable - newDisjGuards :+= g - val ctrs = getCtrsFromExprs(g, exprs) - disjuncts += (g -> ctrs) - if(isTemplate) - paramBlockers += g - g - } - def rec(e: Expr)(implicit insideOperation: Boolean): Expr = e match { - case Or(args) if !insideOperation => - val newargs = (args map rec).map { - case v: Variable if disjuncts.contains(v) => v - case v: Variable if conjuncts.contains(v) => throw new IllegalStateException("or gaurd inside conjunct: " + e + " or-guard: " + v) - case arg => - val g = addToDisjunct(atoms(arg), !getTemplateIds(arg).isEmpty) - //println(s"creating a new OR blocker $g for "+atoms) - g - } - val gor = createTemp("b", BooleanType, blockContext).toVariable - val newor = createOr(newargs) - //println("Creating or const: "+(gor -> newor)) - conjuncts += (gor -> newor) - gor - - case And(args) => - //if the expression has template variables then we separate it using guards - val (nonparams, params) = (args map rec).partition(getTemplateIds(_).isEmpty) - val newargs = - if (!params.isEmpty) - addToDisjunct(params, true) +: nonparams - else nonparams - createAnd(newargs) - - case e : IfExpr => - val (con, th, elze) = (rec(e.cond)(true), rec(e.thenn)(false), rec(e.elze)(false)) - if(!isAtom(con) || !getTemplateIds(con).isEmpty) - throw new IllegalStateException(s"Condition of ifexpr is not an atom: $e") - // create condition and anti-condition blockers - val ncond = addToDisjunct(Seq(con), false) - val thBlock = addToDisjunct(Seq(), false) - val elseBlock = addToDisjunct(Seq(), false) - condBlockers += (ncond -> (thBlock, elseBlock)) - // normalize thn and elze - val trans = (e: Expr) => { - if(getTemplateIds(e).isEmpty) e - else addToDisjunct(atoms(e), true) - } - IfExpr(ncond, trans(th), trans(elze)) - - case Operator(args, op) => - op(args.map(rec(_)(true))) - } - val f1 = simplifyByConstructors(rec(ExpressionTransformer.simplify(simplifyArithmetic( - //TODO: this is a hack as of now. Fix this. - //Note: it is necessary to convert real literals to integers since the linear constraint cannot handle real literals - if(ctx.usereals) ExpressionTransformer.FractionalLiteralToInt(ine) - else ine - )))(false)) - val rootvar = f1 match { - case v: Variable if(conjuncts.contains(v)) => v - case v: Variable if(disjuncts.contains(v)) => throw new IllegalStateException("f1 is a disjunct guard: "+v) - case _ => addToDisjunct(atoms(f1), !getTemplateIds(f1).isEmpty) - } - (rootvar, newDisjGuards) - } - - def pickSatDisjunct(startGaurd : Variable, model: LazyModel, tmplModel: Map[Identifier, Expr], eval: DefaultEvaluator): Seq[Constraint] = { - - def traverseOrs(ine: Expr): Seq[Constraint] = { - val Or(guards) = ine - val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } //pick one guard that is true - if (guard.isEmpty) - throw new IllegalStateException("No satisfiable guard found: " + ine) - BoolConstraint(guard.get) +: traverseAnds(disjuncts(guard.get)) - } - def traverseAnds(inctrs: Seq[Constraint]): Seq[Constraint] = - inctrs.foldLeft(Seq[Constraint]()) { - case (acc, ITE(BoolConstraint(c: Variable), ths, elzes)) => - val conds = disjuncts(c) // here, cond it guaranteed to be an atom - assert(conds.size <= 1) - val ctrs = - if (model(c.id) == tru) - conds ++ traverseAnds(ths) - else { - val condCtr = conds match { - case Seq(bc: BoolConstraint) => BoolConstraint(Not(bc.toExpr)) - case Seq(lc: LinearTemplate) => lc.pickSatDisjunctOfNegation(model, tmplModel, eval) - case Seq(adteq: ADTConstraint) if adteq.comp => - adteq.toExpr match { - case Not(eq) => ADTConstraint(eq) - case eq => ADTConstraint(Not(eq)) - } - } - condCtr +: traverseAnds(elzes) - } - acc ++ ctrs - case (acc, elt: ExtendedConstraint) => - acc :+ elt.pickSatDisjunct(model, tmplModel, eval) - case (acc, ctr @ BoolConstraint(v: Variable)) if conjuncts.contains(v) => //assert(model(v.id) == tru) - acc ++ (ctr +: traverseOrs(conjuncts(v))) - case (acc, ctr @ BoolConstraint(v: Variable)) if disjuncts.contains(v) => //assert(model(v.id) == tru) - acc ++ (ctr +: traverseAnds(disjuncts(v))) - case (acc, ctr) => acc :+ ctr - } - val path = - if (model(startGaurd.id) == fls) Seq() //if startGuard is unsat return empty - else { - if (conjuncts.contains(startGaurd)) - traverseOrs(conjuncts(startGaurd)) - else - BoolConstraint(startGaurd) +: traverseAnds(disjuncts(startGaurd)) - } - /*println("Path: " + simplifyArithmetic(createAnd(path.map(_.toExpr)))) - scala.io.StdIn.readLine()*/ - path - } - - /** - * 'neweexpr' is required to be in negation normal form and And/Ors have been pulled up - */ - def conjoinWithDisjunct(guard: Variable, newexpr: Expr, callParents: List[FunDef], inSpec:Boolean) = { - val (exprRoot, newGuards) = addConstraints(newexpr, callParents, _ => inSpec) - //add 'newguard' in conjunction with 'disjuncts(guard)' - val ctrs = disjuncts(guard) - disjuncts -= guard - disjuncts += (guard -> (BoolConstraint(exprRoot) +: ctrs)) - exprRoot - } - - def conjoinWithRoot(newexpr: Expr, callParents: List[FunDef], inSpec: Boolean) = { - val (exprRoot, newGuards) = addConstraints(newexpr, callParents, _ => inSpec) - roots :+= exprRoot - exprRoot - } - - def getCallsOfGuards(guards: Seq[Variable]): Seq[Call] = { - def calls(ctrs: Seq[Constraint]): Seq[Call] = { - ctrs.flatMap { - case c: Call => Seq(c) - case ITE(_, th, el) => - calls(th) ++ calls(el) - case _ => Seq() - } - } - guards.flatMap{g => calls(disjuncts(g)) } - } - - def callsInFormula: Seq[Call] = getCallsOfGuards(disjuncts.keys.toSeq) - - def templateIdsInFormula = paramBlockers.flatMap { g => - getTemplateIds(createAnd(disjuncts(g).map(_.toExpr))) - }.toSet - - /** - * The first return value is param part and the second one is the - * non-parametric part - */ - def splitParamPart : (Expr, Expr) = { - val paramPart = paramBlockers.toSeq.map{ g => - combiningOp(g,createAnd(disjuncts(g).map(_.toExpr))) - } - val rest = disjuncts.collect { - case (g, ctrs) if !paramBlockers(g) => - combiningOp(g, createAnd(ctrs.map(_.toExpr))) - }.toSeq - val conjs = conjuncts.map((entry) => combiningOp(entry._1, entry._2)).toSeq ++ roots - (createAnd(paramPart), createAnd(rest ++ conjs)) - } - - def toExpr : Expr={ - val disjs = disjuncts.map((entry) => { - val (g,ctrs) = entry - combiningOp(g, createAnd(ctrs.map(_.toExpr))) - }).toSeq - val conjs = conjuncts.map((entry) => combiningOp(entry._1, entry._2)).toSeq - createAnd(disjs ++ conjs ++ roots) - } - - /** - * Creates an unflat expr of the non-param part, - * and returns a constructor for the flat model from unflat models - */ - def toUnflatExpr = { - val paramPart = paramBlockers.toSeq.map{ g => - combiningOp(g,createAnd(disjuncts(g).map(_.toExpr))) - } - // simplify blockers if we can, and close the map - val blockMap = substClosure(disjuncts.collect { - case (g, Seq(ctr)) if !paramBlockers(g) => (g.id -> ctr.toExpr) - case (g, Seq()) => (g.id -> tru) - }.toMap) - val conjs = conjuncts.map { - case (g, rhs) => replaceFromIDs(blockMap, combiningOp(g, rhs)) - }.toSeq ++ roots.map(replaceFromIDs(blockMap, _)) - val flatRest = disjuncts.toSeq collect { - case (g, ctrs) if !paramBlockers(g) && !blockMap.contains(g.id) => - //val ng = blockMap.getOrElse(g.id, g) - (g, replaceFromIDs(blockMap, createAnd(ctrs.map(_.toExpr)))) - } - // compute variables used in more than one disjunct - var sharedVars = (paramPart ++ conjs).flatMap(variablesOf).toSet - var uniqueVars = Set[Identifier]() - var freevars = Set[Identifier]() - flatRest.foreach{ - case (g, rhs) => - val fvs = variablesOf(rhs).toSet - val candUniques = fvs -- sharedVars - val newShared = uniqueVars.intersect(candUniques) - freevars ++= fvs - sharedVars ++= newShared - uniqueVars = (uniqueVars ++ candUniques) -- newShared - } - // unflatten rest - var flatIdMap = blockMap - val unflatRest = (flatRest collect { - case (g, rhs) => - // note: we call simple unflatten in the presence of if-then-else because it will not have flat-ids transcending then and else branches - val (unflatRhs, idmap) = simpleUnflattenWithMap(rhs, sharedVars, includeFuns = false) - // sanity checks - if (debugUnflatten) { - val rhsvars = variablesOf(rhs) - if(!rhsvars.filter(TemplateIdFactory.IsTemplateIdentifier).isEmpty) - throw new IllegalStateException(s"Non-param part has template identifiers ${toString}") - val seenKeys = flatIdMap.keySet.intersect(rhsvars) - if (!seenKeys.isEmpty) - throw new IllegalStateException(s"flat ids used across clauses $seenKeys in ${toString}") - } - flatIdMap ++= idmap - combiningOp(g, unflatRhs) - }).toSeq - - val modelCons = (m: Model, eval: DefaultEvaluator) => new FlatModel(freevars, flatIdMap, m, eval) - - if (dumpUnflatFormula) { - val unf = ((paramPart ++ unflatRest.map(_.toString) ++ conjs.map(_.toString)).mkString("\n")) - val filename = "unflatVC-" + FileCountGUID.getID - val wr = new PrintWriter(new File(filename + ".txt")) - println("Printed VC of " + fd.id + " to file: " + filename) - wr.println(unf) - wr.close() - } - if (ctx.dumpStats) { - Stats.updateCounterStats(atomNum(And(paramPart ++ unflatRest ++ conjs)), "unflatSize", "VC-refinement") - } - (createAnd(paramPart), createAnd(unflatRest ++ conjs), modelCons) - } - - //unpack the disjunct and conjuncts by removing all guards - def eliminateBlockers : Expr = { - //replace all conjunct guards in disjuncts by their mapping - val disjs : Map[Expr,Expr] = disjuncts.map((entry) => { - val (g,ctrs) = entry - val newctrs = ctrs.map { - case BoolConstraint(g@Variable(_)) if conjuncts.contains(g) => conjuncts(g) - case ctr@_ => ctr.toExpr - } - (g, createAnd(newctrs)) - }) - val rootexprs = roots.map { - case g@Variable(_) if conjuncts.contains(g) => conjuncts(g) - case e@_ => e - } - //replace every guard in the 'disjs' by its disjunct. DO this as long as every guard is replaced in every disjunct - var unpackedDisjs = disjs - var replacedGuard = true - //var removeGuards = Seq[Variable]() - while(replacedGuard) { - replacedGuard = false - val newDisjs = unpackedDisjs.map(entry => { - val (g,d) = entry - val guards = variablesOf(d).collect{ case id@_ if disjuncts.contains(id.toVariable) => id.toVariable } - if (guards.isEmpty) entry - else { - replacedGuard = true - //removeGuards ++= guards - (g, replace(unpackedDisjs, d)) - } - }) - unpackedDisjs = newDisjs - } - //replace all the 'guards' in root using 'unpackedDisjs' - replace(unpackedDisjs, createAnd(rootexprs)) - } - - override def toString : String = { - val disjStrs = disjuncts.map((entry) => { - val (g,ctrs) = entry - simplifyArithmetic(combiningOp(g, createAnd(ctrs.map(_.toExpr)))).toString - }).toSeq - val conjStrs = conjuncts.map((entry) => combiningOp(entry._1, entry._2).toString).toSeq - val rootStrs = roots.map(_.toString) - (disjStrs ++ conjStrs ++ rootStrs).foldLeft("")((acc,str) => acc + "\n" + str) - } - - /** - * Functions for stats - */ - def atomsCount = disjuncts.map(_._2.size).sum + conjuncts.map(i => atomNum(i._2)).sum - def funsCount = disjuncts.map(_._2.count { - case _: Call | _: ADTConstraint => true - case _ => false - }).sum - - /** - * Functions solely used for debugging - */ - import solvers.SimpleSolverAPI - def checkUnflattening(tempMap: Map[Expr, Expr], sol: SimpleSolverAPI, eval: DefaultEvaluator) = { - // solve unflat formula - val (temp, rest, modelCons) = toUnflatExpr - val packedFor = TemplateInstantiator.instantiate(And(Seq(rest, temp)), tempMap) - val (unflatSat, unflatModel) = sol.solveSAT(packedFor) - // solve flat formula (using the same values for the uncompressed vars) - val flatVCInst = simplifyArithmetic(TemplateInstantiator.instantiate(toExpr, tempMap)) - val modelExpr = SolverUtil.modelToExpr(unflatModel) - val (flatSat, flatModel) = sol.solveSAT(And(flatVCInst, modelExpr)) - //println("Formula: "+unpackedFor) - //println("packed formula: "+packedFor) - val satdisj = - if (unflatSat == Some(true)) - Some(pickSatDisjunct(firstRoot, new SimpleLazyModel(unflatModel), - tempMap.map{ case (Variable(id), v) => id -> v }.toMap, eval)) - else None - if (unflatSat != flatSat) { - if (satdisj.isDefined) { - val preds = satdisj.get.filter { ctr => - if (getTemplateIds(ctr.toExpr).isEmpty) { - val exp = And(Seq(ctr.toExpr, modelExpr)) - sol.solveSAT(exp)._1 == Some(false) - } else false - } - println(s"Conflicting preds: ${preds.map(_.toExpr)}") - } - throw new IllegalStateException(s"VC produces different result with flattening: unflatSat: $unflatSat flatRes: $flatSat") - } else { - if (satdisj.isDefined) { - // print all differences between the models (only along the satisfiable path, values of other variables may not be computable) - val satExpr = createAnd(satdisj.get.map(_.toExpr)) - val lazyModel = modelCons(unflatModel, eval) - val allvars = variablesOf(satExpr) - val elimIds = allvars -- variablesOf(packedFor) - val diffs = allvars.filterNot(TemplateIdFactory.IsTemplateIdentifier).flatMap { - case id if !flatModel.isDefinedAt(id) => - println("Did not find a solver model for: " + id + " elimIds: " + elimIds(id)) - Seq() - case id if lazyModel(id) != flatModel(id) => - println(s"diff $id : flat: ${lazyModel(id)} solver: ${flatModel(id)}" + " elimIds: " + elimIds(id)) - Seq(id) - case _ => Seq() - } - if (!diffs.isEmpty) - throw new IllegalStateException("Model do not agree on diffs: " + diffs) - } - } - } - - /** - * A method for picking a sat disjunct of unflat formula. Mostly used for debugging. - */ - def pickSatFromUnflatFormula(unflate: Expr, model: Model, evaluator: DefaultEvaluator): Seq[Expr] = { - def rec(e: Expr): Seq[Expr] = e match { - case IfExpr(cond, thn, elze) => - (evaluator.eval(cond, model): @unchecked) match { - case Successful(BooleanLiteral(true)) => cond +: rec(thn) - case Successful(BooleanLiteral(false)) => Not(cond) +: rec(elze) - } - case And(args) => args flatMap rec - case Or(args) => rec(args.find(evaluator.eval(_, model) == Successful(BooleanLiteral(true))).get) - case Equals(b: Variable, rhs) if b.getType == BooleanType => - (evaluator.eval(b, model): @unchecked) match { - case Successful(BooleanLiteral(true)) => - rec(b) ++ rec(rhs) - case Successful(BooleanLiteral(false)) => - Seq(Not(b)) - } - case e => Seq(e) - } - rec(unflate) - } -} diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala deleted file mode 100644 index eb36720b93b678724d8c8bf13a1d4d8b8d4dd720..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/structure/FunctionUtils.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.structure - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.factories._ -import invariant.util._ -import Util._ -import PredicateUtil._ -import scala.language.implicitConversions -import ExpressionTransformer._ - -/** - * Some utiliy methods for functions. - * This also does caching to improve performance. - */ -object FunctionUtils { - - class FunctionInfo(fd: FunDef) { - //flags - lazy val isTheoryOperation = fd.annotations.contains("theoryop") - lazy val isMonotonic = fd.annotations.contains("monotonic") - lazy val isCommutative = fd.annotations.contains("commutative") - lazy val isDistributive = fd.annotations.contains("distributive") - lazy val compose = fd.annotations.contains("compose") - lazy val isLibrary = fd.annotations.contains("library") - lazy val isExtern = fd.annotations.contains("extern") - lazy val isBodyVisible = !fd.annotations.contains("invisibleBody") - lazy val hasFieldFlag = fd.flags.contains(IsField(false)) - lazy val hasLazyFieldFlag = fd.flags.contains(IsField(true)) - lazy val isUserFunction = !hasFieldFlag && !hasLazyFieldFlag - lazy val usePost = fd.annotations.contains("usePost") - - //the template function - lazy val tmplFunctionName = "tmpl" - /** - * checks if the function name is 'tmpl' and there is only one argument - * if not, type checker would anyway throw an error if leon.invariant._ is included - */ - def isTemplateInvocation(finv: Expr) = { - finv match { - case FunctionInvocation(funInv, args) => - funInv.id.name == "tmpl" && funInv.returnType == BooleanType && - args.size == 1 && args(0).isInstanceOf[Lambda] - case _ => - false - } - } - - def isQMark(e: Expr) = e match { - case FunctionInvocation(TypedFunDef(fd, Seq()), args) => - fd.id.name == "?" && fd.returnType == IntegerType && args.size <= 1 - case _ => false - } - - def extractTemplateFromLambda(tempLambda: Lambda): Expr = { - val Lambda(vdefs, body) = tempLambda - val vars = vdefs.map(_.id.toVariable) - val tempVars = vars.map { // reuse template variables if possible - case v if TemplateIdFactory.IsTemplateIdentifier(v.id) => v - case v => - TemplateIdFactory.freshIdentifier(v.id.name).toVariable - } - val repmap = (vars zip tempVars).toMap[Expr, Expr] - replace(repmap, body) - } - - def tmplFunction(paramTypes: Seq[TypeTree]) = { - val lambdaType = FunctionType(paramTypes, BooleanType) - val paramid = FreshIdentifier("lamb", lambdaType) - new FunDef(FreshIdentifier("tmpl", BooleanType), Seq(), Seq(ValDef(paramid)), BooleanType) - } - - /** - * Repackages '?' mark expression into tmpl functions - */ - def qmarksToTmplFunction(ine: Expr) = { - var tempIds = Seq[Identifier]() - var indexToId = Map[BigInt, Identifier]() - val lambBody = simplePostTransform { - case q @ FunctionInvocation(_, Seq()) if isQMark(q) => // question mark with zero args - val freshid = TemplateIdFactory.freshIdentifier("q") - tempIds :+= freshid - freshid.toVariable - - case q @ FunctionInvocation(_, Seq(InfiniteIntegerLiteral(index))) if isQMark(q) => //question mark with one arg - indexToId.getOrElse(index, { - val freshid = TemplateIdFactory.freshIdentifier("q" + index) - tempIds :+= freshid - indexToId += (index -> freshid) - freshid - }).toVariable - - case other => other - }(ine) - FunctionInvocation(TypedFunDef(tmplFunction(tempIds.map(_.getType)), Seq()), - Seq(Lambda(tempIds.map(id => ValDef(id)), lambBody))) - } - - /** - * Does not support mixing of tmpl exprs and '?'. - * Need to check that tmpl functions are not nested. - */ - lazy val (postWoTemplate, templateExpr) = { - if (fd.postcondition.isDefined) { - val Lambda(_, postBody) = fd.postcondition.get - // collect all terms with question marks and convert them to a template - val postWoQmarks = simplifyByConstructors(postBody) match { - case And(args) if args.exists(exists(isQMark)) => - val (tempExprs, otherPreds) = args.partition(exists(isQMark)) - //println(s"Otherpreds: $otherPreds ${qmarksToTmplFunction(createAnd(tempExprs))}") - createAnd(otherPreds :+ qmarksToTmplFunction(createAnd(tempExprs))) - case pb if exists(isQMark)(pb) => - pb match { - case l: Let => - val (letsCons, letsBody) = letStarUnapplyWithSimplify(l) // we try to see if the post is let* .. in e_1 ^ e_2 ^ ... - letsBody match { - case And(args) => - val (tempExprs, rest) = args.partition(exists(isQMark)) - val toTmplFun = qmarksToTmplFunction(letsCons(createAnd(tempExprs))) - createAnd(Seq(letsCons(createAnd(rest)), toTmplFun)) - case _ => - qmarksToTmplFunction(pb) - } - case _ => qmarksToTmplFunction(pb) - } - case other => other - } - //the 'body' could be a template or 'And(pred, template)' - postWoQmarks match { - case finv @ FunctionInvocation(_, args) if isTemplateInvocation(finv) => - (None, Some(finv)) - case And(args) if args.exists(isTemplateInvocation) => - val (tempFuns, otherPreds) = args.partition(isTemplateInvocation) - if (tempFuns.size > 1) { - throw new IllegalStateException("Multiple template functions used in the postcondition: " + postBody) - } else { - val rest = if (otherPreds.size <= 1) otherPreds(0) else And(otherPreds) - (Some(rest), Some(tempFuns(0).asInstanceOf[FunctionInvocation])) - } - case pb => - (Some(pb), None) - } - } else { - (None, None) - } - } - - lazy val template = templateExpr map (finv => extractTemplateFromLambda(finv.args(0).asInstanceOf[Lambda])) - lazy val normalizedTemplate = template.map(normalizeExpr(_, (e1: Expr, e2: Expr) => - throw new IllegalStateException("Not implemented yet!"))) - - def hasTemplate: Boolean = templateExpr.isDefined - def getPostWoTemplate = postWoTemplate match { - case None => tru - case Some(expr) => expr - } - def getTemplate = template.get - } - - // a cache for function infos - private var functionInfos = Map[FunDef, FunctionInfo]() - implicit def funDefToFunctionInfo(fd: FunDef): FunctionInfo = { - functionInfos.getOrElse(fd, { - val info = new FunctionInfo(fd) - functionInfos += (fd -> info) - info - }) - } -} diff --git a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala deleted file mode 100644 index 5c48b2aeb7b6fb044aeca6111bef3d96f9c42c10..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.structure - -import purescala._ -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import leon.purescala.Types._ -import purescala.Extractors._ -import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet, MutableList } -import invariant.util._ -import BigInt._ -import PredicateUtil._ -import Stats._ - -class NotImplementedException(message: String) extends RuntimeException(message) - -//a collections of utility methods that manipulate the templates -object LinearConstraintUtil { - val zero = InfiniteIntegerLiteral(0) - val one = InfiniteIntegerLiteral(1) - val mone = InfiniteIntegerLiteral(-1) - val tru = BooleanLiteral(true) - val fls = BooleanLiteral(false) - - val debugElimination = false - - //some utility methods - def getFIs(ctr: LinearConstraint): Set[FunctionInvocation] = { - val fis = ctr.coeffMap.keys.collect { - case fi: FunctionInvocation => fi - } - fis.toSet - } - - def evaluate(lt: LinearTemplate): Option[Boolean] = lt match { - case lc: LinearConstraint if lc.coeffMap.isEmpty => - ExpressionTransformer.simplify(lt.toExpr) match { - case BooleanLiteral(v) => Some(v) - case _ => None - } - case _ => None - } - - /** - * the expression 'Expr' is required to be a linear atomic predicate (or a template), - * if not, an exception would be thrown. - * For now some of the constructs are not handled. - * The function returns a linear template or a linear constraint depending - * on whether the expression has template variables or not - */ - def exprToTemplate(expr: Expr): LinearTemplate = { - - //these are the result values - var coeffMap = MutableMap[Expr, Expr]() - var constant: Option[Expr] = None - var isTemplate: Boolean = false - - def addCoefficient(term: Expr, coeff: Expr) = { - if (coeffMap.contains(term)) { - val value = coeffMap(term) - val newcoeff = simplifyArithmetic(Plus(value, coeff)) - //if newcoeff becomes zero remove it from the coeffMap - if (newcoeff == zero) { - coeffMap.remove(term) - } else { - coeffMap.update(term, newcoeff) - } - } else coeffMap += (term -> simplifyArithmetic(coeff)) - if (variablesOf(coeff).nonEmpty) { - isTemplate = true - } - } - - def addConstant(coeff: Expr) = { - if (constant.isDefined) { - val value = constant.get - constant = Some(simplifyArithmetic(Plus(value, coeff))) - } else - constant = Some(simplifyArithmetic(coeff)) - if (variablesOf(coeff).nonEmpty) { - isTemplate = true - } - } - - //recurse into plus and get all minterms - def getMinTerms(lexpr: Expr): Seq[Expr] = lexpr match { - case Plus(e1, e2) => getMinTerms(e1) ++ getMinTerms(e2) - case _ => Seq(lexpr) - } - - //the top most operator should be a relation - val Operator(Seq(lhs, InfiniteIntegerLiteral(x)), op) = makeLinear(expr) - /*if (lhs.isInstanceOf[InfiniteIntegerLiteral]) - throw new IllegalStateException("relation on two integers, not in canonical form: " + linearExpr)*/ - //handle each minterm - getMinTerms(lhs).foreach(minterm => minterm match { - case _ if (isTemplateExpr(minterm)) => addConstant(minterm) - case Times(e1, e2) => - e2 match { - case Variable(_) | ResultVariable(_) | FunctionInvocation(_, _) => - case _ => throw new IllegalStateException("Multiplicand not a constraint variable: " + e2) - } - e1 match { - case _ if (isTemplateExpr(e1)) => addCoefficient(e2, e1) - case _ => throw new IllegalStateException("Coefficient not a constant or template expression: " + e1) - } - case Variable(_) => addCoefficient(minterm, one) //here the coefficient is 1 - case ResultVariable(_) => addCoefficient(minterm, one) - case _ => throw new IllegalStateException("Unhandled min term: " + minterm) - }) - - if (coeffMap.isEmpty && constant.isEmpty) { - //here the generated template the constant term is zero. - new LinearConstraint(op, Map.empty, Some(zero)) - } else if (isTemplate) { - new LinearTemplate(op, coeffMap.toMap, constant) - } else { - new LinearConstraint(op, coeffMap.toMap, constant) - } - } - - /** - * This method may have to do all sorts of transformation to make the expressions linear constraints. - * This assumes that the input expression is an atomic predicate (i.e, without and, or and nots) - * This is subjected to constant modification. - */ - def makeLinear(atom: Expr): Expr = { - - //pushes the minus inside the arithmetic terms - //we assume that inExpr is in linear form - def pushMinus(inExpr: Expr): Expr = { - inExpr match { - case IntLiteral(v) => IntLiteral(-v) - case InfiniteIntegerLiteral(v) => InfiniteIntegerLiteral(-v) - case t: Terminal => Times(mone, t) - case fi @ FunctionInvocation(fdef, args) => Times(mone, fi) - case UMinus(e1) => e1 - case RealUMinus(e1) => e1 - case Minus(e1, e2) => Plus(pushMinus(e1), e2) - case RealMinus(e1, e2) => Plus(pushMinus(e1), e2) - case Plus(e1, e2) => Plus(pushMinus(e1), pushMinus(e2)) - case RealPlus(e1, e2) => Plus(pushMinus(e1), pushMinus(e2)) - case Times(e1, e2) => - //here push the minus in to the coefficient which is the first argument - Times(pushMinus(e1), e2) - case RealTimes(e1, e2) => Times(pushMinus(e1), e2) - case _ => throw new NotImplementedException("pushMinus -- Operators not yet handled: " + inExpr) - } - } - - //we assume that ine is in linear form - def pushTimes(mul: Expr, ine: Expr): Expr = { - val isReal = ine.getType == RealType && mul.getType == RealType - val timesCons = - if (isReal) RealTimes - else Times - ine match { - case t: Terminal => timesCons(mul, t) - case fi @ FunctionInvocation(fdef, ars) => timesCons(mul, fi) - case Plus(e1, e2) => Plus(pushTimes(mul, e1), pushTimes(mul, e2)) - case RealPlus(e1, e2) => - val r1 = pushTimes(mul, e1) - val r2 = pushTimes(mul, e2) - if (isReal) RealPlus(r1, r2) - else Plus(r1, r2) - case Times(e1, e2) => - //here push the times into the coefficient which should be the first expression - Times(pushTimes(mul, e1), e2) - case RealTimes(e1, e2) => - val r = pushTimes(mul, e1) - if (isReal) RealTimes(r, e2) - else Times(r, e2) - case _ => throw new NotImplementedException("pushTimes -- Operators not yet handled: " + ine) - } - } - - //collect all the constants in addition and simplify them - //we assume that ine is in linear form and also that all constants are integers - def simplifyConsts(ine: Expr): (Option[Expr], BigInt) = { - ine match { - case IntLiteral(v) => (None, v) - case InfiniteIntegerLiteral(v) => (None, v) - case Plus(e1, e2) => { - val (r1, c1) = simplifyConsts(e1) - val (r2, c2) = simplifyConsts(e2) - val newe = (r1, r2) match { - case (None, None) => None - case (Some(t), None) => Some(t) - case (None, Some(t)) => Some(t) - case (Some(t1), Some(t2)) => Some(Plus(t1, t2)) - } - (newe, c1 + c2) - } - case _ => (Some(ine), 0) - } - } - - def mkLinearRecur(inExpr: Expr): Expr = { - //println("inExpr: "+inExpr + " tpe: "+inExpr.getType) - val res = inExpr match { - case e @ Operator(Seq(e1, e2), op) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] - || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] - || e.isInstanceOf[GreaterEquals])) => { - - //check if the expression has real valued sub-expressions - val isReal = hasRealsOrTemplates(e1) || hasRealsOrTemplates(e2) - val (newe, newop) = e match { - case t: Equals => (Minus(e1, e2), Equals) - case t: LessEquals => (Minus(e1, e2), LessEquals) - case t: GreaterEquals => (Minus(e2, e1), LessEquals) - case t: LessThan => - if (isReal) - (Minus(e1, e2), LessThan) - else - (Plus(Minus(e1, e2), one), LessEquals) - case t: GreaterThan => - if (isReal) - (Minus(e2, e1), LessThan) - else - (Plus(Minus(e2, e1), one), LessEquals) - } - val r = mkLinearRecur(newe) - //simplify the resulting constants - val (r2, const) = simplifyConsts(r) - val finale = if (r2.isDefined) { - if (const != 0) Plus(r2.get, InfiniteIntegerLiteral(const)) - else r2.get - } else InfiniteIntegerLiteral(const) - newop(finale, zero) - } - case Minus(e1, e2) => Plus(mkLinearRecur(e1), pushMinus(mkLinearRecur(e2))) - case RealMinus(e1, e2) => RealPlus(mkLinearRecur(e1), pushMinus(mkLinearRecur(e2))) - case UMinus(e1) => pushMinus(mkLinearRecur(e1)) - case RealUMinus(e1) => pushMinus(mkLinearRecur(e1)) - case Times(_, _) | RealTimes(_, _) => { - val Operator(Seq(e1, e2), op) = inExpr - val (r1, r2) = (mkLinearRecur(e1), mkLinearRecur(e2)) - if (isTemplateExpr(r1)) - pushTimes(r1, r2) - else if (isTemplateExpr(r2)) - pushTimes(r2, r1) - else - throw new IllegalStateException("Expression not linear: " + Times(r1, r2)) - } - case Plus(e1, e2) => Plus(mkLinearRecur(e1), mkLinearRecur(e2)) - case rp @ RealPlus(e1, e2) => - RealPlus(mkLinearRecur(e1), mkLinearRecur(e2)) - case t: Terminal => t - case fi: FunctionInvocation => fi - case _ => throw new IllegalStateException("Expression not linear: " + inExpr) - } - res - } - val rese = mkLinearRecur(atom) - rese - } - - /** - * Replaces an expression by another expression in the terms of the given linear constraint. - */ - def replaceInCtr(replaceMap: Map[Identifier, Expr], lc: LinearConstraint): Option[LinearConstraint] = { - //println("Replacing in "+lc+" repMap: "+replaceMap) - val newexpr = ExpressionTransformer.simplify(replaceFromIDs(replaceMap, lc.toExpr)) - if (newexpr == tru) None - else if (newexpr == fls) throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) - else { - val res = exprToTemplate(newexpr) - //check if res is true or false - evaluate(res) match { - case Some(false) => throw new IllegalStateException("!!Constraint reduced to false during elimination: " + lc) - case Some(true) => None //constraint reduced to true - case _ => - Some(res.asInstanceOf[LinearConstraint]) - } - } - } - - def ctrVars(lc: LinearConstraint) = lc.coeffMap.keySet.map { case Variable(id) => id } - - /** - * Eliminates all variables except the `retainVars` from a conjunction of linear constraints (a disjunct) (that is satisfiable) - * We assume that the disjunct is in nnf form. - * The strategy is to look for (a) equality involving the elimVars or (b) check if all bounds are lower or (c) if all bounds are upper. - * TODO: handle cases wherein the coefficient of the variable that is substituted is not 1 or -1 - * - * @param debugger is a function used for debugging - */ - def apply1PRuleOnDisjunct(linearCtrs: Seq[LinearConstraint], retainVars: Set[Identifier], - debugger: Option[(Seq[LinearConstraint] => Unit)]): Seq[LinearConstraint] = { - val idsWithUpperBounds = MutableSet[Identifier]() // identifiers with only upper bounds - val idsWithLowerBounds = MutableSet[Identifier]() // identifiers with only lower bounds - val idsWithEquality = MutableSet[Identifier]() // identifiers for which an equality constraint exist - var eqctrs = MutableList[LinearConstraint]() - var restctrs = MutableList[LinearConstraint]() - linearCtrs.foreach { - case lc => - val vars = ctrVars(lc) - val elimVars = vars -- retainVars - lc.template match { - case eq: Equals => - idsWithEquality ++= vars - if (!elimVars.isEmpty) - eqctrs += lc - else restctrs += lc - // choose all vars whose coefficient is either 1 or -1 - case _: LessEquals | _: LessThan => - elimVars.foreach { elimVar => - val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) - if (elimCoeff > 0) - idsWithUpperBounds += elimVar //here, we have found an upper bound - else - idsWithLowerBounds += elimVar //here, we have found a lower bound - } - restctrs += lc - case _ => throw new IllegalStateException("LinearConstraint not in expeceted form : " + lc.toExpr) - } - } - // sort 'eqctrs' by the size of the constraints so that we use smaller expressions in 'subst' map. - var currEqs = eqctrs.sortBy(eqc => eqc.coeffMap.keySet.size + (if (eqc.const.isDefined) 1 else 0)) - // compute the subst map recursively - var nextEqs = MutableList[LinearConstraint]() - var foundSubst = true - var subst = Map[Identifier, Expr]() - while (foundSubst) { - foundSubst = false - currEqs.foreach { eq => - // replace the constraint by the current subst (which may require multiple applications) - replaceInCtr(subst, eq) match { - case None => // constraint reduced to true, drop the constraint - case Some(newc) => - // choose one new variable that can be substituted - val elimVarOpt = ctrVars(newc).find { evar => - !retainVars.contains(evar) && !subst.contains(evar) && - (newc.coeffMap(evar.toVariable) match { - case InfiniteIntegerLiteral(elimCoeff) if (elimCoeff == 1 || elimCoeff == -1) => true - case _ => false - }) - } - elimVarOpt match { - case None => - nextEqs += newc // here, the constraint cannot be substituted, so we need to preserve it - case Some(elimVar) => - //if the coeffcient of elimVar is +ve the the sign of the coeff of every other term should be changed - val InfiniteIntegerLiteral(elimCoeff) = newc.coeffMap(elimVar.toVariable) - val changeSign = elimCoeff > 0 - val startval = if (newc.const.isDefined) { - val InfiniteIntegerLiteral(cval) = newc.const.get - val newconst = if (changeSign) -cval else cval - InfiniteIntegerLiteral(newconst) - } else zero - val substExpr = newc.coeffMap.foldLeft(startval: Expr) { - case (acc, (term, InfiniteIntegerLiteral(coeff))) if (term != elimVar.toVariable) => - val newcoeff = if (changeSign) -coeff else coeff - val newsummand = if (newcoeff == 1) term else Times(term, InfiniteIntegerLiteral(newcoeff)) - if (acc == zero) newsummand - else Plus(acc, newsummand) - case (acc, _) => acc - } - if (debugElimination) { - println("Analyzing ctr: " + newc + " found mapping: " + elimVar + " --> " + substExpr) - } - subst = Util.substClosure(subst + (elimVar -> simplifyArithmetic(substExpr))) - foundSubst = true - } - } - } - currEqs = nextEqs - } - val oneSidedVars = ((idsWithUpperBounds -- idsWithLowerBounds) ++ (idsWithLowerBounds -- idsWithUpperBounds)) -- idsWithEquality - val resctrs = (restctrs.flatMap { - case ctr if ctrVars(ctr).intersect(oneSidedVars).isEmpty => - replaceInCtr(subst, ctr) match { - case None => Seq() - case Some(newctr) => Seq(newctr) - } - case _ => Seq() // drop constraints with `oneSidedVars` - } ++ currEqs).distinct // note: this is very important!! - Stats.updateCumStats(currEqs.size, "UneliminatedEqualities") - resctrs - } - - def sizeExpr(ine: Expr): Int = { - val simpe = simplifyArithmetic(ine) - var size = 0 - simplePostTransform((e: Expr) => { - size += 1 - e - })(simpe) - size - } - - def sizeCtr(ctr: LinearConstraint): Int = { - val coeffSize = ctr.coeffMap.foldLeft(0)((acc, pair) => { - val (term, coeff) = pair - if (coeff == one) acc + 1 - else acc + sizeExpr(coeff) + 2 - }) - if (ctr.const.isDefined) coeffSize + 1 - else coeffSize - } - - /** - * Checks if the expression is linear i.e, - * is only conjuntion and disjunction of linear atomic predicates - */ - def isLinearFormula(e: Expr): Boolean = { - e match { - case And(args) => args forall isLinearFormula - case Or(args) => args forall isLinearFormula - case Not(arg) => isLinearFormula(arg) - case Implies(e1, e2) => isLinearFormula(e1) && isLinearFormula(e2) - case t: Terminal => true - case atom => - exprToTemplate(atom).isInstanceOf[LinearConstraint] - } - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala deleted file mode 100644 index 72bffc37c94b5791af3b16582e3d6a4e740c90c5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import solvers._ -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import leon.invariant.util.RealValuedExprEvaluator._ -import PredicateUtil._ -import SolverUtil._ -import Stats._ -import Util._ - -class CegisSolver(ctx: InferenceContext, program: Program, - rootFun: FunDef, ctrTracker: ConstraintTracker, - timeout: Int, bound: Option[Int] = None) extends TemplateSolver(ctx, rootFun, ctrTracker) { - - override def solve(tempIds: Set[Identifier], funcs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) = { - val initCtr = if (bound.isDefined) { - //use a predefined bound on the template variables - createAnd(tempIds.map((id) => { - val idvar = id.toVariable - And(Implies(LessThan(idvar, realzero), GreaterEquals(idvar, InfiniteIntegerLiteral(-bound.get))), - Implies(GreaterEquals(idvar, realzero), LessEquals(idvar, InfiniteIntegerLiteral(bound.get)))) - }).toSeq) - - } else tru - val formula = createOr(funcs.map(getVCForFun _).toSeq) - //using reals with bounds does not converge and also results in overflow - val (res, _, model) = (new CegisCore(ctx, program, timeout, this)).solve(tempIds, formula, initCtr, solveAsInt = true) - res match { - case Some(true) => (Some(model), None) - case Some(false) => (None, None) //no solution exists - case _ => //timed out - throw new IllegalStateException("Timeout!!") - } - } -} - -class CegisCore(ctx: InferenceContext, - program: Program, timeout: Int, - cegisSolver: TemplateSolver) { - - val fls = BooleanLiteral(false) - val tru = BooleanLiteral(true) - val zero = InfiniteIntegerLiteral(0) - val timeoutMillis = timeout.toLong * 1000 - val dumpCandidateInvs = true - val minimizeSum = false - val context = ctx.leonContext - val reporter = context.reporter - - /** - * Finds a model for the template variables in the 'formula' so that 'formula' is falsified - * subject to the constraints on the template variables given by the 'envCtrs' - * - * The parameter solveAsInt when set to true will convert the template constraints - * to integer constraints and solve. This should be enabled when bounds are used to constrain the variables - */ - def solve(tempIds: Set[Identifier], formula: Expr, initCtr: Expr, solveAsInt: Boolean, - initModel: Option[Model] = None): (Option[Boolean], Expr, Model) = { - - //start a timer - val startTime = System.currentTimeMillis() - - //for some sanity checks - var oldModels = Set[Expr]() - def addModel(m: Model) = { - val mexpr = modelToExpr(m) - if (oldModels.contains(mexpr)) - throw new IllegalStateException("repeating model !!:" + m) - else oldModels += mexpr - } - - //add the initial model - val simplestModel = if (initModel.isDefined) initModel.get else { - new Model(tempIds.map((id) => (id -> simplestValue(id.getType))).toMap) - } - addModel(simplestModel) - - val tempVarSum = if (minimizeSum) { - //compute the sum of the tempIds - val rootTempIds = getTemplateVars(cegisSolver.rootFun.getTemplate) - if (rootTempIds.nonEmpty) { - rootTempIds.tail.foldLeft(rootTempIds.head.asInstanceOf[Expr])((acc, tvar) => Plus(acc, tvar)) - } else zero - } else zero - - //convert initCtr to a real-constraint - val initRealCtr = ExpressionTransformer.IntLiteralToReal(initCtr) - if (hasInts(initRealCtr)) - throw new IllegalStateException("Initial constraints have integer terms: " + initRealCtr) - - def cegisRec(model: Model, prevctr: Expr): (Option[Boolean], Expr, Model) = { - - val elapsedTime = (System.currentTimeMillis() - startTime) - if (elapsedTime >= timeoutMillis - 100) { - //if we have timed out return the present set of constrains and the current model we have - (None, prevctr, model) - } else { - - //println("elapsedTime: "+elapsedTime / 1000+" timeout: "+timeout) - Stats.updateCounter(1, "CegisIters") - - if (dumpCandidateInvs) { - reporter.info("Candidate invariants") - val candInvs = TemplateInstantiator.getAllInvariants(model, cegisSolver.ctrTracker.getFuncs) - candInvs.foreach((entry) => println(entry._1.id + "-->" + entry._2)) - } - val tempVarMap: Map[Expr, Expr] = model.map((elem) => (elem._1.toVariable, elem._2)).toMap - val instFormula = simplifyArithmetic(TemplateInstantiator.instantiate(formula, tempVarMap)) - - //sanity checks - val spuriousTempIds = variablesOf(instFormula).intersect(TemplateIdFactory.getTemplateIds) - if (spuriousTempIds.nonEmpty) - throw new IllegalStateException("Found a template variable in instFormula: " + spuriousTempIds) - - //println("solving instantiated vcs...") - val solver1 = new ExtendedUFSolver(context, program) - solver1.assertCnstr(instFormula) - val (res, solTime) = getTime{ solver1.check } - println("1: " + (if (res.isDefined) "solved" else "timedout") + "... in " + solTime / 1000.0 + "s") - res match { - case Some(true) => { - //simplify the tempctrs, evaluate every atom that does not involve a template variable - //this should get rid of all functions - val satctrs = - simplePreTransform { - //is 'e' free of template variables ? - case e if !variablesOf(e).exists(TemplateIdFactory.IsTemplateIdentifier _) => { - //evaluate the term - val value = solver1.evalExpr(e) - if (value.isDefined) value.get - else throw new IllegalStateException("Cannot evaluate expression: " + e) - } - case e => e - }(Not(formula)) - solver1.free() - //sanity checks - val spuriousProgIds = variablesOf(satctrs).filterNot(TemplateIdFactory.IsTemplateIdentifier _) - if (spuriousProgIds.nonEmpty) - throw new IllegalStateException("Found a progam variable in tempctrs: " + spuriousProgIds) - val tempctrs = if (!solveAsInt) ExpressionTransformer.IntLiteralToReal(satctrs) else satctrs - val newctr = And(tempctrs, prevctr) - if (ctx.dumpStats) { - Stats.updateCounterStats(atomNum(newctr), "CegisTemplateCtrs", "CegisIters") - } - val t3 = System.currentTimeMillis() - val elapsedTime = (t3 - startTime) - val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory("extededUF", () => new ExtendedUFSolver(context, program) with TimeoutSolver), - timeoutMillis - elapsedTime)) - val (res1, newModel) = if (solveAsInt) { - val intctr = And(newctr, initRealCtr) - if (minimizeSum) - minimizeIntegers(intctr, tempVarSum) - else - solver2.solveSAT(intctr) - } else { - if (minimizeSum) { - minimizeReals(And(newctr, initRealCtr), tempVarSum) - } else { - solver2.solveSAT(And(newctr, initRealCtr)) - } - } - println("2: " + (if (res1.isDefined) "solved" else "timed out") + "... in " + (System.currentTimeMillis() - t3) / 1000.0 + "s") - if (res1.isDefined) { - if (!res1.get) { - //there exists no solution for templates - (Some(false), newctr, Model.empty) - } else { - //this is for sanity check - addModel(newModel) - //generate more constraints - cegisRec(newModel, newctr) - } - } else { - //we have timed out - (None, prevctr, model) - } - } - case Some(false) => { - solver1.free() - //found a model for disabling the formula - (Some(true), prevctr, model) - } case _ => { - solver1.free() - throw new IllegalStateException("Cannot solve instFormula: " + instFormula) - } - } - } - } - //note: initRealCtr is used inside 'cegisRec' - cegisRec(simplestModel, tru) - } - - /** - * Performs minimization - */ - val MaxIter = 16 //note we may not be able to represent anything beyond 2^16 - val MaxInt = Int.MaxValue - val sqrtMaxInt = 45000 - val half = FractionalLiteral(1, 2) - val two = FractionalLiteral(2, 1) - val rzero = FractionalLiteral(0, 1) - val mone = FractionalLiteral(-1, 1) - val debugMinimization = false - - def minimizeReals(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { - val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory("extendedUF", () => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) - val (res, model1) = sol.solveSAT(inputCtr) - res match { - case Some(true) => { - //do a binary search on sequentially on each of these tempvars - println("minimizing " + objective + " ...") - val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> model1(id))).toMap - var upperBound: FractionalLiteral = evaluate(replace(idMap, objective)) - var lowerBound: Option[FractionalLiteral] = None - var currentModel = model1 - var continue = true - var iter = 0 - do { - iter += 1 - //here we perform some sanity checks to prevent overflow - if (!boundSanityChecks(upperBound, lowerBound)) { - continue = false - } else { - if (lowerBound.isDefined && evaluateRealPredicate(GreaterEquals(lowerBound.get, upperBound))) { - continue = false - } else { - - val currval = if (lowerBound.isDefined) { - val midval = evaluate(Times(half, Plus(upperBound, lowerBound.get))) - floor(midval) - - } else { - val rlit @ FractionalLiteral(n, d) = upperBound - if (isGEZ(rlit)) { - if (n == 0) { - //make the upper bound negative - mone - } else { - floor(evaluate(Times(half, upperBound))) - } - } else floor(evaluate(Times(two, upperBound))) - - } - val boundCtr = LessEquals(objective, currval) - val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory("extendedUF", () => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) - val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) - res match { - case Some(true) => { - //here we have a new upper bound - currentModel = newModel - val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> newModel(id))).toMap - val value = evaluate(replace(idMap, objective)) - upperBound = value - if (this.debugMinimization) - reporter.info("Found new upper bound: " + upperBound) - } - case _ => { - //here we have a new lower bound : currval - lowerBound = Some(currval) - if (this.debugMinimization) - reporter.info("Found new lower bound: " + currval) - } - } - } - } - } while (continue && iter < MaxIter) - //here, we found a best-effort minimum - reporter.info("Minimization complete...") - (Some(true), currentModel) - } - case _ => (res, model1) - } - } - - def boundSanityChecks(ub: FractionalLiteral, lb: Option[FractionalLiteral]): Boolean = { - val FractionalLiteral(n, d) = ub - if (n <= (MaxInt / 2)) { - if (lb.isDefined) { - val FractionalLiteral(n2, _) = lb.get - (n2 <= sqrtMaxInt && d <= sqrtMaxInt) - } else { - (d <= (MaxInt / 2)) - } - } else false - } - - def minimizeIntegers(inputCtr: Expr, objective: Expr): (Option[Boolean], Model) = { - val sol = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory("extendedUF", () => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) - val (res, model1) = sol.solveSAT(inputCtr) - res match { - case Some(true) => { - //do a binary search on sequentially on each of these tempvars - reporter.info("minimizing " + objective + " ...") - val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> model1(id))).toMap - var upperBound = simplifyArithmetic(replace(idMap, objective)).asInstanceOf[InfiniteIntegerLiteral].value - var lowerBound: Option[BigInt] = None - var currentModel = model1 - var continue = true - var iter = 0 - do { - iter += 1 - if (lowerBound.isDefined && lowerBound.get >= upperBound - 1) { - continue = false - } else { - - val currval = if (lowerBound.isDefined) { - val sum = (upperBound + lowerBound.get) - floorDiv(sum, 2) - } else { - if (upperBound >= 0) { - if (upperBound == 0) { - //make the upper bound negative - BigInt(-1) - } else { - floorDiv(upperBound, 2) - } - } else 2 * upperBound - } - val boundCtr = LessEquals(objective, InfiniteIntegerLiteral(currval)) - val solver2 = SimpleSolverAPI(new TimeoutSolverFactory(SolverFactory("extededUF", () => new ExtendedUFSolver(context, program) with TimeoutSolver), timeoutMillis)) - val (res, newModel) = sol.solveSAT(And(inputCtr, boundCtr)) - res match { - case Some(true) => { - //here we have a new upper bound - currentModel = newModel - val idMap: Map[Expr, Expr] = variablesOf(objective).map(id => (id.toVariable -> newModel(id))).toMap - val value = simplifyArithmetic(replace(idMap, objective)).asInstanceOf[InfiniteIntegerLiteral].value - upperBound = value - if (this.debugMinimization) - reporter.info("Found new upper bound: " + upperBound) - } - case _ => { - //here we have a new lower bound : currval - lowerBound = Some(currval) - if (this.debugMinimization) - reporter.info("Found new lower bound: " + currval) - } - } - } - } while (continue && iter < MaxIter) - //here, we found a best-effort minimum - reporter.info("Minimization complete...") - (Some(true), currentModel) - } - case _ => (res, model1) - } - } - - def floorDiv(did: BigInt, div: BigInt): BigInt = { - if (div <= 0) throw new IllegalStateException("Invalid divisor") - if (did < 0) { - if (did % div != 0) did / div - 1 - else did / div - } else { - did / div - } - } - -} diff --git a/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala b/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala deleted file mode 100644 index 27d83ef562a671045d8c7091caa2437af8af0d3c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/DisjunctChooser.scala +++ /dev/null @@ -1,213 +0,0 @@ -package leon -package invariant.templateSolvers - -import z3.scala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ -import solvers._ -import solvers.combinators._ -import solvers.smtlib._ -import solvers.z3._ -import scala.util.control.Breaks._ -import purescala.ScalaPrinter -import scala.collection.mutable.{ Map => MutableMap } -import scala.reflect.runtime.universe -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.util.ExpressionTransformer._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import Stats._ - -import Util._ -import PredicateUtil._ -import SolverUtil._ - -class DisjunctChooser(ctx: InferenceContext, program: Program, ctrTracker: ConstraintTracker, defaultEval: DefaultEvaluator) { - val debugElimination = false - val debugChooseDisjunct = false - val debugTheoryReduction = false - val debugAxioms = false - val debugReducedFormula = false - val verifyInvariant = false - val printPathToFile = false - val dumpPathAsSMTLIB = false - - val leonctx = ctx.leonContext - val linearEval = new LinearRelationEvaluator(ctx) // an evaluator for quickly checking the result of linear predicates - - //additional book-keeping for statistics - val trackNumericalDisjuncts = false - var numericalDisjuncts = List[Expr]() - - /** - * A helper function used only in debugging. - */ - protected def doesSatisfyExpr(expr: Expr, model: LazyModel): Boolean = { - val compModel = variablesOf(expr).map { k => k -> model(k) }.toMap - defaultEval.eval(expr, new Model(compModel)).result match { - case Some(BooleanLiteral(true)) => true - case _ => false - } - } - - /** - * This solver does not use any theories other than UF/ADT. It assumes that other theories are axiomatized in the VC. - * This method can be overloaded by the subclasses. - */ - protected def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = Seq() - - /** - * Chooses a purely numerical disjunct from a given formula that is - * satisfied by the model - * @precondition the formula is satisfied by the model - * @tempIdMap a model for the template variables - */ - def chooseNumericalDisjunct(formula: Formula, initModel: LazyModel, tempIdMap: Map[Identifier, Expr]): (Seq[LinearConstraint], Seq[LinearTemplate], Set[Call]) = { - val satCtrs = formula.pickSatDisjunct(formula.firstRoot, initModel, tempIdMap, defaultEval) //this picks the satisfiable disjunct of the VC modulo axioms - //for debugging - if (debugChooseDisjunct || printPathToFile || dumpPathAsSMTLIB || verifyInvariant) { - val pathctrs = satCtrs.map(_.toExpr) - val plainFormula = createAnd(pathctrs) - val pathcond = simplifyArithmetic(plainFormula) - if (printPathToFile) { - //val simpcond = ExpressionTransformer.unFlatten(pathcond, variablesOf(pathcond).filterNot(TVarFactory.isTemporary _)) - ExpressionTransformer.PrintWithIndentation("full-path", pathcond) - } - if (dumpPathAsSMTLIB) { - val filename = "pathcond" + FileCountGUID.getID + ".smt2" - toZ3SMTLIB(pathcond, filename, "QF_NIA", leonctx, program) - println("Path dumped to: " + filename) - } - if (debugChooseDisjunct) { - satCtrs.filter(_.isInstanceOf[LinearConstraint]).map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyExpr(ctr, initModel)) - throw new IllegalStateException("Path ctr not satisfied by model: " + ctr) - }) - } - if (verifyInvariant) { - println("checking invariant for path...") - val sat = checkInvariant(pathcond, leonctx, program) - } - } - var calls = Set[Call]() - var adtExprs = Seq[Expr]() - satCtrs.foreach { - case t: Call => calls += t - case t: ADTConstraint if (t.cons || t.sel) => adtExprs :+= t.expr - // TODO: ignoring all set constraints here, fix this - case _ => ; - } - val callExprs = calls.map(_.toExpr) - - val axiomCtrs = time { - ctrTracker.specInstantiator.axiomsForCalls(formula, calls, initModel, tempIdMap, defaultEval) - } { updateCumTime(_, "Total-AxiomChoose-Time") } - - //here, handle theory operations by reducing them to axioms. - //Note: uninterpreted calls/ADTs are handled below as they are more general. Here, we handle - //other theory axioms like: multiplication, sets, arrays, maps etc. - val theoryCtrs = time { - axiomsForTheory(formula, calls, initModel) - } { updateCumTime(_, "Total-TheoryAxiomatization-Time") } - - //Finally, eliminate UF/ADT - // convert all adt constraints to 'cons' ctrs, and expand the model - val selTrans = new SelectorToCons() - val cons = selTrans.selToCons(adtExprs) - val expModel = selTrans.getModel(initModel) - // get constraints for UFADTs - val callCtrs = time { - (new UFADTEliminator(leonctx, program)).constraintsForCalls((callExprs ++ cons), - linearEval.predEval(expModel)).map(ConstraintUtil.createConstriant _) - } { updateCumTime(_, "Total-ElimUF-Time") } - - //exclude guards, separate calls and cons from the rest - var lnctrs = Set[LinearConstraint]() - var temps = Set[LinearTemplate]() - (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { - case t: LinearConstraint => lnctrs += t - case t: LinearTemplate => temps += t - case _ => ; - } - if (debugChooseDisjunct) { - lnctrs.map(_.toExpr).foreach((ctr) => { - if (!doesSatisfyExpr(ctr, expModel)) - throw new IllegalStateException("Ctr not satisfied by model: " + ctr) - }) - } - if (debugTheoryReduction) { - val simpPathCond = createAnd((lnctrs ++ temps).map(_.template).toSeq) - if (verifyInvariant) { - println("checking invariant for simp-path...") - checkInvariant(simpPathCond, leonctx, program) - } - } - if (trackNumericalDisjuncts) { - numericalDisjuncts :+= createAnd((lnctrs ++ temps).map(_.template).toSeq) - } - val tempCtrs = temps.toSeq - val elimCtrs = eliminateVars(lnctrs.toSeq, tempCtrs) - //for debugging - if (debugReducedFormula) { - println("Final Path Constraints: " + elimCtrs ++ tempCtrs) - if (verifyInvariant) { - println("checking invariant for final disjunct... ") - checkInvariant(createAnd((elimCtrs ++ tempCtrs).map(_.template)), leonctx, program) - } - } - (elimCtrs, tempCtrs, calls) - } - - /** - * TODO:Remove transitive facts. E.g. a <= b, b <= c, a <=c can be simplified by dropping a <= c - * TODO: simplify the formulas and remove implied conjuncts if possible (note the formula is satisfiable, so there can be no inconsistencies) - * e.g, remove: a <= b if we have a = b or if a < b - * Also, enrich the rules for quantifier elimination: try z3 quantifier elimination on variables that have an equality. - * TODO: Use the dependence chains in the formulas to identify what to assertionize - * and what can never be implied by solving for the templates - */ - import LinearConstraintUtil._ - def eliminateVars(lnctrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): Seq[LinearConstraint] = { - if (temps.isEmpty) lnctrs //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path - else { - if (debugElimination && verifyInvariant) { - println("checking invariant for disjunct before elimination...") - checkInvariant(createAnd((lnctrs ++ temps).map(_.template)), leonctx, program) - } - // for debugging - val debugger = - if (debugElimination && verifyInvariant) { - Some((ctrs: Seq[LinearConstraint]) => { - val debugRes = checkInvariant(createAnd((ctrs ++ temps).map(_.template)), leonctx, program) - }) - } else None - val elimLnctrs = time { - apply1PRuleOnDisjunct(lnctrs, temps.flatMap(lt => variablesOf(lt.template)).toSet, debugger) - } { updateCumTime(_, "ElimTime") } - - if (debugElimination) { - println("Path constriants (after elimination): " + elimLnctrs) - if (verifyInvariant) { - println("checking invariant for disjunct after elimination...") - checkInvariant(createAnd((elimLnctrs ++ temps).map(_.template)), leonctx, program) - } - } - //for stats - if (ctx.dumpStats) { - Stats.updateCounterStats(lnctrs.size, "CtrsBeforeElim", "disjuncts") - Stats.updateCounterStats(lnctrs.size - elimLnctrs.size, "EliminatedAtoms", "disjuncts") - Stats.updateCounterStats(temps.size, "Param-Atoms", "disjuncts") - Stats.updateCounterStats(elimLnctrs.size, "NonParam-Atoms", "disjuncts") - } - elimLnctrs - } - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala deleted file mode 100644 index a2ad660104332fdce5c436fae93dbbfc692f9724..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/ExistentialQuantificationSolver.scala +++ /dev/null @@ -1,92 +0,0 @@ -package leon -package invariant.templateSolvers - -import z3.scala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ -import solvers._ -import solvers.combinators._ -import solvers.smtlib._ -import solvers.z3._ -import scala.util.control.Breaks._ -import purescala.ScalaPrinter -import scala.collection.mutable.{ Map => MutableMap } -import scala.reflect.runtime.universe -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.util.ExpressionTransformer._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import Stats._ - -import Util._ -import PredicateUtil._ -import SolverUtil._ - -/** - * This class uses Farkas' lemma to try and falsify numerical disjuncts with templates provided one by one - */ -class ExistentialQuantificationSolver(ctx: InferenceContext, program: Program, - ctrTracker: ConstraintTracker, defaultEval: DefaultEvaluator) { - import NLTemplateSolver._ - val reporter = ctx.reporter - - var currentCtr: Expr = tru - private val farkasSolver = new FarkasLemmaSolver(ctx, program) - val disjunctChooser = new DisjunctChooser(ctx, program, ctrTracker, defaultEval) - - def getSolvedCtrs = currentCtr - - def generateCtrsForUNSAT(fd: FunDef, univModel: LazyModel, tempModel: Model) = { - // chooose a sat numerical disjunct from the model - val (lnctrs, temps, calls) = - time { - disjunctChooser.chooseNumericalDisjunct(ctrTracker.getVC(fd), univModel, tempModel.toMap) - } { chTime => - updateCounterTime(chTime, "Disj-choosing-time", "disjuncts") - updateCumTime(chTime, "Total-Choose-Time") - } - val disjunct = (lnctrs ++ temps) - if (temps.isEmpty) { - //here ants ^ conseq is sat (otherwise we wouldn't reach here) and there is no way to falsify this path - (fls, disjunct, calls) - } else - (farkasSolver.constraintsForUnsat(lnctrs, temps), disjunct, calls) - } - - /** - * Solves the nonlinear Farkas' constraints - */ - def solveConstraints(newctrs: Seq[Expr], oldModel: Model): (Option[Boolean], Model) = { - val newPart = createAnd(newctrs) - val newSize = atomNum(newPart) - val currSize = atomNum(currentCtr) - Stats.updateCounterStats((newSize + currSize), "NLsize", "disjuncts") - if (verbose) reporter.info("# of atomic predicates: " + newSize + " + " + currSize) - - val combCtr = And(currentCtr, newPart) - val (res, newModel) = farkasSolver.solveFarkasConstraints(combCtr) - res match { - case _ if ctx.abort => - (None, Model.empty) // stop immediately - case None => - //here we have timed out while solving the non-linear constraints - if (verbose) reporter.info("NLsolver timed-out on the disjunct...") - (None, Model.empty) - case Some(false) => - currentCtr = fls - (Some(false), Model.empty) - case Some(true) => - currentCtr = combCtr - //new model may not have mappings for all the template variables, hence, use the mappings from earlier models - (Some(true), completeWithRefModel(newModel, oldModel)) - } - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala deleted file mode 100644 index 8902564ded5e43b97bf736a838966e200cd4cfa3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import z3.scala._ -import purescala.Definitions._ -import purescala.Expressions._ -import leon.solvers.z3.UninterpretedZ3Solver - -/** - * A uninterpreted solver extended with additional functionalities. - * TODO: need to handle bit vectors - */ -class ExtendedUFSolver(context: LeonContext, program: Program) - extends UninterpretedZ3Solver(context.toSctx, program) { - - override val name = "Z3-eu" - override val description = "Extended UF-ADT Z3 Solver" - - /** - * This uses z3 methods to evaluate the model - */ - def evalExpr(expr: Expr): Option[Expr] = { - val ast = toZ3Formula(expr) - val model = solver.getModel - val res = model.eval(ast, true) - if (res.isDefined) - Some(fromZ3Formula(model, res.get, null)) - else None - } - - def getAssertions: Expr = { - val assers = solver.getAssertions.map((ast) => fromZ3Formula(null, ast, null)) - And(assers) - } - - /** - * Uses z3 to convert a formula to SMTLIB. - */ - def ctrsToString(logic: String, unsatcore: Boolean = false): String = { - z3.setAstPrintMode(Z3Context.AstPrintMode.Z3_PRINT_SMTLIB2_COMPLIANT) - var seenHeaders = Set[String]() - var headers = Seq[String]() - var asserts = Seq[String]() - solver.getAssertions().toSeq.foreach((asser) => { - val str = z3.benchmarkToSMTLIBString("benchmark", logic, "unknown", "", Seq(), asser) - //remove from the string the headers and also redeclaration of template variables - //split based on newline to get a list of strings - val strs = str.split("\n") - val newstrs = strs.filter((line) => !seenHeaders.contains(line)) - var newHeaders = Seq[String]() - newstrs.foreach((line) => { - if (line == "; benchmark") newHeaders :+= line - else if (line.startsWith("(set")) newHeaders :+= line - else if (line.startsWith("(declare")) newHeaders :+= line - else if (line.startsWith("(check-sat)")) {} //do nothing - else asserts :+= line - }) - headers ++= newHeaders - seenHeaders ++= newHeaders - }) - val initstr = if (unsatcore) { - "(set-option :produce-unsat-cores true)" - } else "" - val smtstr = headers.foldLeft(initstr)((acc, hdr) => acc + "\n" + hdr) + "\n" + - asserts.foldLeft("")((acc, asrt) => acc + "\n" + asrt) + "\n" + - "(check-sat)" + "\n" + - (if (!unsatcore) "(get-model)" - else "(get-unsat-core)") - smtstr - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala deleted file mode 100644 index 11a560be07ec4c678ad479971dedd8a079149f88..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala +++ /dev/null @@ -1,328 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import solvers.SimpleSolverAPI -import invariant.engine._ -import invariant.util._ -import Util._ -import Stats._ -import SolverUtil._ -import PredicateUtil._ -import invariant.structure._ -import invariant.datastructure._ -import leon.solvers.TimeoutSolver -import leon.solvers.SolverFactory -import leon.solvers.TimeoutSolverFactory -import leon.solvers.Model -import leon.solvers.smtlib.SMTLIBZ3Solver -import leon.invariant.util.RealValuedExprEvaluator._ - -class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { - - //debug flags - val verbose = true - val verifyModel = false - val dumpNLCtrsAsSMTLIB = false - val dumpNLCtrs = false - val debugNLCtrs = false - - // functionality flags - val solveAsBitvectors = false - val bvsize = 5 - val useIncrementalSolvingForNLctrs = false //note: NLsat doesn't support incremental solving. It starts from sratch even in incremental solving. - - val leonctx = ctx.leonContext - val reporter = ctx.reporter - val timeout = ctx.nlTimeout // Note: we are using vcTimeout here as well - - /** - * This procedure produces a set of constraints that need to be satisfiable for the - * conjunction ants and conseqs to be false - * antsSimple - antecedents without template variables - * antsTemp - antecedents with template variables - * Similarly for conseqsSimple and conseqsTemp - * - * Let A,A' and C,C' denote the simple and templated portions of the antecedent and the consequent respectively. - * We need to check \exists a, \forall x, A[x] ^ A'[x,a] ^ C[x] ^ C'[x,a] = false - * - */ - def constraintsForUnsat(linearCtrs: Seq[LinearConstraint], temps: Seq[LinearTemplate]): Expr = { - this.applyFarkasLemma(linearCtrs ++ temps, Seq(), true) - } - - /** - * This procedure produces a set of constraints that need to be satisfiable for the implication to hold - * antsSimple - antecedents without template variables - * antsTemp - antecedents with template variables - * Similarly for conseqsSimple and conseqsTemp - * - * Let A,A' and C,C' denote the simple and templated portions of the antecedent and the consequent respectively. - * We need to check \exists a, \forall x, A[x] ^ A'[x,a] => C[x] ^ C'[x,a] - * - */ - def constraintsForImplication(antsSimple: Seq[LinearConstraint], antsTemp: Seq[LinearTemplate], - conseqsSimple: Seq[LinearConstraint], conseqsTemp: Seq[LinearTemplate], - uisolver: SimpleSolverAPI): Expr = { - - val allAnts = antsSimple ++ antsTemp - val allConseqs = conseqsSimple ++ conseqsTemp - //for debugging - println("#" * 20) - println(allAnts + " => " + allConseqs) - println("#" * 20) - - //Optimization 1: Check if ants are unsat (already handled) - val pathVC = createAnd(antsSimple.map(_.toExpr) ++ conseqsSimple.map(_.toExpr)) - val notPathVC = And(createAnd(antsSimple.map(_.toExpr)), Not(createAnd(conseqsSimple.map(_.toExpr)))) - val (satVC, _) = uisolver.solveSAT(pathVC) - val (satNVC, _) = uisolver.solveSAT(notPathVC) - - //Optimization 2: use the unsatisfiability of VC and not VC to simplify the constraint generation - //(a) if A => C is false and A' is true then the entire formula is unsat - //(b) if A => C is false and A' is not true then we need to ensure A^A' is unsat (i.e, disable Ant) - //(c) if A => C is true (i.e, valid) then it suffices to ensure A^A' => C' is valid - //(d) if A => C is neither true nor false then we cannot do any simplification - //TODO: Food for thought: - //(a) can we do any simplification for case (d) with the model - //(b) could the linearity in the disabled case be exploited - val (ants, conseqs, disableFlag) = (satVC, satNVC) match { - case (Some(false), _) if (antsTemp.isEmpty) => (Seq(), Seq(), false) - case (Some(false), _) => (allAnts, Seq(), true) //here only disable the antecedents - case (_, Some(false)) => (allAnts, conseqsTemp, false) //here we need to only check the inductiveness of the templates - case _ => (allAnts, allConseqs, false) - } - if (ants.isEmpty) { - BooleanLiteral(false) - } else { - this.applyFarkasLemma(ants, conseqs, disableFlag) - } - } - - /** - * This procedure uses Farka's lemma to generate a set of non-linear constraints for the input implication. - * Note that these non-linear constraints are in real arithmetic. - * TODO: Correctness issue: need to handle strict inequalities in consequent - * Do we really need the consequent ?? - */ - def applyFarkasLemma(ants: Seq[LinearTemplate], conseqs: Seq[LinearTemplate], disableAnts: Boolean): Expr = { - - //compute the set of all constraint variables in ants - val antCVars = ants.foldLeft(Set[Expr]())((acc, ant) => acc ++ ant.coeffTemplate.keySet) - - //the creates constraints for a single consequent - def createCtrs(conseq: Option[LinearTemplate]): Expr = { - //create a set of identifiers one for each ants - val lambdas = ants.map((ant) => (ant -> Variable(FreshIdentifier("l", RealType, true)))).toMap - val lambda0 = Variable(FreshIdentifier("l", RealType, true)) - - //add a bunch of constraints on lambdas - var strictCtrLambdas = Seq[Variable]() - val lambdaCtrs = (ants.collect((ant) => ant.template match { - case t: LessEquals => GreaterEquals(lambdas(ant), zero) - case t: LessThan => { - val l = lambdas(ant) - strictCtrLambdas :+= l - GreaterEquals(l, zero) - } - }) :+ GreaterEquals(lambda0, zero)) - - //add the constraints on constant terms - val sumConst = ants.foldLeft(UMinus(lambda0): Expr)((acc, ant) => ant.constTemplate match { - case Some(d) => Plus(acc, Times(lambdas(ant), d)) - case None => acc - }) - - val cvars = antCVars ++ (if (conseq.isDefined) conseq.get.coeffTemplate.keys else Seq()) - //initialize enabled and disabled parts - var enabledPart: Expr = if (conseq.isDefined) { - conseq.get.constTemplate match { - case Some(d) => Equals(d, sumConst) - case None => Equals(zero, sumConst) - } - } else null - //the disabled part handles strict inequalities as well using Motzkin's transposition - var disabledPart: Expr = - if (strictCtrLambdas.isEmpty) Equals(one, sumConst) - else Or(Equals(one, sumConst), - And(Equals(zero, sumConst), createOr(strictCtrLambdas.map(GreaterThan(_, zero))))) - - for (cvar <- cvars) { - //compute the linear combination of all the coeffs of antCVars - //println("Processing cvar: "+cvar) - var sumCoeff: Expr = zero - for (ant <- ants) { - //handle coefficients here - if (ant.coeffTemplate.contains(cvar)) { - val addend = Times(lambdas(ant), ant.coeffTemplate.get(cvar).get) - if (sumCoeff == zero) - sumCoeff = addend - else - sumCoeff = Plus(sumCoeff, addend) - } - } - //println("sum coeff: "+sumCoeff) - //make the sum equal to the coeff. of cvar in conseq - if (conseq.isDefined) { - enabledPart = And(enabledPart, - (if (conseq.get.coeffTemplate.contains(cvar)) - Equals(conseq.get.coeffTemplate.get(cvar).get, sumCoeff) - else Equals(zero, sumCoeff))) - } - - disabledPart = And(disabledPart, Equals(zero, sumCoeff)) - } //end of cvars loop - - //the final constraint is a conjunction of lambda constraints and disjunction of enabled and disabled parts - if (disableAnts) And(createAnd(lambdaCtrs), disabledPart) - else { - And(createAnd(lambdaCtrs), Or(enabledPart, disabledPart)) - } - } - - val ctrs = if (disableAnts) { - createCtrs(None) - } else { - val Seq(head, tail @ _*) = conseqs - val nonLinearCtrs = tail.foldLeft(createCtrs(Some(head)))((acc, conseq) => And(acc, createCtrs(Some(conseq)))) - nonLinearCtrs - } - ExpressionTransformer.IntLiteralToReal(ctrs) - } - - def solveFarkasConstraints(nlctrs: Expr): (Option[Boolean], Model) = { - - // factor out common nonlinear terms and create an equiv-satisfiable constraint - def reduceCommonNLTerms(ctrs: Expr) = { - val nlUsage = new CounterMap[Expr]() - postTraversal{ - case t: Times => nlUsage.inc(t) - case e => ; - }(ctrs) - val repMap = nlUsage.collect{ - case (k, v) if v > 1 => - (k -> FreshIdentifier("t", RealType, true).toVariable) - }.toMap - createAnd(replace(repMap, ctrs) +: repMap.map { - case (k, v) => Equals(v, k) - }.toSeq) - } - - // try eliminate nonlinearity to whatever extent possible - var elimMap = Map[Identifier, (Identifier, Identifier)]() // maps the fresh identifiers to the product of the identifiers they represent. - def reduceNonlinearity(farkasctrs: Expr): Expr = { - val varCounts = new CounterMap[Identifier]() - // collect # of uses of each variable - postTraversal { - case Variable(id) => varCounts.inc(id) - case _ => ; - }(farkasctrs) - var adnlCtrs = Seq[Expr]() - val simpCtrs = simplePostTransform { - case Times(vlb @ Variable(lb), va @ Variable(a)) if (varCounts(lb) == 1 || varCounts(a) == 1) => // is lb or a used only once ? - // stats - Stats.updateCumStats(1, "NonlinearMultEliminated") - val freshid = FreshIdentifier(lb.name + a.name, RealType, true) - val freshvar = freshid.toVariable - elimMap += (freshid -> (lb, a)) - if (varCounts(lb) == 1) - // va = 0 ==> freshvar = 0 - adnlCtrs :+= Implies(Equals(va, realzero), Equals(freshvar, realzero)) - else // here varCounts(a) == 1 - adnlCtrs :+= Implies(Equals(vlb, realzero), Equals(freshvar, realzero)) - freshvar - case e => - e - }(farkasctrs) - createAnd(simpCtrs +: adnlCtrs) - } - val simpctrs = (reduceCommonNLTerms _ andThen - reduceNonlinearity)(nlctrs) - - //for debugging nonlinear constraints - if (this.debugNLCtrs && hasInts(simpctrs)) { - throw new IllegalStateException("Nonlinear constraints have integers: " + simpctrs) - } - if (verbose && LinearConstraintUtil.isLinearFormula(simpctrs)) { - reporter.info("Constraints reduced to linear !") - } - if (this.dumpNLCtrs) { - reporter.info("InputCtrs: " + nlctrs) - reporter.info("SimpCtrs: " + simpctrs) - if (this.dumpNLCtrsAsSMTLIB) { - val filename = program.modules.last.id + "-nlctr" + FileCountGUID.getID + ".smt2" - if (atomNum(simpctrs) >= 5) { - if (solveAsBitvectors) - toZ3SMTLIB(simpctrs, filename, "QF_BV", leonctx, program, useBitvectors = true, bitvecSize = bvsize) - else - toZ3SMTLIB(simpctrs, filename, "QF_NRA", leonctx, program) - reporter.info("NLctrs dumped to: " + filename) - } - } - } - - // solve the resulting constraints using solver - lazy val solver = if (solveAsBitvectors) { - throw new IllegalStateException("Not supported now. Will be in the future!") - //new ExtendedUFSolver(leonctx, program, useBitvectors = true, bitvecSize = bvsize) with TimeoutSolver - } else { - //new AbortableSolver(() => new SMTLIBZ3Solver(leonctx, program) with TimeoutSolver, ctx) - SimpleSolverAPI(new TimeoutSolverFactory( - SolverFactory.getFromName(leonctx, program)("smt-z3-u"), - timeout * 1000)) - } - if (verbose) reporter.info("solving...") - val (res, model) = - if (ctx.abort) (None, Model.empty) - else { - val (r, solTime) = getTime { solver.solveSAT(simpctrs) } - if (verbose) reporter.info((if (r._1.isDefined) "solved" else "timed out") + "... in " + solTime / 1000.0 + "s") - Stats.updateCounterTime(solTime, "NL-solving-time", "disjuncts") - r - } - res match { - case Some(true) => - // construct assignments for the variables that were removed during nonlinearity reduction - def divide(dividend: Expr, divisor: Expr) = { - divisor match { - case `realzero` => - assert(dividend == realzero) - // here result can be anything. So make it zero - realzero - case _ => - val res = evaluate(Division(dividend, divisor)) - res - } - } - val newassignments = elimMap.flatMap { - case (k, (v1, v2)) => - val kval = evaluate(model(k)) - if (model.isDefinedAt(v1) && model.isDefinedAt(v2)) - throw new IllegalStateException( - s"Variables $v1 and $v2 in an eliminated nonlinearity have models") - else if (model.isDefinedAt(v1)) { - val v2val = divide(kval, evaluate(model(v1))) - Seq((v2 -> v2val)) - } else if (model.isDefinedAt(v2)) - Seq((v1 -> divide(kval, evaluate(model(v2))))) - else - // here v1 * v2 = k. Therefore make v1 = k and v2 = 1 - Seq((v1 -> kval), (v2 -> FractionalLiteral(1, 1))) - } - val fullmodel = model ++ newassignments - if (this.verifyModel) { - val formula = replace(fullmodel.map { case (k, v) => (k.toVariable, v)}.toMap, nlctrs) - assert(evaluateRealFormula(formula)) - } - (res, fullmodel) - case _ => - (res, model) - } - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala deleted file mode 100644 index 7a5dd39a921944d067697dc7eef721a172ba56c7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import z3.scala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ -import solvers._ -import solvers.combinators._ -import solvers.smtlib._ -import solvers.z3._ -import scala.util.control.Breaks._ -import purescala.ScalaPrinter -import scala.collection.mutable.{ Map => MutableMap } -import scala.reflect.runtime.universe -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.util.ExpressionTransformer._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import Stats._ - -import Util._ -import PredicateUtil._ -import SolverUtil._ - -object NLTemplateSolver { - val verbose = true -} - -class NLTemplateSolver(ctx: InferenceContext, program: Program, - rootFun: FunDef, ctrTracker: ConstraintTracker, - minimizer: Option[(Expr, Model) => Model]) - extends TemplateSolver(ctx, rootFun, ctrTracker) { - - private val startFromEarlierModel = false - // state for tracking the last model - private var lastFoundModel: Option[Model] = None - - /** - * This function computes invariants belonging to the given templates incrementally. - * The result is a mapping from function definitions to the corresponding invariants. - */ - override def solve(tempIds: Set[Identifier], funs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) = { - val initModel = completeModel( - (if (this.startFromEarlierModel && lastFoundModel.isDefined) lastFoundModel.get - else Model.empty), tempIds) - val univSolver = new UniversalQuantificationSolver(ctx, program, funs, ctrTracker, minimizer) - val (resModel, seenCalls) = univSolver.solveUNSAT(initModel, (m: Model) => lastFoundModel = Some(m)) - univSolver.free - (resModel, seenCalls) - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala deleted file mode 100644 index 78746330187015714b04dee7c9b2cb25d0440b66..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Extractors._ -import solvers._ - -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.structure._ -import Util._ -import PredicateUtil._ - -class NLTemplateSolverWithMult(ctx: InferenceContext, program: Program, rootFun: FunDef, - ctrTracker: ConstraintTracker, minimizer: Option[(Expr, Model) => Model]) - extends NLTemplateSolver(ctx, program, rootFun, ctrTracker, minimizer) { - - throw new IllegalStateException("Not Maintained!!") -// val axiomFactory = new AxiomFactory(ctx) -// -// override def getVCForFun(fd: FunDef): Expr = { -// val plainvc = ctrTracker.getVC(fd).toExpr -// val nlvc = multToTimes(plainvc) -// nlvc -// } -// -// override def splitVC(fd: FunDef) = { -// val (paramPart, rest, modCons) = super.splitVC(fd) -// (multToTimes(paramPart), multToTimes(rest), modCons) -// } -// -// override def axiomsForTheory(formula: Formula, calls: Set[Call], model: LazyModel): Seq[Constraint] = { -// -// //in the sequel we instantiate axioms for multiplication -// val inst1 = unaryMultAxioms(formula, calls, linearEval.predEval(model)) -// val inst2 = binaryMultAxioms(formula, calls, linearEval.predEval(model)) -// val multCtrs = (inst1 ++ inst2).flatMap { -// case And(args) => args.map(ConstraintUtil.createConstriant _) -// case e => Seq(ConstraintUtil.createConstriant(e)) -// } -// -// Stats.updateCounterStats(multCtrs.size, "MultAxiomBlowup", "disjuncts") -// ctx.reporter.info("Number of multiplication induced predicates: " + multCtrs.size) -// multCtrs -// } -// -// def chooseSATPredicate(expr: Expr, predEval: (Expr => Option[Boolean])): Expr = { -// val norme = ExpressionTransformer.normalizeExpr(expr, ctx.multOp) -// val preds = norme match { -// case Or(args) => args -// case Operator(_, _) => Seq(norme) -// case _ => throw new IllegalStateException("Not(ant) is not in expected format: " + norme) -// } -// //pick the first predicate that holds true -// preds.collectFirst { case pred @ _ if predEval(pred).get => pred }.get -// } -// -// def isMultOp(call: Call): Boolean = { -// isMultFunctions(call.fi.tfd.fd) -// } -// -// def unaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { -// val axioms = calls.flatMap { -// case call @ _ if (isMultOp(call) && axiomFactory.hasUnaryAxiom(call)) => { -// val (ant, conseq) = axiomFactory.unaryAxiom(call) -// if (predEval(ant).get) -// Seq(ant, conseq) -// else -// Seq(chooseSATPredicate(Not(ant), predEval)) -// } -// case _ => Seq() -// } -// axioms.toSeq -// } -// -// def binaryMultAxioms(formula: Formula, calls: Set[Call], predEval: (Expr => Option[Boolean])): Seq[Expr] = { -// -// val mults = calls.filter(call => isMultOp(call) && axiomFactory.hasBinaryAxiom(call)) -// val product = cross(mults, mults).collect { case (c1, c2) if c1 != c2 => (c1, c2) } -// -// ctx.reporter.info("Theory axioms: " + product.size) -// Stats.updateCumStats(product.size, "-Total-theory-axioms") -// -// val newpreds = product.flatMap(pair => { -// val axiomInsts = axiomFactory.binaryAxiom(pair._1, pair._2) -// axiomInsts.flatMap { -// case (ant, conseq) if predEval(ant).get => Seq(ant, conseq) //if axiom-pre holds. -// case (ant, _) => Seq(chooseSATPredicate(Not(ant), predEval)) //if axiom-pre does not hold. -// } -// }) -// newpreds.toSeq -// } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala deleted file mode 100644 index c46228f3c99e71e49227ed5e6c1094537ca22d48..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import scala.collection.mutable.{Map => MutableMap} -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import java.io._ -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import leon.solvers.Model -import PredicateUtil._ -import ExpressionTransformer._ - -abstract class TemplateSolver(ctx: InferenceContext, val rootFun: FunDef, - val ctrTracker: ConstraintTracker) { - - protected val reporter = ctx.reporter - private val dumpVCtoConsole = false - private val dumpVCasText = false - - /** - * Completes a model by adding mapping to new template variables - */ - def completeModel(model: Model, ids: Set[Identifier]) = { - val idmap = ids.map { id => - if (!model.isDefinedAt(id)) { - (id, simplestValue(id.getType)) - } else (id, model(id)) - }.toMap - new Model(idmap) - } - - var vcCache = Map[FunDef, Expr]() - protected def getVCForFun(fd: FunDef): Expr = { - vcCache.getOrElse(fd, { - val vcInit = ctrTracker.getVC(fd).toExpr - val vc = if (ctx.usereals) - ExpressionTransformer.IntLiteralToReal(vcInit) - else vcInit - vcCache += (fd -> vc) - vc - }) - } - - /** - * This function computes invariants belonging to the given templates incrementally. - * The result is a mapping from function definitions to the corresponding invariants. - */ - def solveTemplates(): (Option[Model], Option[Set[Call]]) = { - val funcs = ctrTracker.getFuncs - val tempIds = funcs.flatMap { fd => - val vc = ctrTracker.getVC(fd) - if (dumpVCtoConsole || dumpVCasText) { - val filename = "vc-" + FileCountGUID.getID - if (dumpVCtoConsole) { - println("Func: " + fd.id + " VC: " + vc) - } - if (dumpVCasText) { - val wr = new PrintWriter(new File(filename + ".txt")) - println("Printed VC of " + fd.id + " to file: " + filename) - wr.println(vc.toString()) - wr.close() - } - } - if (ctx.dumpStats) { - Stats.updateCounterStats(vc.atomsCount, "VC-size", "VC-refinement") - Stats.updateCounterStats(vc.funsCount, "UIF+ADT", "VC-refinement") - } - vc.templateIdsInFormula - }.toSet - - Stats.updateCounterStats(tempIds.size, "TemplateIds", "VC-refinement") - if (ctx.abort) (None, None) - else solve(tempIds, funcs) - } - - def solve(tempIds: Set[Identifier], funcVCs: Seq[FunDef]): (Option[Model], Option[Set[Call]]) -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala deleted file mode 100644 index d5c0ccc0d1f65d6cab2438bb27ae61e44b37453e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala +++ /dev/null @@ -1,325 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.templateSolvers - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Types._ -import invariant.datastructure._ -import invariant.util._ -import leon.purescala.TypeOps -import PredicateUtil._ -import Stats._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap, MutableList } - -class UFADTEliminator(ctx: LeonContext, program: Program) { - - val debugAliases = false - val makeEfficient = true //this will happen at the expense of completeness - val reporter = ctx.reporter - val verbose = false - - // def collectCompatibleCalls(calls: Set[Expr]) = { - // //compute the cartesian product of the calls and select the pairs having the same function symbol and also implied by the precond - // val vec = calls.toArray - // val size = calls.size - // var j = 0 - // //for stats - // var tuples = 0 - // var functions = 0 - // var adts = 0 - // val product = vec.foldLeft(Set[(Expr, Expr)]())((acc, call) => { - // //an optimization: here we can exclude calls to maxFun from axiomatization, they will be inlined anyway - // /*val shouldConsider = if(InvariantisCallExpr(call)) { - // val BinaryOperator(_,FunctionInvocation(calledFun,_), _) = call - // if(calledFun == DepthInstPhase.maxFun) false - // else true - // } else true*/ - // var pairs = Set[(Expr, Expr)]() - // for (i <- j + 1 until size) { - // val call2 = vec(i) - // if (mayAlias(call, call2)) { - // call match { - // case Equals(_, fin: FunctionInvocation) => functions += 1 - // case Equals(_, tup: Tuple) => tuples += 1 - // case _ => adts += 1 - // } - // if (debugAliases) - // println("Aliases: " + call + "," + call2) - // pairs ++= Set((call, call2)) - // } else { - // if (debugAliases) { - // (call, call2) match { - // case (Equals(_, t1 @ Tuple(_)), Equals(_, t2 @ Tuple(_))) => - // println("No Aliases: " + t1.getType + "," + t2.getType) - // case _ => println("No Aliases: " + call + "," + call2) - // } - // } - // } - // } - // j += 1 - // acc ++ pairs - // }) - // if (verbose) reporter.info("Number of compatible calls: " + product.size) - // Stats.updateCounterStats(product.size, "Compatible-Calls", "disjuncts") - // Stats.updateCumStats(functions, "Compatible-functioncalls") - // Stats.updateCumStats(adts, "Compatible-adtcalls") - // Stats.updateCumStats(tuples, "Compatible-tuples") - // product - // } - - def collectCompatibleTerms(terms: Set[Expr]) = { - class Comp(val key: Either[TypedFunDef, TypeTree]) { - override def equals(other: Any) = other match { - case otherComp: Comp => mayAlias(key, otherComp.key) - case _ => false - } - // an weaker property whose equality is necessary for mayAlias - val hashcode = - (key: @unchecked) match { - case Left(TypedFunDef(fd, _)) => fd.id.hashCode() - case Right(ct: CaseClassType) => ct.classDef.id.hashCode() - case Right(tp @ TupleType(tps)) => (tps.hashCode() << 3) ^ tp.dimension - } - override def hashCode = hashcode - } - val compTerms = MutableMap[Comp, MutableList[Expr]]() - terms.foreach { term => - //an optimization: here we can exclude calls to maxFun from axiomatization, they will be inlined anyway - /*val shouldConsider = if(InvariantisCallExpr(call)) { - val BinaryOperator(_,FunctionInvocation(calledFun,_), _) = call - if(calledFun == DepthInstPhase.maxFun) false - else true - } else true*/ - val compKey: Either[TypedFunDef, TypeTree] = term match { - case Equals(_, rhs) => rhs match { // tuple types require special handling before they are used as keys - case tp: Tuple => - val TupleType(tps) = tp.getType - Right(TupleType(tps.map { TypeOps.bestRealType })) - case FunctionInvocation(tfd, _) => Left(tfd) - case CaseClass(ct, _) => Right(ct) - } - } - val comp = new Comp(compKey) - val compList = compTerms.getOrElse(comp, { - val newl = new MutableList[Expr]() - compTerms += (comp -> newl) - newl - }) - compList += term - } - if (debugAliases) { - compTerms.foreach { - case (_, v) => println("Aliases: " + v.mkString("{", ",", "}")) - } - } - compTerms - } - - /** - * Convert the theory formula into linear arithmetic formula. - * The calls could be functions calls or ADT constructor calls. - * 'predEval' is an evaluator that evaluates a predicate to a boolean value - * TODO: is type parameter inheritance handled correctly ? - */ - def constraintsForCalls(calls: Set[Expr], predEval: (Expr => Option[Boolean])): Seq[Expr] = { - - //check if two calls (to functions or ADT cons) have the same value in the model - def doesAlias(call1: Expr, call2: Expr): Option[Boolean] = { - val Operator(Seq(r1 @ Variable(_), _), _) = call1 - val Operator(Seq(r2 @ Variable(_), _), _) = call2 - predEval(Equals(r1, r2)) match { - case Some(true) if isCallExpr(call1) => - val (ants, _) = axiomatizeCalls(call1, call2) - val antsEvals = ants.map(ant => { - val Operator(Seq(lvar @ Variable(_), rvar @ Variable(_)), _) = ant - predEval(Equals(lvar, rvar)) - }) - // return `false` if at least one argument is false - if (antsEvals.exists(_ == Some(false))) Some(false) - else if (antsEvals.exists(!_.isDefined)) None // here, we cannot decide if the call is true or false - else Some(true) - case r => r - } - } - - def predForEquality(call1: Expr, call2: Expr): Seq[Expr] = { - val eqs = if (isCallExpr(call1)) { - val (_, rhs) = axiomatizeCalls(call1, call2) - Seq(rhs) - } else { - val (lhs, rhs) = axiomatizeADTCons(call1, call2) - lhs :+ rhs - } - //remove self equalities. - val preds = eqs.filter { - case Operator(Seq(Variable(lid), Variable(rid)), _) => { - if (lid == rid) false - else { - if (lid.getType == Int32Type || lid.getType == RealType || lid.getType == IntegerType) true - else false - } - } - case e @ _ => throw new IllegalStateException("Not an equality or Iff: " + e) - } - preds - } - - def predForDisequality(call1: Expr, call2: Expr): Seq[Expr] = { - val (ants, _) = if (isCallExpr(call1)) { - axiomatizeCalls(call1, call2) - } else { - axiomatizeADTCons(call1, call2) - } - if (makeEfficient && ants.exists { - case Equals(l, r) if (l.getType != RealType && l.getType != BooleanType && l.getType != IntegerType) => true - case _ => false - }) { - Seq() - } else { - var unsatIntEq: Option[Expr] = None - var unsatOtherEq: Option[Expr] = None - ants.foreach(eq => - if (unsatOtherEq.isEmpty) { - eq match { - case Equals(lhs @ Variable(_), rhs @ Variable(_)) if predEval(Equals(lhs, rhs)) == Some(false) => { // there must exist at least one such predicate - if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) - unsatOtherEq = Some(eq) - else if (unsatIntEq.isEmpty) - unsatIntEq = Some(eq) - } - case _ => ; - } - }) - if (unsatOtherEq.isDefined) Seq() //need not add any constraint - else if (unsatIntEq.isDefined) { - //pick the constraint a < b or a > b that is satisfied - val Equals(lhs @ Variable(_), rhs @ Variable(_)) = unsatIntEq.get - val lLTr = LessThan(lhs, rhs) - predEval(lLTr) match { - case Some(true) => Seq(lLTr) - case Some(false) => Seq(GreaterThan(lhs, rhs)) - case _ => Seq() // actually this case cannot happen. - } - } else throw new IllegalStateException("All arguments are equal: " + (call1, call2)) - } - } - - var equivClasses = new DisjointSets[Expr]() - var neqSet = MutableSet[(Expr, Expr)]() - val termClasses = collectCompatibleTerms(calls) - val preds = MutableList[Expr]() - termClasses.foreach { - case (_, compTerms) => - val vec = compTerms.toArray - val size = vec.size - vec.zipWithIndex.foreach { - case (t1, j) => - (j + 1 until size).foreach { i => - val t2 = vec(i) - if (compatibleTArgs(termTArgs(t1), termTArgs(t2))) { - //note: here we omit constraints that encode transitive equality facts - val class1 = equivClasses.findOrCreate(t1) - val class2 = equivClasses.findOrCreate(t2) - if (class1 != class2 && !neqSet.contains((t1, t2)) && !neqSet.contains((t2, t1))) { - doesAlias(t1, t2) match { - case Some(true) => - equivClasses.union(class1, class2) - preds ++= predForEquality(t1, t2) - case Some(false) => - neqSet ++= Set((t1, t2)) - preds ++= predForDisequality(t1, t2) - case _ => - // in this case, we construct a weaker disjunct by dropping this predicate - } - } - } - } - } - } - Stats.updateCounterStats(preds.size, "CallADT-Constraints", "disjuncts") - preds.toSeq - } - - def termTArgs(t: Expr) = { - t match { - case Equals(_, e) => - e match { - case FunctionInvocation(TypedFunDef(_, tps), _) => tps - case CaseClass(ct, _) => ct.tps - case tp: Tuple => - val TupleType(tps) = tp.getType - tps - } - } - } - - /** - * This function actually checks if two non-primitive expressions could have the same value - * (when some constraints on their arguments hold). - * Remark: notice that when the expressions have ADT types, then this is basically a form of may-alias check. - * TODO: handling type parameters can become very trickier here. - * For now ignoring type parameters of functions and classes. (This is complete, but may be less efficient) - */ - def mayAlias(term1: Either[TypedFunDef, TypeTree], term2: Either[TypedFunDef, TypeTree]): Boolean = { - (term1, term2) match { - case (Left(TypedFunDef(fd1, _)), Left(TypedFunDef(fd2, _))) => - fd1.id == fd2.id - case (Right(ct1: CaseClassType), Right(ct2: CaseClassType)) => - ct1.classDef.id == ct2.classDef.id - case (Right(tp1 @ TupleType(tps1)), Right(tp2 @ TupleType(tps2))) if tp1.dimension == tp2.dimension => - compatibleTArgs(tps1, tps2) //get the types and check if the types are compatible - case _ => false - } - } - - def compatibleTArgs(tps1: Seq[TypeTree], tps2: Seq[TypeTree]): Boolean = { - (tps1 zip tps2).forall { - case (t1, t2) => - val lub = TypeOps.leastUpperBound(t1, t2) - (lub == Some(t1) || lub == Some(t2)) // is t1 a super type of t2 - } - } - - /** - * This procedure generates constraints for the calls to be equal - */ - def axiomatizeCalls(call1: Expr, call2: Expr): (Seq[Expr], Expr) = { - val (v1, fi1, v2, fi2) = { - val Equals(r1, f1 @ FunctionInvocation(_, _)) = call1 - val Equals(r2, f2 @ FunctionInvocation(_, _)) = call2 - (r1, f1, r2, f2) - } - - val ants = (fi1.args.zip(fi2.args)).foldLeft(Seq[Expr]())((acc, pair) => { - val (arg1, arg2) = pair - acc :+ Equals(arg1, arg2) - }) - val conseq = Equals(v1, v2) - (ants, conseq) - } - - /** - * The returned pairs should be interpreted as a bidirectional implication - */ - def axiomatizeADTCons(sel1: Expr, sel2: Expr): (Seq[Expr], Expr) = { - val (v1, args1, v2, args2) = sel1 match { - case Equals(r1 @ Variable(_), CaseClass(_, a1)) => { - val Equals(r2 @ Variable(_), CaseClass(_, a2)) = sel2 - (r1, a1, r2, a2) - } - case Equals(r1 @ Variable(_), Tuple(a1)) => { - val Equals(r2 @ Variable(_), Tuple(a2)) = sel2 - (r1, a1, r2, a2) - } - } - val ants = (args1.zip(args2)).foldLeft(Seq[Expr]())((acc, pair) => { - val (arg1, arg2) = pair - acc :+ Equals(arg1, arg2) - }) - val conseq = Equals(v1, v2) - (ants, conseq) - } -} diff --git a/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala b/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala deleted file mode 100644 index db02d58c035fcfc08b06cd5d98347dfa34c4d9b7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/templateSolvers/UniversalQuantificationSolver.scala +++ /dev/null @@ -1,402 +0,0 @@ -package leon -package invariant.templateSolvers - -import z3.scala._ -import purescala._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ -import solvers._ -import solvers.combinators._ -import solvers.smtlib._ -import solvers.z3._ -import scala.util.control.Breaks._ -import purescala.ScalaPrinter -import scala.collection.mutable.{ Map => MutableMap } -import scala.reflect.runtime.universe -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.util.ExpressionTransformer._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ -import Stats._ -import leon.evaluators._ -import EvaluationResults._ - -import Util._ -import PredicateUtil._ -import SolverUtil._ - -class UniversalQuantificationSolver(ctx: InferenceContext, program: Program, - funs: Seq[FunDef], ctrTracker: ConstraintTracker, - minimizer: Option[(Expr, Model) => Model]) { - - import NLTemplateSolver._ - - //flags controlling debugging - val debugUnflattening = false - val debugIncrementalVC = false - val trackCompressedVCCTime = false - - val printCounterExample = false - val dumpInstantiatedVC = false - - val reporter = ctx.reporter - val timeout = ctx.vcTimeout - val leonctx = ctx.leonContext - - //flag controlling behavior - val disableCegis = true - private val useIncrementalSolvingForVCs = true - private val usePortfolio = false // portfolio has a bug in incremental solving - - val defaultEval = new DefaultEvaluator(leonctx, program) // an evaluator for extracting models - val existSolver = new ExistentialQuantificationSolver(ctx, program, ctrTracker, defaultEval) - - val solverFactory = - if (usePortfolio) { - if (useIncrementalSolvingForVCs) - throw new IllegalArgumentException("Cannot perform incremental solving with portfolio solvers!") - new PortfolioSolverFactory(leonctx.toSctx, Seq( - SolverFactory.getFromName(leonctx, program)("smt-cvc4-u"), - SolverFactory.getFromName(leonctx, program)("smt-z3-u"))) - } else - SolverFactory.uninterpreted(leonctx, program) - - def splitVC(fd: FunDef) = { - val (paramPart, rest, modCons) = - time { ctrTracker.getVC(fd).toUnflatExpr } { - t => Stats.updateCounterTime(t, "UnflatTime", "VC-refinement") - } - if (ctx.usereals) { - (IntLiteralToReal(paramPart), IntLiteralToReal(rest), modCons) - } else (paramPart, rest, modCons) - } - - case class FunData(modelCons: (Model, DefaultEvaluator) => FlatModel, paramParts: Expr, simpleParts: Expr) - val funInfos = funs.map { fd => - val (paramPart, rest, modelCons) = splitVC(fd) - if (hasReals(rest) && hasInts(rest)) - throw new IllegalStateException("Non-param Part has both integers and reals: " + rest) - if (debugIncrementalVC) { - assert(getTemplateVars(rest).isEmpty) - println("For function: " + fd.id) - println("Param part: " + paramPart) - } - (fd -> FunData(modelCons, paramPart, rest)) - }.toMap - - var funSolvers = initializeSolvers - def initializeSolvers = - if (!ctx.abort) { // this is required to ensure that solvers are not created after interrupts - funInfos.map { - case (fd, FunData(_, _, rest)) => - val vcSolver = solverFactory.getNewSolver() - vcSolver.assertCnstr(rest) - (fd -> vcSolver) - }.toMap - } else Map[FunDef, Solver with TimeoutSolver]() - - def free = { - if (useIncrementalSolvingForVCs) - funSolvers.foreach(entry => entry._2.free) - } - - /** - * State for minimization - */ - class MinimizationInfo { - var minStarted = false - var lastCorrectModel: Option[Model] = None - var minStartTime: Long = 0 // for stats - - def started = minStarted - def reset() = { - minStarted = false - lastCorrectModel = None - } - def updateProgress(model: Model) { - lastCorrectModel = Some(model) - if (!minStarted) { - minStarted = true - minStartTime = System.currentTimeMillis() - } - } - def complete { - reset() - /*val mintime = (System.currentTimeMillis() - minStartTime) - Stats.updateCounterTime(mintime, "minimization-time", "procs") - Stats.updateCumTime(mintime, "Total-Min-Time")*/ - } - def getLastCorrectModel = lastCorrectModel - } - - /** - * State for recording diffcult paths - */ - class DifficultPaths { - var paths = MutableMap[FunDef, Seq[Expr]]() - - def addPath(fd: FunDef, cePath: Expr): Unit = { - if (paths.contains(fd)) { - paths.update(fd, cePath +: paths(fd)) - } else { - paths += (fd -> Seq(cePath)) - } - } - def get(fd: FunDef) = paths.get(fd) - def hasPath(fd: FunDef) = paths.contains(fd) - def pathsToExpr(fd: FunDef) = Not(createOr(paths(fd))) - def size = paths.values.map(_.size).sum - } - - abstract class RefineRes - case class UnsolvableVC() extends RefineRes - case class NoSolution() extends RefineRes - case class CorrectSolution() extends RefineRes - case class NewSolution(tempModel: Model) extends RefineRes - - class ModelRefiner(tempModel: Model) { - val tempVarMap: Map[Expr, Expr] = tempModel.map { case (k, v) => (k.toVariable -> v) }.toMap - val seenPaths = new DifficultPaths() - private var callsInPaths = Set[Call]() - - def callsEncountered = callsInPaths - - def nextCandidate(conflicts: Seq[FunDef]): RefineRes = { - var newConflicts = Seq[FunDef]() - var blockedPaths = false - val newctrsOpt = conflicts.foldLeft(Some(Seq()): Option[Seq[Expr]]) { - case (None, _) => None - case _ if (ctx.abort) => None - case (Some(acc), fd) => - val disabledPaths = - if (seenPaths.hasPath(fd)) { - blockedPaths = true - seenPaths.pathsToExpr(fd) - } else tru - checkVCSAT(fd, tempModel, disabledPaths) match { - case (None, _) => None // VC cannot be decided - case (Some(false), _) => Some(acc) // VC is unsat - case (Some(true), univModel) => // VC is sat - newConflicts :+= fd - if (verbose) reporter.info("Function: " + fd.id + "--Found candidate invariant is not a real invariant! ") - if (printCounterExample) { - reporter.info("Model: " + univModel) - } - // generate constraints for making preventing the model - val (existCtr, linearpath, calls) = existSolver.generateCtrsForUNSAT(fd, univModel, tempModel) - if (existCtr == tru) throw new IllegalStateException("Cannot find a counter-example path!!") - callsInPaths ++= calls - //instantiate the disjunct - val cePath = simplifyArithmetic(TemplateInstantiator.instantiate( - createAnd(linearpath.map(_.template)), tempVarMap)) - //some sanity checks - if (variablesOf(cePath).exists(TemplateIdFactory.IsTemplateIdentifier _)) - throw new IllegalStateException("Found template identifier in counter-example disjunct: " + cePath) - seenPaths.addPath(fd, cePath) - Some(acc :+ existCtr) - } - } - newctrsOpt match { - case None => // give up, the VC cannot be decided - UnsolvableVC() - case Some(newctrs) if (newctrs.isEmpty) => - if (!blockedPaths) { //yes, hurray,found an inductive invariant - CorrectSolution() - } else { - //give up, only hard paths remaining - reporter.info("- Exhausted all easy paths !!") - reporter.info("- Number of remaining hard paths: " + seenPaths.size) - NoSolution() //TODO: what to unroll here ? - } - case Some(newctrs) => - existSolver.solveConstraints(newctrs, tempModel) match { - case (None, _) => - //here we have timed out while solving the non-linear constraints - if (verbose) - reporter.info("NLsolver timed-out on the disjunct... blocking this disjunct...") - Stats.updateCumStats(1, "retries") - nextCandidate(newConflicts) - case (Some(false), _) => // template not solvable, need more unrollings here - NoSolution() - case (Some(true), nextModel) => - NewSolution(nextModel) - } - } - } - def nextCandidate: RefineRes = nextCandidate(funs) - } - - /** - * @param foundModel a call-back that will be invoked every time a new model is found - */ - def solveUNSAT(initModel: Model, foundModel: Model => Unit): (Option[Model], Option[Set[Call]]) = { - val minInfo = new MinimizationInfo() - var sat: Option[Boolean] = Some(true) - var tempModel = initModel - var callsInPaths = Set[Call]() - var minimized = false - while (sat == Some(true) && !ctx.abort) { - Stats.updateCounter(1, "disjuncts") - if (verbose) { - reporter.info("Candidate invariants") - TemplateInstantiator.getAllInvariants(tempModel, ctrTracker.getFuncs).foreach{ - case(f, inv) => reporter.info(f.id + "-->" + PrettyPrinter(inv)) - } - } - val modRefiner = new ModelRefiner(tempModel) - sat = modRefiner.nextCandidate match { - case CorrectSolution() if (minimizer.isDefined && !minimized) => - minInfo.updateProgress(tempModel) - val minModel = minimizer.get(existSolver.getSolvedCtrs, tempModel) - minimized = true - if (minModel.toMap == tempModel.toMap) { - minInfo.complete - Some(false) - } else { - tempModel = minModel - Some(true) - } - case CorrectSolution() => // minimization has completed or is not applicable - minInfo.complete - Some(false) - case NewSolution(newModel) => - foundModel(newModel) - minimized = false - tempModel = newModel - Some(true) - case NoSolution() => // here template is unsolvable or only hard paths remain - None - case UnsolvableVC() if minInfo.started => - tempModel = minInfo.getLastCorrectModel.get - Some(false) - case UnsolvableVC() if !ctx.abort => - if (verbose) { - reporter.info("VC solving failed!...retrying with a bigger model...") - } - existSolver.solveConstraints(retryStrategy(tempModel), tempModel) match { - case (Some(true), newModel) => - foundModel(newModel) - tempModel = newModel - funSolvers = initializeSolvers // reinitialize all VC solvers as they all timed out - Some(true) - case _ => // give up, no other bigger invariant exist or existential solving timed out! - None - } - case _ => None - } - callsInPaths ++= modRefiner.callsEncountered - } - sat match { - case _ if ctx.abort => (None, None) - case None => (None, Some(callsInPaths)) //cannot solve template, more unrollings - case _ => (Some(tempModel), None) // template solved - } - } - - /** - * Strategy: try to find a value for templates that is bigger than the current value - */ - import RealValuedExprEvaluator._ - val rtwo = FractionalLiteral(2, 1) - def retryStrategy(tempModel: Model): Seq[Expr] = { - tempModel.map { - case (id, z @ FractionalLiteral(n, _)) if n == 0 => GreaterThan(id.toVariable, z) - case (id, fl: FractionalLiteral) => GreaterThan(id.toVariable, evaluate(RealTimes(rtwo, fl))) - }.toSeq - } - - protected def instantiateTemplate(e: Expr, tempVarMap: Map[Expr, Expr]): Expr = { - if (ctx.usereals) replace(tempVarMap, e) - else - simplifyArithmetic(TemplateInstantiator.instantiate(e, tempVarMap)) - } - - /** - * Checks if the VC of fd is unsat - */ - def checkVCSAT(fd: FunDef, tempModel: Model, disabledPaths: Expr): (Option[Boolean], LazyModel) = { - val tempIdMap = tempModel.toMap - val tempVarMap: Map[Expr, Expr] = tempIdMap.map { case (k, v) => k.toVariable -> v }.toMap - val funData = funInfos(fd) - val (solver, instExpr, modelCons) = - if (useIncrementalSolvingForVCs) { - val instParamPart = instantiateTemplate(funData.paramParts, tempVarMap) - (funSolvers(fd), And(instParamPart, disabledPaths), funData.modelCons) - } else { - val FunData(modCons, paramPart, rest) = funData - val instPart = instantiateTemplate(paramPart, tempVarMap) - (solverFactory.getNewSolver(), createAnd(Seq(rest, instPart, disabledPaths)), modCons) - } - //For debugging - if (dumpInstantiatedVC) { - val fullExpr = if (useIncrementalSolvingForVCs) And(funData.simpleParts, instExpr) else instExpr - ExpressionTransformer.PrintWithIndentation("vcInst", fullExpr) - } - // sanity check - if (hasMixedIntReals(instExpr)) - throw new IllegalStateException("Instantiated VC of " + fd.id + " contains mixed integer/reals: " + instExpr) - //reporter.info("checking VC inst ...") - solver.setTimeout(timeout * 1000) - val (res, packedModel) = - time { - if (useIncrementalSolvingForVCs) { - solver.push - solver.assertCnstr(instExpr) - val solRes = solver.check match { - case _ if ctx.abort => - (None, Model.empty) - case r @ Some(true) => - (r, solver.getModel) - case r => (r, Model.empty) - } - if (solRes._1.isDefined) // invoking pop() otherwise will throw an exception - solver.pop() - solRes - } else - SimpleSolverAPI(SolverFactory(solver.name, () => solver)).solveSAT(instExpr) - } { vccTime => - if (verbose) reporter.info("checked VC inst... in " + vccTime / 1000.0 + "s") - updateCounterTime(vccTime, "VC-check-time", "disjuncts") - updateCumTime(vccTime, "TotalVCCTime") - } - if (debugUnflattening) { - /*ctrTracker.getVC(fd).checkUnflattening(tempVarMap, - SimpleSolverAPI(SolverFactory(() => solverFactory.getNewSolver())), - defaultEval)*/ - verifyModel(funData.simpleParts, packedModel, SimpleSolverAPI(solverFactory)) - //val unflatPath = ctrTracker.getVC(fd).pickSatFromUnflatFormula(funData.simpleParts, packedModel, defaultEval) - } - //for statistics - if (trackCompressedVCCTime) { - val compressedVC = - unflatten(simplifyArithmetic(instantiateTemplate(ctrTracker.getVC(fd).eliminateBlockers, tempVarMap))) - Stats.updateCounterStats(atomNum(compressedVC), "Compressed-VC-size", "disjuncts") - time { - SimpleSolverAPI(solverFactory).solveSAT(compressedVC) - } { compTime => - Stats.updateCumTime(compTime, "TotalCompressVCCTime") - reporter.info("checked compressed VC... in " + compTime / 1000.0 + "s") - } - } - (res, modelCons(packedModel, defaultEval)) - } - - // cegis code, now not used - //val (cres, cctr, cmodel) = solveWithCegis(tempIds.toSet, createOr(newConfDisjuncts), inputCtr, Some(model)) - // def solveWithCegis(tempIds: Set[Identifier], expr: Expr, precond: Expr, initModel: Option[Model]): (Option[Boolean], Expr, Model) = { - // val cegisSolver = new CegisCore(ctx, program, timeout.toInt, NLTemplateSolver.this) - // val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) - // if (res.isEmpty) - // reporter.info("cegis timed-out on the disjunct...") - // (res, ctr, model) - // } - -} diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala deleted file mode 100644 index ede408c37badd0a5727fe5e7f46e7afe099ec26f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/CallGraph.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import ProgramUtil._ -import invariant.structure.FunctionUtils._ -import invariant.datastructure._ - -/** - * This represents a call graph of the functions in the program - */ -class CallGraph { - val graph = new DirectedGraph[FunDef]() - lazy val reverseCG = graph.reverse - - def addFunction(fd: FunDef) = graph.addNode(fd) - - def addEdgeIfNotPresent(src: FunDef, callee: FunDef): Unit = { - if (!graph.containsEdge(src, callee)) - graph.addEdge(src, callee) - } - - def callees(src: FunDef): Set[FunDef] = { - graph.getSuccessors(src) - } - - def transitiveCallees(src: FunDef): Set[FunDef] = { - graph.BFSReachables(Seq(src)) - } - - def transitiveCallers(dest: FunDef) : Set[FunDef] = { - reverseCG.BFSReachables(Seq(dest)) - } - - def transitiveCallees(srcs: Seq[FunDef]): Set[FunDef] = { - graph.BFSReachables(srcs) - } - - def transitiveCallers(dests: Seq[FunDef]) : Set[FunDef] = { - reverseCG.BFSReachables(dests) - } - - def isRecursive(fd: FunDef): Boolean = { - transitivelyCalls(fd, fd) - } - - /** - * Checks if the src transitively calls the procedure proc. - * Note: We cannot say that src calls itself even though source is reachable from itself in the callgraph - */ - def transitivelyCalls(src: FunDef, proc: FunDef): Boolean = { - graph.BFSReach(src, proc, excludeSrc = true) - } - - def calls(src: FunDef, proc: FunDef): Boolean = { - graph.containsEdge(src, proc) - } - - /** - * Sorting functions in reverse topological order. - * For functions within an SCC, we preserve the initial order - * given as input - */ - def reverseTopologicalOrder(initOrder: Seq[FunDef]): Seq[FunDef] = { - val orderMap = initOrder.zipWithIndex.toMap - graph.sccs.flatMap{scc => scc.sortWith((f1, f2) => orderMap(f1) <= orderMap(f2)) } - } - - override def toString: String = { - val procs = graph.getNodes - procs.foldLeft("")((acc, proc) => { - acc + proc.id + " --calls--> " + - graph.getSuccessors(proc).foldLeft("")((acc, succ) => acc + "," + succ.id) + "\n" - }) - } -} - -object CallGraphUtil { - - def constructCallGraph(prog: Program, - onlyBody: Boolean = false, - withTemplates: Boolean = false, - calleesFun: Expr => Set[FunDef] = getCallees): CallGraph = { - val cg = new CallGraph() - functionsWOFields(prog.definedFunctions).foreach{fd => - cg.addFunction(fd) - if (fd.hasBody) { - var funExpr = fd.body.get - if (!onlyBody) { - if (fd.hasPrecondition) - funExpr = Tuple(Seq(funExpr, fd.precondition.get)) - if (fd.hasPostcondition) - funExpr = Tuple(Seq(funExpr, fd.postcondition.get)) - } - if (withTemplates && fd.hasTemplate) { - funExpr = Tuple(Seq(funExpr, fd.getTemplate)) - } - //introduce a new edge for every callee - calleesFun(funExpr).foreach(cg.addEdgeIfNotPresent(fd, _)) - } - } - cg - } - - def getCallees(expr: Expr): Set[FunDef] = collect { - case expr@FunctionInvocation(TypedFunDef(callee, _), _) if callee.isRealFunction => - Set(callee) - case _ => - Set[FunDef]() - }(expr) - -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/ExprStructure.scala b/src/main/scala/leon/invariant/util/ExprStructure.scala deleted file mode 100644 index 3907ac376e7098415a13cb762ca3ca36cbb1c0dc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/ExprStructure.scala +++ /dev/null @@ -1,78 +0,0 @@ -package leon -package invariant.util - -import purescala.ExprOps._ -import purescala.Expressions._ -import purescala.Extractors.Operator -import purescala.Types.{ClassType, TypeParameter} - -/** - * A class that looks for structural equality of expressions - * by ignoring the variable names. - * Useful for factoring common parts of two expressions into functions. - */ -class ExprStructure(val e: Expr) { - def structurallyEqual(e1: Expr, e2: Expr): Boolean = { - (e1, e2) match { - case (t1: Terminal, t2: Terminal) => - // we need to specially handle type parameters as they are not considered equal by default - (t1.getType, t2.getType) match { - case (ct1: ClassType, ct2: ClassType) => - if (ct1.classDef == ct2.classDef && ct1.tps.size == ct2.tps.size) { - (ct1.tps zip ct2.tps).forall { - case (TypeParameter(_), TypeParameter(_)) => - true - case (a, b) => - println(s"Checking Type arguments: $a, $b") - a == b - } - } else false - case (ty1, ty2) => ty1 == ty2 - } - case (Operator(args1, _), Operator(args2, _)) => - opEquals(e1, e2) && (args1.size == args2.size) && (args1 zip args2).forall { - case (a1, a2) => structurallyEqual(a1, a2) - } - case _ => - false - } - } - - def opEquals(e1: Expr, e2: Expr): Boolean = { - (e1, e2) match { - case (FunctionInvocation(tfd1, _), FunctionInvocation(tfd2, _)) - if tfd1.fd == tfd2.fd => true - case (CaseClass(cct1, _), CaseClass(cct2, _)) - if cct1.classDef == cct2.classDef => true - case (CaseClassSelector(cct1, _, fld1), CaseClassSelector(cct2, _, fld2)) - if cct1.classDef == cct2.classDef && fld1 == fld2 => true - case _ if e1.getClass.equals(e2.getClass) => true // check if e1 and e2 are same instances of the same class - case _ if e1.isInstanceOf[MethodInvocation] || e2.isInstanceOf[MethodInvocation] => - throw new IllegalArgumentException("MethodInvocations are not supported") - case _ => - //println(s"Not op equal: ($e1,$e2) classes: (${e1.getClass}, ${e2.getClass})") - false - } - } - - override def equals(other: Any) = { - other match { - case other: ExprStructure => - structurallyEqual(e, other.e) - case _ => - false - } - } - - val hashcode = { - var opndcount = 0 // operand count - var opcount = 0 // operator count - postTraversal { - case t: Terminal => opndcount += 1 - case _ => opcount += 1 - }(e) - (opndcount << 16) ^ opcount - } - - override def hashCode = hashcode -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala deleted file mode 100644 index 139319a3bf3767d24921596c3b8e16bc0e9cd695..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala +++ /dev/null @@ -1,608 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import purescala.ScalaPrinter -import leon.invariant.factories.TemplateIdFactory -import PredicateUtil._ -import Util._ -import TVarFactory._ - -/** - * A collection of transformation on expressions and some utility methods. - * These operations are mostly semantic preserving (specific assumptions/requirements are specified on the operations) - */ -object ExpressionTransformer { - - // identifier for temporaries that are generated during flattening of terms other than functions - val flatContext = newContext - // temporaries used in the function flattening - val funFlatContext = newContext - // conversion of other language constructs - val langContext = newContext - - def createFlatTemp(name: String, tpe: TypeTree = Untyped) = createTemp(name, tpe, flatContext) - - /** - * This function conjoins the conjuncts created by 'transfomer' within the clauses containing Expr. - * This is meant to be used by operations that may flatten subexpression using existential quantifiers. - * @param insideFunction when set to true indicates that the newConjuncts (second argument) - * should not conjoined to the And(..) / Or(..) expressions found because they - * may be called inside a function. - * TODO: remove this function altogether and treat 'and' and 'or's as functions. - */ - def conjoinWithinClause(e: Expr, transformer: (Expr, Boolean) => (Expr, Set[Expr]), - insideFunction: Boolean): (Expr, Set[Expr]) = { - e match { - case And(args) if !insideFunction => - val newargs = args.map{arg => - val (nexp, ncjs) = transformer(arg, false) - createAnd(nexp +: ncjs.toSeq) - } - (createAnd(newargs), Set()) - - case Or(args) if !insideFunction => - val newargs = args.map{arg => - val (nexp, ncjs) = transformer(arg, false) - createAnd(nexp +: ncjs.toSeq) - } - (createOr(newargs), Set()) - - case t: Terminal => (t, Set()) - - case n @ Operator(args, op) => - var ncjs = Set[Expr]() - val newargs = args.map((arg) => { - val (nexp, js) = transformer(arg, true) - ncjs ++= js - nexp - }) - (op(newargs), ncjs) - case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + e) - } - } - - /** - * Assumed that that given expression has boolean type - * converting if-then-else and let into a logical formula - */ - def reduceLangBlocks(inexpr: Expr, multop: (Expr, Expr) => Expr) = { - - def transform(e: Expr, insideFunction: Boolean): (Expr, Set[Expr]) = { - e match { - // Handle asserts here. Return flattened body as the result - case as @ Assert(pred, _, body) => - val freshvar = createFlatTemp("asrtres", e.getType).toVariable - val newexpr = Equals(freshvar, body) - val resset = transform(newexpr, insideFunction) - (freshvar, resset._2 + resset._1) - - //handles division by constant - case Division(lhs, rhs @ InfiniteIntegerLiteral(v)) => - //this models floor and not integer division - val quo = createTemp("q", IntegerType, langContext).toVariable - var possibs = Seq[Expr]() - for (i <- v - 1 to 0 by -1) { - if (i == 0) possibs :+= Equals(lhs, Times(rhs, quo)) - else possibs :+= Equals(lhs, Plus(Times(rhs, quo), InfiniteIntegerLiteral(i))) - } - //compute the disjunction of all possibs - val newexpr = Or(possibs) - //println("newexpr: "+newexpr) - val resset = transform(newexpr, true) - (quo, resset._2 + resset._1) - - //handles division by variables - case Division(lhs, rhs) => - //this models floor and not integer division - val quo = createTemp("q", IntegerType, langContext).toVariable - val rem = createTemp("r", IntegerType, langContext).toVariable - val mult = multop(quo, rhs) - val divsem = Equals(lhs, Plus(mult, rem)) - //TODO: here, we have to use |rhs| - val newexpr = createAnd(Seq(divsem, LessEquals(zero, rem), LessEquals(rem, Minus(rhs, one)))) - val resset = transform(newexpr, true) - (quo, resset._2 + resset._1) - - case err @ Error(_, msg) => - //replace this by a fresh variable of the error type - (createTemp("err", err.getType, langContext).toVariable, Set[Expr]()) - - case Equals(lhs, rhs) => - val (nexp1, ncjs1) = transform(lhs, true) - val (nexp2, ncjs2) = transform(rhs, true) - (Equals(nexp1, nexp2), ncjs1 ++ ncjs2) - - case IfExpr(cond, thn, elze) if insideFunction => - val freshvar = createTemp("ifres", e.getType, langContext).toVariable - val (ncond, condConjs) = transform(cond, true) - val (nthen, thenConjs) = transform(Equals(freshvar, thn), false) - val (nelze, elzeConjs) = transform(Equals(freshvar, elze), false) - val conjs = condConjs + IfExpr(cond, - createAnd(nthen +: thenConjs.toSeq), createAnd(nelze +: elzeConjs.toSeq)) - (freshvar, conjs) - - case IfExpr(cond, thn, elze) => // here, we are at the top, and hence can avoid creating freshids - val (ncond, condConjs) = transform(cond, true) - val (nthen, thenConjs) = transform(thn, false) - val (nelze, elzeConjs) = transform(elze, false) - (IfExpr(cond, - createAnd(nthen +: thenConjs.toSeq), createAnd(nelze +: elzeConjs.toSeq)), condConjs) - - case Let(binder, value, body) => - //TODO: do we have to consider reuse of let variables ? - val (resbody, bodycjs) = transform(body, true) - val (resvalue, valuecjs) = transform(value, true) - (resbody, (valuecjs + Equals(binder.toVariable, resvalue)) ++ bodycjs) - - case _ => conjoinWithinClause(e, transform, insideFunction) - } - } - val (nexp, ncjs) = transform(inexpr, false) - val res = if (ncjs.nonEmpty) { - createAnd(nexp +: ncjs.toSeq) - } else nexp - res - } - - def isAtom(e: Expr): Boolean = e match { - case _: And | _: Or | _: IfExpr => false - case _ => true - } - - def isADTTheory(e: Expr) = e match { - case _: CaseClassSelector | _: CaseClass | _: TupleSelect | _: Tuple | _: IsInstanceOf => true - case _ => false - } - - def isSetTheory(e: Expr) = e match { - case _: SetUnion | _: ElementOfSet | _: SubsetOf | _: FiniteSet => true - case _ => false - } - - /** - * Requires: The expression has to be in NNF form and without if-then-else and let constructs - * Assumed that that given expression has boolean type - * (a) the function replaces every function call by a variable and creates a new equality - * (b) it also replaces arguments that are not variables by fresh variables and creates - * a new equality mapping the fresh variable to the argument expression - */ - def FlattenFunction(inExpr: Expr): Expr = { - - /** - * First return value is the new expression. The second return value is the - * set of new conjuncts - * @param insideFunction when set to true indicates that the newConjuncts (second argument) - * should not conjoined to the And(..) / Or(..) expressions found because they - * may be called inside a function. - */ - def flattenFunc(e: Expr, insideFunction: Boolean): (Expr, Set[Expr]) = { - e match { - case fi @ FunctionInvocation(fd, args) => - val (newargs, newConjuncts) = flattenArgs(args, true) - val freshResVar = Variable(createTemp("r", fi.getType, funFlatContext)) - (freshResVar, newConjuncts + Equals(freshResVar, FunctionInvocation(fd, newargs))) - - case adte if isADTTheory(adte) => - val Operator(args, op) = adte - val freshName = adte match { - case _: IsInstanceOf => "ci" - case _: CaseClassSelector => "cs" - case _: CaseClass => "cc" - case _: TupleSelect => "ts" - case _: Tuple => "tp" - } - val freshVar = Variable(createFlatTemp(freshName, adte.getType)) - val (newargs, newcjs) = flattenArgs(args, true) - (freshVar, newcjs + Equals(freshVar, op(newargs))) - - case SetUnion(_, _) | ElementOfSet(_, _) | SubsetOf(_, _) => - val Operator(args, op) = e - val (Seq(a1, a2), newcjs) = flattenArgs(args, true) - val newexpr = op(Seq(a1, a2)) - val freshResVar = Variable(createFlatTemp("set", e.getType)) - (freshResVar, newcjs + Equals(freshResVar, newexpr)) - - case fs @ FiniteSet(es, typ) => - val args = es.toSeq - val (nargs, newcjs) = flattenArgs(args, true) - val newexpr = FiniteSet(nargs.toSet, typ) - val freshResVar = Variable(createFlatTemp("fset", fs.getType)) - (freshResVar, newcjs + Equals(freshResVar, newexpr)) - - case And(args) if insideFunction => - val (nargs, cjs) = flattenArithmeticCtrs(args) - (And(nargs), cjs) - - case Or(args) if insideFunction => - val (nargs, cjs) = flattenArithmeticCtrs(args) - (Or(nargs), cjs) - - case IfExpr(cond, thn, elze) => // make condition of if-then-elze an atom - val (nthen, thenConjs) = flattenFunc(thn, false) - val (nelze, elzeConjs) = flattenFunc(elze, false) - val (ncond, condConjs) = flattenFunc(cond, true) match { - case r@(nc, _) if isAtom(nc) && getTemplateIds(nc).isEmpty => r - case (nc, conjs) => - val condvar = createFlatTemp("cond", cond.getType).toVariable - (condvar, conjs + Equals(condvar, nc)) - } - (IfExpr(ncond, createAnd(nthen +: thenConjs.toSeq), - createAnd(nelze +: elzeConjs.toSeq)), condConjs) - - case _ => conjoinWithinClause(e, flattenFunc, insideFunction) - } - } - - def flattenArgs(args: Seq[Expr], insideFunction: Boolean): (Seq[Expr], Set[Expr]) = { - var newConjuncts = Set[Expr]() - val newargs = args.map { - case v: Variable => v - case r: ResultVariable => r - case arg => - val (nexpr, ncjs) = flattenFunc(arg, insideFunction) - newConjuncts ++= ncjs - nexpr match { - case v: Variable => v - case r: ResultVariable => r - case _ => - val freshArgVar = Variable(createFlatTemp("arg", arg.getType)) - newConjuncts += Equals(freshArgVar, nexpr) - freshArgVar - } - } - (newargs, newConjuncts) - } - - def flattenArithmeticCtrs(args: Seq[Expr]) = { - val (flatArgs, cjs) = flattenArgs(args, true) - var ncjs = Set[Expr]() - val nargs = flatArgs.map { - case farg if isArithmeticRelation(farg) != Some(false) => - // 'farg' is a possibly arithmetic relation. - val argvar = createFlatTemp("ar", farg.getType).toVariable - ncjs += Equals(argvar, farg) - argvar - case farg => farg - } - (nargs, cjs ++ ncjs) - } - - val (nexp, ncjs) = flattenFunc(inExpr, false) - if (ncjs.nonEmpty) { - createAnd(nexp +: ncjs.toSeq) - } else nexp - } - - /** - * note: we consider even type parameters as ADT type - */ - def adtType(e: Expr) = { - val tpe = e.getType - tpe.isInstanceOf[ClassType] || tpe.isInstanceOf[TupleType] || tpe.isInstanceOf[TypeParameter] - } - - /** - * The following procedure converts the formula into negated normal form by pushing all not's inside. - * It will not convert boolean equalities or inequalities to disjunctions for performance. - * Assumption: - * (a) the formula does not have match constructs - * (b) all lets have been pulled to the top - * Some important features. - * (a) For a strict inequality with real variables/constants, the following produces a strict inequality - * (b) Strict inequalities with only integer variables/constants are reduced to non-strict inequalities - */ - def toNNF(inExpr: Expr, retainNEQ: Boolean = false): Expr = { - def nnf(expr: Expr): Expr = { -// /println("Invoking nnf on: "+expr) - expr match { - //case e if e.getType != BooleanType => e - case Not(Not(e1)) => nnf(e1) - case e @ Not(t: Terminal) => e - case Not(FunctionInvocation(tfd, args)) => Not(FunctionInvocation(tfd, args map nnf)) - case Not(And(args)) => createOr(args.map(arg => nnf(Not(arg)))) - case Not(Or(args)) => createAnd(args.map(arg => nnf(Not(arg)))) - case Not(Let(i, v, e)) => Let(i, nnf(v), nnf(Not(e))) - case Not(IfExpr(cond, thn, elze)) => IfExpr(nnf(cond), nnf(Not(thn)), nnf(Not(elze))) - case Not(e @ Operator(Seq(e1, e2), op)) => // Not of binary operator ? - e match { - case _: LessThan => GreaterEquals(e1, e2) - case _: LessEquals => GreaterThan(e1, e2) - case _: GreaterThan => LessEquals(e1, e2) - case _: GreaterEquals => LessThan(e1, e2) - case _: Implies => And(nnf(e1), nnf(Not(e2))) - case _: SubsetOf | _: ElementOfSet | _: SetUnion | _: FiniteSet => Not(e) // set ops - // handle equalities (which is shared by theories) - case _: Equals if e1.getType == BooleanType => Not(Equals(nnf(e1), nnf(e2))) - case _: Equals if adtType(e1) || e1.getType.isInstanceOf[SetType] => Not(e) // adt or set equality - case _: Equals if TypeUtil.isNumericType(e1.getType) => - if (retainNEQ) Not(Equals(e1, e2)) - else Or(nnf(LessThan(e1, e2)), nnf(GreaterThan(e1, e2))) - case _ => throw new IllegalStateException(s"Unknown binary operation: $e arg types: ${e1.getType},${e2.getType}") - } - case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs)) - case Equals(lhs, rhs @ (_: SubsetOf | _: ElementOfSet | _: IsInstanceOf | _: TupleSelect | _: CaseClassSelector)) => - Equals(nnf(lhs), rhs) - case Equals(lhs, FunctionInvocation(tfd, args)) => - Equals(nnf(lhs), FunctionInvocation(tfd, args map nnf)) - case Equals(lhs, rhs) if lhs.getType == BooleanType => Equals(nnf(lhs), nnf(rhs)) - case t: Terminal => t - case n @ Operator(args, op) => op(args map nnf) - case _ => throw new IllegalStateException("Impossible event: expr did not match any case: " + inExpr) - } - } - nnf(inExpr) - } - - /** - * Eliminates redundant nesting of ORs and ANDs. - * This is supposed to be a semantic preserving transformation - */ - def pullAndOrs(expr: Expr): Expr = { - simplePostTransform { - case Or(args) => - val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { - case Or(inArgs) => acc ++ inArgs - case _ => acc :+ arg - }) - createOr(newArgs) - case And(args) => - val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { - case And(inArgs) => acc ++ inArgs - case _ => acc :+ arg - }) - createAnd(newArgs) - case e => e - }(expr) - } - - /** - * Normalizes the expressions - */ - def normalizeExpr(expr: Expr, multOp: (Expr, Expr) => Expr): Expr = { - //println("Normalizing " + ScalaPrinter(expr) + "\n") - val redex = reduceLangBlocks(toNNF(matchToIfThenElse(expr)), multOp) - //println("After reducing lang blocks: " + ScalaPrinter(redex) + "\n") - val flatExpr = FlattenFunction(redex) - val simpExpr = pullAndOrs(flatExpr) - //println("After Normalizing: " + ScalaPrinter(flatExpr) + "\n") - simpExpr - } - - /** - * This is the inverse operation of flattening. - * This is used to produce a readable formula or more efficiently - * solvable formulas. - * Note: this is a helper method that assumes that 'flatIds' - * are not shared across disjuncts. - * If this is not guaranteed to hold, use the 'unflatten' method - */ - def simpleUnflattenWithMap(ine: Expr, excludeIds: Set[Identifier] = Set(), - includeFuns: Boolean): (Expr, Map[Identifier,Expr]) = { - - def isFlatTemp(id: Identifier) = - isTemp(id, flatContext) || (includeFuns && isTemp(id, funFlatContext)) - - var idMap = Map[Identifier, Expr]() - /** - * Here, relying on library transforms is dangerous as they - * can perform additional simplifications to the expression on-the-fly, - * which is not desirable here. - */ - def rec(e: Expr): Expr = e match { - case e @ Equals(Variable(id), rhs @ _) if isFlatTemp(id) && !excludeIds(id) => - val nrhs = rec(rhs) - if (idMap.contains(id)) Equals(Variable(id), nrhs) - else { - idMap += (id -> nrhs) - tru - } - // specially handle boolean function to prevent unnecessary simplifications - case Or(args) => Or(args map rec) - case And(args) => And(args map rec) - case IfExpr(cond, th, elze) => IfExpr(rec(cond), rec(th), rec(elze)) - case e => e // we should not recurse in other operations, note: Not(equals) should not be considered - } - val newe = rec(ine) - val closure = (e: Expr) => replaceFromIDs(idMap, e) - val rese = fix(closure)(newe) - (rese, idMap) - } - - def unflattenWithMap(ine: Expr, excludeIds: Set[Identifier] = Set(), - includeFuns: Boolean = true): (Expr, Map[Identifier,Expr]) = { - simpleUnflattenWithMap(ine, sharedIds(ine) ++ excludeIds, includeFuns) - } - - def unflatten(ine: Expr) = unflattenWithMap(ine)._1 - - /** - * convert all integer constants to real constants - */ - def IntLiteralToReal(inexpr: Expr): Expr = { - val transformer = (e: Expr) => e match { - case InfiniteIntegerLiteral(v) => FractionalLiteral(v, 1) - case IntLiteral(v) => FractionalLiteral(v, 1) - case _ => e - } - simplePostTransform(transformer)(inexpr) - } - - /** - * convert all real constants to integers - */ - def FractionalLiteralToInt(inexpr: Expr): Expr = { - val transformer = (e: Expr) => e match { - case FractionalLiteral(v, `bone`) => InfiniteIntegerLiteral(v) - case FractionalLiteral(_, _) => throw new IllegalStateException("cannot convert real literal to integer: " + e) - case _ => e - } - simplePostTransform(transformer)(inexpr) - } - - /** - * A hacky way to implement subexpression check. - * TODO: fix this - */ - def isSubExpr(key: Expr, expr: Expr): Boolean = { - var found = false - simplePostTransform { - case e if (e == key) => - found = true; e - case e => e - }(expr) - found - } - - /** - * Some simplification rules (keep adding more and more rules) - */ - def simplify(expr: Expr): Expr = { - //Note: some simplification are already performed by the class constructors (see Tree.scala) - simplePostTransform { - case Equals(lhs, rhs) if (lhs == rhs) => tru - case LessEquals(lhs, rhs) if (lhs == rhs) => tru - case GreaterEquals(lhs, rhs) if (lhs == rhs) => tru - case LessThan(lhs, rhs) if (lhs == rhs) => fls - case GreaterThan(lhs, rhs) if (lhs == rhs) => fls - case UMinus(InfiniteIntegerLiteral(v)) => InfiniteIntegerLiteral(-v) - case Equals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 == v2) - case LessEquals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 <= v2) - case LessThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 < v2) - case GreaterEquals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 >= v2) - case GreaterThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 > v2) - case e => e - }(expr) - } - - /** - * Input expression is assumed to be in nnf form - * Note: (a) Not(Equals()) and Not(Variable) is allowed - */ - def isDisjunct(e: Expr): Boolean = e match { - case And(args) => args.forall(arg => isDisjunct(arg)) - case Not(Equals(_, _)) | Not(Variable(_)) => true - case Or(_) | Implies(_, _) | Not(_) | Equals(_, _) => false - case _ => true - } - - /** - * assuming that the expression is in nnf form - * Note: (a) Not(Equals()) and Not(Variable) is allowed - */ - def isConjunct(e: Expr): Boolean = e match { - case Or(args) => args.forall(arg => isConjunct(arg)) - case Not(Equals(_, _)) | Not(Variable(_)) => true - case And(_) | Implies(_, _) | Not(_) | Equals(_, _) => false - case _ => true - } - - def PrintWithIndentation(filePrefix: String, expr: Expr): Unit = { - - val filename = filePrefix + FileCountGUID.getID + ".txt" - val wr = new PrintWriter(new File(filename)) - - def uniOP(e: Expr, seen: Int): Boolean = e match { - case And(args) => { - //have we seen an or ? - if (seen == 2) false - else args.forall(arg => uniOP(arg, 1)) - } - case Or(args) => { - //have we seen an And ? - if (seen == 1) false - else args.forall(arg => uniOP(arg, 2)) - } - case t: Terminal => true - case n @ Operator(args, op) => args.forall(arg => uniOP(arg, seen)) - } - - def printRec(e: Expr, indent: Int): Unit = { - if (uniOP(e, 0)) wr.println(ScalaPrinter(e)) - else { - wr.write("\n" + " " * indent + "(\n") - e match { - case And(args) => { - var start = true - args.foreach((arg) => { - wr.print(" " * (indent + 1)) - if (!start) wr.print("^") - printRec(arg, indent + 1) - start = false - }) - } - case Or(args) => { - var start = true - args.foreach((arg) => { - wr.print(" " * (indent + 1)) - if (!start) wr.print("v") - printRec(arg, indent + 1) - start = false - }) - } - case _ => throw new IllegalStateException("how can this happen ? " + e) - } - wr.write(" " * indent + ")\n") - } - } - printRec(expr, 0) - wr.close() - println("Printed to file: " + filename) - } - - /** - * Converts to sum of products form by distributing - * multiplication over addition - */ - def normalizeMultiplication(e: Expr, multop: (Expr, Expr) => Expr): Expr = { - - def isConstantOrTemplateVar(e: Expr) = { - e match { - case l: Literal[_] => true - case Variable(id) if TemplateIdFactory.IsTemplateIdentifier(id) => true - case _ => false - } - } - - def distribute(e: Expr): Expr = { - simplePreTransform { - case e @ FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if isMultFunctions(fd) => - val newe = (e1, e2) match { - case (Plus(sum1, sum2), _) => - // distribute e2 over e1 - Plus(multop(sum1, e2), multop(sum2, e2)) - case (_, Plus(sum1, sum2)) => - // distribute e1 over e2 - Plus(multop(e1, sum1), multop(e1, sum2)) - case (Times(arg1, arg2), _) => - // pull the constants out of multiplication (note: times is used when one of the arguments is a literal or template id - if (isConstantOrTemplateVar(arg1)) { - Times(arg1, multop(arg2, e2)) - } else - Times(arg2, multop(arg1, e2)) // here using commutativity axiom - case (_, Times(arg1, arg2)) => - if (isConstantOrTemplateVar(arg1)) - Times(arg1, multop(e1, arg2)) - else - Times(arg2, multop(e1, arg1)) - case _ if isConstantOrTemplateVar(e1) || isConstantOrTemplateVar(e2) => - // here one of the operands is a literal or template var, so convert mult to times and continue - Times(e1, e2) - case _ => - e - } - newe - case other => other - }(e) - } - distribute(e) - } -} diff --git a/src/main/scala/leon/invariant/util/LetTupleSimplification.scala b/src/main/scala/leon/invariant/util/LetTupleSimplification.scala deleted file mode 100644 index 72b814904a28207791043588d4bbbc908d0663a3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/LetTupleSimplification.scala +++ /dev/null @@ -1,484 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import java.io._ -import purescala.ScalaPrinter -import leon.utils._ -import PredicateUtil._ -import invariant.structure.Call -import invariant.structure.FunctionUtils._ -import leon.transformations.InstUtil._ -import TVarFactory._ - -/** - * A collection of transformation on expressions and some utility methods. - * These operations are mostly semantic preserving (specific assumptions/requirements are specified on the operations) - */ -object LetTupleSimplification { - - val zero = InfiniteIntegerLiteral(0) - val one = InfiniteIntegerLiteral(1) - val mone = InfiniteIntegerLiteral(-1) - val tru = BooleanLiteral(true) - val fls = BooleanLiteral(false) - val bone = BigInt(1) - // fresh ids created during simplification - val simpContext = newContext - - def letSanityChecks(ine: Expr) = { - simplePostTransform(_ match { - case letExpr @ Let(binderId, letValue, body) if (binderId.getType != letValue.getType) => - throw new IllegalStateException("Binder and value type mismatch: " + - s"(${binderId.getType},${letValue.getType})") - case e => e - })(ine) - } - - /** - * This function simplifies lets of the form <Var> = <TupleType Expr> by replacing - * uses of the <Var>._i by the approppriate expression in the let body or by - * introducing a new let <Var'> = <Var>._i and using <Var'> in place of <Var>._i - * in the original let body. - * Caution: this function may not be idempotent. - */ - def simplifyTuples(ine: Expr): Expr = { - - var processedLetBinders = Set[Identifier]() - def recSimplify(e: Expr, replaceMap: Map[Expr, Expr]): Expr = { - - //println("Before: "+e) - val transe = e match { - case letExpr @ Let(binderId, letValue, body) if !processedLetBinders(binderId) => - processedLetBinders += binderId - // transform the 'letValue' with the current map - val nvalue = recSimplify(letValue, replaceMap) - // enrich the map if letValue is of tuple type - nvalue.getType match { - case TupleType(argTypes) => - var freshBinders = Set[Identifier]() - def freshBinder(typ: TypeTree) = { - val freshid = createTemp(binderId.name, typ, simpContext) - freshBinders += freshid - freshid.toVariable - } - val newmap: Map[Expr, Expr] = nvalue match { - case Tuple(args) => // this is an optimization for the case where nvalue is a tuple - args.zipWithIndex.map { - case (t: Terminal, index) => - (TupleSelect(binderId.toVariable, index + 1) -> t) - case (_, index) => - (TupleSelect(binderId.toVariable, index + 1) -> freshBinder(argTypes(index))) - }.toMap - case _ => - argTypes.zipWithIndex.map { - case (argtype, index) => - (TupleSelect(binderId.toVariable, index + 1) -> freshBinder(argtype)) - }.toMap - } - // transform the body using the new map + old map - val nbody = recSimplify(body, replaceMap ++ newmap) - val bodyFreevars = variablesOf(nbody) - // create a sequence of lets for the freshBinders - val nletBody = newmap.foldLeft(nbody) { - case (acc, (k, Variable(id))) if freshBinders(id) && bodyFreevars(id) => - // here, the 'id' is a newly created binder and is also used in the transformed body - Let(id, k, acc) - case (acc, _) => - acc - } - Let(binderId, nvalue, nletBody) - case _ => - // no simplification can be done in this step - Let(binderId, nvalue, recSimplify(body, replaceMap)) - } - case ts @ TupleSelect(_, _) if replaceMap.contains(ts) => - postMap(replaceMap.lift, true)(e) //perform recursive replacements to handle nested tuple selects - //replaceMap(ts) //replace tuple-selects in the map with the new identifier - - case ts @ TupleSelect(Tuple(subes), i) => - subes(i - 1) - - case t: Terminal => t - - case Operator(subes, op) => - op(subes.map(recSimplify(_, replaceMap))) - } - //println("After: "+e) - transe - } - fixpoint((e: Expr) => simplifyArithmetic(recSimplify(e, Map())))(ine) - } - - // sanity checks - def checkTupleSelectInsideMax(e: Expr): Boolean = { - //exists( predicate: Expr => Expr) (e) - var error = false - def helper(e: Expr): Unit = { - e match { - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { - - val Seq(arg1: Expr, arg2: Expr) = args - (arg1, arg2) match { - case (_: TupleSelect, _) => error = true - case (_, _: TupleSelect) => error = true - case _ => { ; } - } - } - - case _ => { ; } - } - } - - postTraversal(helper)(e) - error - } - - def simplifyMax(ine: Expr): Expr = { - val debugMaxSimplify = false - //computes a lower bound value, assuming that every sub-term used in the term is positive - //Note: this is applicable only to expressions involving depth - def positiveTermLowerBound(e: Expr): Int = e match { - case IntLiteral(v) => v - case Plus(l, r) => positiveTermLowerBound(l) + positiveTermLowerBound(r) - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { - val Seq(arg1, arg2) = args - val lb1 = positiveTermLowerBound(arg1) - val lb2 = positiveTermLowerBound(arg2) - if (lb1 >= lb2) lb1 else lb2 - } - case _ => 0 //other case are not handled as they do not appear - } - - //checks if 'sub' is subsumed by 'e' i.e, 'e' will always take a value - // greater than or equal to 'sub'. - //Assuming that every sub-term used in the term is positive - def subsumedBy(sub: Expr, e: Expr): Boolean = e match { - case _ if (sub == e) => true - case Plus(l, r) => subsumedBy(sub, l) || subsumedBy(sub, r) - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => - val Seq(l, r) = args - subsumedBy(sub, l) || subsumedBy(sub, r) - case _ => false - } - - // in the sequel, we are using the fact that 'depth' is positive and - // 'ine' contains only 'depth' variables - val simpe = simplePostTransform((e: Expr) => e match { - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { - if (debugMaxSimplify) { - println("Simplifying: " + e) - } - val newargs: Seq[Expr] = args.map(simplifyArithmetic) - val Seq(arg1: Expr, arg2: Expr) = newargs - val simpval = if (!hasCalls(arg1) && !hasCalls(arg2)) { - import invariant.structure.LinearConstraintUtil._ - val lt = exprToTemplate(LessEquals(Minus(arg1, arg2), InfiniteIntegerLiteral(0))) - //now, check if all the variables in 'lt' have only positive coefficients - val allPositive = lt.coeffTemplate.forall(entry => entry match { - case (k, IntLiteral(v)) if (v >= 0) => true - case _ => false - }) && (lt.constTemplate match { - case None => true - case Some(IntLiteral(v)) if (v >= 0) => true - case _ => false - }) - if (allPositive) arg1 - else { - val allNegative = lt.coeffTemplate.forall(entry => entry match { - case (k, IntLiteral(v)) if (v <= 0) => true - case _ => false - }) && (lt.constTemplate match { - case None => true - case Some(IntLiteral(v)) if (v <= 0) => true - case _ => false - }) - if (allNegative) arg2 - else FunctionInvocation(tfd, newargs) //here we cannot do any simplification. - } - - } else { - (arg1, arg2) match { - case (IntLiteral(v), r) if (v <= positiveTermLowerBound(r)) => r - case (l, IntLiteral(v)) if (v <= positiveTermLowerBound(l)) => l - case (l, r) if subsumedBy(l, r) => r - case (l, r) if subsumedBy(r, l) => l - case _ => FunctionInvocation(tfd, newargs) - } - } - if (debugMaxSimplify) { - println("Simplified value: " + simpval) - } - simpval - } - // case FunctionInvocation(tfd, args) if(tfd.fd.id.name == "max") => { - // throw new IllegalStateException("Found just max in expression " + e + "\n") - // } - case _ => e - })(ine) - simpe - } - - def inlineMax(ine: Expr): Expr = { - //inline 'max' operations here - simplePostTransform((e: Expr) => e match { - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => - val Seq(arg1, arg2) = args - val bindWithLet = (value: Expr, body: (Expr with Terminal) => Expr) => { - value match { - case t: Terminal => body(t) - case Let(id, v, b: Terminal) => - //here we can use 'b' in 'body' - Let(id, v, body(b)) - case _ => - val mt = createTemp("mt", value.getType, simpContext) - Let(mt, value, body(mt.toVariable)) - } - } - bindWithLet(arg1, a1 => bindWithLet(arg2, a2 => - IfExpr(GreaterEquals(a1, a2), a1, a2))) - case _ => e - })(ine) - } - - def removeLetsFromLetValues(ine: Expr): Expr = { - - /** - * Navigates through the sequence of lets in 'e' - * and replaces its 'let' free part by subst. - * Assuming that 'e' has only lets at the top and no nested lets in the value - */ - def replaceLetBody(e: Expr, subst: Expr => Expr): Expr = e match { - case Let(binder, letv, letb) => - Let(binder, letv, replaceLetBody(letb, subst)) - case _ => - subst(e) - } - - // the function removes the lets from the let values - // by pulling them out - def pullLetToTop(e: Expr): Expr = { - val transe = e match { - case Lambda(args, body) => - Lambda(args, pullLetToTop(body)) - case Ensuring(body, pred) => - Ensuring(pullLetToTop(body), pullLetToTop(pred)) - case Require(pre, body) => - Require(pullLetToTop(pre), pullLetToTop(body)) - - case letExpr @ Let(binder, letValue, body) => - // transform the 'letValue' with the current map - pullLetToTop(letValue) match { - case sublet @ Let(binder2, subvalue, subbody) => - // transforming "let v = (let v1 = e1 in e2) in e3" - // to "let v1 = e1 in (let v = e2 in e3)" - // here, subvalue is free of lets, but subbody may again be a let - val newbody = replaceLetBody(subbody, Let(binder, _, pullLetToTop(body))) - Let(binder2, subvalue, newbody) - case nval => - // here, there is no let in the value - Let(binder, nval, pullLetToTop(body)) - } - - //don't pull lets out of if-then-else branches and match cases - case IfExpr(c, th, elze) => - replaceLetBody(pullLetToTop(c), IfExpr(_, pullLetToTop(th), pullLetToTop(elze))) - - case MatchExpr(scr, cases) => - val newcases = cases.map { - case MatchCase(pat, guard, rhs) => - MatchCase(pat, guard map pullLetToTop, pullLetToTop(rhs)) - } - replaceLetBody(pullLetToTop(scr), MatchExpr(_, newcases)) - - case Operator(Seq(), op) => - op(Seq()) - - case t: Terminal => t - - /*case Operator(Seq(e1, e2), op) => - replaceLetBody(pullLetToTop(e1), te1 => - replaceLetBody(pullLetToTop(e2), te2 => op(Seq(te1, te2))))*/ - - // Note: it is necessary to handle unary operators specially - case Operator(Seq(sube), op) => - replaceLetBody(pullLetToTop(sube), e => op(Seq(e))) - - case Operator(subes, op) => - // transform all the sub-expressions - val nsubes = subes map pullLetToTop - //collects all the lets and makes the bodies a tuple - var i = -1 - val transLet = nsubes.tail.foldLeft(nsubes.head) { - case (acc, nsube) => - i += 1 - replaceLetBody(acc, e1 => - replaceLetBody(nsube, e2 => e1 match { - case _ if i == 0 => - Tuple(Seq(e1, e2)) - case Tuple(args) => - Tuple(args :+ e2) - })) - } - // TODO: using tuple here is dangerous it relies on handling unary operators specially - replaceLetBody(transLet, (e: Expr) => e match { - case Tuple(args) => - op(args) - case _ => op(Seq(e)) //here, there was only one argument - }) - } - // println(s"E : $e After Pulling lets to top : \n $transe") - transe - } - val res = pullLetToTop(ine) - /*if(debug) - println(s"InE : $ine After Pulling lets to top : \n ${ScalaPrinter.apply(res)}")*/ - res - } - - def simplifyLetsAndLetsWithTuples(ine: Expr) = { - - def simplerLet(t: Expr): Expr = { - val res = t match { - case letExpr @ Let(i, t: Terminal, b) => - replace(Map(Variable(i) -> t), b) - - // check if the let can be completely removed - case letExpr @ Let(i, e, b) => { - val occurrences = count { - case Variable(x) if x == i => 1 - case _ => 0 - }(b) - - if (occurrences == 0) { - b - } else if (occurrences == 1) { - replace(Map(Variable(i) -> e), b) - } else { - //TODO: we can also remove zero occurrences and compress the tuples - // this may be necessary when instrumentations are combined. - letExpr match { - case letExpr @ Let(binder, lval @ Tuple(subes), b) => - def occurrences(index: Int) = { - val res = count { - case TupleSelect(sel, i) if sel == binder.toVariable && i == index => 1 - case _ => 0 - }(b) - res - } - val binderVar = binder.toVariable - val repmap: Map[Expr, Expr] = subes.zipWithIndex.collect { - case (sube, i) if occurrences(i + 1) == 1 => // sube is used only once ? - (TupleSelect(binderVar, i + 1) -> sube) - case (v @ Variable(_), i) => // sube is a variable ? - (TupleSelect(binderVar, i + 1) -> v) - case (ts @ TupleSelect(Variable(_), _), i) => // sube is a tuple select of a variable ? - (TupleSelect(binderVar, i + 1) -> ts) - }.toMap - Let(binder, lval, replace(repmap, b)) - //note: here, we cannot remove the let, - //if it is not used it will be removed in the next iteration - case e => e - } - } - } - // also perform a tuple simplification - case ts @ TupleSelect(Tuple(subes), i) => - subes(i - 1) - case e => e - } - res - } - val transforms = removeLetsFromLetValues _ andThen fixpoint(simplePostTransform(simplerLet)) _ andThen simplifyArithmetic - transforms(ine) - } - - /* - This function tries to simplify a part of the expression tree consisting of the same operation. - The operatoin needs to be associative and commutative for this simplification to work . - Arguments: - op: An implementation of the opertaion to be simplified - getLeaves: Gets all the operands from the AST (if the argument is not of - the form currently being simplified, this is required to return an empty set) - identity: The identity element for the operation - makeTree: Makes an AST from the operands - */ - def simplifyConstantsGeneral(e: Expr, op: (BigInt, BigInt) => BigInt, - getLeaves: (Expr, Boolean) => Seq[Expr], identity: BigInt, - makeTree: (Expr, Expr) => Expr): Expr = { - - val allLeaves = getLeaves(e, true) - // Here the expression is not of the form we are currently simplifying - if (allLeaves.size == 0) e - else { - // fold constants here - val allConstantsOpped = allLeaves.foldLeft(identity)((acc, e) => e match { - case InfiniteIntegerLiteral(x) => op(acc, x) - case _ => acc - }) - - val allNonConstants = allLeaves.filter((e) => e match { - case _: InfiniteIntegerLiteral => false - case _ => true - }) - - // Reconstruct the expressin tree with the non-constants and the result of constant evaluation above - if (allConstantsOpped != identity) { - allNonConstants.foldLeft(InfiniteIntegerLiteral(allConstantsOpped): Expr)((acc: Expr, currExpr) => makeTree(acc, currExpr)) - } else { - if (allNonConstants.size == 0) InfiniteIntegerLiteral(identity) - else { - allNonConstants.tail.foldLeft(allNonConstants.head)((acc: Expr, currExpr) => makeTree(acc, currExpr)) - } - } - } - } - - //Use the above function to simplify additions and maximums interleaved - def simplifyAdditionsAndMax(e: Expr): Expr = { - def getAllSummands(e: Expr, isTopLevel: Boolean): Seq[Expr] = { - e match { - case Plus(e1, e2) => { - getAllSummands(e1, false) ++ getAllSummands(e2, false) - } - case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) - } - } - - def getAllMaximands(e: Expr, isTopLevel: Boolean): Seq[Expr] = { - e match { - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { - args.foldLeft(Seq[Expr]())((accSet, e) => accSet ++ getAllMaximands(e, false)) - } - case _ => if (isTopLevel) Seq[Expr]() else Seq[Expr](e) - } - } - - simplePostTransform(e => { - val plusSimplifiedExpr = - simplifyConstantsGeneral(e, _ + _, getAllSummands, 0, ((e1, e2) => Plus(e1, e2))) - - // Maximum simplification assumes all arguments to max - // are non-negative (and hence 0 is the identity) - val maxSimplifiedExpr = - simplifyConstantsGeneral(plusSimplifiedExpr, - ((a: BigInt, b: BigInt) => if (a > b) a else b), - getAllMaximands, - 0, - ((e1, e2) => { - val typedMaxFun = TypedFunDef(maxFun, Seq()) - FunctionInvocation(typedMaxFun, Seq(e1, e2)) - })) - - maxSimplifiedExpr - })(e) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala b/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala deleted file mode 100644 index 7c5886faf3ed4a05dc1845910eb149e6188ef81c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/LinearRelationEvaluator.scala +++ /dev/null @@ -1,82 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ -import solvers._ -import solvers.combinators._ -import solvers.smtlib._ -import solvers.z3._ -import scala.util.control.Breaks._ -import purescala.ScalaPrinter -import scala.collection.mutable.{ Map => MutableMap } -import scala.reflect.runtime.universe -import invariant.engine._ -import invariant.factories._ -import invariant.util._ -import invariant.util.ExpressionTransformer._ -import invariant.structure._ -import invariant.structure.FunctionUtils._ - -import Util._ -import PredicateUtil._ -import SolverUtil._ -import RealValuedExprEvaluator._ - -/** - * Evaluator for a predicate that is a simple equality/inequality between two variables. - * Some expressions cannot be evaluated, so we return none in those cases. - */ -class LinearRelationEvaluator(ctx: InferenceContext) { - - def predEval(model: LazyModel): (Expr => Option[Boolean]) = { - if (ctx.usereals) realEval(model) - else intEval(model) - } - - def intEval(model: LazyModel): (Expr => Option[Boolean]) = { - def modelVal(id: Identifier): BigInt = { - val InfiniteIntegerLiteral(v) = model(id) - v - } - def eval: (Expr => Option[Boolean]) = { - case And(args) => - val argres = args.map(eval) - if (argres.exists(!_.isDefined)) None - else - Some(argres.forall(_.get)) - case Equals(Variable(id1), Variable(id2)) => - if (model.isDefinedAt(id1) && model.isDefinedAt(id2)) - Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality - else None - case LessEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) <= modelVal(id2)) - case GreaterEquals(Variable(id1), Variable(id2)) => Some(modelVal(id1) >= modelVal(id2)) - case GreaterThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) > modelVal(id2)) - case LessThan(Variable(id1), Variable(id2)) => Some(modelVal(id1) < modelVal(id2)) - case e => throw new IllegalStateException("Predicate not handled: " + e) - } - eval - } - - def realEval(model: LazyModel): (Expr => Option[Boolean]) = { - def modelVal(id: Identifier): FractionalLiteral = { - //println("Identifier: "+id) - model(id).asInstanceOf[FractionalLiteral] - } - { - case Equals(Variable(id1), Variable(id2)) => Some(model(id1) == model(id2)) //note: ADTs can also be compared for equality - case e @ Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] - || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] - || e.isInstanceOf[GreaterEquals]) => { - Some(evaluateRealPredicate(op(Seq(modelVal(id1), modelVal(id2))))) - } - case e => throw new IllegalStateException("Predicate not handled: " + e) - } - } -} diff --git a/src/main/scala/leon/invariant/util/Minimizer.scala b/src/main/scala/leon/invariant/util/Minimizer.scala deleted file mode 100644 index a05b9d72b49f104c3b0a7eb227a4e8d331c1161a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/Minimizer.scala +++ /dev/null @@ -1,193 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import solvers._ -import solvers.smtlib.SMTLIBZ3Solver -import invariant.engine.InferenceContext -import invariant.factories._ -import leon.invariant.util.RealValuedExprEvaluator._ -import Stats._ - -class Minimizer(ctx: InferenceContext, program: Program) { - - val verbose = false - val debugMinimization = false - /** - * Here we are assuming that that initModel is a model for ctrs - * TODO: make sure that the template for rootFun is the time template - */ - val MaxIter = 16 //note we may not be able to represent anything beyond 2^16 - val half = FractionalLiteral(1, 2) - val two = FractionalLiteral(2, 1) - val rzero = FractionalLiteral(0, 1) - val mone = FractionalLiteral(-1, 1) - - private val leonctx = ctx.leonContext - val reporter = leonctx.reporter - - //for statistics and output - //store the lowerbounds for each template variables in the template of the rootFun provided it is a time template - var lowerBoundMap = Map[Variable, FractionalLiteral]() - def updateLowerBound(tvar: Variable, rval: FractionalLiteral) = { - //record the lower bound if it exist - if (lowerBoundMap.contains(tvar)) { - lowerBoundMap -= tvar - } - lowerBoundMap += (tvar -> rval) - } - - def tightenTimeBounds(timeTemplate: Expr)(inputCtr: Expr, initModel: Model) = { - //the order in which the template variables are minimized is based on the level of nesting of the terms - minimizeBounds(computeCompositionLevel(timeTemplate))(inputCtr, initModel) - } - - /** - * TODO: use incremental solving of z3 when it is supported in nlsat - * Do a binary search sequentially on the tempvars ordered by the rate of growth of the term they - * are a coefficient for. - */ - def minimizeBounds(nestMap: Map[Variable, Int])(inputCtr: Expr, initModel: Model): Model = { - val orderedTempVars = nestMap.toSeq.sortWith((a, b) => a._2 >= b._2).map(_._1) - lazy val solver = new SimpleSolverAPI(new TimeoutSolverFactory( - SolverFactory.getFromName(leonctx,program)("smt-z3-u"), - ctx.vcTimeout * 1000)) - - reporter.info("minimizing...") - var currentModel = initModel - orderedTempVars.foldLeft(inputCtr: Expr)((acc, tvar) => { - var upperBound = if (currentModel.isDefinedAt(tvar.id)) { - currentModel(tvar.id).asInstanceOf[FractionalLiteral] - } else { - initModel(tvar.id).asInstanceOf[FractionalLiteral] - } - //note: the lower bound is an integer by construction (and is by default zero) - var lowerBound: FractionalLiteral = - if (tvar == orderedTempVars(0) && lowerBoundMap.contains(tvar)) - lowerBoundMap(tvar) - else realzero - def updateState(nmodel: Model) = { - upperBound = nmodel(tvar.id).asInstanceOf[FractionalLiteral] - currentModel = nmodel - if (this.debugMinimization) - reporter.info("Found new upper bound: " + upperBound) - } - - if (this.debugMinimization) - reporter.info(s"Minimizing variable: $tvar Initial Bounds: [$upperBound,$lowerBound]") - var continue = true - var iter = 0 - do { - iter += 1 - if (continue) { - val currval = floor(evaluate(Times(half, Plus(upperBound, lowerBound)))) //make sure that curr val is an integer - if (evaluateRealPredicate(GreaterEquals(lowerBound, currval))) //check if the lowerbound, if it exists, is < currval - continue = false - else { - val boundCtr = And(LessEquals(tvar, currval), GreaterEquals(tvar, lowerBound)) - val (res, newModel) = - if (ctx.abort) (None, Model.empty) - else { - time { solver.solveSAT(And(acc, boundCtr)) }{minTime => - updateCumTime(minTime, "BinarySearchTime") - } - } - res match { - case Some(true) => - updateState(newModel) - case _ => //here we have a new lower bound: currval - lowerBound = currval - if (this.debugMinimization) - reporter.info("Found new lower bound: " + currval) - } - } - } - } while (!ctx.abort && continue && iter < MaxIter) - //A last ditch effort to make the upper bound an integer. - val currval @ FractionalLiteral(n, d) = - if (currentModel.isDefinedAt(tvar.id)) - currentModel(tvar.id).asInstanceOf[FractionalLiteral] - else - initModel(tvar.id).asInstanceOf[FractionalLiteral] - if (d != 1 && !ctx.abort) { - val (res, newModel) = solver.solveSAT(And(acc, Equals(tvar, floor(currval)))) - if (res == Some(true)) - updateState(newModel) - } - //here, we found a best-effort minimum - if (lowerBound != realzero) { - updateLowerBound(tvar, lowerBound) - } - And(acc, Equals(tvar, currval)) - }) - new Model(initModel.map { - case (id, e) => - if (currentModel.isDefinedAt(id)) - (id -> currentModel(id)) - else - (id -> initModel(id)) - }.toMap) - } - - def checkBoundingInteger(tvar: Variable, rl: FractionalLiteral, nlctr: Expr, solver: SimpleSolverAPI): Option[Model] = { - val nl @ FractionalLiteral(n, d) = normalizeFraction(rl) - if (d != 1) { - val flval = floor(nl) - val (res, newModel) = solver.solveSAT(And(nlctr, Equals(tvar, flval))) - res match { - case Some(true) => Some(newModel) - case _ => None - } - } else None - } - - /** - * The following code is little tricky - */ - def computeCompositionLevel(template: Expr): Map[Variable, Int] = { - var nestMap = Map[Variable, Int]() - - def updateMax(v: Variable, level: Int) = { - if (verbose) reporter.info("Nesting level: " + v + "-->" + level) - if (nestMap.contains(v)) { - if (nestMap(v) < level) { - nestMap -= v - nestMap += (v -> level) - } - } else - nestMap += (v -> level) - } - - def functionNesting(e: Expr): Int = { - e match { - - case Times(e1, v @ Variable(id)) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { - val nestLevel = functionNesting(e1) - updateMax(v, nestLevel) - nestLevel - } - case Times(v @ Variable(id), e2) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { - val nestLevel = functionNesting(e2) - updateMax(v, nestLevel) - nestLevel - } - case v @ Variable(id) if (TemplateIdFactory.IsTemplateIdentifier(id)) => { - updateMax(v, 0) - 0 - } - case FunctionInvocation(_, args) => 1 + args.foldLeft(0)((acc, arg) => acc + functionNesting(arg)) - case t: Terminal => 0 - /*case UnaryOperator(arg, _) => functionNesting(arg) - case BinaryOperator(a1, a2, _) => functionNesting(a1) + functionNesting(a2)*/ - case Operator(args, _) => args.foldLeft(0)((acc, arg) => acc + functionNesting(arg)) - } - } - functionNesting(template) - nestMap - } -} diff --git a/src/main/scala/leon/invariant/util/RealIntMap.scala b/src/main/scala/leon/invariant/util/RealIntMap.scala deleted file mode 100644 index 12ba7d97ceffc033b73e42a33ed8b49c65c30fff..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/RealIntMap.scala +++ /dev/null @@ -1,111 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.factories._ -import solvers._ -import TemplateIdFactory._ - -/** - * An abstract class for mapping integer typed variables to reals and vice-versa. - * Note: this preserves the template identifier property - */ -abstract class IntRealMap { - var oldToNew = Map[Identifier, Identifier]() - var newToOld = Map[Identifier, Identifier]() - - def mapLiteral(l: Literal[_]): Literal[_] - def unmapLiteral(l: Literal[_]): Literal[_] - def mapIdentifier(v: Identifier): Identifier - - def mapExpr(inexpr: Expr): Expr = { - val transformer = (e: Expr) => e match { - case l : Literal[_] => mapLiteral(l) - case v@Variable(id) => - Variable(oldToNew.getOrElse(id, { - val mid = mapIdentifier(id) - if (mid != id) { - oldToNew += (id -> mid) - newToOld += (mid -> id) - } - mid - })) - case _ => e - } - simplePostTransform(transformer)(inexpr) - } - - def unmapModel(model: Model): Model = { - new Model(model.map { - case (key, value) if (newToOld.contains(key)) => - (newToOld(key), - value match { - case l: Literal[_] => unmapLiteral(l) - case e => e - }) - case other => other - }.toMap) - } - - def mapModel(model: Model): Model = { - new Model(model.map { - case (k, v) if (oldToNew.contains(k)) => - (oldToNew(k), v match { - case l: Literal[_] => mapLiteral(l) - case e => e - }) - case other => other - }.toMap) - } -} - -/** - * maps all real valued variables and literals to new integer variables/literals and - * performs the reverse mapping - */ -class RealToInt extends IntRealMap { - - val bone = BigInt(1) - def mapLiteral(l: Literal[_]): Literal[_] = l match { - case FractionalLiteral(num, `bone`) => InfiniteIntegerLiteral(num) - case FractionalLiteral(_, _) => throw new IllegalStateException("Real literal with non-unit denominator") - case other => other - } - - def unmapLiteral(l: Literal[_]): Literal[_] = l match { - case InfiniteIntegerLiteral(v) => FractionalLiteral(v.toInt, 1) - case other => other - } - - def mapIdentifier(v: Identifier): Identifier = - if (v.getType == RealType) { - if (IsTemplateIdentifier(v)) freshIdentifier(v.name) - else FreshIdentifier(v.name, IntegerType, true) - } else v -} - -/** - * Maps integer literal and identifiers to real literal and identifiers - */ -/*class IntToReal extends IntRealMap { - - def mapLiteral(l: Literal[_]): Literal[_] = l match { - case InfiniteIntegerLiteral(v) => FractionalLiteral(v.toInt, 1) - case other => other - } - - *//** - * Here, we return fractional literals for integer-valued variables, - * and leave to the client to handle them - *//* - def unmapLiteral(l: Literal[_]): Literal[_] = l - - def mapIdentifier(v: Identifier): Identifier = - if (v.getType == IntegerType) { - if (IsTemplateIdentifier(v)) freshIdentifier(v.name, RealType) - else FreshIdentifier(v.name, RealType, true) - } else v -}*/ \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/RealValuedExprEvaluator.scala b/src/main/scala/leon/invariant/util/RealValuedExprEvaluator.scala deleted file mode 100644 index 8f334c0b660104c466012970964667545da6bca1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/RealValuedExprEvaluator.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import scala.math.BigInt.int2bigInt - -object RealValuedExprEvaluator { - - /** - * Requires that the input expression is ground - */ - def evaluate(expr: Expr): FractionalLiteral = { - plainEvaluate(expr) - } - - def plainEvaluate(expr: Expr): FractionalLiteral = expr match { - - case UMinus(e) => { - val FractionalLiteral(num, denom) = plainEvaluate(e) - FractionalLiteral(-num, denom) - } - case Minus(lhs, rhs) => { - plainEvaluate(Plus(lhs, UMinus(rhs))) - } - case Plus(_, _) | RealPlus(_, _) => { - val Operator(Seq(lhs, rhs), op) = expr - val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) - val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) - normalizeFraction(FractionalLiteral((lnum * rdenom + rnum * ldenom), (ldenom * rdenom))) - } - case Times(_, _) | RealTimes(_, _) => { - val Operator(Seq(lhs, rhs), op) = expr - val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) - val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) - normalizeFraction(FractionalLiteral((lnum * rnum), (ldenom * rdenom))) - } - case Division(_, _) | RealDivision(_, _) => { - val Operator(Seq(lhs, rhs), op) = expr - val FractionalLiteral(lnum, ldenom) = plainEvaluate(lhs) - val FractionalLiteral(rnum, rdenom) = plainEvaluate(rhs) - plainEvaluate(Times(FractionalLiteral(lnum, ldenom), FractionalLiteral(rdenom, rnum))) - } - case il @ InfiniteIntegerLiteral(v) => FractionalLiteral(v, 1) - case rl @ FractionalLiteral(_, _) => normalizeFraction(rl) - case _ => throw new IllegalStateException("Not an evaluatable expression: " + expr) - } - - def evaluateRealPredicate(expr: Expr): Boolean = { - expr match { - case Equals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isEQZ(evaluate(Minus(a, b))) - case LessEquals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isLEZ(evaluate(Minus(a, b))) - case LessThan(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isLTZ(evaluate(Minus(a, b))) - case GreaterEquals(a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => isGEZ(evaluate(Minus(a, b))) - case GreaterThan(a @ FractionalLiteral(n1, d1), b @ FractionalLiteral(n2, d2)) => isGTZ(evaluate(Minus(a, b))) - } - } - - def isEQZ(rlit: FractionalLiteral): Boolean = { - val FractionalLiteral(n, d) = rlit - if (d == 0) throw new IllegalStateException("denominator zero") - (n == 0) - } - - def isLEZ(rlit: FractionalLiteral): Boolean = { - val FractionalLiteral(n, d) = rlit - if (d == 0) throw new IllegalStateException("denominator zero") - if (d < 0) throw new IllegalStateException("denominator negative: " + d) - (n <= 0) - } - - def isLTZ(rlit: FractionalLiteral): Boolean = { - val FractionalLiteral(n, d) = rlit - if (d == 0) throw new IllegalStateException("denominator zero") - if (d < 0) throw new IllegalStateException("denominator negative: " + d) - (n < 0) - } - - def isGEZ(rlit: FractionalLiteral): Boolean = { - val FractionalLiteral(n, d) = rlit - if (d == 0) throw new IllegalStateException("denominator zero") - if (d < 0) throw new IllegalStateException("denominator negative: " + d) - (n >= 0) - } - - def isGTZ(rlit: FractionalLiteral): Boolean = { - val FractionalLiteral(n, d) = rlit - if (d == 0) throw new IllegalStateException("denominator zero") - if (d < 0) throw new IllegalStateException("denominator negative: " + d) - (n > 0) - } - - def evaluateRealFormula(expr: Expr): Boolean = expr match { - case And(args) => args forall evaluateRealFormula - case Or(args) => args exists evaluateRealFormula - case Not(arg) => !evaluateRealFormula(arg) - case BooleanLiteral(b) => b - case Operator(args, op) => - op(args map evaluate) match { - case BooleanLiteral(b) => b - case p => evaluateRealPredicate(p) - } - } -} diff --git a/src/main/scala/leon/invariant/util/SelectorToCons.scala b/src/main/scala/leon/invariant/util/SelectorToCons.scala deleted file mode 100644 index a458d8197d7ea9658b65b73a225de29a3c9c9e17..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/SelectorToCons.scala +++ /dev/null @@ -1,116 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.factories._ -import solvers._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } -import ExpressionTransformer._ -import TVarFactory._ - -object SelectToCons { - // temporaries generated during conversion of field selects to ADT constructions - val fieldSelContext = newContext -} - -/** - * A class that converts case-class or tuple selectors in an expression - * to constructors, and updates a given lazy model. - * We assume that all the arguments are flattened in the input expression. - */ -class SelectorToCons { - - import SelectToCons._ - - var fldIdMap = Map[Identifier, (Variable, Int)]() - - /** - * For now this works only on a disjunct - */ - def selToCons(disjunct: Seq[Expr]): Seq[Expr] = { - def classSelToCons(eq: Equals) = eq match { - case Equals(r: Variable, CaseClassSelector(ctype, cc: Variable, selfld)) => - //convert this to a cons by creating dummy variables - val args = ctype.fields.zipWithIndex.map { - case (fld, i) if fld.id == selfld => r - case (fld, i) => - val t = createTemp("fld", fld.getType, fieldSelContext) //create a dummy identifier there - fldIdMap += (t -> (cc, i)) - t.toVariable - } - Equals(cc, CaseClass(ctype, args)) - case _ => - throw new IllegalStateException("Selector not flattened: " + eq) - } - def tupleSelToCons(eq: Equals) = eq match { - case Equals(r: Variable, TupleSelect(tp: Variable, idx)) => - val tupleType = tp.getType.asInstanceOf[TupleType] - //convert this to a Tuple by creating dummy variables - val args = (1 until tupleType.dimension + 1).map { i => - if (i == idx) r - else { - val t = createTemp("fld", tupleType.bases(i - 1), fieldSelContext) //note: we have to use i-1 - fldIdMap += (t -> (tp, i - 1)) - t.toVariable - } - } - Equals(tp, Tuple(args)) - case _ => - throw new IllegalStateException("Selector not flattened: " + eq) - } - //println("Input expression: "+ine) - disjunct.map { // we need to traverse top-down - case eq @ Equals(_, _: CaseClassSelector) => - classSelToCons(eq) - case eq @ Equals(_, _: TupleSelect) => - tupleSelToCons(eq) - case _: CaseClassSelector | _: TupleSelect => - throw new IllegalStateException("Selector not flattened") - case e => e - } -// println("Output expression: "+rese) -// rese - } - - // def tupleSelToCons(e: Expr): Expr = { - // val (r, tpvar, index) = e match { - // case Equals(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) - // // case Iff(r0 @ Variable(_), TupleSelect(tpvar0, index0)) => (r0, tpvar0, index0) - // case _ => throw new IllegalStateException("Not a tuple-selector call") - // } - // //convert this to a Tuple by creating dummy variables - // val tupleType = tpvar.getType.asInstanceOf[TupleType] - // val args = (1 until tupleType.dimension + 1).map((i) => { - // if (i == index) r - // else { - // //create a dummy identifier there (note that here we have to use i-1) - // createTemp("fld", tupleType.bases(i - 1), fieldSelContext).toVariable - // } - // }) - // Equals(tpvar, Tuple(args)) - // } - - /** - * Expands a given model into a model with mappings for identifiers introduced during flattening. - * Note: this class cannot be accessed in parallel. - */ - def getModel(initModel: LazyModel) = new LazyModel { - override def get(iden: Identifier) = { - val idv = initModel.get(iden) - if (idv.isDefined) idv - else { - fldIdMap.get(iden) match { - case Some((Variable(inst), fldIdx)) => - initModel(inst) match { - case CaseClass(_, args) => Some(args(fldIdx)) - case Tuple(args) => Some(args(fldIdx)) - } - case None => None - } - } - } - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/SolverUtil.scala b/src/main/scala/leon/invariant/util/SolverUtil.scala deleted file mode 100644 index 97d94021d00d3ad798595822c4f7fe7129966eea..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/SolverUtil.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import solvers.z3._ -import solvers._ -import leon.invariant.templateSolvers.ExtendedUFSolver -import java.io._ -import Util._ -import PredicateUtil._ -import evaluators._ -import EvaluationResults._ -import purescala.Extractors._ - -object SolverUtil { - - def modelToExpr(model: Model): Expr = { - model.foldLeft(tru: Expr)((acc, elem) => { - val (k, v) = elem - val eq = Equals(k.toVariable, v) - if (acc == tru) eq - else And(acc, eq) - }) - } - - def completeWithRefModel(currModel: Model, refModel: Model) = { - new Model(refModel.toMap.map { - case (id, _) if currModel.isDefinedAt(id) => - (id -> currModel(id)) - case (id, v) => - (id -> v) - }.toMap) - } - - def toZ3SMTLIB(expr: Expr, filename: String, - theory: String, ctx: LeonContext, pgm: Program, - useBitvectors: Boolean = false, - bitvecSize: Int = 32) = { - //create new solver, assert constraints and print - val printSol = new ExtendedUFSolver(ctx, pgm) - printSol.assertCnstr(expr) - val writer = new PrintWriter(filename) - writer.println(printSol.ctrsToString(theory)) - printSol.free() - writer.flush() - writer.close() - } - - def verifyModel(e: Expr, model: Model, solver: SimpleSolverAPI) = { - solver.solveSAT(And(e, modelToExpr(model))) match { - case (Some(false), _) => - throw new IllegalStateException("Model doesn't staisfy formula!") - case _ => - } - } - - /** - * A helper function that can be used to hardcode an invariant and see if it unsatifies the paths - */ - def checkInvariant(expr: Expr, ctx: LeonContext, prog: Program): Option[Boolean] = { - val idmap: Map[Expr, Expr] = variablesOf(expr).collect { - case id @ _ if (id.name.toString == "a?") => id.toVariable -> InfiniteIntegerLiteral(6) - case id @ _ if (id.name.toString == "c?") => id.toVariable -> InfiniteIntegerLiteral(2) - }.toMap - //println("found ids: " + idmap.keys) - if (idmap.keys.nonEmpty) { - val newpathcond = replace(idmap, expr) - //check if this is solvable - val solver = SimpleSolverAPI(SolverFactory("extendedUF", () => new ExtendedUFSolver(ctx, prog))) - solver.solveSAT(newpathcond)._1 match { - case Some(true) => { - println("Path satisfiable for a?,c? -->6,2 ") - Some(true) - } - case _ => { - println("Path unsat for a?,c? --> 6,2") - Some(false) - } - } - } else None - } - - def collectUNSATCores(ine: Expr, ctx: LeonContext, prog: Program): Expr = { - var controlVars = Map[Variable, Expr]() - var newEqs = Map[Expr, Expr]() - val solver = new ExtendedUFSolver(ctx, prog) - val newe = simplePostTransform { - case e@(And(_) | Or(_)) => { - val v = TVarFactory.createTempDefault("a", BooleanType).toVariable - newEqs += (v -> e) - val newe = Equals(v, e) - - //create new variable and add it in disjunction - val cvar = FreshIdentifier("ctrl", BooleanType, true).toVariable - controlVars += (cvar -> newe) - solver.assertCnstr(Or(newe, cvar)) - v - } - case e => e - }(ine) - //create new variable and add it in disjunction - val cvar = FreshIdentifier("ctrl", BooleanType, true).toVariable - controlVars += (cvar -> newe) - solver.assertCnstr(Or(newe, cvar)) - - val res = solver.checkAssumptions(controlVars.keySet.map(Not.apply _)) - println("Result: " + res) - val coreExprs = solver.getUnsatCore - val simpcores = coreExprs.foldLeft(Seq[Expr]())((acc, coreExp) => { - val Not(cvar @ Variable(_)) = coreExp - val newexp = controlVars(cvar) - //println("newexp: "+newexp) - newexp match { - // case Iff(v@Variable(_),rhs) if(newEqs.contains(v)) => acc - case Equals(v @ Variable(_), rhs) if (v.getType == BooleanType && rhs.getType == BooleanType && newEqs.contains(v)) => acc - case _ => { - acc :+ newexp - } - } - }) - val cores = Util.fix((e: Expr) => replace(newEqs, e))(createAnd(simpcores.toSeq)) - - solver.free - //cores - ExpressionTransformer.unflatten(cores) - } - - //tests if the solver uses nlsat - def usesNLSat(solver: AbstractZ3Solver) = { - //check for nlsat - val x = FreshIdentifier("x", RealType).toVariable - val testExpr = Equals(Times(x, x), FractionalLiteral(2, 1)) - solver.assertCnstr(testExpr) - solver.check match { - case Some(true) => true - case _ => false - } - } - -} diff --git a/src/main/scala/leon/invariant/util/Stats.scala b/src/main/scala/leon/invariant/util/Stats.scala deleted file mode 100644 index 9b957d3aced36c2f4c4c881d4dd8337d24ae5922..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/Stats.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Definitions._ -import purescala.Expressions._ -import java.io._ -import scala.collection.mutable.{Map => MutableMap} - -/** - * A generic statistics object that provides: - * (a) Temporal variables that change over time. We track the total sum and max of the values the variable takes over time - * (b) Counters that are incremented over time. Variables can be associated with counters. - * We track the averages value of a variable over time w.r.t to the counters with which it is associated. - */ -object Stats { - val keystats = MutableMap[String, (Long, Long)]() - val counterMap = MutableMap[String, Seq[String]]() - var cumKeys = Seq[String]() - var timekeys = Set[String]() //this may be inner, outer or cumkey - - private def updateStats(newval: Long, key: String, cname: Option[String]) = { - val (cum, max) = keystats.getOrElse(key, { - val init = (0: Long, 0: Long) - keystats += (key -> (0, 0)) - - if (cname.isDefined) { - val presentKeys = counterMap(cname.get) - counterMap.update(cname.get, presentKeys :+ key) - } else { - cumKeys :+= key - } - init - }) - val newcum = cum + newval - val newmax = if (max < newval) newval else max - keystats.update(key, (newcum, newmax)) - } - //a special method for adding times - private def updateTimeStats(newval: Long, key: String, cname: Option[String]) = { - if (!timekeys.contains(key)) - timekeys += key - updateStats(newval, key, cname) - } - - def updateCumStats(newval: Long, key: String) = updateStats(newval, key, None) - def updateCumTime(newval: Long, key: String) = updateTimeStats(newval, key, None) - def updateCounter(incr: Long, key: String) = { - if (!counterMap.contains(key)) { - counterMap.update(key, Seq()) - } - //counters are considered as cumulative stats - updateStats(incr, key, None) - } - def updateCounterStats(newval: Long, key: String, cname: String) = updateStats(newval, key, Some(cname)) - def updateCounterTime(newval: Long, key: String, cname: String) = updateTimeStats(newval, key, Some(cname)) - - private def getCum(key: String): Long = keystats(key)._1 - private def getMax(key: String): Long = keystats(key)._2 - - def dumpStats(pr: PrintWriter) = { - //Print cumulative stats - cumKeys.foreach(key => { - if (timekeys.contains(key)) { - pr.println(key + ": " + (getCum(key).toDouble / 1000.0) + "s") - } else - pr.println(key + ": " + getCum(key)) - }) - - //dump the averages and maximum of all stats associated with counters - counterMap.keys.foreach((ckey) => { - pr.println("### Statistics for counter: " + ckey + " ####") - val counterval = getCum(ckey) - val assocKeys = counterMap(ckey) - assocKeys.foreach((key) => { - if (timekeys.contains(key)) { - pr.println("Avg." + key + ": " + (getCum(key).toDouble / (counterval * 1000.0)) + "s") - pr.println("Max." + key + ": " + (getMax(key).toDouble / 1000.0) + "s") - } else { - pr.println("Avg." + key + ": " + (getCum(key).toDouble / counterval)) - pr.println("Max." + key + ": " + getMax(key)) - } - }) - }) - } - - def time[T](code: => T)(cont: Long => Unit): T = { - var t1 = System.currentTimeMillis() - val r = code - cont((System.currentTimeMillis() - t1)) - r - } - - def getTime[T](code: => T): (T, Long) = { - var t1 = System.currentTimeMillis() - val r = code - (r, (System.currentTimeMillis() - t1)) - } -} - -/** - * Statistics specific for this application - */ -object SpecificStats { - - var output: String = "" - def addOutput(out: String) = { - output += out + "\n" - } - def dumpOutputs(pr: PrintWriter) { - pr.println("########## Outputs ############") - pr.println(output) - pr.flush() - } - - //minimization stats - var lowerBounds = Map[FunDef, Map[Variable, FractionalLiteral]]() - var lowerBoundsOutput = Map[FunDef, String]() - def addLowerBoundStats(fd: FunDef, lbMap: Map[Variable, FractionalLiteral], out: String) = { - lowerBounds += (fd -> lbMap) - lowerBoundsOutput += (fd -> out) - } - def dumpMinimizationStats(pr: PrintWriter) { - pr.println("########## Lower Bounds ############") - lowerBounds.foreach((pair) => { - val (fd, lbMap) = pair - pr.print(fd.id + ": \t") - lbMap.foreach((entry) => { - pr.print("(" + entry._1 + "->" + entry._2 + "), ") - }) - pr.print("\t Test results: " + lowerBoundsOutput(fd)) - pr.println() - }) - pr.flush() - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/TVarFactory.scala b/src/main/scala/leon/invariant/util/TVarFactory.scala deleted file mode 100644 index cafad81b885c81cb3be4e8626954246564831796..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/TVarFactory.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Common._ -import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap} - -object TVarFactory { - - type Context = Int - val temporaries = MutableMap[Context, MutableSet[Identifier]]() - private var context: Context = 0 - - def newContext = { - context += 1 - temporaries += (context -> MutableSet()) - context - } - val defaultContext = newContext - - def createTemp(name: String, tpe: TypeTree = Untyped, context: Context): Identifier = { - val freshid = FreshIdentifier(name, tpe, true) - temporaries(context) += freshid - freshid - } - - def createTempDefault(name: String, tpe: TypeTree = Untyped): Identifier = { - val freshid = FreshIdentifier(name, tpe, true) - temporaries(defaultContext) += freshid - freshid - } - - def isTemp(id: Identifier, context: Context): Boolean = - temporaries.contains(context) && temporaries(context)(id) -} diff --git a/src/main/scala/leon/invariant/util/TimerUtil.scala b/src/main/scala/leon/invariant/util/TimerUtil.scala deleted file mode 100644 index 22dfffa24ec2c400126ea1df913377e511b9a4e0..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/TimerUtil.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import utils._ - -object TimerUtil { - /** - * Timeout in milliseconds - */ - def scheduleTask(callBack: () => Unit, timeOut: Long): Option[java.util.Timer] = { - if (timeOut > 0) { - val timer = new java.util.Timer() - timer.schedule(new java.util.TimerTask() { - def run() { - callBack() - timer.cancel() //the timer will be cancelled after it runs - } - }, timeOut) - Some(timer) - } else None - } -} - -class InterruptOnSignal(it: Interruptible) { - - private class Poll(signal: => Boolean, onSignal: => Unit) extends Thread { - private var keepRunning = true - - override def run(): Unit = { - while (!signal && keepRunning) { - Thread.sleep(100) // a relatively infrequent poll - } - if (signal && keepRunning) { - onSignal - } - } - - def finishedRunning(): Unit = { - keepRunning = false - } - } - - def interruptOnSignal[T](signal: => Boolean)(body: => T): T = { - var recdSignal = false - - val timer = new Poll(signal, { - it.interrupt() - recdSignal = true - }) - - timer.start() - val res = body - timer.finishedRunning() - - if (recdSignal) { - it.recoverInterrupt() - } - res - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/TreeUtil.scala b/src/main/scala/leon/invariant/util/TreeUtil.scala deleted file mode 100644 index 597aa270edad551f186006b12ac3af5442be3b4c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/TreeUtil.scala +++ /dev/null @@ -1,508 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } -import leon.invariant._ -import invariant.engine._ -import invariant.factories._ -import invariant.structure._ -import FunctionUtils._ -import scala.annotation.tailrec -import PredicateUtil._ -import ProgramUtil._ -import TypeUtil._ -import Util._ -import solvers._ -import purescala.DefOps._ - -object ProgramUtil { - - def createLeonContext(ctx: LeonContext, opts: String*): LeonContext = { - Main.processOptions(opts.toList).copy(reporter = ctx.reporter, - interruptManager = ctx.interruptManager, files = ctx.files, timers = ctx.timers) - } - - /** - * Here, we exclude empty units that do not have any modules and empty - * modules that do not have any definitions - */ - def copyProgram(prog: Program, mapdefs: (Seq[Definition] => Seq[Definition])): Program = { - prog.copy(units = prog.units.collect { - case unit if unit.defs.nonEmpty => unit.copy(defs = unit.defs.collect { - case module: ModuleDef if module.defs.nonEmpty => - module.copy(defs = mapdefs(module.defs)) - case other => other - }) - }) - } - - def appendDefsToModules(p: Program, defs: Map[ModuleDef, Traversable[Definition]]): Program = { - val res = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.map { - case m: ModuleDef if defs.contains(m) => - m.copy(defs = m.defs ++ defs(m)) - case other => other - }) - }) - res - } - - def addDefs(p: Program, defs: Traversable[Definition], after: Definition): Program = { - var found = false - val res = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.map { - case m: ModuleDef => - val newdefs = for (df <- m.defs) yield { - df match { - case `after` => - found = true - after +: defs.toSeq - case d => - Seq(d) - } - } - m.copy(defs = newdefs.flatten) - case other => other - }) - }) - if (!found) { - println("addDefs could not find anchor definition!") - } - res - } - - def createTemplateFun(plainTemp: Expr): FunctionInvocation = { - val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) - val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), - Seq(ValDef(FreshIdentifier("arg", tmpl.getType))), BooleanType) - tmplFd.body = Some(BooleanLiteral(true)) - FunctionInvocation(TypedFunDef(tmplFd, Seq()), Seq(tmpl)) - } - - /** - * This is the default template generator. - * Note: we are not creating template for libraries. - */ - def getOrCreateTemplateForFun(fd: FunDef): Expr = { - val plainTemp = if (fd.hasTemplate) fd.getTemplate - else if (fd.annotations.contains("library")) BooleanLiteral(true) - else { - //just consider all the arguments, return values that are integers - val baseTerms = fd.params.filter((vardecl) => isNumericType(vardecl.getType)).map(_.toVariable) ++ - (if (isNumericType(fd.returnType)) Seq(getFunctionReturnVariable(fd)) - else Seq()) - val lhs = baseTerms.foldLeft(TemplateIdFactory.freshTemplateVar(): Expr)((acc, t) => { - Plus(Times(TemplateIdFactory.freshTemplateVar(), t), acc) - }) - val tempExpr = LessEquals(lhs, InfiniteIntegerLiteral(0)) - tempExpr - } - plainTemp - } - - def mapFunctionsInExpr(funmap: Map[FunDef, FunDef])(ine: Expr): Expr = { - simplePostTransform { - case FunctionInvocation(tfd, args) if funmap.contains(tfd.fd) => - FunctionInvocation(TypedFunDef(funmap(tfd.fd), tfd.tps), args) - case e => e - }(ine) - } - - /** - * For functions for which `funToTmpl` is not defined, their templates will be removed. - * Will only consider user-level functions. - */ - def assignTemplateAndCojoinPost(funToTmpl: Map[FunDef, Expr], prog: Program, - funToPost: Map[FunDef, Expr] = Map(), uniqueIdDisplay: Boolean = false): Program = { - - val keys = funToTmpl.keySet ++ funToPost.keySet - val userLevelFuns = userLevelFunctions(prog).toSet - if(!keys.diff(userLevelFuns).isEmpty) - throw new IllegalStateException("AssignTemplate function called on library functions: "+ keys.diff(userLevelFuns)) - - val funMap = userLevelFuns.foldLeft(Map[FunDef, FunDef]()) { - case (accMap, fd) => { - val freshId = FreshIdentifier(fd.id.name, fd.returnType, uniqueIdDisplay) - accMap + (fd -> new FunDef(freshId, fd.tparams, fd.params, fd.returnType)) - } - } - val mapExpr = mapFunctionsInExpr(funMap) _ - for ((from, to) <- funMap) { - to.fullBody = if (!funToTmpl.contains(from)) { - mapExpr { - from.fullBody match { - case Ensuring(b, post) => - Ensuring(b, - Lambda(Seq(ValDef(getResId(from).get)), - createAnd(Seq(from.getPostWoTemplate, funToPost.getOrElse(from, tru))))) - case fb => - fb - } - } - } else { - val newTmpl = createTemplateFun(funToTmpl(from)) - mapExpr { - from.fullBody match { - case Require(pre, body) => - val toPost = - Lambda(Seq(ValDef(FreshIdentifier("res", from.returnType))), - createAnd(Seq(newTmpl, funToPost.getOrElse(from, tru)))) - Ensuring(Require(pre, body), toPost) - - case Ensuring(Require(pre, body), post) => - Ensuring(Require(pre, body), - Lambda(Seq(ValDef(getResId(from).get)), - createAnd(Seq(from.getPostWoTemplate, newTmpl, funToPost.getOrElse(from, tru))))) - - case Ensuring(body, post) => - Ensuring(body, - Lambda(Seq(ValDef(getResId(from).get)), - createAnd(Seq(from.getPostWoTemplate, newTmpl, funToPost.getOrElse(from, tru))))) - - case body => - val toPost = - Lambda(Seq(ValDef(FreshIdentifier("res", from.returnType))), - createAnd(Seq(newTmpl, funToPost.getOrElse(from, tru)))) - Ensuring(body, toPost) - } - } - } - //copy annotations - from.flags.foreach(to.addFlag(_)) - } - val newprog = copyProgram(prog, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if funMap.contains(fd) => - funMap(fd) - case d => d - }) - newprog - } - - def updatePost(funToPost: Map[FunDef, Lambda], prog: Program, uniqueIdDisplay: Boolean = true): Program = { - - val funMap = userLevelFunctions(prog).foldLeft(Map[FunDef, FunDef]()) { - case (accMap, fd) => - val freshId = FreshIdentifier(fd.id.name, fd.returnType, uniqueIdDisplay) - accMap + (fd -> new FunDef(freshId, fd.tparams, fd.params, fd.returnType)) - } - val mapExpr = mapFunctionsInExpr(funMap) _ - for ((from, to) <- funMap) { - to.fullBody = if (!funToPost.contains(from)) { - mapExpr(from.fullBody) - } else { - val newpost = funToPost(from) - mapExpr { - from.fullBody match { - case Ensuring(body, post) => - Ensuring(body, newpost) // replace the old post with new post - case body => - Ensuring(body, newpost) - } - } - } - //copy annotations - from.flags.foreach(to.addFlag(_)) - } - val newprog = copyProgram(prog, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if funMap.contains(fd) => - funMap(fd) - case d => d - }) - newprog - } - - def functionByName(nm: String, prog: Program) = { - prog.definedFunctions.find(fd => fd.id.name == nm) - } - - def functionByFullName(nm: String, prog: Program) = { - prog.definedFunctions.find(fd => fullName(fd)(prog) == nm) - } - - def functionsWOFields(fds: Seq[FunDef]): Seq[FunDef] = { - fds.filter(fd => fd.isRealFunction) - } - - /** - * Functions that are not theory-operations or library methods that are not a part of the main unit - */ - def userLevelFunctions(program: Program): Seq[FunDef] = { - program.units.flatMap { u => - u.definedFunctions.filter(fd => !fd.isTheoryOperation && (u.isMainUnit || !(fd.isLibrary || fd.isInvariant))) - } - } - - def translateExprToProgram(ine: Expr, currProg: Program, newProg: Program): Expr = { - var funCache = Map[String, Option[FunDef]]() - def funInNewprog(fn: String) = - funCache.get(fn) match { - case None => - val fd = functionByFullName(fn, newProg) - funCache += (fn -> fd) - fd - case Some(fd) => fd - } - simplePostTransform { - case FunctionInvocation(TypedFunDef(fd, tps), args) => - val fname = fullName(fd)(currProg) - funInNewprog(fname) match { - case Some(nfd) => - FunctionInvocation(TypedFunDef(nfd, tps), args) - case _ => - throw new IllegalStateException(s"Cannot find translation for ${fname}") - } - case e => e - }(ine) - } - - def getFunctionReturnVariable(fd: FunDef) = { - if (fd.hasPostcondition) getResId(fd).get.toVariable - else ResultVariable(fd.returnType) /*FreshIdentifier("res", fd.returnType).toVariable*/ - } - - def getResId(funDef: FunDef): Option[Identifier] = { - funDef.fullBody match { - case Ensuring(_, post) => { - post match { - case Lambda(Seq(ValDef(fromRes)), _) => Some(fromRes) - } - } - case _ => None - } - } - - //compute the formal to the actual argument mapping - def formalToActual(call: Call): Map[Expr, Expr] = { - val fd = call.fi.tfd.fd - val resvar = getFunctionReturnVariable(fd) - val argmap: Map[Expr, Expr] = Map(resvar -> call.retexpr) ++ fd.params.map(_.id.toVariable).zip(call.fi.args) - argmap - } -} - -object PredicateUtil { - /** - * Returns a constructor for the let* and also the current - * body of let* - */ - def letStarUnapply(e: Expr): (Expr => Expr, Expr) = e match { - case Let(binder, letv, letb) => - val (cons, body) = letStarUnapply(letb) - (e => Let(binder, letv, cons(e)), body) - case base => - (e => e, base) - } - - def letStarUnapplyWithSimplify(e: Expr): (Expr => Expr, Expr) = { - val (letCons, letBody) = letStarUnapply(e) - (letCons andThen simplifyLets, letBody) - } - - /** - * Checks if the input expression has only template variables as free variables - */ - def isTemplateExpr(expr: Expr): Boolean = { - var foundVar = false - postTraversal { - case e @ Variable(id) => - if (!TemplateIdFactory.IsTemplateIdentifier(id)) - foundVar = true - case e @ ResultVariable(_) => - foundVar = true - case e => - }(expr) - !foundVar - } - - def isArithmeticRelation(e: Expr) = { - e match { - case Equals(l, r) => - if (l.getType == Untyped) None - else Some(TypeUtil.isNumericType(l.getType)) - case _: LessThan | _: LessEquals | _: GreaterThan | _: GreaterEquals => Some(true) - case _ => Some(false) - } - } - - def getTemplateIds(expr: Expr) = { - variablesOf(expr).filter(TemplateIdFactory.IsTemplateIdentifier) - } - - def getTemplateVars(expr: Expr): Set[Variable] = { - getTemplateIds(expr).map(_.toVariable) - } - - /** - * Checks if the expression has real valued sub-expressions. - */ - def hasReals(expr: Expr): Boolean = { - var foundReal = false - postTraversal { - case e if e.getType == RealType => - foundReal = true - case _ => - }(expr) - foundReal - } - - /** - * Checks if the expression has real valued sub-expressions. - */ - def hasRealsOrTemplates(expr: Expr): Boolean = { - var found = false - postTraversal { - case Variable(id) if id.getType == RealType || TemplateIdFactory.IsTemplateIdentifier(id) => - found = true - case e if e.getType == RealType => - found = true - case _ => - }(expr) - found - } - - /** - * Checks if the expression has real valued sub-expressions. - * Note: important, <, <=, > etc have default int type. - * However, they can also be applied over real arguments - * So check only if all terminals are real - */ - def hasInts(expr: Expr): Boolean = { - var foundInt = false - postTraversal { - case e: Terminal if (e.getType == Int32Type || e.getType == IntegerType) => - foundInt = true - case _ => - }(expr) - foundInt - } - - def hasMixedIntReals(expr: Expr): Boolean = { - hasInts(expr) && hasReals(expr) - } - - /** - * Assuming a flattenned formula - */ - def atomNum(e: Expr): Int = e match { - case And(args) => (args map atomNum).sum - case Or(args) => (args map atomNum).sum - case IfExpr(c, th, el) => atomNum(c) + atomNum(th) + atomNum(el) - case Not(arg) => atomNum(arg) - case e => 1 - } - - def numUIFADT(e: Expr): Int = { - var count: Int = 0 - simplePostTransform { - case e @ (FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_)) => { - count += 1 - e - } - case e => e - }(e) - count - } - - def hasCalls(e: Expr) = numUIFADT(e) >= 1 - - def getCallExprs(ine: Expr): Set[Expr] = { - var calls = Set[Expr]() - simplePostTransform((e: Expr) => e match { - case call @ _ if isCallExpr(e) => { - calls += e - call - } - case _ => e - })(ine) - calls - } - - def isCallExpr(e: Expr): Boolean = e match { - case Equals(Variable(_), FunctionInvocation(_, _)) => true - // case Iff(Variable(_),FunctionInvocation(_,_)) => true - case _ => false - } - - def isADTConstructor(e: Expr): Boolean = e match { - case Equals(Variable(_), CaseClass(_, _)) => true - case Equals(Variable(_), Tuple(_)) => true - case _ => false - } - - def isMultFunctions(fd: FunDef) = { - (fd.id.name == "mult" || fd.id.name == "pmult") && - fd.isTheoryOperation - } - - //replaces occurrences of mult by Times - def multToTimes(ine: Expr): Expr = { - simplePostTransform { - case FunctionInvocation(TypedFunDef(fd, _), args) if isMultFunctions(fd) => { - Times(args(0), args(1)) - } - case e => e - }(ine) - } - - def createAnd(exprs: Seq[Expr]): Expr = { - val newExprs = exprs.filterNot(conj => conj == tru) - newExprs match { - case Seq() => tru - case Seq(e) => e - case _ => And(newExprs) - } - } - - def createOr(exprs: Seq[Expr]): Expr = { - val newExprs = exprs.filterNot(disj => disj == fls) - newExprs match { - case Seq() => fls - case Seq(e) => e - case _ => Or(newExprs) - } - } - - def precOrTrue(fd: FunDef): Expr = fd.precondition match { - case Some(pre) => pre - case None => BooleanLiteral(true) - } - - /** - * Computes the set of variables that are shared across disjunctions. - * This may return bound variables as well - */ - def sharedIds(ine: Expr): Set[Identifier] = { - - def sharedOfDisjointExprs(args: Seq[Expr]) = { - var uniqueVars = Set[Identifier]() - var sharedVars = Set[Identifier]() - args.foreach { arg => - val candUniques = variablesOf(arg) -- sharedVars - val newShared = uniqueVars.intersect(candUniques) - sharedVars ++= newShared - uniqueVars = (uniqueVars ++ candUniques) -- newShared - } - sharedVars ++ (args flatMap rec) - } - def rec(e: Expr): Set[Identifier] = - e match { - case Or(args) => sharedOfDisjointExprs(args) - case IfExpr(c, th, el) => - rec(c) ++ sharedOfDisjointExprs(Seq(th, el)) - case Variable(_) => Set() - case Operator(args, op) => - (args flatMap rec).toSet - } - rec(ine) - } -} diff --git a/src/main/scala/leon/invariant/util/TypeUtil.scala b/src/main/scala/leon/invariant/util/TypeUtil.scala deleted file mode 100644 index 6d7c59e8068a3fc0ec2ed99bbebb926c6aa9fa41..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/TypeUtil.scala +++ /dev/null @@ -1,58 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ - -object TypeUtil { - def getTypeParameters(t: TypeTree): Seq[TypeParameter] = { - t match { - case tp @ TypeParameter(_) => Seq(tp) - case NAryType(tps, _) => - (tps flatMap getTypeParameters).distinct - } - } - - def getTypeArguments(t: TypeTree) : Seq[TypeTree] = t match { - case ct: ClassType => ct.tps - case NAryType(tps, _) => - (tps flatMap getTypeArguments).distinct - } - - def typeNameWOParams(t: TypeTree): String = t match { - case ct: ClassType => ct.id.name - case TupleType(ts) => ts.map(typeNameWOParams).mkString("(", ",", ")") - case ArrayType(t) => s"Array[${typeNameWOParams(t)}]" - case SetType(t) => s"Set[${typeNameWOParams(t)}]" - case MapType(from, to) => s"Map[${typeNameWOParams(from)}, ${typeNameWOParams(to)}]" - case FunctionType(fts, tt) => - val ftstr = fts.map(typeNameWOParams).mkString("(", ",", ")") - s"$ftstr => ${typeNameWOParams(tt)}" - case t => t.toString - } - - def instantiateTypeParameters(tpMap: Map[TypeParameter, TypeTree])(t: TypeTree): TypeTree = { - t match { - case tp: TypeParameter => tpMap.getOrElse(tp, tp) - case NAryType(subtypes, tcons) => - tcons(subtypes map instantiateTypeParameters(tpMap) _) - } - } - - def isNumericType(t: TypeTree) = t match { - case IntegerType | RealType => true - case Int32Type => - throw new IllegalStateException("BitVector types not supported yet!") - case _ => false - } - - def rootType(t: TypeTree): Option[AbstractClassType] = t match { - case absT: AbstractClassType => Some(absT) - case ct: CaseClassType => ct.parent - case _ => None - } -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/UnflatHelper.scala b/src/main/scala/leon/invariant/util/UnflatHelper.scala deleted file mode 100644 index c32d2208478553eddd822b2018c0ca3866fe79ee..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/UnflatHelper.scala +++ /dev/null @@ -1,99 +0,0 @@ -package leon -package invariant.util - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import invariant.factories._ -import solvers._ -import scala.collection.immutable._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } -import ExpressionTransformer._ -import leon.evaluators._ -import EvaluationResults._ - -trait LazyModel { - def get(iden: Identifier): Option[Expr] - - def apply(iden: Identifier): Expr = - get(iden) match { - case Some(e) => e - case _ => throw new IllegalStateException(s"Cannot create mapping for $iden") - } - - def isDefinedAt(iden: Identifier) = get(iden).isDefined -} - -class SimpleLazyModel(m: Model) extends LazyModel { - def get(iden: Identifier): Option[Expr] = m.get(iden) -} - -/** - * Expands a given model into a model with mappings for identifiers introduced during flattening. - * Note: this class cannot be accessed in parallel. - */ -class FlatModel(freeVars: Set[Identifier], flatIdMap: Map[Identifier, Expr], initModel: Model, eval: DefaultEvaluator) extends LazyModel { - var idModel = initModel.toMap - - override def get(iden: Identifier) = { - var seen = Set[Identifier]() - def recBind(id: Identifier): Option[Expr] = { - val idv = idModel.get(id) - if (idv.isDefined) idv - else { - if (seen(id)) { - //we are in a cycle here - throw new IllegalStateException(s"$id depends on itself") - } else if (flatIdMap.contains(id)) { - val rhs = flatIdMap(id) - // recursively bind all freevars to values (we can ignore the return values) - seen += id - variablesOf(rhs).filterNot(idModel.contains).map(recBind) - eval.eval(rhs, idModel) match { - case Successful(v) => - idModel += (id -> v) - Some(v) - case _ => - None - //throw new IllegalStateException(s"Evaluation Falied for $id -> $rhs") - } - } else if (freeVars(id)) { - // here, `id` either belongs to values of the flatIdMap, or to flate or was lost in unflattening - println(s"Completing $id with simplest value") - val simpv = simplestValue(id.getType) - idModel += (id -> simpv) - Some(simpv) - } else - None - //throw new IllegalStateException(s"Cannot extract model $id as it not contained in the input expression: $ine") - } - } - recBind(iden) - } -} - -object UnflatHelper { - def evaluate(e: Expr, m: LazyModel, eval: DefaultEvaluator): Expr = { - val varsMap = variablesOf(e).collect { - case v if m.isDefinedAt(v) => (v -> m(v)) - }.toMap - eval.eval(e, varsMap) match { - case Successful(v) => v - case _ => - throw new IllegalStateException(s"Evaluation Falied for $e") - } - } -} - -/** - * A class that can used to compress a flattened expression - * and also expand the compressed models to the flat forms - */ -class UnflatHelper(ine: Expr, excludeIds: Set[Identifier], eval: DefaultEvaluator) { - - val (unflate, flatIdMap) = unflattenWithMap(ine, excludeIds, includeFuns = false) - val invars = variablesOf(ine) - - def getModel(m: Model) = new FlatModel(invars, flatIdMap, m, eval) -} \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala deleted file mode 100644 index 472bc11921501ff7fa2ac3c604f13740fe714250..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/invariant/util/Util.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package invariant.util - -import purescala.Expressions._ -import purescala.Types._ -import purescala.PrettyPrintable -import purescala.PrinterContext -import purescala.PrinterHelpers._ -import purescala.Definitions._ -import purescala.Common._ -import purescala.ExprOps._ - -object FileCountGUID { - var fileCount = 0 - def getID: Int = { - val oldcnt = fileCount - fileCount += 1 - oldcnt - } -} - -//three valued logic -object TVL { - abstract class Value - object FALSE extends Value - object TRUE extends Value - object MAYBE extends Value -} - -//this is used as a place holder for result -case class ResultVariable(tpe: TypeTree) extends Expr with Terminal with PrettyPrintable { - val getType = tpe - override def toString: String = "#res" - - def printWith(implicit pctx: PrinterContext) = { - p"#res" - } -} - -object Util { - - val zero = InfiniteIntegerLiteral(0) - val one = InfiniteIntegerLiteral(1) - val mone = InfiniteIntegerLiteral(-1) - val bone = BigInt(1) - val tru = BooleanLiteral(true) - val fls = BooleanLiteral(false) - - def fix[A](f: (A) => A)(a: A): A = { - val na = f(a) - if (a == na) a else fix(f)(na) - } - - def gcd(x: Int, y: Int): Int = { - if (x == 0) y - else gcd(y % x, x) - } - - /** - * A cross product with an optional filter - */ - def cross[U, V](a: Set[U], b: Set[V], selector: Option[(U, V) => Boolean] = None): Set[(U, V)] = { - - val product = (for (x <- a; y <- b) yield (x, y)) - if (selector.isDefined) - product.filter(pair => selector.get(pair._1, pair._2)) - else - product - } - - /** - * Transitively close the substitution map from identifiers to expressions. - * Note: the map is required to be acyclic. - */ - def substClosure(initMap: Map[Identifier, Expr]): Map[Identifier, Expr] = { - if (initMap.isEmpty) initMap - else { - var stables = Seq[(Identifier, Expr)]() - var unstables = initMap.toSeq - var changed = true - while (changed) { - changed = false - var foundStable = false - unstables = unstables.flatMap { - case (k, v) if variablesOf(v).intersect(initMap.keySet).isEmpty => - foundStable = true - stables +:= (k -> v) - Seq() - case (k, v) => - changed = true - Seq((k -> replaceFromIDs(initMap, v))) - } - if (!foundStable) - throw new IllegalStateException(s"No stable entry was found in the map! The map is possibly cyclic: $initMap") - } - stables.toMap - } - } -} diff --git a/src/main/scala/leon/laziness/ClosurePreAsserter.scala b/src/main/scala/leon/laziness/ClosurePreAsserter.scala deleted file mode 100644 index 48c51556cb1d5b1a8c75b122a5841e2a476ea3fa..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/ClosurePreAsserter.scala +++ /dev/null @@ -1,157 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.Types._ -import leon.invariant.util.TypeUtil._ -import leon.invariant.util.ProgramUtil._ -import leon.invariant.util.PredicateUtil._ -import LazinessUtil._ - -/** - * Generate lemmas that ensure that preconditions hold for closures. - * Note: here we cannot use `ClosureFactory` for anything other than state, - * since we work with the translated, type correct program here. - */ -class ClosurePreAsserter(p: Program) { - - /** - * A mapping from `closures` that are *created* in the program - * to the corresponding functions - */ - var ccToOp = Map[CaseClassDef, FunDef]() - def lookupOp(ccd: CaseClassDef): FunDef = { - ccToOp.getOrElse(ccd, { - val opname = ccNameToOpName(ccd.id.name) - val op = functionByName(opname, p).get - ccToOp += (ccd -> op) - op - }) - } - - def hasClassInvariants(cc: CaseClass): Boolean = { - lookupOp(cc.ct.classDef).hasPrecondition - } - - // TODO: A nasty way of finding anchor functions - // Fix this soon !! - var anchorfd: Option[FunDef] = None - val lemmas = p.definedFunctions.flatMap { - case fd if (fd.hasBody && !fd.isLibrary) => - //println("collection closure creation preconditions for: "+fd) - val closures = CollectorWithPaths { - case FunctionInvocation(TypedFunDef(fund, _), - Seq(cc: CaseClass, st)) if isClosureCons(fund) && hasClassInvariants(cc) => - (cc, st) - } traverse (fd.body.get) // Note: closures cannot be created in specs - // Note: once we have separated normal preconditions from state preconditions - // it suffices to just consider state preconditions here - closures.map { - case ((CaseClass(CaseClassType(ccd, _), argsRet), st), path) => - anchorfd = Some(fd) - val target = lookupOp(ccd) //find the target corresponding to the closure - - val pre = target.precondition.get - val args = - if (!isMemoized(target)) - argsRet.dropRight(1) // drop the return value which is the right-most field - else argsRet - val nargs = - if (target.params.size > args.size) // target takes state ? - args :+ st - else args - val pre2 = replaceFromIDs((target.params.map(_.id) zip nargs).toMap, pre) - val vc = path withCond precOrTrue(fd) implies pre2 - // create a function for each vc - val lemmaid = FreshIdentifier(ccd.id.name + fd.id.name + "Lem", Untyped, true) - val params = variablesOf(vc).toSeq.map(v => ValDef(v)) - val tparams = params.flatMap(p => getTypeParameters(p.getType)).distinct map TypeParameterDef - val lemmafd = new FunDef(lemmaid, tparams, params, BooleanType) - // reset the types of locals - val initGamma = params.map(vd => vd.id -> vd.getType).toMap - lemmafd.body = Some(TypeChecker.inferTypesOfLocals(vc, initGamma)) - // assert the lemma is true - val resid = FreshIdentifier("holds", BooleanType) - lemmafd.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) - //println("Created lemma function: "+lemmafd) - lemmafd - } - case _ => Seq() - } - - /** - * Create functions that check the monotonicity of the preconditions - * of the ops - */ - val monoLemmas = { - var exprsProcessed = Set[ExprStructure]() - ccToOp.values.flatMap { - case op if op.hasPrecondition && !isMemoized(op) => // ignore memoized functions which are always evaluated at the time of creation - // get the state param - op.paramIds.find(isStateParam) match { - case Some(stparam) => - // remove disjuncts that do not depend on the state - val preDisjs = op.precondition.get match { - case And(args) => - args.filter(a => variablesOf(a).contains(stparam)) - case l: Let => // checks if the body of the let can be deconstructed as And - val (letsCons, letsBody) = letStarUnapply(l) - letsBody match { - case And(args) => - args.filter(a => variablesOf(a).contains(stparam)).map { - e => simplifyLets(letsCons(e)) - } - case _ => Seq() - } - case e => Seq() - } - if (preDisjs.nonEmpty) { - // create a new state parameter - val superSt = FreshIdentifier("st2@", stparam.getType) - val stType = stparam.getType.asInstanceOf[CaseClassType] - // assert that every component of `st` is a subset of `stparam` - val subsetExpr = createAnd( - stType.classDef.fields.map { fld => - val fieldSelect = (id: Identifier) => CaseClassSelector(stType, id.toVariable, fld.id) - SubsetOf(fieldSelect(stparam), fieldSelect(superSt)) - }) - // create a function for each pre-disjunct that is not processed - preDisjs.map(new ExprStructure(_)).collect { - case preStruct if !exprsProcessed(preStruct) => - exprsProcessed += preStruct - val pred = preStruct.e - val vc = Implies(And(subsetExpr, pred), - replaceFromIDs(Map(stparam -> superSt.toVariable), pred)) - val lemmaid = FreshIdentifier(op.id.name + "PreMonotone", Untyped, true) - val params = variablesOf(vc).toSeq.map(v => ValDef(v)) - val lemmafd = new FunDef(lemmaid, op.tparams, params, BooleanType) - lemmafd.body = Some(vc) - // assert that the lemma is true - val resid = FreshIdentifier("holds", BooleanType) - lemmafd.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) - // add the trace induct annotation - lemmafd.addFlag(new Annotation("traceInduct", Seq())) - //println("Created lemma function: "+lemmafd) - lemmafd - } - } else Seq.empty[FunDef] // nothing to be done - case None => - Seq.empty[FunDef] // nothing to be done - } - case _ => - Seq.empty[FunDef] // nothing to be done - } - } - - def apply: Program = { - if (!lemmas.isEmpty) - addFunDefs(p, lemmas ++ monoLemmas, anchorfd.get) - else p - } -} diff --git a/src/main/scala/leon/laziness/FreeVariableFactory.scala b/src/main/scala/leon/laziness/FreeVariableFactory.scala deleted file mode 100644 index 7388227095e049d344b5e1b4a354ff69f46c7a08..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/FreeVariableFactory.scala +++ /dev/null @@ -1,97 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Types._ - -/** - * A class that maintains a data type that can used to - * create free variables at different points in the program. - * All free variables are of type `FreeVar` which can be mapped - * to a required type by applying uninterpreted functions. - */ -object FreeVariableFactory { - - val fvClass = new AbstractClassDef(FreshIdentifier("FreeVar@"), Seq(), None) - val fvType = AbstractClassType(fvClass, Seq()) - val varCase = { - val cdef = new CaseClassDef(FreshIdentifier("Var@"), Seq(), Some(fvType), false) - cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) - fvClass.registerChild(cdef) - cdef - } - val nextCase = { - val cdef = new CaseClassDef(FreshIdentifier("NextVar@"), Seq(), Some(fvType), false) - cdef.setFields(Seq(ValDef(FreshIdentifier("fl", fvType)))) - fvClass.registerChild(cdef) - cdef - } - val nilCase = { - val cdef = new CaseClassDef(FreshIdentifier("NilVar@"), Seq(), Some(fvType), false) - fvClass.registerChild(cdef) - cdef - } - - class FreeVarListIterator(initRef: Variable) { - require(initRef.getType == fvType) - var refExpr : Expr = initRef - def current = CaseClass(varCase.typed, Seq(refExpr)) // Var(refExpr) - def next { - refExpr = CaseClass(nextCase.typed, Seq(refExpr)) // Next(refExpr) - } - // returns the current expressions and increments state - def nextExpr = { - val e = current - next - e - } - } - - def getFreeListIterator(initRef: Variable) = new FreeVarListIterator(initRef) - - var uifuns = Map[TypeTree, FunDef]() - def getOrCreateUF(t: TypeTree) = { - uifuns.getOrElse(t, { - val funName = "uop@" + TypeUtil.typeNameWOParams(t) - val param = ValDef(FreshIdentifier("a", fvType)) - val tparams = TypeUtil.getTypeParameters(t) map TypeParameterDef.apply _ - val uop = new FunDef(FreshIdentifier(funName), tparams, Seq(param), t) - uifuns += (t -> uop) - uop - }) - } - - class FreeVariableGenerator(initRef: Variable) { - val flIter = new FreeVarListIterator(initRef) - - /** - * Free operations are not guaranteed to be unique: They are - * uninterpreted functions of the form: f(ref). - * f(res_1) could be equal to f(res_2). - */ - def nextFV(t: TypeTree) = { - val uop = getOrCreateUF(t) - val fv = FunctionInvocation(TypedFunDef(uop, Seq()), Seq(flIter.current)) - flIter.next - fv - } - - /** - * References are guaranteed to be unique. - */ - def nextRef = { - val ref = flIter.current - flIter.next - ref - } - } - - def getFreeVarGenerator(initRef: Variable) = new FreeVariableGenerator(initRef) - - def fvClasses = Seq(fvClass, varCase, nextCase, nilCase) - - def fvFunctions = uifuns.keys.toSeq -} diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala deleted file mode 100644 index 2d65f404e789f63ff80e3f4e1e77778160304fff..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala +++ /dev/null @@ -1,170 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.ScalaPrinter -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import LazinessUtil._ -import LazyVerificationPhase._ -import utils._ -import java.io._ -import invariant.engine.InferenceReport -/** - * TODO: Function names are assumed to be small case. Fix this!! - */ -object LazinessEliminationPhase extends SimpleLeonPhase[Program, LazyVerificationReport] { - val dumpInputProg = false - val dumpLiftProg = false - val dumpProgramWithClosures = false - val dumpTypeCorrectProg = false - val dumpProgWithPreAsserts = false - val dumpProgWOInstSpecs = false - val dumpInstrumentedProgram = false - val debugSolvers = false - val skipStateVerification = false - val skipResourceVerification = false - - val name = "Laziness Elimination Phase" - val description = "Coverts a program that uses lazy construct" + - " to a program that does not use lazy constructs" - - // options that control behavior - val optRefEquality = LeonFlagOptionDef("refEq", "Uses reference equality for comparing closures", false) - val optUseOrb = LeonFlagOptionDef("useOrb", "Use Orb to infer constants", false) - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optUseOrb, optRefEquality) - - /** - * TODO: add inlining annotations for optimization. - */ - def apply(ctx: LeonContext, prog: Program): LazyVerificationReport = { - val (progWOInstSpecs, instProg) = genVerifiablePrograms(ctx, prog) - val checkCtx = contextForChecks(ctx) - val stateVeri = - if (!skipStateVerification) - Some(checkSpecifications(progWOInstSpecs, checkCtx)) - else None - - val resourceVeri = - if (!skipResourceVerification) - Some(checkInstrumentationSpecs(instProg, checkCtx, - checkCtx.findOption(LazinessEliminationPhase.optUseOrb).getOrElse(false))) - else None - // dump stats if enabled - if (ctx.findOption(GlobalOptions.optBenchmark).getOrElse(false)) { - val modid = prog.units.find(_.isMainUnit).get.id - val filename = modid + "-stats.txt" - val pw = new PrintWriter(filename) - Stats.dumpStats(pw) - SpecificStats.dumpOutputs(pw) - ctx.reporter.info("Stats dumped to file: " + filename) - pw.close() - } - // return a report - new LazyVerificationReport(stateVeri, resourceVeri) - } - - def genVerifiablePrograms(ctx: LeonContext, prog: Program): (Program, Program) = { - if (dumpInputProg) - println("Input prog: \n" + ScalaPrinter.apply(prog)) - - val (pass, msg) = sanityChecks(prog, ctx) - assert(pass, msg) - - // refEq is by default false - val nprog = LazyExpressionLifter.liftLazyExpressions(prog, ctx.findOption(optRefEquality).getOrElse(false)) - if (dumpLiftProg) - prettyPrintProgramToFile(nprog, ctx, "-lifted", true) - - val funsManager = new LazyFunctionsManager(nprog) - val closureFactory = new LazyClosureFactory(nprog) - val progWithClosures = (new LazyClosureConverter(nprog, ctx, closureFactory, funsManager)).apply - if (dumpProgramWithClosures) - prettyPrintProgramToFile(progWithClosures, ctx, "-closures") - - //Rectify type parameters and local types - val typeCorrectProg = (new TypeRectifier(progWithClosures, closureFactory)).apply - if (dumpTypeCorrectProg) - prettyPrintProgramToFile(typeCorrectProg, ctx, "-typed") - - val progWithPre = (new ClosurePreAsserter(typeCorrectProg)).apply - if (dumpProgWithPreAsserts) - prettyPrintProgramToFile(progWithPre, ctx, "-withpre", uniqueIds = true) - - // verify the contracts that do not use resources - val progWOInstSpecs = InliningPhase.apply(ctx, removeInstrumentationSpecs(progWithPre)) - if (dumpProgWOInstSpecs) - prettyPrintProgramToFile(progWOInstSpecs, ctx, "-woinst") - - // instrument the program for resources (note: we avoid checking preconditions again here) - val instrumenter = new LazyInstrumenter(InliningPhase.apply(ctx, typeCorrectProg), ctx, closureFactory) - val instProg = instrumenter.apply - if (dumpInstrumentedProgram) - prettyPrintProgramToFile(instProg, ctx, "-withinst", uniqueIds = true) - (progWOInstSpecs, instProg) - } - - /** - * TODO: enforce that lazy and nested types do not overlap - * TODO: we are forced to make an assumption that lazy ops takes as type parameters only those - * type parameters of their return type and not more. (This is checked in the closureFactory,\ - * but may be check this upfront) - */ - def sanityChecks(p: Program, ctx: LeonContext): (Boolean, String) = { - // using a bit of a state here - var failMsg = "" - val checkres = p.definedFunctions.forall { - case fd if !fd.isLibrary => - /** - * Fails when the argument to a suspension creation - * is either a normal or memoized function depending on the flag - * 'argMem' = true implies fail if the argument is a memoized function - */ - def failOnClosures(argMem: Boolean, e: Expr) = e match { - case finv: FunctionInvocation if isLazyInvocation(finv)(p) => - finv match { - case FunctionInvocation(_, Seq(Lambda(_, FunctionInvocation(callee, _)))) if isMemoized(callee.fd) => argMem - case _ => !argMem - } - case _ => false - } - // specs should not create lazy closures, but can refer to memoized functions - val specCheckFailed = exists(failOnClosures(false, _))(fd.precOrTrue) || exists(failOnClosures(false, _))(fd.postOrTrue) - if (specCheckFailed) { - failMsg = "Lazy closure creation in the specification of function: " + fd.id - false - } else { - // cannot suspend a memoized function - val bodyCheckFailed = exists(failOnClosures(true, _))(fd.body.getOrElse(Util.tru)) - if (bodyCheckFailed) { - failMsg = "Suspending a memoized function is not supported! in body of: " + fd.id - false - } else { - def nestedSusp(e: Expr) = e match { - case finv @ FunctionInvocation(_, Seq(Lambda(_, call: FunctionInvocation))) if isLazyInvocation(finv)(p) && isLazyInvocation(call)(p) => true - case _ => false - } - val nestedCheckFailed = exists(nestedSusp)(fd.body.getOrElse(Util.tru)) - if (nestedCheckFailed) { - failMsg = "Nested suspension creation in the body: " + fd.id - false - } else { - // arguments or return types of memoized functions cannot be lazy because we do not know how to compare them for equality - if (isMemoized(fd)) { - val argCheckFailed = (fd.params.map(_.getType) :+ fd.returnType).exists(LazinessUtil.isLazyType) - if (argCheckFailed) { - failMsg = "Memoized function has a lazy argument or return type: " + fd.id - false - } else true - } else true - } - } - } - case _ => true - } - (checkres, failMsg) - } -} diff --git a/src/main/scala/leon/laziness/LazinessUtil.scala b/src/main/scala/leon/laziness/LazinessUtil.scala deleted file mode 100644 index 858b1fec28e5aff3d307ca485006c7673dae3dc7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazinessUtil.scala +++ /dev/null @@ -1,223 +0,0 @@ -package leon -package laziness - -import purescala.ScalaPrinter -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.DefOps._ -import purescala.Types._ -import java.io.File -import java.io.FileWriter -import java.io.BufferedWriter -import scala.util.matching.Regex -import utils.FileOutputPhase - -object LazinessUtil { - - def isMemoized(fd: FunDef) = { - fd.flags.contains(Annotation("memoize", Seq())) - } - - def prettyPrintProgramToFile(p: Program, ctx: LeonContext, suffix: String, uniqueIds: Boolean = false) { - val optOutputDirectory = FileOutputPhase.optOutputDirectory - val outputFolder = ctx.findOptionOrDefault(optOutputDirectory) - try { - new File(outputFolder).mkdir() - } catch { - case _: java.io.IOException => - ctx.reporter.fatalError("Could not create directory " + outputFolder) - } - - for (u <- p.units if u.isMainUnit) { - val outputFile = s"$outputFolder${File.separator}${u.id.toString}$suffix.scala" - try { - val out = new BufferedWriter(new FileWriter(outputFile)) - val plainText = ScalaPrinter.apply(u, purescala.PrinterOptions(printUniqueIds = uniqueIds)) - //println("Plain text: "+plainText) - // remove '@' from the end of the identifier names - val pat = new Regex("""(\w+)(@)(\w*)(\*?)(\S*)""", "base", "at", "mid", "star", "rest") - - val pgmText = try{ pat.replaceAllIn(plainText, - m => { - m.group("base") + m.group("mid") + ( - if (!m.group("star").isEmpty()) "S" else "") + m.group("rest") - }) - } catch { - case _: IndexOutOfBoundsException => plainText - } - out.write(pgmText) - out.close() - } catch { - case _: java.io.IOException => ctx.reporter.fatalError("Could not write on " + outputFile) - } - } - ctx.reporter.info("Output written on " + outputFolder) - } - - def isLazyInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.$" - case _ => - false - } - - def isEagerInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.eager" - case _ => - false - } - - def isInStateCall(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq()) => - val fn = fullName(fd)(p) - (fn == "leon.lazyeval.inState" || fn == "leon.mem.inState") - case _ => - false - } - - def isOutStateCall(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq()) => - val fn = fullName(fd)(p) - (fn == "leon.lazyeval.outState" || fn == "leon.mem.outState") - case _ => - false - } - - def isEvaluatedInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.Lazy.isEvaluated" - case _ => false - } - - def isSuspInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_, _)) => - fullName(fd)(p) == "leon.lazyeval.Lazy.isSuspension" - case _ => false - } - - def isWithStateCons(e: Expr)(implicit p: Program): Boolean = e match { - case CaseClass(cct, Seq(_)) => - val fn = fullName(cct.classDef)(p) - (fn == "leon.lazyeval.WithState" || fn == "leon.mem.memWithState") - case _ => false - } - - def isMemCons(e: Expr)(implicit p: Program): Boolean = e match { - case CaseClass(cct, Seq(_)) => - fullName(cct.classDef)(p) == "leon.mem.Mem" - case _ => false - } - - /** - * There are many overloads of withState functions with different number - * of arguments. All of them should pass this check. - */ - def isWithStateFun(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), _) => - val fn = fullName(fd)(p) - (fn == "leon.lazyeval.WithState.withState" || - fn == "leon.mem.memWithState.withState") - case _ => false - } - - def isCachedInv(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.mem.Mem.isCached" - case _ => false - } - - def isValueInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.Lazy.value" - case _ => false - } - - def isStarInvocation(e: Expr)(implicit p: Program): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, _), Seq(_)) => - fullName(fd)(p) == "leon.lazyeval.Lazy.*" - case _ => false - } - - def isLazyType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => - ccd.id.name == "Lazy" - case _ => false - } - - def isMemType(tpe: TypeTree): Boolean = tpe match { - case CaseClassType(ccd, Seq(_)) if !ccd.hasParent && !ccd.isCaseObject => - ccd.id.name == "Mem" - case _ => false - } - - /** - * Lazy types are not nested by precondition - */ - def unwrapLazyType(tpe: TypeTree) = tpe match { - case ctype @ CaseClassType(_, Seq(innerType)) if isLazyType(ctype) || isMemType(ctype) => - Some(innerType) - case _ => None - } - - def opNameToCCName(name: String) = { - name.capitalize + "@" - } - - /** - * Convert the first character to lower case - * and remove the last character. - */ - def ccNameToOpName(name: String) = { - name.substring(0, 1).toLowerCase() + - name.substring(1, name.length() - 1) - } - - def typeNameToADTName(name: String) = { - "Lazy" + name - } - - def adtNameToTypeName(name: String) = { - name.substring(4) - } - - def typeToFieldName(name: String) = { - name.toLowerCase() - } - - def closureConsName(typeName: String) = { - "new@" + typeName - } - - def isClosureCons(fd: FunDef) = { - fd.id.name.startsWith("new@") - } - - def evalFunctionName(absTypeName: String) = { - "eval@" + absTypeName - } - - def isEvalFunction(fd: FunDef) = { - fd.id.name.startsWith("eval@") - } - - def isStateParam(id: Identifier) = { - id.name.startsWith("st@") - } - - def isPlaceHolderTParam(tp: TypeParameter) = { - tp.id.name.endsWith("@") - } - - def freshenTypeArguments(tpe: TypeTree): TypeTree = { - tpe match { - case NAryType(targs, tcons) => - val ntargs = targs.map { - case targ: TypeParameter => targ.freshen - case targ => targ - } - tcons(ntargs) - } - } -} diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala deleted file mode 100644 index 8e86581d86d38216485bb96a0a7a25866007d0c1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ /dev/null @@ -1,787 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import leon.invariant.util.TypeUtil._ -import LazinessUtil._ -import ProgramUtil._ -import PredicateUtil._ -import purescala.TypeOps.bestRealType - -/** - * (a) add state to every function in the program - * (b) thread state through every expression in the program sequentially - * (c) replace lazy constructions with case class creations - * (d) replace isEvaluated with currentState.contains() - * (e) replace accesses to $.value with calls to dispatch with current state - */ -class LazyClosureConverter(p: Program, ctx: LeonContext, - closureFactory: LazyClosureFactory, - funsManager: LazyFunctionsManager) { - val debug = false - // flags - //val removeRecursionViaEval = false - val refEq = ctx.findOptionOrDefault(LazinessEliminationPhase.optRefEquality) - - val funsNeedStates = funsManager.funsNeedStates - val funsRetStates = funsManager.funsRetStates - val starCallers = funsManager.funsNeedStateTps - val lazyTnames = closureFactory.lazyTypeNames - val lazyops = closureFactory.lazyops - - /** - * Note this method has side-effects - */ - var idmap = Map[Identifier, Identifier]() - def makeIdOfType(oldId: Identifier, tpe: TypeTree): Identifier = { - if (oldId.getType != tpe) { - val freshid = FreshIdentifier(oldId.name, tpe, true) - idmap += (oldId -> freshid) - freshid - } else oldId - } - - val funMap = p.definedFunctions.collect { - case fd if (fd.hasBody && !fd.isLibrary && !fd.isInvariant) => // skipping class invariants for now. - // replace lazy types in parameters and return values - val nparams = fd.params map { vd => - val nparam = makeIdOfType(vd.id, replaceLazyTypes(vd.getType)) - ValDef(nparam) - } - val nretType = replaceLazyTypes(fd.returnType) - val stparams = - if (funsNeedStates(fd) || starCallers(fd)) - // create fresh type parameters for the state - closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) - else Seq() - - val nfd = if (funsNeedStates(fd)) { // this also includes lazy constructors - val stType = CaseClassType(closureFactory.state, stparams) - val stParam = ValDef(FreshIdentifier("st@", stType)) - val retTypeWithState = - if (funsRetStates(fd)) - TupleType(Seq(nretType, stType)) - else - nretType - // the type parameters will be unified later - new FunDef(FreshIdentifier(fd.id.name), fd.tparams ++ (stparams map TypeParameterDef), - nparams :+ stParam, retTypeWithState) - // body of these functions are defined later - } else { - new FunDef(FreshIdentifier(fd.id.name), fd.tparams ++ (stparams map TypeParameterDef), - nparams, nretType) - } - // copy annotations - fd.flags.foreach(nfd.addFlag(_)) - (fd -> nfd) - }.toMap - - // were used for optimization purposes - // lazy val uiTarges = funMap.collect { - // case (k, v) if closureFactory.isLazyOp(k) => - // val ufd = new FunDef(FreshIdentifier(v.id.name, v.id.getType, true), - // v.tparams, v.params, v.returnType) - // (k -> ufd) - // }.toMap - - //TODO: Optimization: we can omit come functions whose translations will not be recursive. - def takesStateButIndep(fd: FunDef) = - funsNeedStates(fd) && funsManager.hasStateIndependentBehavior(fd) - - /** - * A set of uninterpreted functions that are used - * in specs to ensure that value part is independent of the state - */ - val uiFuncs: Map[FunDef, (FunDef, Option[FunDef])] = funMap.collect { - case (k, v) if takesStateButIndep(k) => - val params = v.params.take(k.params.size) // ignore the state params - val retType = - if (funsRetStates(k)) { - val TupleType(valtype +: _) = v.returnType - valtype - } else v.returnType - val tparams = (params.map(_.getType) :+ retType).flatMap(getTypeParameters(_)).distinct - val tparamsDef = tparams.map(TypeParameterDef(_)) - val ufd = new FunDef(FreshIdentifier(v.id.name + "UI"), tparamsDef, params, retType) - - // we also need to create a function that assumes that the result equals - // the uninterpreted function - val valres = ValDef(FreshIdentifier("valres", retType)) - val pred = new FunDef(FreshIdentifier(v.id.name + "ValPred"), tparamsDef, - params :+ valres, BooleanType) - val resid = FreshIdentifier("res", pred.returnType) - pred.fullBody = Ensuring( - Equals(valres.id.toVariable, - FunctionInvocation(TypedFunDef(ufd, tparams), params.map(_.id.toVariable))), // res = funUI(..) - Lambda(Seq(ValDef(resid)), resid.toVariable)) // holds - pred.addFlag(Annotation("axiom", Seq())) // @axiom is similar to @library - (k -> (ufd, Some(pred))) - - case (k, v) if lazyops(k) => - // here 'v' cannot for sure take state params, otherwise it must be state indep - if (funsNeedStates(k)) - throw new IllegalStateException("Lazyop that has a state dependent behavior: " + k) - else { - val tparams = (v.params.map(_.getType) :+ v.returnType).flatMap(getTypeParameters(_)).distinct - val tparamsDef = tparams.map(TypeParameterDef(_)) - val ufd = new FunDef(FreshIdentifier(v.id.name + "UI"), tparamsDef, v.params, v.returnType) - (k -> (ufd, None)) - } - }.toMap - - /** - * A set of uninterpreted functions that return fixed but uninterpreted states - * Note: here I am using mutation on purpose to create uninterpreted states on - * demand. - */ - var uiStateFuns = Map[String, FunDef]() - def getUninterpretedState(lazyTypename: String, tparams: Seq[TypeParameter]) = { - val uiStateFun = if (uiStateFuns.contains(lazyTypename)) { - uiStateFuns(lazyTypename) - } else { - // create a body-less fundef that will return a state - val stType = CaseClassType(closureFactory.state, closureFactory.state.tparams.map(_.tp)) - val fd = new FunDef(FreshIdentifier("ui" + lazyTypename), closureFactory.state.tparams, Seq(), stType) - uiStateFuns += (lazyTypename -> fd) - fd - } - FunctionInvocation(TypedFunDef(uiStateFun, tparams), Seq()) - } - - def replaceLazyTypes(t: TypeTree): TypeTree = { - unwrapLazyType(t) match { - case None => - val NAryType(tps, tcons) = t - tcons(tps map replaceLazyTypes) - case Some(btype) => - val absClass = closureFactory.absClosureType(typeNameWOParams(btype)) - val ntype = AbstractClassType(absClass, getTypeParameters(btype)) - val NAryType(tps, tcons) = ntype - tcons(tps map replaceLazyTypes) - } - } - - /** - * Create dispatch functions for each lazy type. - * Note: the dispatch functions will be annotated as library so that - * their pre/posts are not checked (the fact that they hold are verified separately) - * Note that by using 'assume-pre' we can also assume the preconditions of closures at - * the call-sites. - */ - val evalFunctions = lazyTnames.map { tname => - val ismem = closureFactory.isMemType(tname) - val tpe = /*freshenTypeArguments(*/ closureFactory.lazyType(tname) //) - val absdef = closureFactory.absClosureType(tname) - val cdefs = closureFactory.closures(tname) - - // construct parameters and return types - val recvTparams = getTypeParameters(tpe) - val stTparams = closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) - val param1 = FreshIdentifier("cl", AbstractClassType(absdef, recvTparams)) - val stType = CaseClassType(closureFactory.state, stTparams) - val param2 = FreshIdentifier("st@", stType) - val retType = TupleType(Seq(tpe, stType)) - - // create a eval function - val dfun = new FunDef(FreshIdentifier(evalFunctionName(absdef.id.name)), - (recvTparams ++ stTparams) map TypeParameterDef, - Seq(ValDef(param1), ValDef(param2)), retType) - - //println("Creating eval function: "+dfun) - // assign body of the eval fucntion - // create a match case to switch over the possible class defs and invoke the corresponding functions - val bodyMatchCases = cdefs map { cdef => - val ctype = CaseClassType(cdef, recvTparams) // we assume that the type parameters of cdefs are same as absdefs - val binder = FreshIdentifier("t", ctype) - val pattern = InstanceOfPattern(Some(binder), ctype) - // create a body of the match - // the last field represents the result (only if the type is a susp type) - val flds = - if (!ismem) cdef.fields.dropRight(1) - else cdef.fields - val args = flds map { fld => - CaseClassSelector(ctype, binder.toVariable, fld.id) - } - val op = closureFactory.caseClassToOp(cdef) - val targetFun = funMap(op) - // invoke the target fun with appropriate values - val invoke = - if (funsNeedStates(op)) - FunctionInvocation(TypedFunDef(targetFun, recvTparams ++ stTparams), args :+ param2.toVariable) - else - FunctionInvocation(TypedFunDef(targetFun, recvTparams), args) - val invokeRes = FreshIdentifier("dres", invoke.getType) - //println(s"invoking function $targetFun with args $args") - val updateFun = TypedFunDef(closureFactory.stateUpdateFuns(tname), stTparams) - val (valPart, stPart) = - if (funsRetStates(op)) { - val invokeSt = TupleSelect(invokeRes.toVariable, 2) - val nst = FunctionInvocation(updateFun, Seq(invokeSt, binder.toVariable)) - (TupleSelect(invokeRes.toVariable, 1), nst) - } else { - val nst = FunctionInvocation(updateFun, Seq(param2.toVariable, binder.toVariable)) - (invokeRes.toVariable, nst) - } - val rhs = Let(invokeRes, invoke, Tuple(Seq(valPart, stPart))) - MatchCase(pattern, None, rhs) - } - val cases = if (!ismem) { - // create a new match case for eager evaluation - val eagerDef = closureFactory.eagerClosure(tname).get - val ctype = CaseClassType(eagerDef, recvTparams) - val binder = FreshIdentifier("t", ctype) - // create a body of the match - val valPart = CaseClassSelector(ctype, binder.toVariable, eagerDef.fields(0).id) - val rhs = Tuple(Seq(valPart, param2.toVariable)) // state doesn't change for eager closure - MatchCase(InstanceOfPattern(Some(binder), ctype), None, rhs) +: bodyMatchCases - } else - bodyMatchCases - dfun.body = Some(MatchExpr(param1.toVariable, cases)) - dfun.addFlag(Annotation("axiom", Seq())) - (tname -> dfun) - }.toMap - - /** - * These are evalFunctions that do not affect the state. - */ - val computeFunctions = evalFunctions.map { - case (tname, evalfd) => - val tpe = /*freshenTypeArguments(*/ closureFactory.lazyType(tname) //) - val param1 = evalfd.params.head - val fun = new FunDef(FreshIdentifier(evalfd.id.name + "*", Untyped), - evalfd.tparams, Seq(param1), tpe) - val stTparams = evalfd.tparams.collect { - case tpd if isPlaceHolderTParam(tpd.tp) => tpd.tp - } - val uiState = getUninterpretedState(tname, stTparams) - val invoke = FunctionInvocation(TypedFunDef(evalfd, evalfd.tparams.map(_.tp)), - Seq(param1.id.toVariable, uiState)) - fun.body = Some(TupleSelect(invoke, 1)) - fun.addFlag(IsInlined) - (tname -> fun) - }.toMap - - /** - * Create closure construction functions that ensures a postcondition. - * They are defined for each lazy class since it avoids generics and - * simplifies the type inference (which is not full-fledged in Leon) - */ - val closureCons = lazyTnames.map { tname => - val adt = closureFactory.absClosureType(tname) - val param1Type = AbstractClassType(adt, adt.tparams.map(_.tp)) - val param1 = FreshIdentifier("cc", param1Type) - val stTparams = closureFactory.state.tparams.map(_ => TypeParameter.fresh("P@")) - val stType = CaseClassType(closureFactory.state, stTparams) - val param2 = FreshIdentifier("st@", stType) - val tparamdefs = adt.tparams ++ (stTparams map TypeParameterDef) - val fun = new FunDef(FreshIdentifier(closureConsName(tname)), tparamdefs, - Seq(ValDef(param1), ValDef(param2)), param1Type) - fun.body = Some(param1.toVariable) - fun.addFlag(IsInlined) - // assert that the closure in unevaluated if useRefEquality is enabled! is this needed ? - // not supported as of now - /*if (refEq) { - val resid = FreshIdentifier("res", param1Type) - val postbody = Not(ElementOfSet(resid.toVariable, param2.toVariable)) - fun.postcondition = Some(Lambda(Seq(ValDef(resid)), postbody)) - fun.addFlag(Annotation("axiom", Seq())) - }*/ - (tname -> fun) - }.toMap - - def mapExpr(expr: Expr)(implicit stTparams: Seq[TypeParameter]): (Option[Expr] => Expr, Boolean) = expr match { - - case finv @ FunctionInvocation(_, Seq(Lambda(_, FunctionInvocation(TypedFunDef(argfd, tparams), args)))) // lazy construction ? - if isLazyInvocation(finv)(p) => - val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { - val adt = closureFactory.closureOfLazyOp(argfd) - // create lets to bind the nargs to variables - val (flatArgs, letCons) = nargs.foldRight((Seq[Variable](), (e: Expr) => e)) { - case (narg, (fargs, lcons)) => - val id = FreshIdentifier("a", narg.getType, true) - (id.toVariable +: fargs, e => Let(id, narg, lcons(e))) - } - val ccArgs = if (!isMemoized(argfd)) { - // construct a value for the result (an uninterpreted function) - val resval = FunctionInvocation(TypedFunDef(uiFuncs(argfd)._1, tparams), flatArgs) - flatArgs :+ resval - } else - flatArgs - val cc = CaseClass(CaseClassType(adt, tparams), ccArgs) - val baseLazyTypeName = closureFactory.lazyTypeNameOfClosure(adt) - val fi = FunctionInvocation(TypedFunDef(closureCons(baseLazyTypeName), tparams ++ stTparams), - Seq(cc, st.get)) - letCons(fi) // this could be 'fi' wrapped into lets - }, false) - mapNAryOperator(args, op) - - case cc @ CaseClass(_, Seq(FunctionInvocation(TypedFunDef(argfd, tparams), args))) if isMemCons(cc)(p) => - // in this case argfd is a memoized function - val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { - val adt = closureFactory.closureOfLazyOp(argfd) - CaseClass(CaseClassType(adt, tparams), nargs) - }, false) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, Seq(Lambda(_, arg))) if isEagerInvocation(finv)(p) => - // here arg is guaranteed to be a variable - ((st: Option[Expr]) => { - val rootType = bestRealType(arg.getType) - val tname = typeNameWOParams(rootType) - val tparams = getTypeArguments(rootType) - val eagerClosure = closureFactory.eagerClosure(tname).get - CaseClass(CaseClassType(eagerClosure, tparams), Seq(arg)) - }, false) - - case finv @ FunctionInvocation(_, args) if isEvaluatedInvocation(finv)(p) => // isEval function ? - val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { - val narg = nargs(0) // there must be only one argument here - val baseType = unwrapLazyType(narg.getType).get - val tname = typeNameWOParams(baseType) - val st = stOpt.get - val stType = CaseClassType(closureFactory.state, stTparams) - val cls = closureFactory.selectFieldOfState(tname, st, stType) - val memberTest = ElementOfSet(narg, cls) - val subtypeTest = IsInstanceOf(narg, - CaseClassType(closureFactory.eagerClosure(tname).get, getTypeArguments(baseType))) - Or(memberTest, subtypeTest) - }, false) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, args) if isCachedInv(finv)(p) => // isCached function ? - val baseType = unwrapLazyType(args(0).getType).get - val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { - val narg = nargs(0) // there must be only one argument here - //println("narg: "+narg+" type: "+narg.getType) - val tname = typeNameWOParams(baseType) - val st = stOpt.get - val stType = CaseClassType(closureFactory.state, stTparams) - val cls = closureFactory.selectFieldOfState(tname, st, stType) - ElementOfSet(narg, cls) - }, false) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, Seq(recvr, funcArg)) if isSuspInvocation(finv)(p) => - ((st: Option[Expr]) => { - // `funcArg` is a closure whose body is a function invocation - //TODO: make sure the function is not partially applied in the body - funcArg match { - case Lambda(_, FunctionInvocation(TypedFunDef(fd, _), _)) => - // retrieve the case-class for the operation from the factory - val caseClass = closureFactory.closureOfLazyOp(fd) - val targs = TypeUtil.getTypeArguments(unwrapLazyType(recvr.getType).get) - val caseClassType = CaseClassType(caseClass, targs) - IsInstanceOf(recvr, caseClassType) - case _ => - throw new IllegalArgumentException("The argument to isSuspension should be " + - "a partially applied function of the form: <method-name> _") - } - }, false) - - case finv @ FunctionInvocation(_, Seq(recvr, stArgs @ _*)) if isWithStateFun(finv)(p) => - // recvr is a `WithStateCaseClass` and `stArgs` could be arbitrary expressions that return values of types of fields of state - val numStates = closureFactory.state.fields.size - if (stArgs.size != numStates) - throw new IllegalStateException("The arguments to `withState` should equal the number of states: " + numStates) - - val CaseClass(_, Seq(exprNeedingState)) = recvr - val (nexprCons, exprReturnsState) = mapExpr(exprNeedingState) - val nstConses = stArgs map mapExpr - if (nstConses.exists(_._2)) // any 'stArg' returning state - throw new IllegalStateException("One of the arguments to `withState` returns a new state, which is not supported: " + finv) - else { - ((st: Option[Expr]) => { - // create a new state using the nstConses - val nstSets = nstConses map { case (stCons, _) => stCons(st) } - val tparams = nstSets.flatMap(nst => getTypeParameters(nst.getType)).distinct - val nst = CaseClass(CaseClassType(closureFactory.state, tparams), nstSets) - nexprCons(Some(nst)) - }, exprReturnsState) - } - - case finv @ FunctionInvocation(_, args) if isValueInvocation(finv)(p) => // is value function ? - val op = (nargs: Seq[Expr]) => ((stOpt: Option[Expr]) => { - val st = stOpt.get - val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here - val tname = typeNameWOParams(baseType) - val dispFun = evalFunctions(tname) - val tparams = (getTypeParameters(baseType) ++ stTparams).distinct - FunctionInvocation(TypedFunDef(dispFun, tparams), nargs :+ st) - }, true) - mapNAryOperator(args, op) - - case finv @ FunctionInvocation(_, args) if isStarInvocation(finv)(p) => // is * function ? - val op = (nargs: Seq[Expr]) => ((st: Option[Expr]) => { - val baseType = unwrapLazyType(nargs(0).getType).get // there must be only one argument here - val tname = typeNameWOParams(baseType) - val dispFun = computeFunctions(tname) - val tparams = getTypeParameters(baseType) ++ stTparams - FunctionInvocation(TypedFunDef(dispFun, tparams), nargs) - }, false) - mapNAryOperator(args, op) - - case FunctionInvocation(TypedFunDef(fd, tparams), args) if funMap.contains(fd) => - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Option[Expr]) => { - val stArgs = - if (funsNeedStates(fd)) { - st.toSeq - } else Seq() - val stparams = - if (funsNeedStates(fd) || starCallers(fd)) { - stTparams - } else Seq() - FunctionInvocation(TypedFunDef(funMap(fd), tparams ++ stparams), nargs ++ stArgs) - }, funsRetStates(fd))) - - case Let(id, value, body) => - val (valCons, valUpdatesState) = mapExpr(value) - val (bodyCons, bodyUpdatesState) = mapExpr(body) - ((st: Option[Expr]) => { - val nval = valCons(st) - if (valUpdatesState) { - val freshid = FreshIdentifier(id.name, nval.getType, true) - val nextState = TupleSelect(freshid.toVariable, 2) - val transBody = replace(Map(id.toVariable -> TupleSelect(freshid.toVariable, 1)), - bodyCons(Some(nextState))) - if (bodyUpdatesState) - Let(freshid, nval, transBody) - else - Let(freshid, nval, Tuple(Seq(transBody, nextState))) - } else - Let(id, nval, bodyCons(st)) - }, valUpdatesState || bodyUpdatesState) - - case IfExpr(cond, thn, elze) => - val (condCons, condState) = mapExpr(cond) - val (thnCons, thnState) = mapExpr(thn) - val (elzeCons, elzeState) = mapExpr(elze) - ((st: Option[Expr]) => { - val (ncondCons, nst) = - if (condState) { - val cndExpr = condCons(st) - val bder = FreshIdentifier("c", cndExpr.getType) - val condst = TupleSelect(bder.toVariable, 2) - ((th: Expr, el: Expr) => - Let(bder, cndExpr, IfExpr(TupleSelect(bder.toVariable, 1), th, el)), - Some(condst)) - } else { - ((th: Expr, el: Expr) => IfExpr(condCons(st), th, el), st) - } - val nelze = - if ((condState || thnState) && !elzeState) - Tuple(Seq(elzeCons(nst), nst.get)) - else elzeCons(nst) - val nthn = - if (!thnState && (condState || elzeState)) - Tuple(Seq(thnCons(nst), nst.get)) - else thnCons(nst) - ncondCons(nthn, nelze) - }, condState || thnState || elzeState) - - case MatchExpr(scr, cases) => - val (scrCons, scrUpdatesState) = mapExpr(scr) - val casesRes = cases.foldLeft(Seq[(Option[Expr] => Expr, Boolean)]()) { - case (acc, MatchCase(pat, None, rhs)) => - acc :+ mapExpr(rhs) - case mcase => - throw new IllegalStateException("Match case with guards are not supported yet: " + mcase) - } - val casesUpdatesState = casesRes.exists(_._2) - ((st: Option[Expr]) => { - val scrExpr = scrCons(st) - val (nscrCons, scrst) = - if (scrUpdatesState) { - val bder = FreshIdentifier("scr", scrExpr.getType) - val scrst = Some(TupleSelect(bder.toVariable, 2)) - ((ncases: Seq[MatchCase]) => - Let(bder, scrExpr, MatchExpr(TupleSelect(bder.toVariable, 1), ncases)), - scrst) - } else { - //println(s"Scrutiny does not update state: current state: $st") - ((ncases: Seq[MatchCase]) => MatchExpr(scrExpr, ncases), st) - } - val ncases = (cases zip casesRes).map { - case (MatchCase(pat, None, _), (caseCons, caseUpdatesState)) => - val nrhs = - if ((scrUpdatesState || casesUpdatesState) && !caseUpdatesState) - Tuple(Seq(caseCons(scrst), scrst.get)) - else caseCons(scrst) - MatchCase(pat, None, nrhs) - } - nscrCons(ncases) - }, scrUpdatesState || casesUpdatesState) - - // need to reset types in the case of case class constructor calls - case CaseClass(cct, args) => - val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Option[Expr]) => CaseClass(ntype, nargs), false)) - - // need to reset field ids of case class select - case CaseClassSelector(cct, clExpr, fieldId) if fieldMap.contains(fieldId) => - val ntype = replaceLazyTypes(cct).asInstanceOf[CaseClassType] - val nfield = fieldMap(fieldId) - mapNAryOperator(Seq(clExpr), - (nargs: Seq[Expr]) => ((st: Option[Expr]) => CaseClassSelector(ntype, nargs.head, nfield), false)) - - case Operator(args, op) => - // here, 'op' itself does not create a new state - mapNAryOperator(args, - (nargs: Seq[Expr]) => ((st: Option[Expr]) => op(nargs), false)) - - case t: Terminal => (_ => t, false) - } - - def mapNAryOperator(args: Seq[Expr], op: Seq[Expr] => (Option[Expr] => Expr, Boolean))(implicit stTparams: Seq[TypeParameter]) = { - // create n variables to model n lets - val letvars = args.map(arg => FreshIdentifier("arg", arg.getType, true).toVariable) - (args zip letvars).foldRight(op(letvars)) { - case ((arg, letvar), (accCons, stUpdatedBefore)) => - val (argCons, stUpdateFlag) = mapExpr(arg) - val cl = if (!stUpdateFlag) { - // here arg does not affect the newstate - (st: Option[Expr]) => replace(Map(letvar -> argCons(st)), accCons(st)) - } else { - // here arg does affect the newstate - (st: Option[Expr]) => - { - val narg = argCons(st) - val argres = FreshIdentifier("a", narg.getType, true).toVariable - val nstate = Some(TupleSelect(argres, 2)) - val letbody = - if (stUpdatedBefore) accCons(nstate) // here, 'acc' already returns a superseeding state - else Tuple(Seq(accCons(nstate), nstate.get)) // here, 'acc; only returns the result - Let(argres.id, narg, - Let(letvar.id, TupleSelect(argres, 1), letbody)) - } - } - (cl, stUpdatedBefore || stUpdateFlag) - } - } - - def fieldsOfState(st: Expr, stType: CaseClassType): Seq[Expr] = { - closureFactory.lazyTypeNames.map { tn => - closureFactory.selectFieldOfState(tn, st, stType) - } - } - - def assignBodiesToFunctions = { - val paramMap: Map[Expr, Expr] = idmap.map(e => (e._1.toVariable -> e._2.toVariable)) - funMap foreach { - case (fd, nfd) => - //println("Considering function: "+fd) - // Here, using name to identify 'state' parameters - val stateParam = nfd.params.collectFirst { - case vd if isStateParam(vd.id) => - vd.id.toVariable - } - val stType = stateParam.map(_.getType.asInstanceOf[CaseClassType]) - // Note: stTparams may be provided even if stParam is not required. - val stTparams = nfd.tparams.collect { - case tpd if isPlaceHolderTParam(tpd.tp) => tpd.tp - } - val (nbodyFun, bodyUpdatesState) = mapExpr(fd.body.get)(stTparams) - val nbody = nbodyFun(stateParam) - val bodyWithState = - if (!bodyUpdatesState && funsRetStates(fd)) - Tuple(Seq(nbody, stateParam.get)) - else - nbody - nfd.body = Some(simplifyLets(replace(paramMap, bodyWithState))) - //println(s"Body of ${fd.id.name} after conversion&simp: ${nfd.body}") - - // Important: specifications use lazy semantics but - // their state changes are ignored after their execution. - // This guarantees their observational purity/transparency - // collect class invariants that need to be added - if (fd.hasPrecondition) { - val (npreFun, preUpdatesState) = mapExpr(fd.precondition.get)(stTparams) - val npre = replace(paramMap, npreFun(stateParam)) - nfd.precondition = - if (preUpdatesState) - Some(TupleSelect(npre, 1)) // ignore state updated by pre - else Some(npre) - } - - // create a new result variable - val newres = - if (fd.hasPostcondition) { - val Lambda(Seq(ValDef(r)), _) = fd.postcondition.get - FreshIdentifier(r.name, bodyWithState.getType) - } else FreshIdentifier("r", nfd.returnType) - - // create an output state map - val outState = - if (bodyUpdatesState || funsRetStates(fd)) { - Some(TupleSelect(newres.toVariable, 2)) - } else - stateParam - - // create a specification that relates input-output states - val stateRel = - if (funsRetStates(fd)) { // add specs on states - val instates = fieldsOfState(stateParam.get, stType.get) - val outstates = fieldsOfState(outState.get, stType.get) - val stateRel = - if (fd.annotations.contains("invstate")) Equals.apply _ - else SubsetOf.apply _ - Some(createAnd((instates zip outstates).map(p => stateRel(p._1, p._2)))) - } else None - //println("stateRel: "+stateRel) - - // create a predicate that ensures that the value part is independent of the state - val valRel = - if (takesStateButIndep(fd)) { // add specs on value - val uipred = uiFuncs(fd)._2.get - val args = nfd.params.take(fd.params.size).map(_.id.toVariable) - val retarg = - if (funsRetStates(fd)) - TupleSelect(newres.toVariable, 1) - else newres.toVariable - Some(FunctionInvocation(TypedFunDef(uipred, nfd.tparams.map(_.tp)), - args :+ retarg)) - } else None - - val targetPost = - if (fd.hasPostcondition) { - val Lambda(Seq(ValDef(resid)), post) = fd.postcondition.get - val resval = - if (bodyUpdatesState || funsRetStates(fd)) - TupleSelect(newres.toVariable, 1) - else newres.toVariable - // thread state through postcondition - val (npostFun, postUpdatesState) = mapExpr(post)(stTparams) - // bind calls to instate and outstate calls to their respective values - val tpost = simplePostTransform { - case e if LazinessUtil.isInStateCall(e)(p) => - val baseType = getTypeArguments(e.getType).head - val tname = typeNameWOParams(baseType) - closureFactory.selectFieldOfState(tname, stateParam.get, stType.get) - - case e if LazinessUtil.isOutStateCall(e)(p) => - val baseType = getTypeArguments(e.getType).head - val tname = typeNameWOParams(baseType) - closureFactory.selectFieldOfState(tname, outState.get, stType.get) - - case e => e - }(replace(paramMap ++ Map(resid.toVariable -> resval), npostFun(outState))) - - val npost = - if (postUpdatesState) { - TupleSelect(tpost, 1) // ignore state updated by post - } else - tpost - Some(npost) - } else { - None - } - nfd.postcondition = Some(Lambda(Seq(ValDef(newres)), - createAnd(stateRel.toList ++ valRel.toList ++ targetPost.toList))) - } - } - - def assignContractsForEvals = evalFunctions.foreach { - case (tname, evalfd) => - val ismem = closureFactory.isMemType(tname) - val cdefs = closureFactory.closures(tname) - val recvTparams = getTypeParameters(evalfd.params.head.getType) - val postres = FreshIdentifier("res", evalfd.returnType) - val postMatchCases = cdefs map { cdef => - // create a body of the match (which asserts that return value equals the uninterpreted function) - // and also that the result field equals the result - val op = closureFactory.lazyopOfClosure(cdef) - val ctype = CaseClassType(cdef, recvTparams) - val binder = FreshIdentifier("t", ctype) - val pattern = InstanceOfPattern(Some(binder), ctype) - // t.clres == res._1 - val clause1 = if (!ismem) { - val clresField = cdef.fields.last - Equals(TupleSelect(postres.toVariable, 1), - CaseClassSelector(ctype, binder.toVariable, clresField.id)) - } else - Util.tru - //res._1 == uifun(args) - val clause2 = if (takesStateButIndep(op)) { - val flds = - if (!ismem) cdef.fields.dropRight(1) - else cdef.fields - val args = flds map { - fld => CaseClassSelector(ctype, binder.toVariable, fld.id) - } - Some(Equals(TupleSelect(postres.toVariable, 1), - FunctionInvocation(TypedFunDef(uiFuncs(op)._1, recvTparams), args))) - } else None - val rhs = createAnd(clause1 +: clause2.toList) - MatchCase(pattern, None, rhs) - } - // create a default case ot match other cases - val default = MatchCase(WildcardPattern(None), None, Util.tru) - evalfd.postcondition = Some( - Lambda(Seq(ValDef(postres)), - MatchExpr(evalfd.params.head.toVariable, postMatchCases :+ default))) - } - - /** - * Overrides the types of the lazy fields in the case class definitions - * Note: here we reset CaseClass fields instead of having to duplicate the - * entire class hierarchy. - */ - var fieldMap = Map[Identifier, Identifier]() - def copyField(oldId: Identifier, tpe: TypeTree): Identifier = { - val freshid = FreshIdentifier(oldId.name, tpe) - fieldMap += (oldId -> freshid) - freshid - } - - def transformCaseClasses = p.definedClasses.foreach { - case ccd: CaseClassDef if !ccd.flags.contains(Annotation("library", Seq())) && - ccd.fields.exists(vd => isLazyType(vd.getType)) => - val nfields = ccd.fields.map { fld => - unwrapLazyType(fld.getType) match { - case None => fld - case Some(btype) => - val clType = closureFactory.absClosureType(typeNameWOParams(btype)) - val typeArgs = getTypeArguments(btype) - //println(s"AbsType: $clType type args: $typeArgs") - val adtType = AbstractClassType(clType, typeArgs) - ValDef(copyField(fld.id, adtType)) - } - } - ccd.setFields(nfields) - case _ => ; - } - - def apply: Program = { - // TODO: for now pick a arbitrary point to add new defs. But ideally the lazy closure will be added to a separate module - // and imported every where - val anchor = funMap.values.last - transformCaseClasses - assignBodiesToFunctions - assignContractsForEvals - ProgramUtil.addDefs( - copyProgram(p, - (defs: Seq[Definition]) => defs.flatMap { - case fd: FunDef if funMap.contains(fd) => - uiFuncs.get(fd) match { - case Some((funui, Some(predui))) => - Seq(funMap(fd), funui, predui) - case Some((funui, _)) => - Seq(funMap(fd), funui) - case _ => Seq(funMap(fd)) - } - case d => Seq(d) - }), - closureFactory.allClosuresAndParents ++ Seq(closureFactory.state) ++ - closureCons.values ++ evalFunctions.values ++ - computeFunctions.values ++ uiStateFuns.values ++ - closureFactory.stateUpdateFuns.values, anchor) - } -} diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala deleted file mode 100644 index ae7da4aa2bcafd6b3234569f6244038cb3a6b950..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyClosureFactory.scala +++ /dev/null @@ -1,210 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.invariant.util.TypeUtil._ -import LazinessUtil._ - -//case class ClosureData(tpe: TypeTree, absDef: AbstractClassDef, caseClass: Seq[CaseClassDef]) - -class LazyClosureFactory(p: Program) { - val debug = false - implicit val prog = p - - /** - * all the operations that could be lazily evaluated - */ - val lazyopsList = p.definedFunctions.flatMap { - case fd if (fd.hasBody) => - filter(isLazyInvocation)(fd.body.get) map { - case FunctionInvocation(_, Seq(Lambda(_, FunctionInvocation(tfd, _)))) => tfd.fd - } - case _ => Seq() - }.distinct - - if (debug) { - println("Lazy operations found: \n" + lazyopsList.map(_.id).mkString("\n")) - } - - /** - * Create a mapping from types to the lazy/mem ops that may produce a value of that type - * TODO: relax the requirement that type parameters of return type of a function - * lazy evaluated/memoized should include all of its type parameters. - */ - private def closuresForOps(ops: List[FunDef]) = { - // using tpe name below to avoid mismatches due to type parameters - val tpeToLazyops = ops.groupBy(lop => typeNameWOParams(lop.returnType)) - if (debug) { - println("Type to Ops: " + tpeToLazyops.map { - case (k, v) => s"$k --> ${v.map(_.id).mkString(",")}" - }.mkString("\n")) - } - val tpeToAbsClass = tpeToLazyops.map { - case (tpename, ops) => - val tpcount = getTypeParameters(ops(0).returnType).size - //Safety check: - // (a) check that tparams of all ops should match and should be equal to the tparams of return type - // (b) all are either memoized or all are lazy - ops.foreach { op => - if (op.tparams.size != tpcount) - throw new IllegalStateException(s"Type parameters of the lazy/memoized operation ${op.id.name}" + - "should match the type parameters of the return type of the operation") - } - if(ops.size >= 2) { - ops.tail.forall(op => isMemoized(op) == isMemoized(ops.head)) - } - val absTParams = (1 to tpcount).map(i => TypeParameterDef(TypeParameter.fresh("T" + i))) - tpename -> new AbstractClassDef(FreshIdentifier(typeNameToADTName(tpename), Untyped), - absTParams, None) - }.toMap - var opToAdt = Map[FunDef, CaseClassDef]() - val tpeToADT = tpeToLazyops map { - case (tpename, ops) => - val ismem = isMemoized(ops(0)) - val baseT = ops(0).returnType //TODO: replace targs here i.e, use clresType ? - val absClass = tpeToAbsClass(tpename) - val absTParamsDef = absClass.tparams - val absTParams = absTParamsDef.map(_.tp) - - // create a case class for every operation - val cdefs = ops map { opfd => - assert(opfd.tparams.size == absTParamsDef.size) - val absType = AbstractClassType(absClass, opfd.tparams.map(_.tp)) - val classid = FreshIdentifier(opNameToCCName(opfd.id.name), Untyped) - val cdef = new CaseClassDef(classid, opfd.tparams, Some(absType), isCaseObject = false) - val nfields = opfd.params.map { vd => - val fldType = vd.getType - unwrapLazyType(fldType) match { - case None => - ValDef(FreshIdentifier(vd.id.name, fldType)) - case Some(btype) => - val btname = typeNameWOParams(btype) - val baseAbs = tpeToAbsClass(btname) - ValDef(FreshIdentifier(vd.id.name, - AbstractClassType(baseAbs, getTypeParameters(btype)))) - } - } - if (!ismem) { - // add a result field as well - val resField = ValDef(FreshIdentifier("clres", opfd.returnType)) - cdef.setFields(nfields :+ resField) - } else - cdef.setFields(nfields) - absClass.registerChild(cdef) - opToAdt += (opfd -> cdef) - cdef - } - if (!ismem) { - // create a case class to represent eager evaluation (when handling lazy ops) - val clresType = ops.head.returnType match { - case NAryType(tparams, tcons) => tcons(absTParams) - } - val eagerid = FreshIdentifier("Eager" + TypeUtil.typeNameWOParams(clresType)) - val eagerClosure = new CaseClassDef(eagerid, absTParamsDef, - Some(AbstractClassType(absClass, absTParams)), isCaseObject = false) - eagerClosure.setFields(Seq(ValDef(FreshIdentifier("a", clresType)))) - absClass.registerChild(eagerClosure) - (tpename -> (baseT, absClass, cdefs, Some(eagerClosure), ismem)) - } else - (tpename -> (baseT, absClass, cdefs, None, ismem)) - } - /*tpeToADT.foreach { - case (k, v) => println(s"$k --> ${ (v._2 +: v._3).mkString("\n\t") }") - }*/ - (tpeToADT, opToAdt) - } - - private val (tpeToADT, opToCaseClass) = closuresForOps(lazyopsList) - - // this fixes an ordering on lazy types - val lazyTypeNames = tpeToADT.keys.toSeq - val lazyops = opToCaseClass.keySet - lazy val caseClassToOp = opToCaseClass map { case (k, v) => v -> k } - val allClosuresAndParents: Seq[ClassDef] = tpeToADT.values.flatMap(v => (v._2 +: v._3) ++ v._4.toList).toSeq - val allClosureSet = allClosuresAndParents.toSet - - // lazy operations - def lazyType(tn: String) = tpeToADT(tn)._1 - def isMemType(tn: String) = tpeToADT(tn)._5 - def absClosureType(tn: String) = tpeToADT(tn)._2 - def closures(tn: String) = tpeToADT(tn)._3 - def eagerClosure(tn: String) = tpeToADT(tn)._4 - def lazyopOfClosure(cl: CaseClassDef) = caseClassToOp(cl) - def closureOfLazyOp(op: FunDef) = opToCaseClass(op) - def isLazyOp(op: FunDef) = opToCaseClass.contains(op) - def isClosureType(cd: ClassDef) = allClosureSet.contains(cd) - - /** - * Here, the lazy type name is recovered from the closure's name. - * This avoids the use of additional maps. - */ - def lazyTypeNameOfClosure(cl: CaseClassDef) = adtNameToTypeName(cl.parent.get.classDef.id.name) - - /** - * Define a state as an ADT whose fields are sets of closures. - * Note that we need to ensure that there are state ADT is not recursive. - */ - val state = { - var tparams = Seq[TypeParameter]() - var i = 0 - def freshTParams(n: Int): Seq[TypeParameter] = { - val start = i + 1 - i += n // create 'n' fresh ids - val nparams = (start to i).map(index => TypeParameter.fresh("T" + index)) - tparams ++= nparams - nparams - } - // field of the ADT - val fields = lazyTypeNames map { tn => - val absClass = absClosureType(tn) - val tparams = freshTParams(absClass.tparams.size) - val fldType = SetType(AbstractClassType(absClass, tparams)) - ValDef(FreshIdentifier(typeToFieldName(tn), fldType)) - } - val ccd = new CaseClassDef(FreshIdentifier("State@"), tparams map TypeParameterDef, None, false) - ccd.setFields(fields) - ccd - } - - def selectFieldOfState(tn: String, st: Expr, stType: CaseClassType) = { - val selName = typeToFieldName(tn) - stType.classDef.fields.find { fld => fld.id.name == selName } match { - case Some(fld) => - CaseClassSelector(stType, st, fld.id) - case _ => - throw new IllegalStateException(s"Cannot find a field of $stType with name: $selName") - } - } - - val stateUpdateFuns: Map[String, FunDef] = - lazyTypeNames.map { tn => - val fldname = typeToFieldName(tn) - val tparams = state.tparams.map(_.tp) - val stType = CaseClassType(state, tparams) - val param1 = FreshIdentifier("st@", stType) - val SetType(baseT) = stType.classDef.fields.find { fld => fld.id.name == fldname }.get.getType - val param2 = FreshIdentifier("cl", baseT) - - val updateFun = new FunDef(FreshIdentifier("updState" + tn), - state.tparams, Seq(ValDef(param1), ValDef(param2)), stType) - // create a body for the updateFun: - val nargs = state.fields.map { fld => - val fldSelect = CaseClassSelector(stType, param1.toVariable, fld.id) - if (fld.id.name == fldname) { - SetUnion(fldSelect, FiniteSet(Set(param2.toVariable), baseT)) // st@.tn + Set(param2) - } else { - fldSelect - } - } - val nst = CaseClass(stType, nargs) - updateFun.body = Some(nst) - // Inlining this seems to slow down verification. Why!! - //updateFun.addFlag(IsInlined) - (tn -> updateFun) - }.toMap -} diff --git a/src/main/scala/leon/laziness/LazyExpressionLifter.scala b/src/main/scala/leon/laziness/LazyExpressionLifter.scala deleted file mode 100644 index 1aa312c7ac87243f0aa63299466132a0ea12d600..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyExpressionLifter.scala +++ /dev/null @@ -1,252 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.TypeOps._ -import leon.invariant.util.TypeUtil._ -import LazinessUtil._ -import invariant.util.ProgramUtil._ -import FreeVariableFactory._ - -object LazyExpressionLifter { - - /** - * convert the argument of every lazy constructors to a procedure - */ - var globalId = 0 - def freshFunctionNameForArg = { - globalId += 1 - "lazyarg" + globalId - } - - /** - * (a) The functions lifts arguments of '$' to functions - * (b) lifts eager computations to lazy computations if necessary - * (c) converts memoization to lazy evaluation - * (d) Adds unique references to programs that create lazy closures. - */ - def liftLazyExpressions(prog: Program, createUniqueIds: Boolean = false): Program = { - - lazy val funsMan = new LazyFunctionsManager(prog) - lazy val needsId = funsMan.callersnTargetOfLazyCons - - var newfuns = Map[ExprStructure, (FunDef, ModuleDef)]() - val fdmap = ProgramUtil.userLevelFunctions(prog).collect { - case fd if fd.hasBody => - val nname = FreshIdentifier(fd.id.name) - val nfd = - if (createUniqueIds && needsId(fd)) { - val idparam = ValDef(FreshIdentifier("id@", fvType)) - new FunDef(nname, fd.tparams, fd.params :+ idparam, fd.returnType) - } else - new FunDef(nname, fd.tparams, fd.params, fd.returnType) - fd.flags.foreach(nfd.addFlag(_)) - (fd -> nfd) - }.toMap - - lazy val lazyFun = ProgramUtil.functionByFullName("leon.lazyeval.$", prog).get - lazy val valueFun = ProgramUtil.functionByFullName("leon.lazyeval.Lazy.value", prog).get - - var anchorDef: Option[FunDef] = None // a hack to find anchors - prog.modules.foreach { md => - def exprLifter(inmem: Boolean)(fl: Option[FreeVarListIterator])(expr: Expr) = expr match { - case finv @ FunctionInvocation(lazytfd, Seq(callByNameArg)) if isLazyInvocation(finv)(prog) => - val Lambda(Seq(), arg) = callByNameArg // extract the call-by-name parameter - arg match { - case _: FunctionInvocation => - finv - case _ => - val freevars = variablesOf(arg).toList - val tparams = freevars.map(_.getType).flatMap(getTypeParameters).distinct - val argstruc = new ExprStructure(arg) - val argfun = - if (newfuns.contains(argstruc)) { - newfuns(argstruc)._1 - } else { - //construct type parameters for the function - // note: we should make the base type of arg as the return type - val nname = FreshIdentifier(freshFunctionNameForArg, Untyped, true) - val tparamDefs = tparams map TypeParameterDef.apply - val params = freevars.map(ValDef(_)) - val retType = bestRealType(arg.getType) - val nfun = - if (createUniqueIds) { - val idparam = ValDef(FreshIdentifier("id@", fvType)) - new FunDef(nname, tparamDefs, params :+ idparam, retType) - } else - new FunDef(nname, tparamDefs, params, retType) - nfun.body = Some(arg) - newfuns += (argstruc -> (nfun, md)) - nfun - } - val fvVars = freevars.map(_.toVariable) - val params = - if (createUniqueIds) - fvVars :+ fl.get.nextExpr - else fvVars - FunctionInvocation(lazytfd, Seq(Lambda(Seq(), - FunctionInvocation(TypedFunDef(argfun, tparams), params)))) - } - - // is the argument of eager invocation not a variable ? - case finv @ FunctionInvocation(TypedFunDef(fd, Seq(tp)), cbn @ Seq(Lambda(Seq(), arg))) if isEagerInvocation(finv)(prog) => - val rootType = bestRealType(tp) - val ntps = Seq(rootType) - arg match { - case _: Variable => - FunctionInvocation(TypedFunDef(fd, ntps), cbn) - case _ => - val freshid = FreshIdentifier("t", rootType) - Let(freshid, arg, FunctionInvocation(TypedFunDef(fd, ntps), - Seq(Lambda(Seq(), freshid.toVariable)))) - } - - // is this an invocation of a memoized function ? - case FunctionInvocation(TypedFunDef(fd, targs), args) if isMemoized(fd) && !inmem => - // calling a memoized function is modeled as creating a lazy closure and forcing it - val tfd = TypedFunDef(fdmap.getOrElse(fd, fd), targs) - val finv = FunctionInvocation(tfd, args) - // enclose the call within the $ and force it - val susp = FunctionInvocation(TypedFunDef(lazyFun, Seq(tfd.returnType)), Seq(Lambda(Seq(), finv))) - FunctionInvocation(TypedFunDef(valueFun, Seq(tfd.returnType)), Seq(susp)) - - // every other function calls ? - case FunctionInvocation(TypedFunDef(fd, targs), args) if fdmap.contains(fd) => - val nargs = - if (createUniqueIds && needsId(fd)) - args :+ fl.get.nextExpr - else args - FunctionInvocation(TypedFunDef(fdmap(fd), targs), nargs) - - case e => e - } - md.definedFunctions.foreach { - case fd if fd.hasBody && !fd.isLibrary && !fd.isInvariant => - // create a free list iterator - val nfd = fdmap(fd) - val fliter = - if (createUniqueIds && needsId(fd)) { - if (!anchorDef.isDefined) - anchorDef = Some(nfd) - val initRef = nfd.params.last.id.toVariable - Some(getFreeListIterator(initRef)) - } else - None - - def rec(inmem: Boolean)(e: Expr): Expr = e match { - case Operator(args, op) => - val nargs = args map rec(inmem || isMemCons(e)(prog)) - exprLifter(inmem)(fliter)(op(nargs)) - } - if(fd.hasPrecondition) - nfd.precondition = Some(rec(true)(fd.precondition.get)) - if (fd.hasPostcondition) - nfd.postcondition = Some(rec(true)(fd.postcondition.get)) - nfd.body = Some(rec(false)(fd.body.get)) - case fd => - } - } - val progWithFuns = copyProgram(prog, (defs: Seq[Definition]) => defs.map { - case fd: FunDef => fdmap.getOrElse(fd, fd) - case d => d - }) - val progWithClasses = - if (createUniqueIds) ProgramUtil.addDefs(progWithFuns, fvClasses, anchorDef.get) - else progWithFuns - if (!newfuns.isEmpty) { - val modToNewDefs = newfuns.values.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }.toMap - appendDefsToModules(progWithClasses, modToNewDefs) - } else - progWithClasses - } - - /** - * NOT USED CURRENTLY - * Lift the specifications on functions to the invariants corresponding - * case classes. - * Ideally we should class invariants here, but it is not currently supported - * so we create a functions that can be assume in the pre and post of functions. - * TODO: can this be optimized - */ - /* def liftSpecsToClosures(opToAdt: Map[FunDef, CaseClassDef]) = { - val invariants = opToAdt.collect { - case (fd, ccd) if fd.hasPrecondition => - val transFun = (args: Seq[Identifier]) => { - val argmap: Map[Expr, Expr] = (fd.params.map(_.id.toVariable) zip args.map(_.toVariable)).toMap - replace(argmap, fd.precondition.get) - } - (ccd -> transFun) - }.toMap - val absTypes = opToAdt.values.collect { - case cd if cd.parent.isDefined => cd.parent.get.classDef - } - val invFuns = absTypes.collect { - case abs if abs.knownCCDescendents.exists(invariants.contains) => - val absType = AbstractClassType(abs, abs.tparams.map(_.tp)) - val param = ValDef(FreshIdentifier("$this", absType)) - val tparams = abs.tparams - val invfun = new FunDef(FreshIdentifier(abs.id.name + "$Inv", Untyped), - tparams, BooleanType, Seq(param)) - (abs -> invfun) - }.toMap - // assign bodies for the 'invfuns' - invFuns.foreach { - case (abs, fd) => - val bodyCases = abs.knownCCDescendents.collect { - case ccd if invariants.contains(ccd) => - val ctype = CaseClassType(ccd, fd.tparams.map(_.tp)) - val cvar = FreshIdentifier("t", ctype) - val fldids = ctype.fields.map { - case ValDef(fid, Some(fldtpe)) => - FreshIdentifier(fid.name, fldtpe) - } - val pattern = CaseClassPattern(Some(cvar), ctype, - fldids.map(fid => WildcardPattern(Some(fid)))) - val rhsInv = invariants(ccd)(fldids) - // assert the validity of substructures - val rhsValids = fldids.flatMap { - case fid if fid.getType.isInstanceOf[ClassType] => - val t = fid.getType.asInstanceOf[ClassType] - val rootDef = t match { - case absT: AbstractClassType => absT.classDef - case _ if t.parent.isDefined => - t.parent.get.classDef - } - if (invFuns.contains(rootDef)) { - List(FunctionInvocation(TypedFunDef(invFuns(rootDef), t.tps), - Seq(fid.toVariable))) - } else - List() - case _ => List() - } - val rhs = Util.createAnd(rhsInv +: rhsValids) - MatchCase(pattern, None, rhs) - } - // create a default case - val defCase = MatchCase(WildcardPattern(None), None, Util.tru) - val matchExpr = MatchExpr(fd.params.head.id.toVariable, bodyCases :+ defCase) - fd.body = Some(matchExpr) - } - invFuns - }*/ - // Expressions for testing solvers - // a test expression - /*val tparam = - val dummyFunDef = new FunDef(FreshIdentifier("i"),Seq(), Seq(), IntegerType) - val eq = Equals(FunctionInvocation(TypedFunDef(dummyFunDef, Seq()), Seq()), InfiniteIntegerLiteral(0)) - import solvers._ - val solver = SimpleSolverAPI(SolverFactory(() => new solvers.smtlib.SMTLIBCVC4Solver(ctx, prog))) - solver.solveSAT(eq) match { - case (Some(true), m) => - println("Model: "+m.toMap) - case _ => println("Formula is unsat") - } - System.exit(0)*/ -} diff --git a/src/main/scala/leon/laziness/LazyFunctionsManager.scala b/src/main/scala/leon/laziness/LazyFunctionsManager.scala deleted file mode 100644 index 46220895e63bc872014283950cdbeee854381f65..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyFunctionsManager.scala +++ /dev/null @@ -1,122 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import LazinessUtil._ - -class LazyFunctionsManager(p: Program) { - - // includes calls made through the specs - val cg = CallGraphUtil.constructCallGraph(p, false, true, - // a specialized callee function that ignores functions called inside `withState` calls, because they would have state as an argument - (inexpr: Expr) => { - var callees = Set[FunDef]() - def rec(e: Expr): Unit = e match { - case cc @ CaseClass(_, args) if LazinessUtil.isWithStateCons(cc)(p) => - ; //nothing to be done - case f : FunctionInvocation if LazinessUtil.isSuspInvocation(f)(p) => - // we can ignore the arguments to susp invocation as they are not actual calls, but only a test - ; - case cc : CaseClass if LazinessUtil.isMemCons(cc)(p) => - ; // we can ignore the arguments to mem - //note: do not consider field invocations - case f @ FunctionInvocation(TypedFunDef(callee, _), args) if callee.isRealFunction => - callees += callee - args map rec - case Operator(args, _) => args map rec - } - rec(inexpr) - callees - }) - - val (funsNeedStates, funsRetStates, funsNeedStateTps) = { - var starRoots = Set[FunDef]() - var readRoots = Set[FunDef]() - var valRoots = Set[FunDef]() - p.definedFunctions.foreach { - case fd if fd.hasBody => - postTraversal { - case finv: FunctionInvocation if isStarInvocation(finv)(p) => - starRoots += fd - case finv: FunctionInvocation if isLazyInvocation(finv)(p) => - // the lazy invocation constructor will need the state - readRoots += fd - case finv: FunctionInvocation if isEvaluatedInvocation(finv)(p) || isCachedInv(finv)(p) => - readRoots += fd - case finv: FunctionInvocation if isValueInvocation(finv)(p) => - valRoots += fd - case _ => - ; - }(fd.body.get) - case _ => ; - } - val valCallers = cg.transitiveCallers(valRoots.toSeq) - val readfuns = cg.transitiveCallers(readRoots.toSeq) - val starCallers = cg.transitiveCallers(starRoots.toSeq) - //println("Ret roots: "+retRoots.map(_.id)+" ret funs: "+retfuns.map(_.id)) - (readfuns ++ valCallers, valCallers, starCallers ++ readfuns ++ valCallers) - } - - lazy val callersnTargetOfLazyCons = { - var consRoots = Set[FunDef]() - var targets = Set[FunDef]() - funsNeedStates.foreach { - case fd if fd.hasBody => - postTraversal { - case finv: FunctionInvocation if isLazyInvocation(finv)(p) => // this is the lazy invocation constructor - consRoots += fd - targets += finv.tfd.fd - case _ => - ; - }(fd.body.get) - case _ => ; - } - cg.transitiveCallers(consRoots.toSeq) ++ targets - } - - lazy val cgWithoutSpecs = CallGraphUtil.constructCallGraph(p, true, false) - lazy val callersOfIsEvalandIsSusp = { - var roots = Set[FunDef]() - funsNeedStates.foreach { - case fd if fd.hasBody => - postTraversal { - case finv: FunctionInvocation if - isEvaluatedInvocation(finv)(p) || isSuspInvocation(finv)(p) || isCachedInv(finv)(p) => // call to isEvaluated || isSusp ? - roots += fd - case _ => - ; - }(fd.body.get) - case _ => ; - } - cgWithoutSpecs.transitiveCallers(roots.toSeq) - } - - def isRecursive(fd: FunDef) : Boolean = { - cg.isRecursive(fd) - } - - def hasStateIndependentBehavior(fd: FunDef) : Boolean = { - // every function that does not call isEvaluated or is Susp has a state independent behavior - !callersOfIsEvalandIsSusp.contains(fd) - } - -// lazy val targetsOfLazyCons = { -// var callees = Set[FunDef]() -// funsNeedStates.foreach { -// case fd if fd.hasBody => -// postTraversal { -// case finv: FunctionInvocation if isLazyInvocation(finv)(p) => // this is the lazy invocation constructor -// callees += finv.tfd.fd -// case _ => -// ; -// }(fd.body.get) -// case _ => ; -// } -// callees -// } - -} diff --git a/src/main/scala/leon/laziness/LazyInstrumenter.scala b/src/main/scala/leon/laziness/LazyInstrumenter.scala deleted file mode 100644 index 6dc5a7e50f184daa34907ca12bc58e285539a318..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyInstrumenter.scala +++ /dev/null @@ -1,70 +0,0 @@ -package leon -package laziness - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Types._ -import leon.invariant.util.TypeUtil._ -import leon.transformations._ -import LazinessUtil._ - -class LazyInstrumenter(p: Program, ctx: LeonContext, clFactory: LazyClosureFactory) { - - val exprInstFactory = (x: Map[FunDef, FunDef], y: SerialInstrumenter, z: FunDef) => - new LazyExprInstrumenter(x, y)(z) - val serialInst = new SerialInstrumenter(p, Some(exprInstFactory)) - /*def funsWithInstSpecs = { - serialInst.instToInstrumenter.values.flatMap{inst => - inst.getRootFuncs(p) - }.toList.distinct - }*/ - - def apply: Program = serialInst.apply - - class LazyExprInstrumenter(funMap: Map[FunDef, FunDef], serialInst: SerialInstrumenter)(implicit currFun: FunDef) - extends ExprInstrumenter(funMap, serialInst)(currFun) { - - val costOfMemoization: Map[Instrumentation, Int] = - Map(Time -> 1, Stack -> 1, Rec -> 1, TPR -> 1, Depth -> 1) - - override def apply(e: Expr): Expr = { - if (isEvalFunction(currFun)) { - val closureParam = currFun.params(0).id.toVariable - val stateParam = currFun.params(1).id.toVariable - // we need to specialize instrumentation of body - val nbody = e match { - case MatchExpr(scr, mcases) => - val ncases = mcases map { - case MatchCase(pat, guard, body) => - // instrument the state part (and ignore the val part) - // (Note: this is an hack to ensure that we always consider only one call to targets) - /*val transState = transform(statepart)(Map()) - val transVal = transform(valpart)(Map()) - - val caseId = FreshIdentifier("cd", transState.getType, true) - val casePart = Tuple(Seq(TupleSelect(transVal, 1), TupleSelect(caseId.toVariable, 1))) - val instPart = instrumenters map { m => selectInst(caseId.toVariable, m.inst) } - val lete = Let(caseId, transState, Tuple(casePart +: instPart))*/ - MatchCase(pat, guard, transform(body)(Map())) - } - MatchExpr(scr, ncases) - } - //val nbody = super.apply(e) - val bodyId = FreshIdentifier("bd", nbody.getType, true) - // we need to select the appropriate field of the state - val lazyTname = adtNameToTypeName(typeNameWOParams(closureParam.getType)) - val setField = clFactory.selectFieldOfState(lazyTname, stateParam, - stateParam.getType.asInstanceOf[CaseClassType]) - val instExprs = instrumenters map { m => - IfExpr(ElementOfSet(closureParam, setField), - InfiniteIntegerLiteral(costOfMemoization(m.inst)), - selectInst(bodyId.toVariable, m.inst)) - } - Let(bodyId, nbody, - Tuple(TupleSelect(bodyId.toVariable, 1) +: instExprs)) - } else - super.apply(e) - } - } -} diff --git a/src/main/scala/leon/laziness/LazyVerificationPhase.scala b/src/main/scala/leon/laziness/LazyVerificationPhase.scala deleted file mode 100644 index 14b84a116802d3ec938af03824b2eb9833a5a551..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/LazyVerificationPhase.scala +++ /dev/null @@ -1,233 +0,0 @@ -package leon -package laziness - -import invariant.util._ -import invariant.structure.FunctionUtils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import solvers._ -import transformations._ -import LazinessUtil._ -import purescala.Constructors._ -import verification._ -import PredicateUtil._ -import invariant.engine._ - -object LazyVerificationPhase { - - val debugInstVCs = false - val debugInferProgram = false - - class LazyVerificationReport(val stateVerification: Option[VerificationReport], - val resourceVeri: Option[VerificationReport]) { - def inferReport= resourceVeri match { - case Some(inf: InferenceReport) => Some(inf) - case _ => None - } - } - - def removeInstrumentationSpecs(p: Program): Program = { - def hasInstVar(e: Expr) = { - exists { e => InstUtil.InstTypes.exists(i => i.isInstVariable(e)) }(e) - } - val newPosts = p.definedFunctions.collect { - case fd if fd.postcondition.exists { exists(hasInstVar) } => - // remove the conjuncts that use instrumentation vars - val Lambda(resdef, pbody) = fd.postcondition.get - val npost = simplifyByConstructors(pbody) match { - case And(args) => - createAnd(args.filterNot(hasInstVar)) - case l: Let => // checks if the body of the let can be deconstructed as And - //println(s"Fist let val: ${l.value} body: ${l.body}") - val (letsCons, letsBody) = letStarUnapply(l) - //println("Let* body: "+letsBody) - letsBody match { - case And(args) => - letsCons(createAnd(args.filterNot(hasInstVar))) - case _ => Util.tru - } - case e => Util.tru - } - (fd -> Lambda(resdef, npost)) - }.toMap - ProgramUtil.updatePost(newPosts, p) //note: this will not update libraries - } - - def contextForChecks(userOptions: LeonContext) = { - val solverOptions = Main.processOptions(Seq("--solvers=smt-cvc4,smt-z3", "--assumepre")) - LeonContext(userOptions.reporter, userOptions.interruptManager, - solverOptions.options ++ userOptions.options) - } - - def collectCumulativeStats(rep: VerificationReport) { - Stats.updateCumTime(rep.totalTime, "Total-Verification-Time") - Stats.updateCumStats(rep.totalConditions, "Total-VCs-Generated") - val (withz3, withcvc) = rep.vrs.partition { - case (vc, vr) => - vr.solvedWith.map(s => s.name.contains("smt-z3")).get - } - Stats.updateCounter(withz3.size, "Z3SolvedVCs") - Stats.updateCounter(withcvc.size, "CVC4SolvedVCs") - Stats.updateCounterStats(withz3.map(_._2.timeMs.getOrElse(0L)).sum, "Z3-Time", "Z3SolvedVCs") - Stats.updateCounterStats(withcvc.map(_._2.timeMs.getOrElse(0L)).sum, "CVC4-Time", "CVC4SolvedVCs") - } - - def checkSpecifications(prog: Program, checkCtx: LeonContext): VerificationReport = { - // convert 'axiom annotation to library - prog.definedFunctions.foreach { fd => - if (fd.annotations.contains("axiom")) - fd.addFlag(Annotation("library", Seq())) - } - val report = VerificationPhase.apply(checkCtx, prog) - // collect stats - collectCumulativeStats(report) - if(!checkCtx.findOption(GlobalOptions.optSilent).getOrElse(false)) { - println(report.summaryString) - } - report - } - - def checkInstrumentationSpecs(p: Program, checkCtx: LeonContext, useOrb: Boolean): VerificationReport = { - p.definedFunctions.foreach { fd => - if (fd.annotations.contains("axiom")) - fd.addFlag(Annotation("library", Seq())) - } - val rep = - if (useOrb) { - /*// create an inference context - val inferOpts = Main.processOptions(Seq("--disableInfer", "--assumepreInf", "--minbounds", "--solvers=smt-cvc4")) - val ctxForInf = LeonContext(checkCtx.reporter, checkCtx.interruptManager, - inferOpts.options ++ checkCtx.options) - val inferctx = new InferenceContext(p, ctxForInf) - val vcSolver = (funDef: FunDef, prog: Program) => new VCSolver(inferctx, prog, funDef) - if (debugInferProgram){ - prettyPrintProgramToFile(inferctx.inferProgram, checkCtx, "-inferProg", true) - } - - val results = (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, - funsToCheck.map(InstUtil.userFunctionName), vcSolver, None) - new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, inferctx.inferProgram)(inferctx)*/ - val inferctx = getInferenceContext(checkCtx, p) - checkUsingOrb(new InferenceEngine(inferctx), inferctx) - } else { - val funsToCheck = p.definedFunctions.filter(shouldGenerateVC) - val rep = checkVCs(funsToCheck.map(vcForFun), checkCtx, p) - // record some stats - collectCumulativeStats(rep) - rep - } - if (!checkCtx.findOption(GlobalOptions.optSilent).getOrElse(false)) - println("Resource Verification Results: \n" + rep.summaryString) - rep - } - - def getInferenceContext(checkCtx: LeonContext, p: Program): InferenceContext = { - // create an inference context - val inferOpts = Main.processOptions(Seq("--disableInfer", "--assumepreInf", "--minbounds", "--solvers=smt-cvc4")) - val ctxForInf = LeonContext(checkCtx.reporter, checkCtx.interruptManager, - inferOpts.options ++ checkCtx.options) - new InferenceContext(p, ctxForInf) - } - - def checkUsingOrb(infEngine: InferenceEngine, inferctx: InferenceContext, - progressCallback: Option[InferenceCondition => Unit] = None) = { - if (debugInferProgram) { - prettyPrintProgramToFile(inferctx.inferProgram, inferctx.leonContext, "-inferProg", true) - } - val funsToCheck = inferctx.initProgram.definedFunctions.filter(shouldGenerateVC) - val vcSolver = (funDef: FunDef, prog: Program) => new VCSolver(inferctx, prog, funDef) - val results = infEngine.analyseProgram(inferctx.inferProgram, - funsToCheck.map(InstUtil.userFunctionName), vcSolver, progressCallback) - new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, inferctx.inferProgram)(inferctx) - } - - def accessesSecondRes(e: Expr, resid: Identifier): Boolean = - exists(_ == TupleSelect(resid.toVariable, 2))(e) - - /** - * Note: we also skip verification of uninterpreted functions - */ - def shouldGenerateVC(fd: FunDef) = { - !fd.isInvariant && !fd.isLibrary && InstUtil.isInstrumented(fd) && fd.hasBody && - fd.postcondition.exists { post => - val Lambda(Seq(resdef), pbody) = post - accessesSecondRes(pbody, resdef.id) - } - } - - /** - * creates vcs - * Note: we only need to check specs involving instvars since others were checked before. - * Moreover, we can add other specs as assumptions since (A => B) ^ ((A ^ B) => C) => A => B ^ C - * checks if the expression uses res._2 which corresponds to instvars after instrumentation - */ - def vcForFun(fd: FunDef) = { - val (body, ants, post, tmpl) = collectAntsPostTmpl(fd) - if (tmpl.isDefined) - throw new IllegalStateException("Postcondition has holes! Run with --useOrb option") - val vc = implies(And(ants, body), post) - if (debugInstVCs) - println(s"VC for function ${fd.id} : " + vc) - VC(vc, fd, VCKinds.Postcondition) - } - - def collectAntsPostTmpl(fd: FunDef) = { - val Lambda(Seq(resdef), _) = fd.postcondition.get - val (pbody, tmpl) = (fd.getPostWoTemplate, fd.template) - val (instPost, assumptions) = simplifyByConstructors(pbody) match { - case And(args) => - val (instSpecs, rest) = args.partition(accessesSecondRes(_, resdef.id)) - (createAnd(instSpecs), createAnd(rest)) - case l: Let => - val (letsCons, letsBody) = letStarUnapplyWithSimplify(l) - letsBody match { - case And(args) => - val (instSpecs, rest) = args.partition(accessesSecondRes(_, resdef.id)) - (letsCons(createAnd(instSpecs)), letsCons(createAnd(rest))) - case _ => - (l, Util.tru) - } - case e => (e, Util.tru) - } - val ants = - if (fd.usePost) createAnd(Seq(fd.precOrTrue, assumptions)) - else fd.precOrTrue - (Equals(resdef.id.toVariable, fd.body.get), ants, instPost, tmpl) - } - - def checkVCs(vcs: List[VC], checkCtx: LeonContext, p: Program) = { - val timeout: Option[Long] = None - // Solvers selection and validation - val baseSolverF = SolverFactory.getFromSettings(checkCtx, p) - val solverF = timeout match { - case Some(sec) => - baseSolverF.withTimeout(sec / 1000) - case None => - baseSolverF - } - val vctx = new VerificationContext(checkCtx, p, solverF) - try { - VerificationPhase.checkVCs(vctx, vcs) - //println("Resource Verification Results: \n" + veriRep.summaryString) - } finally { - solverF.shutdown() - } - } - - class VCSolver(ctx: InferenceContext, p: Program, rootFd: FunDef) extends - UnfoldingTemplateSolver(ctx, p, rootFd) { - - override def constructVC(fd: FunDef): (Expr, Expr, Expr) = { - val (body, ants, post, tmpl) = collectAntsPostTmpl(rootFd) - val conseq = matchToIfThenElse(createAnd(Seq(post, tmpl.getOrElse(Util.tru)))) - //println(s"body: $body ants: $ants conseq: $conseq") - (matchToIfThenElse(body), matchToIfThenElse(ants), conseq) - } - - override def verifyVC(newprog: Program, newroot: FunDef) = { - solveUsingLeon(contextForChecks(ctx.leonContext), newprog, vcForFun(newroot)) - } - } -} diff --git a/src/main/scala/leon/laziness/TypeChecker.scala b/src/main/scala/leon/laziness/TypeChecker.scala deleted file mode 100644 index ab9c1b3e02a41b1a27c564b6ee845981ab8133d9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/TypeChecker.scala +++ /dev/null @@ -1,215 +0,0 @@ -package leon -package laziness - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.TypeOps._ -import invariant.util.TypeUtil._ - -object TypeChecker { - /** - * `gamma` is the initial type environment which has - * type bindings for free variables of `ine`. - * It is not necessary that gamma should match the types of the - * identifiers of the free variables. - * Set and Maps are not supported yet - */ - def inferTypesOfLocals(ine: Expr, initGamma: Map[Identifier, TypeTree]): Expr = { - var idmap = Map[Identifier, Identifier]() - var gamma = initGamma - - /** - * Note this method has side-effects - */ - def makeIdOfType(oldId: Identifier, tpe: TypeTree): Identifier = { - if (oldId.getType != tpe) { - val freshid = FreshIdentifier(oldId.name, tpe, true) - idmap += (oldId -> freshid) - gamma += (oldId -> tpe) - freshid - } else oldId - } - - def rec(e: Expr): (TypeTree, Expr) = { - val res = e match { - case Let(id, value, body) => - val (valType, nval) = rec(value) - val nid = makeIdOfType(id, valType) - val (btype, nbody) = rec(body) - (btype, Let(nid, nval, nbody)) - - case Ensuring(body, Lambda(Seq(resdef @ ValDef(resid)), postBody)) => - body match { - case NoTree(tpe) => - val nres = makeIdOfType(resid, tpe) - (tpe, Ensuring(body, Lambda(Seq(ValDef(nres)), rec(postBody)._2))) - case _ => - val (btype, nbody) = rec(body) - val nres = makeIdOfType(resid, btype) - (btype, Ensuring(nbody, Lambda(Seq(ValDef(nres)), rec(postBody)._2))) - } - - case MatchExpr(scr, mcases) => - val (scrtype, nscr) = rec(scr) - val ncases = mcases.map { - case MatchCase(pat, optGuard, rhs) => - // resetting the type of patterns in the matches - def mapPattern(p: Pattern, expType: TypeTree): (Pattern, TypeTree) = { - p match { - case InstanceOfPattern(bopt, ict) => - // choose the subtype of the `expType` that - // has the same constructor as `ict` - val ntype = subcast(ict, expType.asInstanceOf[ClassType]) - if (!ntype.isDefined) - throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") - val nbopt = bopt.map(makeIdOfType(_, ntype.get)) - (InstanceOfPattern(nbopt, ntype.get), ntype.get) - - case CaseClassPattern(bopt, ict, subpats) => - val ntype = subcast(ict, expType.asInstanceOf[ClassType]) - if (!ntype.isDefined) - throw new IllegalStateException(s"Cannot find subtype of $expType with name: ${ict.classDef.id.toString}") - val cct = ntype.get.asInstanceOf[CaseClassType] - val nbopt = bopt.map(makeIdOfType(_, cct)) - val npats = (subpats zip cct.fieldsTypes).map { - case (p, t) => - //println(s"Subpat: $p expected type: $t") - mapPattern(p, t)._1 - } - (CaseClassPattern(nbopt, cct, npats), cct) - - case TuplePattern(bopt, subpats) => - val TupleType(subts) = expType - val patnTypes = (subpats zip subts).map { - case (p, t) => mapPattern(p, t) - } - val npats = patnTypes.map(_._1) - val ntype = TupleType(patnTypes.map(_._2)) - val nbopt = bopt.map(makeIdOfType(_, ntype)) - (TuplePattern(nbopt, npats), ntype) - - case WildcardPattern(bopt) => - val nbopt = bopt.map(makeIdOfType(_, expType)) - (WildcardPattern(nbopt), expType) - - case LiteralPattern(bopt, lit) => - val ntype = lit.getType - val nbopt = bopt.map(makeIdOfType(_, ntype)) - (LiteralPattern(nbopt, lit), ntype) - case _ => - throw new IllegalStateException("Not supported yet!") - } - } - val npattern = mapPattern(pat, scrtype)._1 - val nguard = optGuard.map(rec(_)._2) - val nrhs = rec(rhs)._2 - //println(s"New rhs: $nrhs inferred type: ${nrhs.getType}") - MatchCase(npattern, nguard, nrhs) - } - val nmatch = MatchExpr(nscr, ncases) - //println("Old match expr: "+e+" \n new expr: "+nmatch) - (nmatch.getType, nmatch) - - case cs @ CaseClassSelector(cltype, clExpr, fld) => - val (ncltype: ClassType, nclExpr) = rec(clExpr) - // this is a hack. TODO: fix this - subcast(cltype, ncltype) match { - case Some(ntype : CaseClassType) => - val nop = CaseClassSelector(ntype, nclExpr, fld) - (nop.getType, nop) - case _ => - throw new IllegalStateException(s"$nclExpr : $ncltype cannot be cast to case class type: $cltype") - } - - case AsInstanceOf(clexpr, cltype) => - val (ncltype: ClassType, nexpr) = rec(clexpr) - subcast(cltype, ncltype) match { - case Some(ntype) => (ntype, AsInstanceOf(nexpr, ntype)) - case _ => - //println(s"asInstanceOf type of $clExpr is: $cltype inferred type of $nclExpr : $ct") - throw new IllegalStateException(s"$nexpr : $ncltype cannot be cast to case class type: $cltype") - } - - case v @ Variable(id) => - if (gamma.contains(id)) { - if (idmap.contains(id)) - (gamma(id), idmap(id).toVariable) - else { - (gamma(id), v) - } - } else (id.getType, v) - - case FunctionInvocation(TypedFunDef(fd, tparams), args) => - //println(s"Consider expr: $e initial type: ${e.getType}") - val nargs = args.map(arg => rec(arg)._2) - var tpmap = Map[TypeParameter, TypeTree]() - (fd.params zip nargs).foreach { x => - (x._1.getType, x._2.getType) match { - case (t1, t2) => - getTypeArguments(t1) zip getTypeArguments(t2) foreach { - case (tf : TypeParameter, ta) => - tpmap += (tf -> ta) - case _ => ; - } - /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" - + s"do not match for call: $call")*/ - } - } - // for uninterpreted functions, we could have a type parameter used only in the return type - val dummyTParam = TypeParameter.fresh("R@") - val ntparams = fd.tparams.map(_.tp).zipAll(tparams, dummyTParam, dummyTParam).map{ - case (paramt, argt) => - tpmap.getOrElse(paramt /* in this case we inferred the type parameter */, - argt /* in this case we reuse the argument type parameter */ ) - } - val nexpr = FunctionInvocation(TypedFunDef(fd, ntparams), nargs) - if (nexpr.getType == Untyped) { - throw new IllegalStateException(s"Cannot infer type for expression: $e "+ - s"arg types: ${nargs.map(_.getType).mkString(",")} \n Callee: ${fd} \n caller: ${nexpr}") - } - (nexpr.getType, nexpr) - - case FiniteSet(els, baseType) => - val nels = els.map(rec(_)._2) - // make sure every element has the same type (upcast it to the rootType) - val nbaseType = bestRealType(nels.head.getType) - if(!nels.forall(el => bestRealType(el.getType) == nbaseType)) - throw new IllegalStateException("Not all elements in the set have the same type: "+nbaseType) - val nop = FiniteSet(nels, nbaseType) - (nop.getType, nop) - - // need to handle tuple select specially - case TupleSelect(tup, i) => - val nop = TupleSelect(rec(tup)._2, i) - (nop.getType, nop) - case Operator(args, op) => - val nargs = args.map(arg => rec(arg)._2) - val nop = op(nargs) - (nop.getType, nop) - case t: Terminal => - (t.getType, t) - } - //println(s"Inferred type of $e : ${res._1} new expression: ${res._2}") - if (res._1 == Untyped) { - throw new IllegalStateException(s"Cannot infer type for expression: $e") - } - res - } - - def subcast(oldType: ClassType, newType: ClassType): Option[ClassType] = { - newType match { - case AbstractClassType(absClass, tps) if absClass.knownCCDescendants.contains(oldType.classDef) => - //here oldType.classDef <: absClass - Some(CaseClassType(oldType.classDef.asInstanceOf[CaseClassDef], tps)) - case cct: CaseClassType => - Some(cct) - case _ => - None - } - } - rec(ine)._2 - } -} diff --git a/src/main/scala/leon/laziness/TypeRectifier.scala b/src/main/scala/leon/laziness/TypeRectifier.scala deleted file mode 100644 index 72d61bd7c33d8dcf6bcfc894dbde864921bd33f5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/laziness/TypeRectifier.scala +++ /dev/null @@ -1,188 +0,0 @@ -package leon -package laziness - -import invariant.structure.FunctionUtils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import leon.invariant.util.TypeUtil._ -import leon.invariant.util.LetTupleSimplification._ -import LazinessUtil._ -import leon.invariant.datastructure.DisjointSets -import invariant.util.ProgramUtil._ - -/** - * This performs type parameter inference based on few known facts. - * This algorithm is inspired by Hindley-Milner type Inference (but not the same). - * Result is a program in which all type paramters of functions, types of - * parameters of functions are correct. - * The subsequent phase performs a local type inference. - */ -class TypeRectifier(p: Program, clFactory: LazyClosureFactory) { - - val typeClasses = { - val tc = new DisjointSets[TypeTree]() - p.definedFunctions.foreach { - case fd if fd.hasBody && !fd.isLibrary && !fd.isInvariant => - postTraversal { - case call @ FunctionInvocation(TypedFunDef(fd, tparams), args) => - // unify formal type parameters with actual type arguments - (fd.tparams zip tparams).foreach(x => tc.union(x._1.tp, x._2)) - /** - * Unify the type parameters of types of formal parameters with - * type arguments of actual arguments. - * Note: we have a cyclic problem here, since we do not know the - * type of the variables in the programs, we cannot use them - * to infer type parameters, on the other hand we need to know the - * type parameters (at least of fundefs) to infer types of variables. - * The idea to start from few variables whose types we know are correct - * except for type paramters. - * Eg. the state parameters, parameters that have closure ADT type (not the '$' type), - * some parameters that are not of lazy type ('$') type may also have - * correct types, but it is hard to rely on them - */ - (fd.params zip args).foreach { x => - (x._1.getType, x._2.getType) match { - case (CaseClassType(cd1, targs1), CaseClassType(cd2, targs2)) if cd1 == cd2 && cd1 == clFactory.state => - (targs1 zip targs2).foreach { - case (t1: TypeParameter, t2: TypeParameter) => - tc.union(t1, t2) - case _ => - } - case (ct1: ClassType, ct2: ClassType) - if clFactory.isClosureType(ct1.classDef) && clFactory.isClosureType(ct1.classDef) => - // both types are newly created closures, so their types can be trusted - (ct1.tps zip ct2.tps).foreach { - case (t1: TypeParameter, t2: TypeParameter) => - tc.union(t1, t2) - case _ => - } - case (t1, t2) => - /*throw new IllegalStateException(s"Types of formal and actual parameters: ($tf, $ta)" - + s"do not match for call: $call")*/ - } - } - // consider also set contains methods - case ElementOfSet(arg, set) => - // merge the type parameters of `arg` and `set` - set.getType match { - case SetType(baseT) => - // TODO: this may break easily. Fix this. - // Important: here 'arg' may have type lazy type $[ltype] - // we need to get the type argument of ltype - getTypeParameters(arg.getType) zip getTypeArguments(baseT) foreach { - case (tf, ta) => - tc.union(tf, ta) - } - case _ => - } - case _ => - }(fd.fullBody) - case _ => ; - } - tc - } - - val equivTypeParams = typeClasses.toMap - - val fdMap = p.definedFunctions.collect { - case fd if !fd.isLibrary && !fd.isInvariant => - val (tempTPs, otherTPs) = fd.tparams.map(_.tp).partition { - case tp if isPlaceHolderTParam(tp) => true - case _ => false - } - val others = otherTPs.toSet[TypeTree] - // for each of the type parameter pick one representative from its equivalence class - val tpMap = fd.tparams.map { - case TypeParameterDef(tp) => - val tpclass = equivTypeParams.getOrElse(tp, Set(tp)) - val candReps = tpclass.filter(r => others.contains(r) || !r.isInstanceOf[TypeParameter]) - val concRep = candReps.find(!_.isInstanceOf[TypeParameter]) - val rep = - if (concRep.isDefined) // there exists a concrete type ? - concRep.get - else if (!candReps.isEmpty) - candReps.head - else - throw new IllegalStateException(s"Cannot find a non-placeholder in equivalence class $tpclass for fundef: \n $fd") - tp -> rep - }.toMap - val instf = instantiateTypeParameters(tpMap) _ - val paramMap = fd.params.map { - case vd @ ValDef(id) => - (id -> FreshIdentifier(id.name, instf(vd.getType))) - }.toMap - val ntparams = fd.tparams.map(tpd => tpMap(tpd.tp)).distinct.collect { - case tp: TypeParameter => tp - } map TypeParameterDef - val nfd = new FunDef(fd.id.freshen, ntparams, fd.params.map(vd => ValDef(paramMap(vd.id))), - instf(fd.returnType)) - fd -> (nfd, tpMap, paramMap) - }.toMap - - /** - * Replace fundefs and unify type parameters in function invocations. - * Replace old parameters by new parameters - */ - def transformFunBody(ifd: FunDef) = { - val (nfd, tpMap, paramMap) = fdMap(ifd) - // need to handle tuple select specially as it breaks if the type of - // the tupleExpr if it is not TupleTyped. - // cannot use simplePostTransform because of this - def rec(e: Expr): Expr = e match { - case FunctionInvocation(TypedFunDef(callee, targsOld), args) => // this is already done by the type checker - val targs = targsOld.map { - case tp: TypeParameter => tpMap.getOrElse(tp, tp) - case t => t - }.distinct - val ncallee = - if (fdMap.contains(callee)) - fdMap(callee)._1 - else callee - FunctionInvocation(TypedFunDef(ncallee, targs), args map rec) - - case CaseClass(cct, args) => - val targs = cct.tps.map { - case tp: TypeParameter => tpMap.getOrElse(tp, tp) - case t => t - }.distinct - CaseClass(CaseClassType(cct.classDef, targs), args map rec) - - case Variable(id) if paramMap.contains(id) => - paramMap(id).toVariable - case TupleSelect(tup, index) => - TupleSelect(rec(tup), index) - case Ensuring(NoTree(_), post) => - Ensuring(nfd.fullBody, rec(post)) // the newfd body would already be type correct - case Operator(args, op) => op(args map rec) - case t: Terminal => t - } - val nbody = rec(ifd.fullBody) - val initGamma = nfd.params.map(vd => vd.id -> vd.getType).toMap - - //println(s"Inferring types for ${ifd.id}: "+nbody) - val typedBody = TypeChecker.inferTypesOfLocals(nbody, initGamma) - /*if(ifd.id.name.contains("pushLeftWrapper")) { - //println(s"Inferring types for ${ifd.id} new fun: $nfd \n old body: ${ifd.fullBody} \n type correct body: $typedBody") - System.exit(0) - }*/ - typedBody - } - - def apply: Program = { - copyProgram(p, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if fdMap.contains(fd) => - val nfd = fdMap(fd)._1 - if (!fd.fullBody.isInstanceOf[NoTree]) { - nfd.fullBody = simplifyLetsAndLetsWithTuples(transformFunBody(fd)) - } - fd.flags.foreach(nfd.addFlag(_)) - //println("New fun: "+fd) - nfd - case d => d - }) - } -} diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala deleted file mode 100644 index af2727f7775059a4143f9a6d5df162c231064e83..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/DefOps.scala +++ /dev/null @@ -1,720 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import Expressions._ -import Common.Identifier -import ExprOps.{preMap, functionCallsOf} -import leon.purescala.Types.AbstractClassType -import leon.purescala.Types._ - -import scala.collection.mutable.{Map => MutableMap} - -object DefOps { - - private def packageOf(df: Definition)(implicit pgm: Program): PackageRef = { - df match { - case _ : Program => List() - case u : UnitDef => u.pack - case _ => unitOf(df).map(_.pack).getOrElse(List()) - } - } - - private def unitOf(df: Definition)(implicit pgm: Program): Option[UnitDef] = df match { - case p : Program => None - case u : UnitDef => Some(u) - case other => pgm.units.find(_.containsDef(df)) - } - - def moduleOf(df: Definition)(implicit pgm: Program): Option[ModuleDef] = df match { - case p : Program => None - case u : UnitDef => None - case other => pgm.units.flatMap(_.modules).find { _.containsDef(df) } - } - - def pathFromRoot(df: Definition)(implicit pgm: Program): List[Definition] = { - def rec(from: Definition): List[Definition] = { - from :: (if (from == df) { - Nil - } else { - from.subDefinitions.find { sd => (sd eq df) || sd.containsDef(df) } match { - case Some(sd) => - rec(sd) - case None => - Nil - } - }) - } - rec(pgm) - } - - def unitsInPackage(p: Program, pack: PackageRef) = p.units filter { _.pack == pack } - - - - /** Returns the set of definitions directly visible from the current definition - * Definitions that are shadowed by others are not returned. - */ - def visibleDefsFrom(df: Definition)(implicit pgm: Program): Set[Definition] = { - var toRet = Map[String,Definition]() - val asList = - (pathFromRoot(df).reverse flatMap { _.subDefinitions }) ++ { - unitsInPackage(pgm, packageOf(df)) flatMap { _.subDefinitions } - } ++ - List(pgm) ++ - ( for ( u <- unitOf(df).toSeq; - imp <- u.imports; - impDf <- imp.importedDefs(u) - ) yield impDf - ) - for ( - df <- asList; - name = df.id.toString - ) { - if (!(toRet contains name)) toRet += name -> df - } - toRet.values.toSet - } - - def visibleFunDefsFrom(df: Definition)(implicit pgm: Program): Set[FunDef] = { - visibleDefsFrom(df).collect { - case fd: FunDef => fd - } - } - - def funDefsFromMain(implicit pgm: Program): Set[FunDef] = { - pgm.units.filter(_.isMainUnit).toSet.flatMap{ (u: UnitDef) => - u.definedFunctions - } - } - - def visibleFunDefsFromMain(implicit p: Program): Set[FunDef] = { - p.units.filter(_.isMainUnit).toSet.flatMap{ (u: UnitDef) => - visibleFunDefsFrom(u) ++ u.definedFunctions - } - } - - private def stripPrefix(off: List[String], from: List[String]): List[String] = { - val commonPrefix = (off zip from).takeWhile(p => p._1 == p._2) - - val res = off.drop(commonPrefix.size) - - if (res.isEmpty) { - if (off.isEmpty) List() - else List(off.last) - } else { - res - } - } - - def simplifyPath(namesOf: List[String], from: Definition, useUniqueIds: Boolean)(implicit pgm: Program) = { - val pathFrom = pathFromRoot(from).dropWhile(_.isInstanceOf[Program]) - - val namesFrom = pathToNames(pathFrom, useUniqueIds) - - val names: Set[List[String]] = Set(namesOf, stripPrefix(namesOf, namesFrom)) ++ - getNameUnderImports(pathFrom, namesOf) - - names.toSeq.minBy(_.size).mkString(".") - } - - def fullNameFrom(of: Definition, from: Definition, useUniqueIds: Boolean)(implicit pgm: Program): String = { - val pathFrom = pathFromRoot(from).dropWhile(_.isInstanceOf[Program]) - - val namesFrom = pathToNames(pathFrom, useUniqueIds) - val namesOf = pathToNames(pathFromRoot(of), useUniqueIds) - - val sp = stripPrefix(namesOf, namesFrom) - if (sp.isEmpty) return "**** " + of.id.uniqueName - val names: Set[List[String]] = - Set(namesOf, stripPrefix(namesOf, namesFrom)) ++ getNameUnderImports(pathFrom, namesOf) - - names.toSeq.minBy(_.size).mkString(".") - } - - private def getNameUnderImports(pathFrom: List[Definition], namesOf: List[String]): Seq[List[String]] = { - pathFrom match { - case (u: UnitDef) :: _ => - val imports = u.imports.map { - case Import(path, true) => path - case Import(path, false) => path.init - }.toList - - def stripImport(of: List[String], imp: List[String]): Option[List[String]] = { - if (of.startsWith(imp)) { - Some(stripPrefix(of, imp)) - } else { - None - } - } - - for {imp <- imports - strippedImport <- stripImport(namesOf, imp) - } yield strippedImport - case _ => - Nil - } - } - - def pathToNames(path: List[Definition], useUniqueIds: Boolean): List[String] = { - path.flatMap { - case p: Program => - Nil - case u: UnitDef => - u.pack - case m: ModuleDef if m.isPackageObject => - Nil - case d => - if (useUniqueIds) { - List(d.id.uniqueName) - } else { - List(d.id.toString) - } - } - } - - def pathToString(path: List[Definition], useUniqueIds: Boolean): String = { - pathToNames(path, useUniqueIds).mkString(".") - } - - def fullName(df: Definition, useUniqueIds: Boolean = false)(implicit pgm: Program): String = { - pathToString(pathFromRoot(df), useUniqueIds) - } - - def qualifiedName(fd: FunDef, useUniqueIds: Boolean = false)(implicit pgm: Program): String = { - pathToString(pathFromRoot(fd).takeRight(2), useUniqueIds) - } - - private def nameToParts(name: String) = { - name.split("\\.").toList - } - - def searchWithin(name: String, within: Definition): Seq[Definition] = { - searchWithin(nameToParts(name), within) - } - - def searchWithin(ns: List[String], within: Definition): Seq[Definition] = { - (ns, within) match { - case (ns, p: Program) => - p.units.flatMap { u => - searchWithin(ns, u) - } - - case (ns, u: UnitDef) => - if (ns.startsWith(u.pack)) { - val rest = ns.drop(u.pack.size) - - u.defs.flatMap { - case d: ModuleDef if d.isPackageObject => - searchWithin(rest, d) - - case d => - rest match { - case n :: ns => - if (d.id.name == n) { - searchWithin(ns, d) - } else { - Nil - } - case Nil => - List(u) - } - } - } else { - Nil - } - - case (Nil, d) => List(d) - case (n :: ns, d) => - d.subDefinitions.filter(_.id.name == n).flatMap { sd => - searchWithin(ns, sd) - } - } - } - - def searchRelative(name: String, from: Definition)(implicit pgm: Program): Seq[Definition] = { - val names = nameToParts(name) - val path = pathFromRoot(from) - - searchRelative(names, path.reverse) - } - - private def resolveImports(imports: Seq[Import], names: List[String]): Seq[List[String]] = { - def resolveImport(i: Import): Option[List[String]] = { - if (!i.isWild && names.startsWith(i.path.last)) { - Some(i.path ++ names.tail) - } else if (i.isWild) { - Some(i.path ++ names) - } else { - None - } - } - - imports.flatMap(resolveImport) - } - - private def searchRelative(names: List[String], rpath: List[Definition])(implicit pgm: Program): Seq[Definition] = { - (names, rpath) match { - case (n :: ns, d :: ds) => - (d match { - case p: Program => - searchWithin(names, p) - - case u: UnitDef => - val inModules = d.subDefinitions.filter(_.id.name == n).flatMap { sd => - searchWithin(ns, sd) - } - - val namesImported = resolveImports(u.imports, names) - val nameWithPackage = u.pack ++ names - - val allNames = namesImported :+ nameWithPackage - - allNames.foldLeft(inModules) { _ ++ searchRelative(_, ds) } - - case d => - if (n == d.id.name) { - searchWithin(ns, d) - } else { - searchWithin(n :: ns, d) - } - }) ++ searchRelative(names, ds) - - case _ => - Nil - } - } - - def replaceDefsInProgram(p: Program)(fdMap: Map[FunDef, FunDef] = Map.empty, - cdMap: Map[ClassDef, ClassDef] = Map.empty): Program = { - p.copy(units = for (u <- p.units) yield { - u.copy(defs = u.defs.map { - case m : ModuleDef => - m.copy(defs = for (df <- m.defs) yield { - df match { - case cd : ClassDef => cdMap.getOrElse(cd, cd) - case fd : FunDef => fdMap.getOrElse(fd, fd) - case d => d - } - }) - case cd: ClassDef => cdMap.getOrElse(cd, cd) - case d => d - }) - }) - } - - def definitionReplacer( - fdMapF: FunDef => Option[FunDef], - cdMapF: ClassDef => Option[ClassDef], - fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap, - ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap - ): DefinitionTransformer = { - - val idMap = new utils.Bijection[Identifier, Identifier] - val cdMap = new utils.Bijection[ClassDef , ClassDef ] - val fdMap = new utils.Bijection[FunDef , FunDef ] - - new DefinitionTransformer(idMap, fdMap, cdMap) { - override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = expr match { - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - val transformFd = transform(fd) - if(transformFd != fd) - fiMapF(fi, transformFd) - else - None - case cc @ CaseClass(cct, args) => - val transformCct = transform(cct).asInstanceOf[CaseClassType] - if(transformCct != cct) - ciMapF(cc, transformCct) - else - None - case _ => - None - } - - override def transformFunDef(fd: FunDef): Option[FunDef] = fdMapF(fd) - override def transformClassDef(cd: ClassDef): Option[ClassDef] = cdMapF(cd) - } - - } - - private def defaultFiMap(fi: FunctionInvocation, nfd: FunDef): Option[Expr] = (fi, nfd) match { - case (FunctionInvocation(old, args), newfd) if old.fd != newfd => - Some(FunctionInvocation(newfd.typed(old.tps), args)) - case _ => - None - } - - /** Return a [[DefinitionTransformer]] that transforms according to the given [[FunDef]] maps */ - def funDefReplacer( - fdMapF: FunDef => Option[FunDef], - fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap - ): DefinitionTransformer = { - definitionReplacer(fdMapF, cd => None, fiMapF) - } - - def transformProgram(transformer: DefinitionTransformer, p: Program) = { - val cdsMap = p.definedClasses.map (cd => cd -> transformer.transform(cd) ).toMap - val fdsMap = p.definedFunctions.map (fd => fd -> transformer.transform(fd) ).toMap - replaceDefsInProgram(p)(fdsMap, cdsMap) - } - - private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match { - case (CaseClass(old, args), newCcd) if old.classDef != newCcd.classDef => - Some(CaseClass(newCcd, args)) - case _ => - None - } - - /** Clones the given program by replacing some classes by other classes. - * - * @param p The original program - * @param ciMapF Given a previous case class invocation and its new case class definition, returns the expression to use. - * By default it is the case class construction using the new case class definition. - * @return the new program with a map from the old case classes to the new case classes, with maps concerning identifiers and function definitions. */ - def replaceCaseClassDefs(p: Program)(cdMapFOriginal: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], - ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) - : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { - var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() - var cdMapCache = Map[ClassDef, Option[ClassDef]]() - var idMapCache = Map[Identifier, Identifier]() - var fdMapFCache = Map[FunDef, Option[FunDef]]() - var fdMapCache = Map[FunDef, Option[FunDef]]() - - def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { - cd match { - case ccd: CaseClassDef => - cdMapFCache.getOrElse(ccd, { - val new_cd_potential = cdMapFOriginal(ccd) - cdMapFCache += ccd -> new_cd_potential - new_cd_potential - }) - case acd: AbstractClassDef => None - } - } - - def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ - case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) - case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) - case e => None - }(tt).asInstanceOf[T] - - def duplicateClassDef(cd: ClassDef): ClassDef = { - cdMapCache.get(cd) match { - case Some(new_cd) => - new_cd.get // None would have meant that this class would never be duplicated, which is not possible. - case None => - val parent = cd.parent.map(duplicateAbstractClassType) - val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse { - cd match { - case acd:AbstractClassDef => acd.duplicate(parent = parent) - case ccd:CaseClassDef => - ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. - } - } - cdMapCache += cd -> Some(new_cd) - new_cd - } - } - - def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { - TypeOps.postMap{ - case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) - case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) - case _ => None - }(act).asInstanceOf[AbstractClassType] - } - - // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. - // If something extends List[A] and A is modified, then the first something should be modified. - def dependencies(s: ClassDef): Set[ClassDef] = { - leon.utils.fixpoint((s: Set[ClassDef]) => s ++ s.flatMap(_.knownDescendants) ++ s.flatMap(_.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ - case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownDescendants - case CaseClassType(ccd, _) => Set(ccd:ClassDef) - case _ => Set() - }(p))))(Set(s)) - } - - def cdMap(cd: ClassDef): ClassDef = { - cdMapCache.get(cd) match { - case Some(Some(new_cd)) => new_cd - case Some(None) => cd - case None => - if(cdMapF(cd).isDefined || dependencies(cd).exists(cd => cdMapF(cd).isDefined)) { // Needs replacement in any case. - duplicateClassDef(cd) - } else { - cdMapCache += cd -> None - } - cdMapCache(cd).getOrElse(cd) - } - } - - def idMap(id: Identifier): Identifier = { - if (!(idMapCache contains id)) { - val new_id = id.duplicate(tpe = tpMap(id.getType)) - idMapCache += id -> new_id - } - idMapCache(id) - } - - def idHasToChange(id: Identifier): Boolean = { - typeHasToChange(id.getType) - } - - def typeHasToChange(tp: TypeTree): Boolean = { - TypeOps.exists{ - case AbstractClassType(acd, _) => cdMap(acd) != acd - case CaseClassType(ccd, _) => cdMap(ccd) != ccd - case _ => false - }(tp) - } - - def patternHasToChange(p: Pattern): Boolean = { - PatternOps.exists { - case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) - case InstanceOfPattern(optId, cct) => optId.exists(idHasToChange) || typeHasToChange(cct) - case Extractors.Pattern(optId, subp, builder) => optId.exists(idHasToChange) - case e => false - } (p) - } - - def exprHasToChange(e: Expr): Boolean = { - ExprOps.exists{ - case Let(id, expr, body) => idHasToChange(id) - case Variable(id) => idHasToChange(id) - case ci @ CaseClass(cct, args) => typeHasToChange(cct) - case CaseClassSelector(cct, expr, identifier) => typeHasToChange(cct) || idHasToChange(identifier) - case IsInstanceOf(e, cct) => typeHasToChange(cct) - case AsInstanceOf(e, cct) => typeHasToChange(cct) - case MatchExpr(scrut, cases) => - cases.exists{ - case MatchCase(pattern, optGuard, rhs) => - patternHasToChange(pattern) - } - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - tps.exists(typeHasToChange) - case _ => - false - }(e) - } - - def funDefHasToChange(fd: FunDef): Boolean = { - exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) - } - - def funHasToChange(fd: FunDef): Boolean = { - funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => - fdMapFCache.get(fd) match { - case Some(Some(_)) => true - case Some(None) => false - case None => funDefHasToChange(fd) - }) - } - - def fdMapFCached(fd: FunDef): Option[FunDef] = { - fdMapFCache.get(fd) match { - case Some(e) => e - case None => - val new_fd = if(funHasToChange(fd)) { - Some(fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType))) - } else { - None - } - fdMapFCache += fd -> new_fd - new_fd - } - } - - def duplicateParents(fd: FunDef): Unit = { - fdMapCache.get(fd) match { - case None => - fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) - for(fp <- p.callGraph.callers(fd)) { - duplicateParents(fp) - } - case _ => - } - } - - def fdMap(fd: FunDef): FunDef = { - fdMapCache.get(fd) match { - case Some(Some(e)) => e - case Some(None) => fd - case None => - if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { - duplicateParents(fd) - } else { - fdMapCache += fd -> None - } - fdMapCache(fd).getOrElse(fd) - } - } - - val newP = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.map { - case m : ModuleDef => - m.copy(defs = for (df <- m.defs) yield { - df match { - case cd : ClassDef => cdMap(cd) - case fd : FunDef => fdMap(fd) - case d => d - } - }) - case d => d - } - ) - }) - - def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ - case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) - case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) - case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) - case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) - case e => None - }(e) - - def replaceClassDefsUse(e: Expr): Expr = { - ExprOps.postMap { - case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) - case Lambda(vd, body) => Some(Lambda(vd.map(vd => ValDef(idMap(vd.id))), body)) - case Variable(id) => Some(Variable(idMap(id))) - case ci @ CaseClass(ct, args) => - ciMapF(ci, tpMap(ct)).map(_.setPos(ci)) - case CaseClassSelector(cct, expr, identifier) => - val new_cct = tpMap(cct) - val selection = if (new_cct != cct || new_cct.classDef.fieldsIds != cct.classDef.fieldsIds) idMap(identifier) else identifier - Some(CaseClassSelector(new_cct, expr, selection)) - case IsInstanceOf(e, ct) => Some(IsInstanceOf(e, tpMap(ct))) - case AsInstanceOf(e, ct) => Some(AsInstanceOf(e, tpMap(ct))) - case MatchExpr(scrut, cases) => - Some(MatchExpr(scrut, cases.map{ - case MatchCase(pattern, optGuard, rhs) => - MatchCase(replaceClassDefUse(pattern), optGuard, rhs) - })) - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - defaultFiMap(fi, fdMap(fd)).map(_.setPos(fi)) - case _ => - None - }(e) - } - - for (fd <- newP.definedFunctions) { - if (fdMapCache.getOrElse(fd, None).isDefined) { - fd.fullBody = replaceClassDefsUse(fd.fullBody) - } - } - - // make sure classDef invariants are correctly assigned to transformed classes - for ((cd, optNew) <- cdMapCache; newCd <- optNew; inv <- newCd.invariant) { - newCd.setInvariant(fdMap(inv)) - } - - (newP, - cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, - idMapCache, - fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) - } - - def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { - var found = false - val res = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.flatMap { - case m: ModuleDef => - val newdefs = for (df <- m.defs) yield { - df match { - case `after` => - found = true - after +: cds.toSeq - case d => Seq(d) - } - } - - Seq(m.copy(defs = newdefs.flatten)) - case `after` => - found = true - after +: cds.toSeq - case d => Seq(d) - } - ) - }) - - if (!found) { - println(s"addDefs could not find anchor definition! Not found: $after") - p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match { - case Nil => - case e => - println("Did you mean one of:") - e foreach println - } - (new Exception).printStackTrace() - } - - res - } - - def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) - - def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) - - // @Note: This function does not filter functions in classdefs - def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { - p.copy(units = p.units.map { u => - u.copy( - defs = u.defs.collect { - case md: ModuleDef => - md.copy(defs = md.defs.filter { - case fd: FunDef => fdF(fd) - case d => true - }) - - case cd => cd - } - ) - }) - } - - /** - * Returns a call graph starting from the given sources, taking into account - * instantiations of function type parameters, - * If given limit of explored nodes reached, it returns a partial set of reached TypedFunDefs - * and the boolean set to "false". - * Otherwise, it returns the full set of reachable TypedFunDefs and "true" - */ - - def typedTransitiveCallees(sources: Set[TypedFunDef], limit: Option[Int] = None): (Set[TypedFunDef], Boolean) = { - import leon.utils.SearchSpace.reachable - reachable( - sources, - (tfd: TypedFunDef) => functionCallsOf(tfd.fullBody) map { _.tfd }, - limit - ) - } - - def augmentCaseClassFields(extras: Seq[(CaseClassDef, Seq[(ValDef, Expr)])]) - (program: Program) = { - - def updateBody(body: Expr): Expr = { - preMap({ - case CaseClass(ct, args) => extras.find(p => p._1 == ct.classDef).map{ - case (ccd, extraFields) => - CaseClass(CaseClassType(ccd, ct.tps), args ++ extraFields.map{ case (_, v) => v }) - } - case _ => None - })(body) - } - - extras.foreach{ case (ccd, extraFields) => ccd.setFields(ccd.fields ++ extraFields.map(_._1)) } - for { - fd <- program.definedFunctions - } { - fd.body = fd.body.map(body => updateBody(body)) - fd.precondition = fd.precondition.map(pre => updateBody(pre)) - fd.postcondition = fd.postcondition.map(post => updateBody(post)) - } - } - -} diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala deleted file mode 100644 index a86700e5419d736b96f834606ad40aa3355431fa..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/Definitions.scala +++ /dev/null @@ -1,645 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import utils.Library -import Common._ -import Expressions._ -import ExprOps._ -import Types._ -import TypeOps._ - -object Definitions { - - sealed abstract class Definition extends Tree { - - val id: Identifier - - def subDefinitions: Seq[Definition] // The enclosed scopes/definitions by this definition - - def containsDef(df: Definition): Boolean = { - subDefinitions.exists { sd => - sd == df || sd.containsDef(df) - } - } - - override def hashCode : Int = id.hashCode - override def equals(that : Any) : Boolean = that match { - case t : Definition => t.id == this.id - case _ => false - } - - def writeScalaFile(filename: String, opgm: Option[Program] = None) { - import java.io.FileWriter - import java.io.BufferedWriter - val fstream = new FileWriter(filename) - val out = new BufferedWriter(fstream) - out.write(ScalaPrinter(this, opgm = opgm)) - out.close() - } - } - - /** - * A ValDef declares a new identifier to be of a certain type. - * The optional tpe, if present, overrides the type of the underlying Identifier id - * This is useful to instantiate argument types of polymorphic functions - */ - case class ValDef(id: Identifier) extends Definition with Typed { - self: Serializable => - - val getType = id.getType - - var defaultValue : Option[FunDef] = None - - var isVar: Boolean = false - - def setIsVar(b: Boolean): this.type = { this.isVar = b; this } - - def subDefinitions = Seq() - - /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] */ - def toVariable : Variable = Variable(id) - } - - /** A wrapper for a program. For now a program is simply a single object. */ - case class Program(units: List[UnitDef]) extends Definition { - val id = FreshIdentifier("program") - - lazy val library = Library(this) - - def subDefinitions = units - - def definedFunctions = units.flatMap(_.definedFunctions) - def definedClasses = units.flatMap(_.definedClasses) - def classHierarchyRoots = units.flatMap(_.classHierarchyRoots) - def singleCaseClasses = units.flatMap(_.singleCaseClasses) - def modules = { - units.flatMap(_.defs.collect { - case md: ModuleDef => md - }) - } - - lazy val callGraph = new CallGraph(this) - - def caseClassDef(name: String) = definedClasses.collectFirst { - case ccd: CaseClassDef if ccd.id.name == name => ccd - }.getOrElse(throw LeonFatalError("Unknown case class '"+name+"'")) - - def lookupAll(name: String) = DefOps.searchWithin(name, this) - def lookup(name: String) = lookupAll(name).headOption - - def lookupCaseClass(name: String) = lookupAll(name).collect{ case c: CaseClassDef => c }.headOption - def lookupAbstractClass(name: String) = lookupAll(name).collect{ case c: AbstractClassDef => c }.headOption - def lookupFunDef(name: String) = lookupAll(name).collect{ case c: FunDef => c }.headOption - } - - object Program { - lazy val empty: Program = Program(Nil) - } - - case class TypeParameterDef(tp: TypeParameter) extends Definition { - def subDefinitions = Seq() - def freshen = TypeParameterDef(tp.freshen) - val id = tp.id - } - - /** A package as a path of names */ - type PackageRef = List[String] - - case class Import(path: List[String], isWild: Boolean) extends Tree { - def importedDefs(in: UnitDef)(implicit pgm: Program): Seq[Definition] = { - val found = DefOps.searchRelative(path.mkString("."), in) - if (isWild) { - found.flatMap(_.subDefinitions) - } else { - found - } - } - } - - /** Definition of a compilation unit, corresponding to a source file - * - * @param id The name of the file this [[UnitDef]] was compiled from - * @param pack The package of this [[UnitDef]] - * @param imports The imports of this [[UnitDef]] - * @param defs The [[Definition]]s (classes and objects) in this [[UnitDef]] - * @param isMainUnit Whether this is a user-provided file or a library file - */ - case class UnitDef( - id: Identifier, - pack: PackageRef, - imports: Seq[Import], - defs: Seq[Definition], - isMainUnit: Boolean - ) extends Definition { - - def subDefinitions = defs - - def definedFunctions = defs.flatMap{ - case m: ModuleDef => m.definedFunctions - case _ => Nil - } - - def definedClasses = defs.flatMap { - case c: ClassDef => List(c) - case m: ModuleDef => m.definedClasses - case _ => Nil - } - - def classHierarchyRoots = { - definedClasses.filter(!_.hasParent) - } - - // Guarantees that a parent always appears before its children - def classHierarchies = classHierarchyRoots map { root => - root +: root.knownDescendants - } - - def singleCaseClasses = { - definedClasses.collect { - case ccd: CaseClassDef if !ccd.hasParent => ccd - } - } - - def modules = defs.collect { - case md: ModuleDef => md - } - } - - object UnitDef { - def apply(id: Identifier, modules : Seq[ModuleDef]) : UnitDef = - UnitDef(id, Nil, Nil, modules, true) - } - - /** Corresponds to an '''object''' in scala. Contains [[FunDef]]s, [[ClassDef]]s and [[ValDef]]s. */ - case class ModuleDef(id: Identifier, defs: Seq[Definition], isPackageObject: Boolean) extends Definition { - - def subDefinitions = defs - - lazy val definedFunctions : Seq[FunDef] = defs.collect { case fd: FunDef => fd } - - lazy val definedClasses : Seq[ClassDef] = defs.collect { case ctd: ClassDef => ctd } - - lazy val classHierarchyRoots : Seq[ClassDef] = defs.collect { - case ctd: ClassDef if !ctd.hasParent => ctd - } - - lazy val algebraicDataTypes : Map[AbstractClassDef, Seq[CaseClassDef]] = defs.collect { - case c : CaseClassDef if c.parent.isDefined => c - }.groupBy(_.parent.get.classDef) - - lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { - case c : CaseClassDef if !c.parent.isDefined => c - } - } - - /** A trait that represents flags that annotate a ClassDef with different attributes */ - sealed trait ClassFlag - - object ClassFlag { - def fromName(name: String, args: Seq[Option[Any]]): ClassFlag = Annotation(name, args) - } - - /** A trait that represents flags that annotate a FunDef with different attributes */ - sealed trait FunctionFlag - - object FunctionFlag { - def fromName(name: String, args: Seq[Option[Any]]): FunctionFlag = name match { - case "inline" => IsInlined - case _ => Annotation(name, args) - } - } - - // Whether this FunDef was originally a (lazy) field - case class IsField(isLazy: Boolean) extends FunctionFlag - // Compiler annotations given in the source code as @annot - case class Annotation(annot: String, args: Seq[Option[Any]]) extends FunctionFlag with ClassFlag - // If this class was a method. owner is the original owner of the method - case class IsMethod(owner: ClassDef) extends FunctionFlag - // If this function represents a loop that was there before XLangElimination - // Contains a link to the FunDef where the loop was defined - case class IsLoop(owner: FunDef) extends FunctionFlag - // If extraction fails of the function's body fais, it is marked as abstract - case object IsAbstract extends FunctionFlag - // Currently, the only synthetic functions are those that calculate default values of parameters - case object IsSynthetic extends FunctionFlag - // Is inlined - case object IsInlined extends FunctionFlag - // Is an ADT invariant method - case object IsADTInvariant extends FunctionFlag - case object IsInner extends FunctionFlag - - /** Represents a class definition (either an abstract- or a case-class) */ - sealed trait ClassDef extends Definition { - self => - - def subDefinitions = fields ++ methods ++ tparams - - val id: Identifier - val tparams: Seq[TypeParameterDef] - def fields: Seq[ValDef] - val parent: Option[AbstractClassType] - - def hasParent = parent.isDefined - - def fieldsIds = fields.map(_.id) - - private var _children: List[ClassDef] = Nil - - def registerChild(chd: ClassDef) = { - _children = (chd :: _children).sortBy(_.id.name) - } - - private var _methods = List[FunDef]() - - def registerMethod(fd: FunDef) = { - _methods = _methods ::: List(fd) - } - - def unregisterMethod(id: Identifier) = { - _methods = _methods filterNot (_.id == id) - } - - def clearMethods() { - _methods = Nil - } - - def methods = _methods - - private var _flags: Set[ClassFlag] = Set() - - def addFlags(flags: Set[ClassFlag]): this.type = { - this._flags ++= flags - this - } - - def addFlag(flag: ClassFlag): this.type = addFlags(Set(flag)) - - def flags = _flags - - private var _invariant: Option[FunDef] = None - - def invariant: Option[FunDef] = parent.flatMap(_.classDef.invariant).orElse(_invariant) - def setInvariant(fd: FunDef): Unit = parent match { - case Some(act) => act.classDef.setInvariant(fd) - case None => _invariant = Some(fd) - } - - def hasInvariant: Boolean = invariant.isDefined || (this +: root.knownDescendants).exists(cd => cd.methods.exists(_.isInvariant)) - - def annotations: Set[String] = extAnnotations.keySet - def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap - - lazy val ancestors: Seq[ClassDef] = parent.toSeq flatMap { p => p.classDef +: p.classDef.ancestors } - - lazy val root = ancestors.lastOption.getOrElse(this) - - def knownChildren: Seq[ClassDef] = _children - - def knownDescendants: Seq[ClassDef] = { - knownChildren ++ knownChildren.flatMap { - case acd: AbstractClassDef => acd.knownDescendants - case _ => Nil - } - } - - def knownCCDescendants: Seq[CaseClassDef] = knownDescendants.collect { - case ccd: CaseClassDef => - ccd - } - - def isInductive: Boolean = { - def induct(tpe: TypeTree, seen: Set[ClassDef]): Boolean = tpe match { - case ct: ClassType => - val root = ct.classDef.root - seen(root) || ct.fields.forall(vd => induct(vd.getType, seen + root)) - case TupleType(tpes) => - tpes.forall(tpe => induct(tpe, seen)) - case _ => true - } - - if (this == root && !this.isAbstract) false - else if (this != root) root.isInductive - else knownCCDescendants.forall { ccd => - ccd.fields.forall(vd => induct(vd.getType, Set(root))) - } - } - - val isAbstract: Boolean - val isCaseObject: Boolean - - lazy val definedFunctions : Seq[FunDef] = methods - lazy val definedClasses = Seq(this) - - def typeArgs = tparams map (_.tp) - - def typed(tps: Seq[TypeTree]): ClassType - def typed: ClassType - } - - /** Abstract classes. */ - class AbstractClassDef(val id: Identifier, - val tparams: Seq[TypeParameterDef], - val parent: Option[AbstractClassType]) extends ClassDef { - - val fields = Nil - val isAbstract = true - val isCaseObject = false - - lazy val singleCaseClasses : Seq[CaseClassDef] = Nil - - def typed(tps: Seq[TypeTree]) = { - require(tps.length == tparams.length) - AbstractClassType(this, tps) - } - def typed: AbstractClassType = typed(tparams.map(_.tp)) - - /** Duplication of this [[CaseClassDef]]. - * @note This will not add known case class children - */ - def duplicate( - id: Identifier = this.id.freshen, - tparams: Seq[TypeParameterDef] = this.tparams, - parent: Option[AbstractClassType] = this.parent - ): AbstractClassDef = { - val acd = new AbstractClassDef(id, tparams, parent) - acd.addFlags(this.flags) - if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => acd.setInvariant(inv)) - parent.foreach(_.classDef.registerChild(acd)) - acd.copiedFrom(this) - } - } - - /** Case classes/ case objects. */ - class CaseClassDef(val id: Identifier, - val tparams: Seq[TypeParameterDef], - val parent: Option[AbstractClassType], - val isCaseObject: Boolean) extends ClassDef { - - private var _fields = Seq[ValDef]() - - def fields = _fields - - def setFields(fields: Seq[ValDef]) { - _fields = fields - } - - val isAbstract = false - - def selectorID2Index(id: Identifier) : Int = { - val index = fields.indexWhere(_.id == id) - - if (index < 0) { - scala.sys.error( - "Could not find '"+id+"' ("+id.uniqueName+") within "+ - fields.map(_.id.uniqueName).mkString(", ") - ) - } else index - } - - lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) - - def typed: CaseClassType = typed(tparams.map(_.tp)) - def typed(tps: Seq[TypeTree]): CaseClassType = { - require(tps.length == tparams.length) - CaseClassType(this, tps) - } - - /** Duplication of this [[CaseClassDef]]. - * @note This will not replace recursive [[CaseClassDef]] calls in [[fields]] nor the parent abstract class types - */ - def duplicate( - id: Identifier = this.id.freshen, - tparams: Seq[TypeParameterDef] = this.tparams, - fields: Seq[ValDef] = this.fields, - parent: Option[AbstractClassType] = this.parent, - isCaseObject: Boolean = this.isCaseObject - ): CaseClassDef = { - val cd = new CaseClassDef(id, tparams, parent, isCaseObject) - cd.setFields(fields) - cd.addFlags(this.flags) - if (!parent.exists(_.classDef.hasInvariant)) invariant.foreach(inv => cd.setInvariant(inv)) - parent.foreach(_.classDef.registerChild(cd)) - cd.copiedFrom(this) - } - } - - /** Function/method definition. - * - * This class represents methods or fields of objects or classes. By "fields" we mean - * fields defined in the body of a class/object, not the constructor arguments of a case class - * (those are accessible through [[leon.purescala.Definitions.ClassDef.fields]]). - * - * When it comes to verification, all are treated the same (as functions). - * They are only differentiated when it comes to code generation/ pretty printing. - * By default, the FunDef represents a function/method as opposed to a field, - * unless otherwise specified by its flags. - * - * Bear in mind that [[id]] will not be consistently typed. - */ - class FunDef( - val id: Identifier, - val tparams: Seq[TypeParameterDef], - val params: Seq[ValDef], - val returnType: TypeTree - ) extends Definition { - - /* Body manipulation */ - - var fullBody: Expr = NoTree(returnType) - - def body: Option[Expr] = withoutSpec(fullBody) - def body_=(b: Option[Expr]) = { - fullBody = withBody(fullBody, b) - } - - def precondition = preconditionOf(fullBody) - def precondition_=(oe: Option[Expr]) = { - fullBody = withPrecondition(fullBody, oe) - } - def precondition_=(p: Path) = { - fullBody = withPath(fullBody, p) - } - def precOrTrue = precondition getOrElse BooleanLiteral(true) - - def postcondition = postconditionOf(fullBody) - def postcondition_=(op: Option[Expr]) = { - fullBody = withPostcondition(fullBody, op) - } - def postOrTrue = postcondition getOrElse { - val arg = ValDef(FreshIdentifier("res", returnType, alwaysShowUniqueID = true)) - Lambda(Seq(arg), BooleanLiteral(true)) - } - - def hasBody = body.isDefined - def hasPrecondition = precondition.isDefined - def hasPostcondition = postcondition.isDefined - - /* Nested definitions */ - def directlyNestedFuns = directlyNestedFunDefs(fullBody) - def subDefinitions = params ++ tparams ++ directlyNestedFuns.toList - - /** Duplication of this [[FunDef]]. - * @note This will not replace recursive function calls in [[fullBody]] - */ - def duplicate( - id: Identifier = this.id.freshen, - tparams: Seq[TypeParameterDef] = this.tparams, - params: Seq[ValDef] = this.params, - returnType: TypeTree = this.returnType - ): FunDef = { - val fd = new FunDef(id, tparams, params, returnType) - fd.fullBody = this.fullBody - fd.addFlags(this.flags) - fd.copiedFrom(this) - } - - /* Flags */ - - private[this] var _flags: Set[FunctionFlag] = Set() - - def addFlags(flags: Set[FunctionFlag]): FunDef = { - this._flags ++= flags - this - } - - def addFlag(flag: FunctionFlag): FunDef = addFlags(Set(flag)) - - def flags = _flags - - def annotations: Set[String] = extAnnotations.keySet - def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { - case Annotation(s, args) => s -> args - }.toMap - def canBeLazyField = flags.contains(IsField(true)) && params.isEmpty && tparams.isEmpty - def canBeStrictField = flags.contains(IsField(false)) && params.isEmpty && tparams.isEmpty - def canBeField = canBeLazyField || canBeStrictField - def isRealFunction = !canBeField - def isSynthetic = flags contains IsSynthetic - def isInvariant = flags contains IsADTInvariant - def isInner = flags contains IsInner - def methodOwner = flags collectFirst { case IsMethod(cd) => cd } - - /* Wrapping in TypedFunDef */ - - def typed(tps: Seq[TypeTree]): TypedFunDef = { - assert(tps.size == tparams.size) - TypedFunDef(this, tps) - } - - def typed: TypedFunDef = typed(tparams.map(_.tp)) - - /* Auxiliary methods */ - - def qualifiedName(implicit pgm: Program) = DefOps.qualifiedName(this, useUniqueIds = false) - - def isRecursive(p: Program) = p.callGraph.transitiveCallees(this) contains this - - def paramIds = params map { _.id } - - def typeArgs = tparams map (_.tp) - - def applied(args: Seq[Expr]): FunctionInvocation = Constructors.functionInvocation(this, args) - def applied = FunctionInvocation(this.typed, this.paramIds map Variable) - } - - - // Wrapper for typing function according to valuations for type parameters - case class TypedFunDef(fd: FunDef, tps: Seq[TypeTree]) extends Tree { - val id = fd.id - - def signature = { - if (tps.nonEmpty) { - id.toString+tps.mkString("[", ", ", "]") - } else { - id.toString - } - } - - private lazy val typesMap: Map[TypeParameter, TypeTree] = { - (fd.typeArgs zip tps).toMap.filter(tt => tt._1 != tt._2) - } - - def translated(t: TypeTree): TypeTree = instantiateType(t, typesMap) - - def translated(e: Expr): Expr = instantiateType(e, typesMap, paramsMap) - - /** A mapping from this [[TypedFunDef]]'s formal parameters to real arguments - * - * @param realArgs The arguments to which the formal argumentas are mapped - * */ - def paramSubst(realArgs: Seq[Expr]) = { - require(realArgs.size == params.size) - (paramIds zip realArgs).toMap - } - - /** Substitute this [[TypedFunDef]]'s formal parameters with real arguments in some expression - * - * @param realArgs The arguments to which the formal argumentas are mapped - * @param e The expression in which the substitution will take place - */ - def withParamSubst(realArgs: Seq[Expr], e: Expr) = { - replaceFromIDs(paramSubst(realArgs), e) - } - - def applied(realArgs: Seq[Expr]): FunctionInvocation = { - FunctionInvocation(this, realArgs) - } - - def applied: FunctionInvocation = - applied(params map { _.toVariable }) - - /** - * Params will return ValDefs instantiated with the correct types - * For such a ValDef(id,tp) it may hold that (id.getType != tp) - */ - lazy val (params: Seq[ValDef], paramsMap: Map[Identifier, Identifier]) = { - if (typesMap.isEmpty) { - (fd.params, Map()) - } else { - val newParams = fd.params.map { vd => - val newTpe = translated(vd.getType) - val newId = FreshIdentifier(vd.id.name, newTpe, true).copiedFrom(vd.id) - vd.copy(id = newId).setPos(vd) - } - - val paramsMap: Map[Identifier, Identifier] = (fd.params zip newParams).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap - - (newParams, paramsMap) - } - } - - lazy val functionType = FunctionType(params.map(_.getType).toList, returnType) - - lazy val returnType: TypeTree = translated(fd.returnType) - - lazy val paramIds = params map { _.id } - - private var trCache = Map[Expr, Expr]() - - private def cached(e: Expr): Expr = { - trCache.getOrElse(e, { - val res = translated(e) - trCache += e -> res - res - }) - } - - // Methods that extract expressions from the underlying FunDef, using a cache - - def fullBody = cached(fd.fullBody) - def body = fd.body map cached - def precondition = fd.precondition map cached - def precOrTrue = cached(fd.precOrTrue) - def postcondition = fd.postcondition map cached - def postOrTrue = cached(fd.postOrTrue) - - def hasImplementation = body.isDefined - def hasBody = hasImplementation - def hasPrecondition = precondition.isDefined - def hasPostcondition = postcondition.isDefined - - override def getPos = fd.getPos - } -} diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala deleted file mode 100644 index a98dbf64b8a52c9d049ca6b00260379940683945..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import Expressions._ -import ExprOps._ -import TypeOps.instantiateType -import Common.Identifier -import leon.purescala.Types.TypeParameter -import utils.GraphOps._ - -object FunctionClosure extends TransformationPhase { - - override val name: String = "Function Closure" - override val description: String = "Closing function with its scoping variables" - - /** Takes a FunDef and returns a Seq of all internal FunDef's contained in fd in closed form - * (and fd itself, without inned FunDef's). - * - * The strategy is as follows: Remove one layer of nested FunDef's, then call - * close recursively on the new functions. - */ - def close(fd: FunDef): Seq[FunDef] = { - - // Directly nested functions with their p.c. - val nestedWithPathsFull = { - val funDefs = directlyNestedFunDefs(fd.fullBody) - collectWithPC { - case LetDef(fd1, body) => fd1.filter(funDefs) - }(fd.fullBody) - } - - val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap - val nestedFuns = nestedWithPaths.keys.toSeq - - // Transitively called funcions from each function - val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( - nestedFuns.map { f => - val calls = functionCallsOf(f.fullBody) collect { - case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => fd - } - - val pcCalls = functionCallsOf(nestedWithPaths(f).fullClause) collect { - case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => fd - } - - f -> (calls ++ pcCalls) - }.toMap - ) - //println("nested funs: " + nestedFuns) - //println("call graph: " + callGraph) - - def freeVars(fd: FunDef, pc: Path): Set[Identifier] = - variablesOf(fd.fullBody) ++ pc.variables ++ pc.bindings.map(_._1) -- fd.paramIds - - // All free variables one should include. - // Contains free vars of the function itself plus of all transitively called functions. - // also contains free vars from PC if the PC is relevant to the fundef - val transFreeWithBindings = { - def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { - nestedFuns.map(fd => { - val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2)) - val reqPath = nestedWithPaths(fd).filterByIds(transFreeVars) - (fd, transFreeVars ++ freeVars(fd, reqPath)) - }).toMap - } - - utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap) - } - - val transFree: Map[FunDef, Seq[Identifier]] = - //transFreeWithBindings.map(p => (p._1, p._2 -- nestedWithPaths(p._1).bindings.map(_._1))).map(p => (p._1, p._2.toSeq)) - transFreeWithBindings.map(p => (p._1, p._2.toSeq)) - - - // Closed functions along with a map (old var -> new var). - val closed = nestedWithPaths.map { - case (inner, pc) => inner -> closeFd(inner, fd, pc, transFree(inner)) - } - - // Remove LetDefs from fd - fd.fullBody = preMap({ - case LetDef(fds, bd) => - Some(bd) - case _ => - None - }, applyRec = true)(fd.fullBody) - - // A dummy substitution for fd, saying we should not change parameters - val dummySubst = FunSubst( - fd, - Map.empty.withDefault(id => id), - Map.empty.withDefault(id => id) - ) - - // Refresh function calls - (dummySubst +: closed.values.toSeq).foreach { - case FunSubst(f, callerMap, callerTMap) => - f.fullBody = preMap { - case fi @ FunctionInvocation(tfd, args) if closed contains tfd.fd => - val FunSubst(newCallee, calleeMap, calleeTMap) = closed(tfd.fd) - - // This needs some explanation. - // Say we have caller and callee. First we find the param. substitutions of callee - // (say old -> calleeNew) and reverse them. So we have a mapping (calleeNew -> old). - // We also have the caller mapping, (old -> callerNew). - // So we pass the callee parameters through these two mappings to get the caller parameters. - val mapReverse = calleeMap map { _.swap } - val extraArgs = newCallee.paramIds.drop(args.size).map { id => - callerMap(mapReverse(id)).toVariable - } - - // Similarly for type params - val tReverse = calleeTMap map { _.swap } - val tOrigExtraOrdered = newCallee.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse) - val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp => - callerTMap(tp) - ) - - Some(FunctionInvocation( - newCallee.typed(tfd.tps ++ tFinalExtra), - args ++ extraArgs - ).copiedFrom(fi)) - case _ => None - }(f.fullBody) - } - - val funs = closed.values.toSeq.map{ _.newFd } - funs foreach (_.addFlag(IsInner)) - - // Recursively close new functions - fd +: funs.flatMap(close) - } - - // Represents a substitution to a new function, along with parameter and type parameter - // mappings - private case class FunSubst( - newFd: FunDef, - paramsMap: Map[Identifier, Identifier], - tparamsMap: Map[TypeParameter, TypeParameter] - ) - - // Takes one inner function and closes it. - private def closeFd(inner: FunDef, outer: FunDef, pc: Path, free: Seq[Identifier]): FunSubst = { - //println("inner: " + inner) - //println("pc: " + pc) - //println("free: " + free.map(_.uniqueName)) - - val reqPC = pc.filterByIds(free.toSet) - - val tpFresh = outer.tparams map { _.freshen } - val tparamsMap = outer.typeArgs.zip(tpFresh map {_.tp}).toMap - - val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap)) - val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap - val freshParams = (inner.paramIds ++ free).filterNot(v => reqPC.isBound(v)).map(v => freeMap(v)) - - val newFd = inner.duplicate( - inner.id.freshen, - inner.tparams ++ tpFresh, - freshParams.map(ValDef(_)), - instantiateType(inner.returnType, tparamsMap) - ) - - val instBody = instantiateType( - withPath(newFd.fullBody, reqPC), - tparamsMap, - freeMap - ) - - newFd.fullBody = preMap { - case Let(id, v, r) if freeMap.isDefinedAt(id) => Some(Let(freeMap(id), v, r)) - case fi @ FunctionInvocation(tfd, args) if tfd.fd == inner => - Some(FunctionInvocation( - newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), - args ++ freshParams.drop(args.length).map(Variable) - ).setPos(fi)) - case _ => None - }(instBody) - - //HACK to make sure substitution happened even in nested fundef - newFd.fullBody = replaceFromIDs(freeMap.map(p => (p._1, p._2.toVariable)), newFd.fullBody) - - - FunSubst(newFd, freeMap, tparamsMap) - } - - override def apply(ctx: LeonContext, program: Program): Program = { - val newUnits = program.units.map { u => u.copy(defs = u.defs map { - case m: ModuleDef => - m.copy(defs = m.definedClasses ++ m.definedFunctions.flatMap(close)) - case cd => - cd - })} - Program(newUnits) - } - -} diff --git a/src/main/scala/leon/purescala/FunctionMapping.scala b/src/main/scala/leon/purescala/FunctionMapping.scala deleted file mode 100644 index cee6df3d2ccc562ecd98625b73fd5532cc873fb9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/FunctionMapping.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import Expressions._ -import ExprOps._ -import Types._ - -abstract class FunctionMapping extends TransformationPhase { - - val functionToFunction : Map[FunDef, FunctionTransformer] - - case class FunctionTransformer( - to : FunDef, - onArgs : Seq[Expr] => List[Expr], - onTypes : Seq[TypeTree] => List[TypeTree] - ) - - val name = "Function Mapping" - val description = "Replace functions and their invocations according to a given mapping" - - private def replaceCalls(e : Expr) : Expr = preMap { - case fi@FunctionInvocation(TypedFunDef(fd, tps), args) if functionToFunction contains fd => - val FunctionTransformer(to, onArgs, onTypes) = functionToFunction(fd) - Some(FunctionInvocation(TypedFunDef(to, onTypes(tps)), onArgs(args)).setPos(fi)) - // case MethodInvocation - case _ => None - }(e) - - def apply(ctx: LeonContext, program: Program): Program = { - val newP = - program.copy(units = for (u <- program.units) yield { - u.copy( - defs = u.defs map { - case m: ModuleDef => - m.copy(defs = for (df <- m.defs) yield { - df match { - case f : FunDef => - val newF = functionToFunction.get(f).map{_.to}.getOrElse(f) - newF.fullBody = replaceCalls(f.fullBody) - newF - case c : ClassDef => - // val oldMethods = c.methods - // c.clearMethods() - // for (m <- oldMethods) { - // c.registerMethod(functionToFunction.get(m).map{_.to}.getOrElse(m)) - // } - c - case d => - d - } - }) - case d => d - } - ) - }) - - newP - - } - -} diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala deleted file mode 100644 index d9d9e6f5a4b42c352e1177293063d5c313cb022d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Definitions._ -import Expressions._ -import Extractors._ -import ExprOps._ -import Types._ -import Constructors._ -import TypeOps.instantiateType -import xlang.Expressions._ - -object MethodLifting extends TransformationPhase { - - val name = "Method Lifting" - val description = "Translate methods into functions of the companion object" - - // Takes cd and its subclasses and creates cases that together will form a composite method. - // fdId is the method id which will be searched for in the subclasses. - // cd is the hierarchy root - // A Seq of MatchCases is returned, along with a boolean that signifies if the matching is complete. - private def makeCases(cd: ClassDef, fdId: Identifier, breakDown: Expr => Expr): (Seq[MatchCase], Boolean) = cd match { - case ccd: CaseClassDef => - - // Common for both cases - val ct = ccd.typed - val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) - val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap - def subst(e: Expr): Expr = e match { - case CaseClassSelector(`ct`, This(`ct`), i) => - Variable(fBinders(i)).setPos(e) - case This(`ct`) => - Variable(binder).setPos(e) - case OldThis(`ct`) => - Old(binder).setPos(e) - case e => - e - } - - ccd.methods.find(_.id == fdId).map { m => - - // Ancestor's method is a method in the case class - val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id)))) - val patt = CaseClassPattern(Some(binder), ct, subPatts) - val newE = simplePreTransform(subst)(breakDown(m.fullBody)) - val cse = SimpleCase(patt, newE).setPos(newE) - (List(cse), true) - - } orElse ccd.fields.find(_.id == fdId).map { f => - - // Ancestor's method is a case class argument in the case class - val subPatts = ccd.fields map (fld => - if (fld.id == f.id) - WildcardPattern(Some(fBinders(f.id))) - else - WildcardPattern(None) - ) - val patt = CaseClassPattern(Some(binder), ct, subPatts) - val newE = breakDown(Variable(fBinders(f.id))) - val cse = SimpleCase(patt, newE).setPos(newE) - (List(cse), true) - - } getOrElse { - (List(), false) - } - - case acd: AbstractClassDef => - val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip - val recs = r.flatten - val complete = !(c contains false) - if (complete) { - // Children define all cases completely, we don't need to add anything - (recs, true) - } else if (!acd.methods.exists(m => m.id == fdId && m.body.nonEmpty)) { - // We don't have anything to add - (recs, false) - } else { - // We have something to add - val m = acd.methods.find(m => m.id == fdId).get - val at = acd.typed - val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true) - val newE = simplePreTransform { - case This(ct) => asInstOf(Variable(binder), ct) - case e => e - } (breakDown(m.fullBody)) - - val cse = SimpleCase(InstanceOfPattern(Some(binder), at), newE).setPos(newE) - (recs :+ cse, true) - } - } - - def makeInvCases(cd: ClassDef): (Seq[MatchCase], Boolean) = { - val ct = cd.typed - val binder = FreshIdentifier(cd.id.name.toLowerCase, ct, true) - val fd = cd.methods.find(_.isInvariant).get - - cd match { - case ccd: CaseClassDef => - val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap - - // Ancestor's method is a method in the case class - val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id)))) - val patt = CaseClassPattern(Some(binder), ct.asInstanceOf[CaseClassType], subPatts) - val newE = simplePreTransform { - case e @ CaseClassSelector(`ct`, This(`ct`), i) => - Variable(fBinders(i)).setPos(e) - case e @ This(`ct`) => - Variable(binder).setPos(e) - case e @ OldThis(`ct`) => - Old(binder).setPos(e) - case e => - e - } (fd.fullBody) - - if (newE == BooleanLiteral(true)) { - (Nil, false) - } else { - val cse = SimpleCase(patt, newE).setPos(newE) - (List(cse), true) - } - - case acd: AbstractClassDef => - val (r, c) = acd.knownChildren.map(makeInvCases).unzip - val recs = r.flatten - val complete = !(c contains false) - - val newE = simplePreTransform { - case This(ct) => asInstOf(Variable(binder), ct) - case OldThis(ct) => asInstOf(Old(binder), ct) - case e => e - } (fd.fullBody) - - if (newE == BooleanLiteral(true)) { - (recs, false) - } else { - val rhs = if (recs.isEmpty) { - newE - } else { - val allCases = if (complete) recs else { - recs :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true)) - } - and(newE, MatchExpr(binder.toVariable, allCases)).setPos(newE) - } - val cse = SimpleCase(InstanceOfPattern(Some(binder), ct), rhs).setPos(newE) - (Seq(cse), true) - } - } - } - - def apply(ctx: LeonContext, program: Program): Program = { - - // First we create the appropriate functions from methods: - var mdToFds = Map[FunDef, FunDef]() - var mdToCls = Map[FunDef, ClassDef]() - - // Lift methods to the root class - for { - u <- program.units - ch <- u.classHierarchies - c <- ch - if c.parent.isDefined - fd <- c.methods - if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) - } { - val root = c.ancestors.last - val tMap = c.typeArgs.zip(root.typeArgs).toMap - val tSubst: TypeTree => TypeTree = instantiateType(_, tMap) - - val fdParams = fd.params map { vd => - val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) - vd.copy(id = newId).setPos(vd.getPos) - } - val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap - val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) - - val newFd = fd.duplicate(fd.id, fd.tparams, fdParams, tSubst(fd.returnType)) // FIXME: I don't like reusing the Identifier - - mdToCls += newFd -> c - - newFd.fullBody = eSubst(newFd.fullBody) - - c.unregisterMethod(fd.id) - root.registerMethod(newFd) - } - - val newUnits = for (u <- program.units) yield { - var fdsOf = Map[String, Set[FunDef]]() - // 1) Create one function for each method - for { cd <- u.classHierarchyRoots; fd <- cd.methods } { - // We import class type params and freshen them - val ctParams = cd.tparams map { _.freshen } - val tparamsMap = cd.typeArgs.zip(ctParams map { _.tp }).toMap - - val id = fd.id.freshen - val recType = cd.typed(ctParams.map(_.tp)) - val retType = instantiateType(fd.returnType, tparamsMap) - val fdParams = fd.params map { vd => - val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap)) - vd.copy(id = newId).setPos(vd.getPos) - } - - val receiver = FreshIdentifier("thiss", recType).setPos(cd.id) - - val nfd = fd.duplicate(id, ctParams ++ fd.tparams, ValDef(receiver) +: fdParams, retType) - nfd.addFlag(IsMethod(cd)) - - def classPre(fd: FunDef) = mdToCls.get(fd) match { - case None => - BooleanLiteral(true) - case Some(cl) => - isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp })) - } - - if (cd.knownDescendants.forall(cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { - // Don't need to compose methods - val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap - def thisToReceiver(e: Expr): Option[Expr] = e match { - case th @ This(ct) => - Some(asInstOf(receiver.toVariable, ct).setPos(th)) - case th @ OldThis(ct) => - Some(asInstOf(Old(receiver), ct).setPos(th)) - case _ => - None - } - - val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap) - nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody)) - - // Add precondition if the method was defined in a subclass - val pre = and(classPre(fd), nfd.precOrTrue) - nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre)) - - } else { - // We need to compose methods of subclasses - - /* (Type) parameter substitutions that look at all subclasses */ - val paramsMap = (for { - c <- cd.knownDescendants :+ cd - m <- c.methods if m.id == fd.id - (from,to) <- m.params zip fdParams - } yield (from.id, to.id)).toMap - - val classParamsMap = (for { - c <- cd.knownDescendants :+ cd - (from, to) <- c.tparams zip ctParams - } yield (from.tp, to.tp)).toMap - - val methodParamsMap = (for { - c <- cd.knownDescendants :+ cd - m <- c.methods if m.id == fd.id - (from,to) <- m.tparams zip fd.tparams - } yield (from.tp, to.tp)).toMap - - def inst(cs: Seq[MatchCase]) = instantiateType( - matchExpr(Variable(receiver), cs).setPos(fd), - classParamsMap ++ methodParamsMap, - paramsMap + (receiver -> receiver) - ) - - if (fd.isInvariant) { - val (cases, complete) = makeInvCases(cd) - val allCases = if (complete) cases else { - cases :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true)) - } - - nfd.fullBody = inst(allCases).setPos(fd.getPos) - } else { - /* Separately handle pre, post, body */ - val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true))) - val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse( - Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true)) - )) - val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) - - // Some simplifications - val preSimple = { - val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) } - - val compositePre = - if (nonTrivial == 0) { - BooleanLiteral(true) - } else { - inst(pre).setPos(fd.getPos) - } - - Some(and(classPre(fd), compositePre)) - } - - val postSimple = { - val trivial = post.forall { - case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true - case _ => false - } - if (trivial) None - else { - val resVal = FreshIdentifier("res", retType, true) - Some(Lambda( - Seq(ValDef(resVal)), - inst(post map { cs => cs.copy( rhs = - application(cs.rhs, Seq(Variable(resVal))) - )}) - ).setPos(fd)) - } - } - - val bodySimple = { - val trivial = body forall { - case SimpleCase(_, NoTree(_)) => true - case _ => false - } - if (trivial) NoTree(retType) else inst(body) - } - - /* Construct full body */ - nfd.fullBody = withPostcondition( - withPrecondition(bodySimple, preSimple), - postSimple - ) - } - } - - mdToFds += fd -> nfd - fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) - - if (fd.isInvariant) cd.setInvariant(nfd) - } - - // 2) Place functions in existing companions: - val defs = u.defs map { - case md: ModuleDef if fdsOf contains md.id.name => - val fds = fdsOf(md.id.name) - fdsOf -= md.id.name - ModuleDef(md.id, md.defs ++ fds, false) - case d => d - } - - // 3) Create missing companions - val newCompanions = for ((name, fds) <- fdsOf) yield { - ModuleDef(FreshIdentifier(name), fds.toSeq, false) - } - - u.copy(defs = defs ++ newCompanions) - } - - val pgm = Program(newUnits) - - // 4) Remove methods in classes - for (cd <- pgm.definedClasses) { - cd.clearMethods() - } - - // 5) Replace method calls with function calls - for (fd <- pgm.definedFunctions) { - fd.fullBody = postMap{ - case mi @ MethodInvocation(IsTyped(rec, ct: ClassType), cd, tfd, args) => - Some(FunctionInvocation(mdToFds(tfd.fd).typed(ct.tps ++ tfd.tps), rec +: args).setPos(mi)) - case _ => None - }(fd.fullBody) - } - - pgm - } - -} diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala deleted file mode 100644 index 7a80c0e85e8cf3029bce63668f420f3febbca543..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import Expressions._ -import ExprOps.replaceFromIDs -import DefOps._ -import Types._ - -object RestoreMethods extends TransformationPhase { - - val name = "Restore methods" - val description = "Restore methods that were previously turned into standalone functions" - - // @TODO: This code probably needs fixing, but is mostly unused and completely untested. - def apply(ctx: LeonContext, p: Program) = { - - val classMethods = p.definedFunctions.groupBy(_.methodOwner).collect { - case (Some(cd: ClassDef), fds) => cd -> fds - } - - val fdToMd = for( (cd, fds) <- classMethods; fd <- fds) yield { - val md = fd.duplicate(tparams = fd.tparams.drop(cd.tparams.size), params = fd.params.tail) - - val thisArg = fd.params.head - val thisExpr = This(thisArg.getType.asInstanceOf[ClassType]) - - md.fullBody = replaceFromIDs(Map(thisArg.id -> thisExpr), fd.fullBody) - - fd -> md - } - - // We inject methods, - def processClassDef(cd: ClassDef): ClassDef = { - if (classMethods contains cd) { - for (md <- classMethods(cd).map(fdToMd)) { - cd.registerMethod(md) - } - } - cd - } - - val np = p.copy(units = p.units.map { u => - u.copy(defs = u.defs.collect { - case md: ModuleDef => - md.copy( - defs = md.defs.flatMap { - case fd: FunDef if fdToMd contains(fd) => None - case cd: ClassDef => Some(processClassDef(cd)) - case d => Some(d) - } - ) - case cd: ClassDef => - processClassDef(cd) - }) - }) - - val np2 = transformProgram(funDefReplacer(fd => None, { (fi, fd) => - fdToMd.get(fi.tfd.fd) match { - case Some(md) => - Some(MethodInvocation( - fi.args.head, - fi.args.head.getType.asInstanceOf[ClassType].classDef, - md.typed(fi.tfd.tps.takeRight(md.tparams.size)), - fi.args.tail - )) - case None => - None - } - }), np) - - np2 - - - //// We need a special type of transitive closure, detecting only trans. calls on the same argument - //def transCallsOnSameArg(fd : FunDef) : Set[FunDef] = { - // require(fd.params.length == 1) - // require(fd.params.head.getType.isInstanceOf[ClassType]) - // def callsOnSameArg(fd : FunDef) : Set[FunDef] = { - // val theArg = fd.params.head.toVariable - // functionCallsOf(fd.fullBody) collect { case fi if fi.args contains theArg => fi.tfd.fd } - // } - // reachable(callsOnSameArg,fd) - //} - - //def refreshModule(m : ModuleDef) = { - // val newFuns : Seq[FunDef] = m.definedFunctions diff fd2MdFinal.keys.toSeq map substituteMethods// only keep non-methods - // for (cl <- m.definedClasses) { - // // We're going through some hoops to ensure strict fields are defined in topological order - - // // We need to work with the functions of the original program to have access to CallGraph - // val (strict, other) = whoseMethods.getOrElse(cl,Seq()).partition{ fd2MdFinal(_).canBeStrictField } - // val strictSet = strict.toSet - // // Make the call-subgraph that only includes the strict fields of this class - // val strictCallGraph = strict.map { st => - // (st, transCallsOnSameArg(st) & strictSet) - // }.toMap - // // Topologically sort, or warn in case of cycle - // val strictOrdered = topologicalSorting(strictCallGraph) fold ( - // cycle => { - // ctx.reporter.warning( - // s"""|Fields - // |${cycle map {_.id} mkString "\n"} - // |are involved in circular definition!""".stripMargin - // ) - // strict - // }, - // r => r - // ) - - // for (fun <- strictOrdered ++ other) { - // cl.registerMethod(fd2MdFinal(fun)) - // } - // } - // m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m) - //} - } -} diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala deleted file mode 100644 index 393c38ad356796e4007220146317bb46ef0c62bb..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Extractors._ -import PrinterHelpers._ -import Common._ -import Expressions._ -import Types._ -import Definitions._ - -/** This pretty-printer only prints valid scala syntax */ -class ScalaPrinter(opts: PrinterOptions, - opgm: Option[Program], - sb: StringBuffer = new StringBuffer) extends PrettyPrinter(opts, opgm, sb) { - - override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { - - tree match { - case m: ModuleDef => - // Don't print synthetic functions - super.pp(m.copy(defs = m.defs.filter { - case f: FunDef if f.isSynthetic => false - case _ => true - })) - case Not(Equals(l, r)) => optP { p"$l != $r" } - - case s @ FiniteSet(rss, t) => p"Set[$t](${rss.toSeq})" - case SetAdd(s,e) => optP { p"$s + $e" } - case ElementOfSet(e,s) => p"$s.contains($e)" - case SetUnion(l,r) => optP { p"$l ++ $r" } - case SetDifference(l,r) => optP { p"$l -- $r" } - case SetIntersection(l,r) => optP { p"$l & $r" } - case SetCardinality(s) => p"$s.size" - case SubsetOf(subset,superset) => p"$subset.subsetOf($superset)" - - case b @ FiniteBag(els, t) => p"Bag[$t](${els.toSeq})" - case BagAdd(s,e) => optP { p"$s + $e" } - case BagUnion(l,r) => optP { p"$l ++ $r" } - case BagDifference(l,r) => optP { p"$l -- $r" } - case BagIntersection(l,r) => optP { p"$l & $r" } - - case MapUnion(l,r) => optP { p"$l ++ $r" } - case m @ FiniteMap(els, k, v) => p"Map[$k,$v](${els.toSeq})" - - case InfiniteIntegerLiteral(v) => p"BigInt($v)" - case a@FiniteArray(elems, oDef, size) => - import ExprOps._ - val ArrayType(underlying) = a.getType - val default = oDef.getOrElse(simplestValue(underlying)) - size match { - case IntLiteral(s) => { - val explicitArray = Array.fill(s)(default) - for((id, el) <- elems) - explicitArray(id) = el - val lit = explicitArray.toList - p"Array($lit)" - } - case size => { - p"""|{ - | val tmp = Array.fill($size)($default) - |""" - for((id, el) <- elems) - p""""| tmp($id) = $el - |""" - p"""| tmp - |}""" - - } - } - - case Not(expr) => p"!$expr" - case Forall(args, bd) => - p"""|forall(($args) => - | $bd - |)""" - case NoTree(tpe) => - p"(_ : $tpe)" - case _ => - super.pp(tree) - } - } -} - -object ScalaPrinter extends PrettyPrinterFactory { - def create(opts: PrinterOptions, opgm: Option[Program]) = new ScalaPrinter(opts, opgm) -} diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala deleted file mode 100644 index e5481b2cf97eac2246dbeea5168f513cb5f583ae..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import scala.collection.mutable.ListBuffer -import Common._ -import Definitions._ -import Expressions._ -import Extractors._ -import Constructors.letDef - -class ScopeSimplifier extends Transformer { - case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { - - def register(oldNew: (Identifier, Identifier)): Scope = { - val newId = oldNew._2 - copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew) - } - - def register(oldNews: Seq[(Identifier, Identifier)]): Scope = { - (this /: oldNews){ case (oldScope, oldNew) => oldScope.register(oldNew) } - } - - def registerFunDef(oldNew: (FunDef, FunDef)): Scope = { - copy(funDefs = funDefs + oldNew) - } - } - - protected def genId(id: Identifier, scope: Scope): Identifier = { - val existCount = scope.inScope.count(_.name == id.name) - - FreshIdentifier.forceId(id.name, existCount, id.getType, existCount >= 1) - } - - protected def rec(e: Expr, scope: Scope): Expr = e match { - case Let(i, e, b) => - val si = genId(i, scope) - val se = rec(e, scope) - val sb = rec(b, scope.register(i -> si)) - Let(si, se, sb) - - case LetDef(fds, body: Expr) => - var newScope: Scope = scope - // First register all functions - val fds_newIds = for(fd <- fds) yield { - val newId = genId(fd.id, scope) - newScope = newScope.register(fd.id -> newId) - (fd, newId) - } - - val fds_mapping = for((fd, newId) <- fds_newIds) yield { - val localScopeToRegister = ListBuffer[(Identifier, Identifier)]() // We record the mapping of these variables only for the function. - val newArgs = for(ValDef(id) <- fd.params) yield { - val newArg = genId(id, newScope.register(localScopeToRegister)) - localScopeToRegister += (id -> newArg) // This renaming happens only inside the function. - ValDef(newArg) - } - - val newFd = fd.duplicate(id = newId, params = newArgs) - - newScope = newScope.registerFunDef(fd -> newFd) - (newFd, localScopeToRegister, fd) - } - - for((newFd, localScopeToRegister, fd) <- fds_mapping) { - newFd.fullBody = rec(fd.fullBody, newScope.register(localScopeToRegister)) - } - letDef(fds_mapping.map(_._1), rec(body, newScope)) - - case MatchExpr(scrut, cases) => - val rs = rec(scrut, scope) - - def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { - val (newBinder, newScope) = p.binder match { - case Some(id) => - val newId = genId(id, scope) - val newScope = scope.register(id -> newId) - (Some(newId), newScope) - case None => - (None, scope) - } - - var curScope = newScope - val newSubPatterns = for (sp <- p.subPatterns) yield { - val (subPattern, subScope) = trPattern(sp, curScope) - curScope = subScope - subPattern - } - - val newPattern = p match { - case InstanceOfPattern(b, ctd) => - InstanceOfPattern(newBinder, ctd) - case WildcardPattern(b) => - WildcardPattern(newBinder) - case CaseClassPattern(b, ccd, sub) => - CaseClassPattern(newBinder, ccd, newSubPatterns) - case TuplePattern(b, sub) => - TuplePattern(newBinder, newSubPatterns) - case UnapplyPattern(b, obj, sub) => - UnapplyPattern(newBinder, obj, newSubPatterns) - case LiteralPattern(_, lit) => - LiteralPattern(newBinder, lit) - } - - (newPattern, curScope) - } - - MatchExpr(rs, cases.map { c => - val (newP, newScope) = trPattern(c.pattern, scope) - MatchCase(newP, c.optGuard map {rec(_, newScope)}, rec(c.rhs, newScope)) - }) - - case Variable(id) => - Variable(scope.oldToNew.getOrElse(id, id)) - - case FunctionInvocation(tfd, args) => - val newFd = scope.funDefs.getOrElse(tfd.fd, tfd.fd) - val newArgs = args.map(rec(_, scope)) - - FunctionInvocation(newFd.typed(tfd.tps), newArgs) - - case Operator(es, builder) => - builder(es.map(rec(_, scope))) - - case _ => - sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") - } - - def transform(e: Expr): Expr = { - rec(e, Scope()) - } -} diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala deleted file mode 100644 index 5a3483a0c5c7c27c3bb0b8bc10e2ab03e3789b46..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.purescala - -import Constructors._ -import Expressions._ -import Types._ -import Common._ -import Definitions._ -import leon.evaluators._ -import leon.LeonContext -import leon.utils.StreamUtils - -import scala.collection.mutable.ListBuffer - -object SelfPrettyPrinter { - def prettyPrintersForType(inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { - (new SelfPrettyPrinter).prettyPrintersForType(inputType) - } - def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { - (new SelfPrettyPrinter).print(v, orElse, excluded) - } -} - -/** T is the type of pretty-printers which have to be found (e.g. Lambda or Lambdas with identifiers) - * U is the type of the arguments during gathering */ -trait PrettyPrinterFinder[T, U >: T] { - protected def isExcluded(fd: FunDef): Boolean - protected def isAllowed(fd: FunDef): Boolean - - @inline def isValidPrinterName(s: String) = { val n = s.toLowerCase(); n.endsWith("tostring") || n.endsWith("mkstring") } - - @inline def isCandidate(fd: FunDef) = fd.returnType == StringType && - fd.params.nonEmpty && - !isExcluded(fd) && - (isAllowed(fd) || isValidPrinterName(fd.id.name)) - - /** Returns a list of possible lambdas that can transform the input type to a String. - * At this point, it does not consider yet the inputType. Only [[prettyPrinterFromCandidate]] will consider it. */ - def prettyPrintersForType(inputType: TypeTree/*, existingPp: Map[TypeTree, List[Lambda]] = Map()*/)(implicit ctx: LeonContext, program: Program): Stream[T] = { - program.definedFunctions.toStream flatMap { fd => - if(isCandidate(fd)) prettyPrinterFromCandidate(fd, inputType) else Stream.Empty - } - } - - /** How to fill the arguments for user-defined pretty-printers */ - def getPrintersForType(t: TypeTree, topLevel: TypeTree)(implicit ctx: LeonContext, program: Program): Option[Stream[U]] = t match { - case FunctionType(Seq(in), StringType) if in != topLevel => // Should have one argument. - Some(prettyPrintersForType(in)) - case _ => None - } - - // To Implement - def buildLambda(inputType: TypeTree, fd: FunDef, slu: Stream[List[U]]): Stream[T] - - def prettyPrinterFromCandidate(fd: FunDef, inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[T] = { - TypeOps.canBeSupertypeOf(fd.params.head.getType, inputType) match { - case Some(genericTypeMap) => - //println("Found a mapping for inputType = " + inputType + " " + fd) - def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[U]] = ListBuffer[Stream[U]]()): Option[Stream[List[U]]] = funIds match { - case Nil => Some(StreamUtils.cartesianProduct(acc.toList)) - case funId::tail => // For each function, find an expression which could be provided if it exists. - getPrintersForType(funId.getType, inputType) match { - case Some(u) => gatherPrettyPrinters(tail, acc += u) - case None => - None - } - } - val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, genericTypeMap)).toList - gatherPrettyPrinters(funIds) match { - case Some(l) => buildLambda(inputType, fd, l) - case None => Stream.empty - } - case None => Stream.empty - } - } -} - -/** This pretty-printer uses functions defined in Leon itself. - * If not pretty printing function is defined, return the default value instead - */ -class SelfPrettyPrinter extends PrettyPrinterFinder[Lambda, Lambda] { top => - implicit val section = leon.utils.DebugSectionEvaluation - /* Functions whose name does not need to end with `tostring` or which can be abstract, i.e. which may contain a choose construct.*/ - protected var allowedFunctions = Set[FunDef]() - /* Functions totally excluded from the set of pretty-printer candidates */ - protected var excluded = Set[FunDef]() - /** Function whose name does not need to end with `tostring` or which can be abstract, i.e. which may contain a choose construct.*/ - def allowFunction(fd: FunDef) = { allowedFunctions += fd; this } - /** Functions whose name does not need to end with `tostring` or which can be abstract, i.e. which may contain a choose construct.*/ - def allowFunctions(fds: Set[FunDef]) = { allowedFunctions ++= fds; this } - /** Functions totally excluded from the set of pretty-printer candidates*/ - def excludeFunctions(fds: Set[FunDef]) = { excluded ++= fds; this } - /** Function totally excluded from the set of pretty-printer candidates*/ - def excludeFunction(fd: FunDef) = { excluded += fd; this } - - protected def isExcluded(fd: FunDef): Boolean = top.excluded(fd) - protected def isAllowed(fd: FunDef): Boolean = top.allowedFunctions(fd) - - /** How to fill the arguments for user-defined pretty-printers */ - override def getPrintersForType(t: TypeTree, underlying: TypeTree)(implicit ctx: LeonContext, program: Program): Option[Stream[Lambda]] = t match { - case FunctionType(Seq(StringType), StringType) => // Should have one argument. - val s = FreshIdentifier("s", StringType) // verify the type - Some(Stream(Lambda(Seq(ValDef(s)), Variable(s))) ++ super.getPrintersForType(t, underlying).getOrElse(Stream.empty) ) - case _ => super.getPrintersForType(t, underlying) - } - - /** From a list of lambdas used for pretty-printing the arguments of a function, builds the lambda for the function itself. */ - def buildLambda(inputType: TypeTree, fd: FunDef, slu: Stream[List[Lambda]]): Stream[Lambda] = { - for(lambdas <- slu) yield { - val x = FreshIdentifier("x", inputType) // verify the type - Lambda(Seq(ValDef(x)), functionInvocation(fd, Variable(x)::lambdas)) - } - } - - object withPossibleParameters extends PrettyPrinterFinder[(Lambda, List[Identifier]), (Expr, List[Identifier])] { - protected def isExcluded(fd: FunDef): Boolean = top.excluded(fd) - protected def isAllowed(fd: FunDef): Boolean = top.allowedFunctions(fd) - - /** If the returned identifiers are instantiated, each lambda becomes a pretty-printer. - * This allows to make use of mkString functions such as for maps */ - def prettyPrintersForTypes(inputType: TypeTree)(implicit ctx: LeonContext, program: Program) = { - (program.definedFunctions.toStream flatMap { fd => - if(isCandidate(fd)) prettyPrinterFromCandidate(fd, inputType) else Stream.Empty - }) #::: { - inputType match { - case Int32Type => - val i = FreshIdentifier("i", Int32Type) - Stream((Lambda(Seq(ValDef(i)), Int32ToString(Variable(i))), List[Identifier]())) - case IntegerType => - val i = FreshIdentifier("i", IntegerType) - Stream((Lambda(Seq(ValDef(i)), IntegerToString(Variable(i))), List[Identifier]())) - case StringType => - val i = FreshIdentifier("i", StringType) - Stream((Lambda(Seq(ValDef(i)), Variable(i)), List[Identifier]())) - case BooleanType => - val i = FreshIdentifier("i", BooleanType) - Stream((Lambda(Seq(ValDef(i)), BooleanToString(Variable(i))), List[Identifier]())) - case _ => Stream.empty - } - } - } - import leon.purescala.Extractors._ - - /** How to fill the arguments for user-defined pretty-printers */ - override def getPrintersForType(t: TypeTree, underlying: TypeTree)(implicit ctx: LeonContext, program: Program) = t match { - case FunctionType(Seq(StringType), StringType) => // Should have one argument. - val s = FreshIdentifier("s", StringType) // verify the type - Some(Stream((Lambda(Seq(ValDef(s)), Variable(s)), List())) ++ super.getPrintersForType(t, underlying).getOrElse(Stream.empty) ) - case FunctionType(Seq(t@ WithStringconverter(converter)), StringType) => // Should have one argument. - val s = FreshIdentifier("s", t) // verify the type - Some(Stream((Lambda(Seq(ValDef(s)), converter(Variable(s))), List())) ++ super.getPrintersForType(t, underlying).getOrElse(Stream.empty) ) - case StringType => - val const = FreshIdentifier("const", StringType) - Some(Stream((Variable(const), List(const)))) - case TupleType(targs) => - def convertPrinters(ts: Seq[TypeTree]): Option[Seq[Stream[(Expressions.Expr, List[Common.Identifier])]]] = { - ts match { - case Nil => Some(Seq()) - case t::tail => - getPrintersForType(t, underlying).flatMap(current => - convertPrinters(tail).map(remaining => - current +: remaining)) - } - } - convertPrinters(targs) match { - case None => None - case Some(t) => - val regrouped = leon.utils.StreamUtils.cartesianProduct(t) - val result = regrouped.map{lst => - val identifiers = lst.flatMap(_._2) - val lambdas = lst.collect{ case (l: Lambda, _) => l} - val valdefs = lambdas.flatMap(_.args) - val bodies = lst.map{ case (l: Lambda, _) => l.body case (e, _) => e } - if(valdefs.isEmpty) { - (Tuple(bodies), identifiers) - } else { - (Lambda(valdefs, Tuple(bodies)), identifiers) - } - } - Some(result) - } - case _ => super.getPrintersForType(t, underlying) - } - - /** From a list of expressions gathered for the parameters of the function definition, builds the lambda. */ - def buildLambda(inputType: TypeTree, fd: FunDef, slu: Stream[List[(Expr, List[Identifier])]]) = { - for(lambdas <- slu) yield { - val (args, ids) = lambdas.unzip - val all_ids = ids.flatten - val x = FreshIdentifier("x", inputType) // verify the type - (Lambda(Seq(ValDef(x)), functionInvocation(fd, Variable(x)::args)), all_ids) - } - } - } - - /** Actually prints the expression with as alternative the given orElse - * @param excluded The list of functions which should be excluded from pretty-printing - * (to avoid rendering counter-examples of toString methods using the method itself) - * @return a user defined string for the given typed expression. - **/ - def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { - this.excluded = excluded - val s = prettyPrintersForType(v.getType) // TODO: Included the variable excluded if necessary. - s.take(100).find { - // Limit the number of pretty-printers. - case Lambda(_, FunctionInvocation(TypedFunDef(fd, _), _)) => - (program.callGraph.transitiveCallees(fd) + fd).forall { fde => - !ExprOps.exists(_.isInstanceOf[Choose])(fde.fullBody) - } - case _ => false - } match { - case None => orElse - case Some(Lambda(Seq(ValDef(id)), body)) => - ctx.reporter.debug("Executing pretty printer for type " + v.getType + " : " + v + " on " + v) - val ste = new DefaultEvaluator(ctx, program) - try { - val result = ste.eval(body, Map(id -> v)) - - result.result match { - case Some(StringLiteral(res)) if res != "" => - res - case res => - ctx.reporter.debug("not a string literal " + result) - orElse - } - } catch { - case e: ContextualEvaluator#EvalError => - ctx.reporter.debug("Error " + e.msg) - orElse - } - } - } -} diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala deleted file mode 100644 index 4b64e468ebaaea2c5a2879f416369a208e6d6e80..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Expressions._ -import Types._ -import ExprOps._ -import Extractors._ -import Constructors._ -import solvers._ - -class SimplifierWithPaths(sf: SolverFactory[Solver], val initPath: Path = Path.empty) extends TransformerWithPC { - - val solver = SimpleSolverAPI(sf) - - def impliedBy(e: Expr, path: Path) : Boolean = try { - solver.solveVALID(path implies e) match { - case Some(true) => true - case _ => false - } - } catch { - case _ : Exception => false - } - - def contradictedBy(e: Expr, path: Path) : Boolean = try { - solver.solveVALID(path implies not(e)) match { - case Some(true) => true - case _ => false - } - } catch { - case _ : Exception => false - } - - def valid(e: Expr) : Boolean = try { - solver.solveVALID(e) match { - case Some(true) => true - case _ => false - } - } catch { - case _ : Exception => false - } - - def sat(e: Expr) : Boolean = try { - solver.solveSAT(e) match { - case (Some(false),_) => false - case _ => true - } - } catch { - case _ : Exception => true - } - - protected override def rec(e: Expr, path: Path) = e match { - case Require(pre, body) if impliedBy(pre, path) => - body - - case IfExpr(cond, thenn, _) if impliedBy(cond, path) => - rec(thenn, path) - - case IfExpr(cond, _, elze ) if contradictedBy(cond, path) => - rec(elze, path) - - case And(e +: _) if contradictedBy(e, path) => - BooleanLiteral(false).copiedFrom(e) - - case And(e +: es) if impliedBy(e, path) => - val remaining = if (es.size > 1) And(es).copiedFrom(e) else es.head - rec(remaining, path) - - case Or(e +: _) if impliedBy(e, path) => - BooleanLiteral(true).copiedFrom(e) - - case Or(e +: es) if contradictedBy(e, path) => - val remaining = if (es.size > 1) Or(es).copiedFrom(e) else es.head - rec(remaining, path) - - case Implies(lhs, rhs) if impliedBy(lhs, path) => - rec(rhs, path) - - case Implies(lhs, rhs) if contradictedBy(lhs, path) => - BooleanLiteral(true).copiedFrom(e) - - case me @ MatchExpr(scrut, cases) => - val rs = rec(scrut, path) - - var stillPossible = true - - val conds = matchExprCaseConditions(me, path) - - val newCases = cases.zip(conds).flatMap { case (cs, cond) => - if (stillPossible && sat(cond.toClause)) { - - if (valid(cond.toClause)) { - stillPossible = false - } - - Seq((cs match { - case SimpleCase(p, rhs) => - SimpleCase(p, rec(rhs, cond)) - case GuardedCase(p, g, rhs) => - // @mk: This is quite a dirty hack. We just know matchCasePathConditions - // returns the current guard as the last element. - // We don't include it in the path condition when we recurse into itself. - // @nv: baaaaaaaad!!! - val condWithoutGuard = new Path(cond.elements.dropRight(1)) - val newGuard = rec(g, condWithoutGuard) - if (valid(newGuard)) - SimpleCase(p, rec(rhs,cond)) - else - GuardedCase(p, newGuard, rec(rhs, cond)) - }).copiedFrom(cs)) - } else { - Seq() - } - } - - newCases match { - case List() => - Error(e.getType, "Unreachable code").copiedFrom(e) - case _ => - matchExpr(rs, newCases).copiedFrom(e) - } - - case a @ Assert(pred, _, body) if impliedBy(pred, path) => - body - - case a @ Assert(pred, msg, body) if contradictedBy(pred, path) => - Error(body.getType, s"Assertion failed: $msg").copiedFrom(a) - - case b if b.getType == BooleanType && impliedBy(b, path) => - BooleanLiteral(true).copiedFrom(b) - - case b if b.getType == BooleanType && contradictedBy(b, path) => - BooleanLiteral(false).copiedFrom(b) - - case _ => - super.rec(e, path) - } -} diff --git a/src/main/scala/leon/purescala/TreeNormalizations.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala deleted file mode 100644 index 4f5f5ffa9df53681ebccbc230262004cc39596e1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/TreeNormalizations.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Types._ -import Expressions._ -import ExprOps._ - -object TreeNormalizations { - - /* TODO: we should add CNF and DNF at least */ - - case class NonLinearExpressionException(msg: String) extends Exception - - //assume the function is an arithmetic expression, not a relation - //return a normal form where the [t a1 ... an] where - //expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn] - //do not keep the evaluation order - def linearArithmeticForm(expr: Expr, xs: Array[Identifier]): Array[Expr] = { - - //assume the expr is a literal (mult of constants and variables) with degree one - def extractCoef(e: Expr): (Expr, Identifier) = { - var id: Option[Identifier] = None - var coef: BigInt = 1 - - def rec(e: Expr): Unit = e match { - case InfiniteIntegerLiteral(i) => coef = coef*i - case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") - case Times(e1, e2) => rec(e1); rec(e2) - } - - rec(e) - assert(id.isDefined) - (InfiniteIntegerLiteral(coef), id.get) - } - - - def containsId(e: Expr, id: Identifier): Boolean = e match { - case Times(e1, e2) => containsId(e1, id) || containsId(e2, id) - case InfiniteIntegerLiteral(_) => false - case Variable(id2) => id == id2 - case err => throw NonLinearExpressionException("unexpected in containsId: " + err) - } - - def group(es: Seq[Expr], id: Identifier): Expr = { - val totalCoef = es.foldLeft(BigInt(0))((acc, e) => { - val (InfiniteIntegerLiteral(i), id2) = extractCoef(e) - assert(id2 == id) - acc + i - }) - Times(InfiniteIntegerLiteral(totalCoef), Variable(id)) - } - - var exprs: Seq[Expr] = expandedForm(expr) - val res: Array[Expr] = new Array(xs.length + 1) - - xs.zipWithIndex.foreach{case (id, index) => { - val (terms, rests) = exprs.partition(containsId(_, id)) - exprs = rests - val Times(coef, Variable(_)) = group(terms, id) - res(index+1) = coef - }} - - res(0) = simplifyArithmetic(exprs.foldLeft[Expr](InfiniteIntegerLiteral(0))(Plus)) - res - } - - //multiply two sums together and distribute in a larger sum - //do not keep the evaluation order - def multiply(es1: Seq[Expr], es2: Seq[Expr]): Seq[Expr] = { - for { - e1 <- es1 - e2 <- es2 - } yield Times(e1,e2) - } - - //expand the expr in a sum of "atoms", each atom being a product of literal and variable - //do not keep the evaluation order - def expandedForm(expr: Expr): Seq[Expr] = expr match { - case Plus(es1, es2) => expandedForm(es1) ++ expandedForm(es2) - case Minus(e1, e2) => expandedForm(e1) ++ expandedForm(e2).map(Times(InfiniteIntegerLiteral(-1), _): Expr) - case UMinus(e) => expandedForm(e).map(Times(InfiniteIntegerLiteral(-1), _): Expr) - case Times(es1, es2) => multiply(expandedForm(es1), expandedForm(es2)) - case v@Variable(_) if v.getType == IntegerType => Seq(v) - case n@InfiniteIntegerLiteral(_) => Seq(n) - case err => throw NonLinearExpressionException("unexpected in expandedForm: " + err) - } - -} diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala deleted file mode 100644 index ae6cf3da32985c29c400e4d59af1fd6844277dad..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/Types.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Expressions._ -import Definitions._ -import TypeOps._ - -object Types { - - trait Typed extends Printable { - def getType: TypeTree - def isTyped : Boolean = getType != Untyped - } - - class TypeErrorException(msg: String) extends Exception(msg) - - object TypeErrorException { - def apply(obj: Expr, exp: List[TypeTree]): TypeErrorException = { - new TypeErrorException("Type error: "+obj+", expected: "+exp.mkString(" or ")+", found "+obj.getType) - } - - def apply(obj: Expr, exp: TypeTree): TypeErrorException = { - apply(obj, List(exp)) - } - } - - abstract class TypeTree extends Tree with Typed { - val getType = this - - // Checks whether the subtypes of this type contain Untyped, - // and if so sets this to Untyped. - // Assumes the subtypes are correctly formed, so it does not descend - // deep into the TypeTree. - def unveilUntyped: TypeTree = this match { - case NAryType(tps, _) => - if (tps contains Untyped) Untyped else this - } - } - - case object Untyped extends TypeTree - case object BooleanType extends TypeTree - case object UnitType extends TypeTree - case object CharType extends TypeTree - case object IntegerType extends TypeTree - case object RealType extends TypeTree - - abstract class BitVectorType(val size: Int) extends TypeTree - case object Int32Type extends BitVectorType(32) - case object StringType extends TypeTree - - class TypeParameter private (name: String) extends TypeTree { - val id = FreshIdentifier(name, this) - def freshen = new TypeParameter(name) - - override def equals(that: Any) = that match { - case TypeParameter(id) => this.id == id - case _ => false - } - - override def hashCode = id.hashCode - } - - object TypeParameter { - def unapply(tp: TypeParameter): Option[Identifier] = Some(tp.id) - def fresh(name: String) = new TypeParameter(name) - } - - /* - * If you are not sure about the requirement, - * you should use tupleTypeWrap in purescala.Constructors - */ - case class TupleType(bases: Seq[TypeTree]) extends TypeTree { - val dimension: Int = bases.length - require(dimension >= 2) - } - - case class SetType(base: TypeTree) extends TypeTree - case class BagType(base: TypeTree) extends TypeTree - case class MapType(from: TypeTree, to: TypeTree) extends TypeTree - case class FunctionType(from: Seq[TypeTree], to: TypeTree) extends TypeTree - case class ArrayType(base: TypeTree) extends TypeTree - - sealed abstract class ClassType extends TypeTree { - val classDef: ClassDef - val id: Identifier = classDef.id - - override def hashCode : Int = id.hashCode + tps.hashCode - override def equals(that : Any) : Boolean = that match { - case t : ClassType => t.id == this.id && t.tps == this.tps - case _ => false - } - - val tps: Seq[TypeTree] - - assert(classDef.tparams.size == tps.size) - - lazy val fields = { - val tmap = (classDef.typeArgs zip tps).toMap - if (tmap.isEmpty) { - classDef.fields - } else { - // !! WARNING !! - // vd.id changes but this should not be an issue as selector uses - // classDef.params ids which do not change! - classDef.fields.map { vd => - val newTpe = instantiateType(vd.getType, tmap) - val newId = FreshIdentifier(vd.id.name, newTpe).copiedFrom(vd.id) - vd.copy(id = newId).setPos(vd) - } - } - } - - def invariant = classDef.invariant.map(_.typed(tps)) - - def knownDescendants = classDef.knownDescendants.map( _.typed(tps) ) - - def knownCCDescendants: Seq[CaseClassType] = classDef.knownCCDescendants.map( _.typed(tps) ) - - lazy val fieldsTypes = fields.map(_.getType) - - lazy val root: ClassType = parent.map{ _.root }.getOrElse(this) - - lazy val parent = classDef.parent.map { pct => - instantiateType(pct, (classDef.typeArgs zip tps).toMap) match { - case act: AbstractClassType => act - case t => throw LeonFatalError("Unexpected translated parent type: "+t) - } - } - - } - - case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType - case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - - object NAryType extends TreeExtractor[TypeTree] { - def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { - case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) - case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) - case TupleType(ts) => Some((ts, TupleType)) - case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) - case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) - case BagType(t) => Some((Seq(t), ts => BagType(ts.head))) - case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) - case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) - - /* TODO: use some extractable interface once this proved useful */ - case solvers.RawArrayType(from,to) => Some((Seq(from, to), t => solvers.RawArrayType(t(0), t(1)))) - - /* nullary types */ - case t => Some(Nil, _ => t) - } - } - - object FirstOrderFunctionType { - def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { - case FunctionType(from, to) => - unapply(to).map(p => (from ++ p._1) -> p._2) orElse Some(from -> to) - case _ => None - } - } - - def optionToType(tp: Option[TypeTree]) = tp getOrElse Untyped - -} diff --git a/src/main/scala/leon/repair/RepairNDEvaluator.scala b/src/main/scala/leon/repair/RepairNDEvaluator.scala deleted file mode 100644 index 63a114ff9dcc0c4b934b2c92ce3a399937abb629..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/repair/RepairNDEvaluator.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package repair - -import purescala.Definitions.Program -import purescala.Expressions._ -import purescala.ExprOps.valuesOf -import evaluators.StreamEvaluator - -/** A [[leon.evaluators.StreamEvaluator StreamEvaluator]] that treats a specified expression [[nd]] as a non-deterministic value - * @note Expressions are compared against [[nd]] with reference equality. - */ -class RepairNDEvaluator(ctx: LeonContext, prog: Program, nd: Expr) extends StreamEvaluator(ctx, prog) { - - override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Stream[Expr] = expr match { - case Not(c) if c eq nd => - // This is a hack: We know the only way nd is wrapped within a Not is if it is NOT within - // a recursive call. So we need to treat it deterministically at this point... - super.e(c) collect { case BooleanLiteral(b) => BooleanLiteral(!b) } - case c if c eq nd => - valuesOf(c.getType) - case other => - super.e(other) - } - -} diff --git a/src/main/scala/leon/repair/RepairPhase.scala b/src/main/scala/leon/repair/RepairPhase.scala deleted file mode 100644 index aff7747ad53eaa54ccc2d7864da154c4793985b2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/repair/RepairPhase.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package repair - -import purescala.Definitions._ -import purescala.DefOps._ - -object RepairPhase extends UnitPhase[Program]() { - val name = "Repair" - val description = "Repairing" - - implicit val debugSection = utils.DebugSectionRepair - - def apply(ctx: LeonContext, program: Program) = { - val repairFuns: Option[Seq[String]] = ctx.findOption(GlobalOptions.optFunctions) - val verifTimeoutMs: Option[Long] = ctx.findOption(GlobalOptions.optTimeout) map { _ * 1000 } - - val reporter = ctx.reporter - - val fdFilter = { - import OptionsHelpers._ - - filterInclusive(repairFuns.map(fdMatcher(program)), None) - } - - val toRepair = funDefsFromMain(program).toList.filter(fdFilter).filter{ _.hasPostcondition }.sortWith((fd1, fd2) => fd1.getPos < fd2.getPos) - - if (toRepair.isEmpty) reporter.warning("No functions found with the given names") - - for (fd <- toRepair) { - val rep = new Repairman(ctx, program, fd, verifTimeoutMs, verifTimeoutMs) - rep.repair() - } - - } -} diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala deleted file mode 100644 index 4cf9e1d359b8725527217e807f629257512618e8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/repair/Repairman.scala +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package repair - -import purescala.Path -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.Constructors._ - -import evaluators._ -import solvers._ -import utils._ -import codegen._ -import verification._ -import datagen.GrammarDataGen - -import synthesis._ -import synthesis.rules._ -import synthesis.Witnesses._ -import synthesis.graph.{dotGenIds, DotGenerator} - -import rules._ - -class Repairman(ctx: LeonContext, program: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) { - implicit val ctx0 = ctx - - val reporter = ctx.reporter - - val doBenchmark = ctx.findOptionOrDefault(GlobalOptions.optBenchmark) - - implicit val debugSection = DebugSectionRepair - - def repair(): Unit = { - val to = new TimeoutFor(ctx.interruptManager) - - to.interruptAfter(repairTimeoutMs) { - reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id)) - - val timer = new Timer().start - - val eb = discoverTests() - - if (eb.invalids.nonEmpty) { - reporter.info(f" - Passing: ${eb.valids.size}%3d") - reporter.info(f" - Failing: ${eb.invalids.size}%3d") - - reporter.ifDebug { printer => - printer(eb.asString("Discovered Tests")) - } - - reporter.info(ASCIIHelpers.title("2. Minimizing tests")) - val eb2 = eb.minimizeInvalids(fd, ctx, program) - - // We exclude redundant failing tests, and only select the minimal tests - reporter.info(f" - Minimal Failing Set Size: ${eb2.invalids.size}%3d") - - reporter.ifDebug { printer => - printer(eb2.asString("Minimal Failing Tests")) - } - - val timeTests = timer.stop - timer.start - - val synth = getSynthesizer(eb2) - - try { - reporter.info(ASCIIHelpers.title("3. Synthesizing repair")) - val (search0, sols0) = synth.synthesize() - - val timeSynth = timer.stop - timer.start - - val (search, solutions) = synth.validate((search0, sols0), allowPartial = false) - - val timeVerify = timer.stop - - if (doBenchmark) { - val be = BenchmarkEntry.fromContext(ctx) ++ Map( - "function" -> fd.id.name, - "time_tests" -> timeTests, - "time_synthesis" -> timeSynth, - "time_verification" -> timeVerify, - "success" -> solutions.nonEmpty, - "verified" -> solutions.forall(_._2) - ) - - val bh = new BenchmarksHistory("repairs.dat") - - bh += be - - bh.write() - } - - reporter.ifDebug { printer => - import java.text.SimpleDateFormat - import java.util.Date - - val categoryName = fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") - val benchName = categoryName+"."+fd.id.name - - val defs = visibleFunDefsFromMain(program).collect { - case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) - } - - val pSize = defs.sum - val fSize = formulaSize(fd.body.get) - - def localizedExprs(n: graph.Node): List[Expr] = { - n match { - case on: graph.OrNode => - on.selected.flatMap(localizedExprs) - case an: graph.AndNode if an.ri.rule == Focus => - an.selected.flatMap(localizedExprs) - case an: graph.AndNode => - val TopLevelAnds(clauses) = an.p.ws - - val res = clauses.collect { - case Guide(expr) => expr - } - - res.toList - } - } - - val locSize = localizedExprs(search.g.root).map(formulaSize).sum - - val (solSize, proof) = solutions.headOption match { - case Some((sol, trusted)) => - val solExpr = sol.toSimplifiedExpr(ctx, program, fd) - val totalSolSize = formulaSize(solExpr) - (locSize+totalSolSize-fSize, if (trusted) "$\\chmark$" else "") - case _ => - (0, "X") - } - - val date = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) - - val fw = new java.io.FileWriter("repair-report.txt", true) - - try { - fw.write(f"$date: $benchName%-30s & $pSize%4d & $fSize%4d & $locSize%4d & $solSize%4d & ${timeTests/1000.0}%2.1f & ${timeSynth/1000.0}%2.1f & $proof%7s \\\\\n") - } finally { - fw.close() - } - }(DebugSectionReport) - - if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search) - dot.writeFile("derivation"+ dotGenIds.nextGlobal + ".dot") - } - - if (solutions.isEmpty) { - reporter.error(ASCIIHelpers.title("Failed to repair!")) - } else { - - reporter.info(ASCIIHelpers.title("Repair successful:")) - for ( ((sol, isTrusted), i) <- solutions.zipWithIndex) { - reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+ (if(isTrusted) "" else " (untrusted)" ) + ":")) - val expr = sol.toSimplifiedExpr(ctx, synth.program, fd) - reporter.info(expr.asString(program)(ctx)) - } - } - } finally { - synth.shutdown() - } - } else { - reporter.info(s"Could not find a wrong execution.") - } - } - } - - def getSynthesizer(eb: ExamplesBank): Synthesizer = { - - val origBody = fd.body.get - - val term = Terminating(fd.applied) - val guide = Guide(origBody) - val pre = fd.precOrTrue - - val prob = Problem.fromSpec(fd.postOrTrue, Path(Seq(pre, guide, term)), eb, Some(fd)) - - val ci = SourceInfo(fd, origBody, prob) - - // Return synthesizer for this choose - val so0 = SynthesisPhase.processOptions(ctx) - - val soptions = so0.copy( - functionsToIgnore = so0.functionsToIgnore + fd, - rules = Seq(Focus, CEGLESS) ++ so0.rules - ) - - new Synthesizer(ctx, program, ci, soptions) - } - - def getVerificationCExs(fd: FunDef): Seq[Example] = { - val timeoutMs = verifTimeoutMs.getOrElse(3000L) - val solverf = SolverFactory.getFromSettings(ctx, program).withTimeout(timeoutMs) - val vctx = new VerificationContext(ctx, program, solverf) - val vcs = VerificationPhase.generateVCs(vctx, Seq(fd)) - - try { - val report = VerificationPhase.checkVCs( - vctx, - vcs, - stopWhen = _.isInvalid - ) - - val vrs = report.vrs - - vrs.collect { case (_, VCResult(VCStatus.Invalid(ex), _, _)) => - InExample(fd.paramIds map ex) - } - } finally { - solverf.shutdown() - } - } - - def discoverTests(): ExamplesBank = { - - val maxEnumerated = 1000 - val maxValid = 400 - - val evaluator = new CodeGenEvaluator(ctx, program) - - val inputsToExample: Seq[Expr] => Example = { ins => - evaluator.eval(functionInvocation(fd, ins)) match { - case EvaluationResults.Successful(res) => - new InOutExample(ins, List(res)) - case _ => - new InExample(ins) - } - } - - val dataGen = new GrammarDataGen(evaluator) - - val generatedTests = dataGen - .generateFor(fd.paramIds, fd.precOrTrue, maxValid, maxEnumerated) - .map(inputsToExample) - .toList - - val (genPassing, genFailing) = generatedTests.partition { - case _: InOutExample => true - case _ => false - } - - val genTb = ExamplesBank(genPassing, genFailing).stripOuts - - // Extract passing/failing from the passes in POST - val userTb = new ExamplesFinder(ctx, program).extractFromFunDef(fd, partition = true) - - val allTb = genTb union userTb - - if (allTb.invalids.isEmpty) { - ExamplesBank(allTb.valids, getVerificationCExs(fd)) - } else { - allTb - } - } - - def programSize(pgm: Program): Int = { - visibleFunDefsFromMain(pgm).foldLeft(0) { - case (s, f) => - 1 + f.params.size + formulaSize(f.fullBody) + s - } - } -} diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala deleted file mode 100644 index a77ed826a25bb8142d097c368525b11907b7b0de..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ /dev/null @@ -1,264 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package repair -package rules - -import purescala.Path -import purescala.Expressions._ -import purescala.Common._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Extractors._ - -import utils.fixpoint -import evaluators._ - -import synthesis._ -import Witnesses._ -import graph.AndNode - -case object Focus extends PreprocessingRule("Focus") { - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - hctx.parentNode match { - case Some(an: AndNode) if an.ri.rule == Focus => - // We proceed as usual - case Some(_) => - return None; - case _ => - - } - - val qeb = p.qebFiltered - - val fd = hctx.functionContext - val program = hctx.program - - val evaluator = new DefaultEvaluator(hctx, program) - - // Check how an expression behaves on tests - // - returns Some(true) if for all tests e evaluates to true - // - returns Some(false) if for all tests e evaluates to false - // - returns None otherwise - def forAllTests(e: Expr, env: Map[Identifier, Expr], evaluator: Evaluator): Option[Boolean] = { - var soFar: Option[Boolean] = None - - qeb.invalids.foreach { ex => - evaluator.eval(e, (p.as zip ex.ins).toMap ++ env) match { - case EvaluationResults.Successful(BooleanLiteral(b)) => - soFar match { - case None => - soFar = Some(b) - case Some(`b`) => - /* noop */ - case _ => - return None - } - - case e => - //println("Evaluator said "+e) - return None - } - } - - soFar - } - - def existsFailing(e: Expr, env: Map[Identifier, Expr], evaluator: DeterministicEvaluator): Boolean = { - qeb.invalids.exists { ex => - evaluator.eval(e, (p.as zip ex.ins).toMap ++ env).result match { - case Some(BooleanLiteral(b)) => b - case _ => true - } - } - } - - - - val TopLevelAnds(clauses) = p.ws - - val guides = clauses.collect { - case Guide(expr) => expr - } - - val wss = clauses.filter { - case _: Guide => false - case _ => true - } - - def ws(g: Expr) = andJoin(Guide(g) +: wss) - - def testCondition(guide: IfExpr) = { - val IfExpr(cond, thenn, elze) = guide - val spec = letTuple(p.xs, IfExpr(Not(cond), thenn, elze), p.phi) - forAllTests(spec, Map(), new AngelicEvaluator(new RepairNDEvaluator(hctx, program, cond))) - } - - guides.flatMap { - case g @ IfExpr(c, thn, els) => - testCondition(g) match { - case Some(true) => - val cx = FreshIdentifier("cond", BooleanType) - // Focus on condition - val np = Problem(p.as, ws(c), p.pc, letTuple(p.xs, IfExpr(cx.toVariable, thn, els), p.phi), List(cx), qeb.stripOuts) - - Some(decomp(List(np), termWrap(IfExpr(_, thn, els)), s"Focus on if-cond '${c.asString}'")(p)) - - case _ => - // Try to focus on branches - forAllTests(c, Map(), evaluator) match { - case Some(true) => - val np = Problem(p.as, ws(thn), p.pc withCond c, p.phi, p.xs, qeb.filterIns(c)) - - Some(decomp(List(np), termWrap(IfExpr(c, _, els), c), s"Focus on if-then")(p)) - case Some(false) => - val np = Problem(p.as, ws(els), p.pc withCond not(c), p.phi, p.xs, qeb.filterIns(not(c))) - - Some(decomp(List(np), termWrap(IfExpr(c, thn, _), not(c)), s"Focus on if-else")(p)) - case None => - // We split - val sub1 = p.copy(ws = ws(thn), pc = p.pc map (replace(Map(g -> thn), _)) withCond c , eb = qeb.filterIns(c)) - val sub2 = p.copy(ws = ws(els), pc = p.pc map (replace(Map(g -> thn), _)) withCond Not(c), eb = qeb.filterIns(Not(c))) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s1, s2) => - Some(Solution(or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(c, s1.term, s2.term))) - case _ => - None - } - - Some(decomp(List(sub1, sub2), onSuccess, s"Focus on both branches of '${c.asString}'")) - } - } - - case MatchExpr(scrut, cases) => - var pcSoFar = Path.empty - - // Generate subproblems for each match-case that fails at least one test. - var casesInfos = for (c <- cases) yield { - val map = mapForPattern(scrut, c.pattern) - - val thisCond = matchCaseCondition(scrut, c) - val prevPCSoFar = pcSoFar - val cond = pcSoFar merge thisCond - pcSoFar = pcSoFar merge thisCond.negate - - val subP = if (existsFailing(cond.toClause, map, evaluator)) { - - val vars = map.keys.toSeq - - val (p2e, _) = patternToExpression(c.pattern, scrut.getType) - - val substAs = ((scrut, p2e) match { - case (Variable(i), _) if p.as.contains(i) => Seq(i -> p2e) - case (Tuple(as), Tuple(tos)) => - val res = as.zip(tos) collect { - case (Variable(i), to) if p.as.contains(i) => i -> to - } - if (res.size == as.size) res else Nil - case _ => Nil - }).toMap - - if (substAs.nonEmpty) { - val subst = replaceFromIDs(substAs, (_:Expr)) - // FIXME intermediate binders?? - val newAs = (p.as diff substAs.keys.toSeq) ++ vars - val newPc = (p.pc merge prevPCSoFar) map subst - val newWs = subst(ws(c.rhs)) - val newPhi = subst(p.phi) - val eb2 = qeb.filterIns(cond.toClause) - val ebF: Seq[(Identifier, Expr)] => List[Seq[Expr]] = { (ins: Seq[(Identifier, Expr)]) => - val eval = evaluator.eval(tupleWrap(vars map Variable), map ++ ins) - val insWithout = ins.collect{ case (id, v) if !substAs.contains(id) => v } - eval.result.map(r => insWithout ++ unwrapTuple(r, vars.size)).toList - } - val newEb = eb2.flatMapIns(ebF) - Some(Problem(newAs, newWs, newPc, newPhi, p.xs, newEb)) - } else { - // Filter tests by the path-condition - val eb2 = qeb.filterIns(cond.toClause) - - val newPc = cond withBindings vars.map(id => id -> map(id)) - - Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb2)) - } - } else { - None - } - - c -> (subP, cond) - } - - // Check if the match might be missing a case? (we check if one test - // goes to no defined cases) - val elsePc = pcSoFar - - if (existsFailing(elsePc.toClause, Map(), evaluator)) { - val newCase = MatchCase(WildcardPattern(None), None, NoTree(scrut.getType)) - - val qeb2 = qeb.filterIns(elsePc.toClause) - - val newProblem = Problem(p.as, andJoin(wss), p.pc merge elsePc, p.phi, p.xs, qeb2) - - casesInfos :+= (newCase -> (Some(newProblem), elsePc)) - } - - // Is there at least one subproblem? - if (casesInfos.exists(_._2._1.isDefined)) { - val infosP = casesInfos.collect { - case (c, (Some(p), pc)) => (c, (p, pc)) - } - - val nps = infosP.map(_._2._1).toList - - val appName = s"Focus on match-cases ${infosP.map(i => "'"+i._1.pattern.asString+"'").mkString(", ")}" - - val onSuccess: List[Solution] => Option[Solution] = { - case ss => - val matchSols = (infosP zip ss).map { case ((c, (pc)), s) => (c, (pc, s)) } - - val pres = matchSols.map { - case (_, (pc, s)) => - if(s.pre == BooleanLiteral(true)) { - BooleanLiteral(true) - } else { - p.pc and s.pre - } - } - - val solsMap = matchSols.toMap - - val expr = MatchExpr(scrut, casesInfos.map { case (c, _) => solsMap.get(c) match { - case Some((pc, s)) => - c.copy(rhs = s.term) - case None => - c - }}) - - Some(Solution(orJoin(pres), ss.map(_.defs).reduceLeft(_ ++ _), expr)) - } - - Some(decomp(nps, onSuccess, appName)(p)) - } else { - None - } - - case Let(id, value, body) => - val ebF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => - val map = (p.as zip e).toMap - - evaluator.eval(value, map).result.map { r => - e :+ r - }.toList - } - - val np = Problem(p.as, ws(body), p.pc withBinding (id -> value), p.phi, p.xs, qeb.flatMapIns(ebF)) - - Some(decomp(List(np), termWrap(Let(id, value, _)), s"Focus on let-body")(p)) - - case _ => None - } - } -} diff --git a/src/main/scala/leon/synthesis/Algebra.scala b/src/main/scala/leon/synthesis/Algebra.scala deleted file mode 100644 index 3ba54147d0974bf9f48aac183fd47942540d9d8f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Algebra.scala +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis - -/* - * This provides some algebra/number theory functions, including operation such as true division, - * GCD and LCM as well as some matrix computation. - * - * Notice that all those functionalities are independent of the Leon language and - * are working for Integer (by opposition to real numbers). - */ - -object Algebra { - /** Returns the remainder of the euclidian division between x an y (always positive) */ - def remainder(x: Int, y: Int) = ((x % y) + y.abs) % y.abs - - /** Returns the quotient of the euclidian division between a and b.*/ - def divide(a: Int, b: Int): (Int, Int) = { - val r = remainder(a, b) - ((a - r)/b, r) - } - - /** Returns the remainder of the euclidian division between the big integers x an y (always positive) */ - def remainder(x: BigInt, y: BigInt) = ((x % y) + y.abs) % y.abs - /** Returns the quotient of the euclidian division between the big integers x an y */ - def divide(a: BigInt, b: BigInt): (BigInt, BigInt) = { - val r = remainder(a, b) - ((a - r)/b, r) - } - /** Returns the gcd of two integers */ - def gcd(a: Int, b: Int): Int = { - val (na, nb) = (a.abs, b.abs) - def gcd0(a: Int, b: Int): Int = { - require(a >= b) - if(b == 0) a else gcd0(b, a % b) - } - if(na > nb) gcd0(na, nb) else gcd0(nb, na) - } - - /** Returns the gcd of three or more integers */ - def gcd(a1: Int, a2: Int, a3: Int, as: Int*): Int = { - var tmp = gcd(a1, a2) - tmp = gcd(tmp, a3) - var i = 0 - while(i < as.size) { - tmp = gcd(tmp, as(i)) - i += 1 - } - tmp - } - - /** Returns the gcd of a non-empty sequence of integers */ - def gcd(as: Seq[Int]): Int = { - require(as.length >= 1) - if(as.length == 1) - as.head.abs - else { - var tmp = gcd(as(0), as(1)) - var i = 2 - while(i < as.size) { - tmp = gcd(tmp, as(i)) - i += 1 - } - tmp - } - } - - /** Returns the gcd of two big integers */ - def gcd(a: BigInt, b: BigInt): BigInt = { - val (na, nb) = (a.abs, b.abs) - def gcd0(a: BigInt, b: BigInt): BigInt = { - require(a >= b) - if(b == BigInt(0)) a else gcd0(b, a % b) - } - if(na > nb) gcd0(na, nb) else gcd0(nb, na) - } - - /** Returns the gcd of three or more big integers */ - def gcd(a1: BigInt, a2: BigInt, a3: BigInt, as: BigInt*): BigInt = { - var tmp = gcd(a1, a2) - tmp = gcd(tmp, a3) - var i = 0 - while(i < as.size) { - tmp = gcd(tmp, as(i)) - i += 1 - } - tmp - } - - /** Returns the gcd of a non-empty sequence of big integers */ - def gcd(as: Seq[BigInt]): BigInt = { - require(as.length >= 1) - if(as.length == 1) - as.head.abs - else { - var tmp = gcd(as(0), as(1)) - var i = 2 - while(i < as.size) { - tmp = gcd(tmp, as(i)) - i += 1 - } - tmp - } - } - - /** Returns the lcm of two integers */ - def lcm(a: Int, b: Int): Int = { - val (na, nb) = (a.abs, b.abs) - na*nb/gcd(a, b) - } - - /** Returns the lcm of three or more integers */ - def lcm(a1: Int, a2: Int, a3: Int, as: Int*): Int = { - var tmp = lcm(a1, a2) - tmp = lcm(tmp, a3) - var i = 0 - while(i < as.size) { - tmp = lcm(tmp, as(i)) - i += 1 - } - tmp - } - - /** Returns the lcm of a sequence of integers */ - def lcm(as: Seq[Int]): Int = { - require(as.length >= 1) - if(as.length == 1) - as.head.abs - else { - var tmp = lcm(as(0), as(1)) - var i = 2 - while(i < as.size) { - tmp = lcm(tmp, as(i)) - i += 1 - } - tmp - } - } - - /** Returns the lcm of two big integers */ - def lcm(a: BigInt, b: BigInt): BigInt = { - val (na, nb) = (a.abs, b.abs) - na*nb/gcd(a, b) - } - - /** Returns the lcm of three or more big integers */ - def lcm(a1: BigInt, a2: BigInt, a3: BigInt, as: BigInt*): BigInt = { - var tmp = lcm(a1, a2) - tmp = lcm(tmp, a3) - var i = 0 - while(i < as.size) { - tmp = lcm(tmp, as(i)) - i += 1 - } - tmp - } - - /** Returns the lcm of a sequence of big integers */ - def lcm(as: Seq[BigInt]): BigInt = { - require(as.length >= 1) - if(as.length == 1) - as(0).abs - else { - var tmp = lcm(as(0), as(1)) - var i = 2 - while(i < as.size) { - tmp = lcm(tmp, as(i)) - i += 1 - } - tmp - } - } - - //return (x, y) such that ax + by = gcd(a, b) - def extendedEuclid(a: Int, b: Int): (Int, Int) = { - def rec(a: Int, b: Int): (Int, Int) = { - require(a >= 0 && b >= 0) - if(b == 0) (1, 0) else { - val (q, r) = divide(a, b) - val (s, t) = extendedEuclid(b, r) - (t, s - q * t) - } - } - if(a >= 0 && b >= 0) rec(a, b) - else if(a < 0 && b >= 0) {val (x, y) = rec(-a, b); (-x, y)} - else if(a >= 0 && b < 0) {val (x, y) = rec(a, -b); (x, -y)} - else if(a < 0 && b < 0) {val (x, y) = rec(-a, -b); (-x, -y)} - else sys.error("shouldn't have forgot a case here") - } - def extendedEuclid(a: BigInt, b: BigInt): (BigInt, BigInt) = { - def rec(a: BigInt, b: BigInt): (BigInt, BigInt) = { - require(a >= 0 && b >= 0) - if(b == BigInt(0)) (1, 0) else { - val (q, r) = divide(a, b) - val (s, t) = extendedEuclid(b, r) - (t, s - q * t) - } - } - if(a >= 0 && b >= 0) rec(a, b) - else if(a < 0 && b >= 0) {val (x, y) = rec(-a, b); (-x, y)} - else if(a >= 0 && b < 0) {val (x, y) = rec(a, -b); (x, -y)} - else if(a < 0 && b < 0) {val (x, y) = rec(-a, -b); (-x, -y)} - else sys.error("shouldn't have forgot a case here") - } - - - //val that the sol vector with the term in the equation - def eval(sol: Array[Int], equation: Array[Int]): Int = { - require(sol.length == equation.length) - sol.zip(equation).foldLeft(0)((acc, p) => acc + p._1 * p._2) - } - - //multiply the matrix by the vector: [M1 M2 .. Mn] * [v1 .. vn] = v1*M1 + ... + vn*Mn] - def mult(matrix: Array[Array[Int]], vector: Array[Int]): Array[Int] = { - require(vector.length == matrix(0).length && vector.length > 0) - val tmat = matrix.transpose - var tmp: Array[Int] = null - tmp = mult(vector(0), tmat(0)) - var i = 1 - while(i < vector.length) { - tmp = add(tmp, mult(vector(i), tmat(i))) - i += 1 - } - tmp - } - - def mult(c: Int, v: Array[Int]): Array[Int] = v.map(_ * c) - - def add(v1: Array[Int], v2: Array[Int]): Array[Int] = { - require(v1.length == v2.length) - v1.zip(v2).map(p => p._1 + p._2) - } - -} diff --git a/src/main/scala/leon/synthesis/ConversionPhase.scala b/src/main/scala/leon/synthesis/ConversionPhase.scala deleted file mode 100644 index d6329768e42e561c92d76c55a39764b11b650f66..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/ConversionPhase.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Definitions._ -import purescala.Constructors._ - -object ConversionPhase extends UnitPhase[Program] { - val name = "Eliminate holes, withOracle and abstract definitions" - val description = "Convert Holes, withOracle found in bodies and abstract function bodies to equivalent Choose" - - /** - * This phase does 3 things: - * - * 1) Converts a body with "withOracle{ .. }" into a choose construct: - * - * def foo(a: T) = { - * require(..a..) - * withOracle { o => - * expr(a,o) ensuring { x => post(x) } - * } - * } - * - * gets converted into: - * - * def foo(a: T) { - * require(..a..) - * val o = choose { (o) => { - * val res = expr(a, o) - * pred(res) - * } - * expr(a,o) - * } ensuring { res => - * pred(res) - * } - * - * - * 2) Converts a body with "???" into a choose construct: - * - * def foo(a: T) = { - * require(..a..) - * expr(a, ???) - * } ensuring { x => post(x) } - * - * gets converted into: - * - * def foo(a: T) = { - * require(..a..) - * val h = choose ( h' => - * val res = expr(a, h') - * post(res) - * ) - * expr(a, h) - * } ensuring { res => - * post(res) - * } - * - * 3) Completes abstract definitions (IF NOT EXTERN): - * - * def foo(a: T) = { - * require(..a..) - * _ - * } ensuring { res => - * post(res) - * } - * - * gets converted to: - * - * def foo(a: T) = { - * require(..a..) - * choose(x => post(x)) - * } - * (in practice, there will be no pre-and postcondition) - * - * 4) Functions that have only a choose as body gets their spec from the choose. - * - * def foo(a: T) = { - * choose(x => post(a, x)) - * } - * - * gets converted to: - * - * def foo(a: T) = { - * choose(x => post(a, x)) - * } ensuring { x => post(a, x) } - * - * (in practice, there will be no pre-and postcondition) - */ - - def convert(e : Expr, ctx : LeonContext, isExtern: Boolean) : Expr = { - val (pre, body, post) = breakDownSpecs(e) - - // Ensure that holes are not found in pre and/or post conditions - (pre ++ post).foreach { - preTraversal{ - case h : Hole => - ctx.reporter.error(s"Holes like $h are not supported in pre- or postconditions. @ ${h.getPos}") - case wo: WithOracle => - ctx.reporter.error(s"WithOracle expressions are not supported in pre- or postconditions: ${wo.asString(ctx)} @ ${wo.getPos}") - case _ => - } - } - - def toExpr(h: Hole): (Expr, List[Identifier]) = { - h.alts match { - case Seq() => - val h1 = FreshIdentifier("hole", h.getType, true) - (h1.toVariable, List(h1)) - - case Seq(v) => - val h1 = FreshIdentifier("hole", BooleanType, true) - val h2 = FreshIdentifier("hole", h.getType, true) - (IfExpr(h1.toVariable, h2.toVariable, v), List(h1, h2)) - - case exs => - var ids: List[Identifier] = Nil - val ex = exs.init.foldRight(exs.last)({ (e: Expr, r: Expr) => - val h = FreshIdentifier("hole", BooleanType, true) - ids ::= h - IfExpr(h.toVariable, e, r) - }) - - (ex, ids.reverse) - } - } - - val fullBody = body match { - case Some(body) => - var holes = List[Identifier]() - - val withoutHoles = preMap { - case h : Hole => - val (expr, ids) = toExpr(h) - - holes ++= ids - - Some(expr) - case wo @ WithOracle(os, b) => - withoutSpec(b) map { pred => - val chooseOs = os.map(_.freshen) - - val pred = post.getOrElse( // FIXME: We need to freshen variables - Lambda(chooseOs.map(ValDef(_)), BooleanLiteral(true)) - ) - - letTuple(os, Choose(pred), b) - } - case _ => - None - }(body) - - if (holes.nonEmpty) { - val cids: List[Identifier] = holes.map(_.freshen) - val hToFresh = (holes zip cids.map(_.toVariable)).toMap - - val spec = post match { - case Some(post: Lambda) => - val asLet = letTuple(post.args.map(_.id), withoutHoles, post.body) - - Lambda(cids.map(ValDef(_)), replaceFromIDs(hToFresh, asLet)) - - case _ => - Lambda(cids.map(ValDef(_)), BooleanLiteral(true)) - } - - val choose = Choose(spec) - - val newBody = if (holes.size == 1) { - replaceFromIDs(Map(holes.head -> choose), withoutHoles) - } else { - letTuple(holes, choose, withoutHoles) - } - - withPostcondition(withPrecondition(newBody, pre), post) - - } else { - e - } - - case None => - if (isExtern) { - e - } else { - val newPost = post getOrElse Lambda(Seq(ValDef(FreshIdentifier("res", e.getType))), BooleanLiteral(true)) - withPrecondition(Choose(newPost), pre) - } - } - - // extract spec from chooses at the top-level - fullBody match { - case Require(_, Choose(spec)) => - withPostcondition(fullBody, Some(spec)) - case Choose(spec) => - withPostcondition(fullBody, Some(spec)) - case _ => - fullBody - } - } - - - def apply(ctx: LeonContext, pgm: Program): Unit = { - // TODO: remove side-effects - for (fd <- pgm.definedFunctions) { - fd.fullBody = convert(fd.fullBody, ctx, fd.annotations("extern")) - } - } - -} diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala deleted file mode 100644 index 781940e71ec6d62c57895a9cafee2b4073e704bf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ -import purescala.ExprOps._ - -/** A named way of computing the cost of problem and solutions.*/ -abstract class CostModel(val name: String) { - def solution(s: Solution): Cost - - def impossible: Cost - - def isImpossible(c: Cost): Boolean = { - c >= impossible - } -} - -/** Represents a cost used when comparing synthesis solutions for example */ -case class Cost(minSize: Int) extends AnyVal with Ordered[Cost] { - def compare(that: Cost): Int = { - this.minSize-that.minSize - } - - override def toString: String = { - f"$minSize%3d" - } -} - -/** Contains all and the default [CostModel] */ -object CostModels { - def default: CostModel = NaiveCostModel - - def all: Set[CostModel] = Set( - NaiveCostModel, - WeightedBranchesCostModel - ) -} - -/** Wrapped cost model. Not used at this moment. */ -class WrappedCostModel(cm: CostModel, name: String) extends CostModel(name) { - - def solution(s: Solution): Cost = cm.solution(s) - - def impossible = cm.impossible -} - -/** Computes a cost corresponding of the size of the solution expression divided by 10. - * For problems, returns a cost of 1 */ -class SizeBasedCostModel(name: String) extends CostModel(name) { - def solution(s: Solution) = { - Cost(formulaSize(s.term)) - } - - def impossible = Cost(1000) -} - -case object NaiveCostModel extends SizeBasedCostModel("Naive") - -case object WeightedBranchesCostModel extends SizeBasedCostModel("WeightedBranches") { - - def branchesCost(e: Expr): Int = { - case class BC(cost: Int, nesting: Int) - - def pre(e: Expr, c: BC) = { - (e, c.copy(nesting = c.nesting + 1)) - } - - def costOfBranches(alts: Int, nesting: Int) = { - if (nesting > 10) { - alts - } else { - (10-nesting)*alts - } - } - - def post(e: Expr, bc: BC) = e match { - case ie : IfExpr => - (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) - case ie : LetDef => - (e, bc.copy(cost = bc.cost + costOfBranches(2, bc.nesting))) - case ie : MatchExpr => - (e, bc.copy(cost = bc.cost + costOfBranches(ie.cases.size, bc.nesting))) - case _ => - (e, bc) - } - - def combiner(e: Expr, cs: Seq[BC]) = { - cs.foldLeft(BC(0,0))((bc1, bc2) => BC(bc1.cost + bc2.cost, 0)) - } - - val (_, bc) = genericTransform[BC](pre, post, combiner)(BC(0, 0))(e) - - bc.cost - } - - override def solution(s: Solution) = { - Cost(formulaSize(s.toExpr) + branchesCost(s.toExpr)) - } - -} diff --git a/src/main/scala/leon/synthesis/ExamplesBank.scala b/src/main/scala/leon/synthesis/ExamplesBank.scala deleted file mode 100644 index 913ae4376721933527b1a54c6470af5c6741bd33..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/ExamplesBank.scala +++ /dev/null @@ -1,230 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Common._ -import evaluators.{TrackingEvaluator, DefaultEvaluator} -import leon.utils.ASCIIHelpers._ - -/** Sets of valid and invalid examples */ -case class ExamplesBank(valids: Seq[Example], invalids: Seq[Example]) { - def examples = valids ++ invalids - - // Minimize tests of a function so that tests that are invalid because of a - // recursive call are eliminated - def minimizeInvalids(fd: FunDef, ctx: LeonContext, program: Program): ExamplesBank = { - val evaluator = new TrackingEvaluator(ctx, program) - - invalids foreach { ts => - evaluator.eval(functionInvocation(fd, ts.ins)) - } - - val outInfo = invalids.collect { - case InOutExample(ins, outs) => ins -> outs - }.toMap - - val callGraph = evaluator.fullCallGraph - - def isFailing(fi: (FunDef, Seq[Expr])) = !evaluator.fiStatus(fi) && (fi._1 == fd) - - val failing = callGraph filter { case (from, to) => - isFailing(from) && (to forall (!isFailing(_)) ) - } - - val newInvalids = failing.keySet map { - case (_, args) => - outInfo.get(args) match { - case Some(outs) => - InOutExample(args, outs) - - case None => - InExample(args) - } - } - - ExamplesBank(valids, newInvalids.toSeq) - } - - def union(that: ExamplesBank) = { - ExamplesBank( - distinctIns(this.valids union that.valids), - distinctIns(this.invalids union that.invalids) - ) - } - - private def distinctIns(s: Seq[Example]): Seq[Example] = { - val insOuts = s.collect { - case InOutExample(ins, outs) => ins -> outs - }.toMap - - s.map(_.ins).distinct.map { - case ins => - insOuts.get(ins) match { - case Some(outs) => InOutExample(ins, outs) - case _ => InExample(ins) - } - } - } - - def flatMap(f: Example => List[Example]) = { - ExamplesBank(valids.flatMap(f), invalids.flatMap(f)) - } - - /** Expands each input example through the function f */ - def flatMapIns(f: Seq[Expr] => List[Seq[Expr]]) = { - flatMap { - case InExample(in) => - f(in).map(InExample) - - case InOutExample(in, out) => - f(in).map(InOutExample(_, out)) - } - } - - /** Expands each output example through the function f */ - def flatMapOuts(f: Seq[Expr] => List[Seq[Expr]]) = { - flatMap { - case InOutExample(in, out) => - f(out).map(InOutExample(in, _)) - - case e => - List(e) - } - } - - def stripOuts = { - flatMap { - case InOutExample(in, out) => - List(InExample(in)) - case e => - List(e) - } - } - - def asString(title: String)(implicit ctx: LeonContext): String = { - var tt = new Table(title) - - if (examples.nonEmpty) { - - val ow = examples.map { - case InOutExample(_, out) => out.size - case _ => 1 - }.max - - val iw = examples.map(_.ins.size).max - - def testsRows(section: String, ts: Seq[Example]) { - if (tt.rows.nonEmpty) { - tt += Row(Seq( - Cell(" ", iw + ow + 1) - )) - } - - tt += Row(Seq( - Cell(Console.BOLD+section+Console.RESET+":", iw + ow + 1) - )) - tt += Separator - - for (t <- ts) { - val os = t match { - case InOutExample(_, outs) => - outs.map(o => Cell(o.asString)) - case _ => - Seq(Cell("?", ow)) - } - - tt += Row( - t.ins.map(i => Cell(i.asString)) ++ Seq(Cell("->")) ++ os - ) - } - } - - if (valids.nonEmpty) { - testsRows("Valid tests", valids) - } - - if (invalids.nonEmpty) { - testsRows("Invalid tests", invalids) - } - - tt.render - } else { - "No tests." - } - } -} - -object ExamplesBank { - def empty = ExamplesBank(Nil, Nil) - -} - -/** Same as an ExamplesBank, but with identifiers corresponding to values. This - * allows us to evaluate expressions. */ -case class QualifiedExamplesBank(as: List[Identifier], xs: List[Identifier], eb: ExamplesBank)(implicit hctx: SearchContext) { - - // TODO: This might be slightly conservative. We might want something closer to a partial evaluator, - // to conserve things like (e: A).isInstanceOf[A] even when evaluation of e leads to choose - private lazy val evaluator = new DefaultEvaluator(hctx, hctx.program).setEvaluationFailOnChoose(true) - - def removeOuts(toRemove: Set[Identifier]): QualifiedExamplesBank = { - val nxs = xs.filterNot(toRemove) - val toKeep = xs.zipWithIndex.filterNot(x => toRemove(x._1)).map(_._2) - - QualifiedExamplesBank(as, nxs, eb flatMapOuts { out => List(toKeep.map(out)) }) - } - - def removeIns(toRemove: Set[Identifier]) = { - val nas = as.filterNot(toRemove) - val toKeep: List[Int] = as.zipWithIndex.filterNot(a => toRemove(a._1)).map(_._2) - - QualifiedExamplesBank(nas, xs, eb flatMapIns { (in: Seq[Expr]) => List(toKeep.map(in)) }) - } - - def evalIns: QualifiedExamplesBank = copy( eb = flatMapIns { mapping => - val evalAs = evaluator.evalEnv(mapping) - List(as map evalAs) - }) - - /** Filter inputs through expr which is an expression evaluating to a boolean */ - def filterIns(expr: Expr): QualifiedExamplesBank = { - filterIns(m => evaluator.eval(expr, m).result.contains(BooleanLiteral(true))) - } - - /** Filters inputs through the predicate pred, with an assignment of input variables to expressions. */ - def filterIns(pred: Map[Identifier, Expr] => Boolean): QualifiedExamplesBank = { - QualifiedExamplesBank(as, xs, - eb flatMapIns { in => - val m = (as zip in).toMap - if(pred(m)) { - List(in) - } else { - Nil - } - } - ) - } - - /** Maps inputs through the function f - * - * @return A new ExampleBank */ - def flatMapIns(f: Seq[(Identifier, Expr)] => List[Seq[Expr]]): ExamplesBank = { - eb flatMap { - case InExample(in) => - f(as zip in).map(InExample) - - case InOutExample(in, out) => - f(as zip in).map(InOutExample(_, out)) - } - } -} - -import scala.language.implicitConversions - -object QualifiedExamplesBank { - implicit def qebToEb(qeb: QualifiedExamplesBank): ExamplesBank = qeb.eb -} diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala deleted file mode 100644 index aa15058f23beb728c648f469d6d1009f23ac2440..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ /dev/null @@ -1,321 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.Common._ -import purescala.Constructors._ -import evaluators._ -import leon.grammars._ -import codegen._ -import datagen._ -import solvers._ -import solvers.z3._ - -class ExamplesFinder(ctx0: LeonContext, program: Program) { - - lazy val evaluator = new DefaultEvaluator(ctx, program) - - lazy val abstractEvaluator = new AbstractEvaluator(ctx, program) - - implicit val ctx = ctx0 - - val reporter = ctx.reporter - - private var keepAbstractExamples = false - /** If true, will not evaluate examples to check them. */ - def setKeepAbstractExamples(b: Boolean) = { this.keepAbstractExamples = b; this } - /** Sets if evalution of the result of tests should stop on choose statements. - * Useful for programming by Example */ - def setEvaluationFailOnChoose(b: Boolean) = { evaluator.setEvaluationFailOnChoose(b); this } - - def extractFromFunDef(fd: FunDef, partition: Boolean): ExamplesBank = fd.postcondition match { - case Some(Lambda(Seq(ValDef(id)), post)) => - // @mk FIXME: make this more general - val tests = extractTestsOf(post) - - val insIds = fd.params.map(_.id).toSet - val outsIds = Set(id) - val allIds = insIds ++ outsIds - - val examples = tests.toSeq.flatMap { t => - val ids = t.keySet - if ((ids & allIds) == allIds) { - Some(InOutExample(fd.params.map(p => t(p.id)), Seq(t(id)))) - } else if ((ids & insIds) == insIds) { - Some(InExample(fd.params.map(p => t(p.id)))) - } else if((ids & outsIds) == outsIds) { // Examples provided on a part of the inputs. - Some(InOutExample(fd.params.map(p => t.getOrElse(p.id, Variable(p.id))), Seq(t(id)))) - } else { - None - } - } - - def isValidTest(e: Example): Boolean = { - e match { - case InOutExample(ins, outs) => - evaluator.eval(Equals(outs.head, FunctionInvocation(fd.typed, ins))) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } - - case _ => false - } - } - - if (partition) { - val (v, iv) = examples.partition(isValidTest) - ExamplesBank(v, iv) - } else { - ExamplesBank(examples, Nil) - } - case None => - ExamplesBank(Nil, Nil) - } - - /** Extract examples from the passes found in expression */ - def extractFromProblem(p: Problem): ExamplesBank = { - val testClusters = extractTestsOf(p.pc and p.phi) - - // Finally, we keep complete tests covering all as++xs - val allIds = (p.as ++ p.xs).toSet - val insIds = p.as.toSet - val outsIds = p.xs.toSet - - val examples = testClusters.toSeq.flatMap { t => - val ids = t.keySet - if ((ids & allIds) == allIds) { - Some(InOutExample(p.as.map(t), p.xs.map(t))) - } else if ((ids & insIds) == insIds) { - Some(InExample(p.as.map(t))) - } else if((ids & outsIds) == outsIds) { // Examples provided on a part of the inputs. - Some(InOutExample(p.as.map(p => t.getOrElse(p, Variable(p))), p.xs.map(t))) - } else { - None - } - } - - def isValidExample(ex: Example): Boolean = { - if (this.keepAbstractExamples) return true // TODO: Abstract interpretation here ? - val (mapping, cond) = ex match { - case io: InOutExample => - (Map((p.as zip io.ins) ++ (p.xs zip io.outs): _*), p.pc and p.phi) - case i => - ((p.as zip i.ins).toMap, p.pc.toClause) - } - - evaluator.eval(cond, mapping) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } - } - - ExamplesBank(examples.filter(isValidExample), Seq()) - } - - def generateForPC(ids: List[Identifier], pc: Expr, ctx: LeonContext, maxValid: Int = 400, maxEnumerated: Int = 1000): ExamplesBank = { - //println(program.definedClasses) - - val evaluator = new CodeGenEvaluator(ctx, program) - val datagen = new GrammarDataGen(evaluator, ValueGrammar) - val solverF = SolverFactory.getFromSettings(ctx, program) - val solverDataGen = new SolverDataGen(ctx, program, solverF) - - val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) - - val solverExamples = try { - solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) - } catch { - case e: leon.Unsupported => - Nil - } - - ExamplesBank(generatedExamples.toSeq ++ solverExamples.toList, Nil) - } - - /** Extracts all passes constructs from the given postcondition, merges them if needed */ - private def extractTestsOf(e: Expr): Set[Map[Identifier, Expr]] = { - val allTests = collect[Map[Identifier, Expr]] { - case Passes(ins, outs, cases) => - val infos = extractIds(Tuple(Seq(ins, outs))) - val ioPairs = cases.flatMap(caseToExamples(ins, outs, _)) - - val exs = ioPairs.map{ case (i, o) => - val test = Tuple(Seq(i, o)) - val ids = variablesOf(test) - - // Test could contain expressions, we evaluate - abstractEvaluator.eval(test, Model.empty) match { - case EvaluationResults.Successful((res, _)) => res - case _ => test - } - } - try { - // Check whether we can extract all ids from example - val results = exs.collect { case e if this.keepAbstractExamples || infos.forall(_._2.isDefinedAt(e)) => - infos.map{ case (id, f) => id -> f(e) }.toMap - } - results.toSet - } catch { - case e: IDExtractionException => Set() - } - case _ => - Set() - }(e) - - - consolidateTests(allTests) - } - - private def expand(e: Expr): Expr= { - abstractEvaluator.eval(e) match { - case EvaluationResults.Successful((res, a)) => res - case _ => e - } - } - - private def expand(e: (Expr, Expr)): (Expr, Expr) = (expand(e._1), expand(e._2)) - - /** Processes ((in, out) passes { - * cs[=>Case pattExpr if guard => outR]*/ - private def caseToExamples(in: Expr, out: Expr, cs: MatchCase, examplesPerCase: Int = 5): Seq[(Expr,Expr)] = { - - def doSubstitute(subs : Seq[(Identifier, Expr)], e : Expr) = - subs.foldLeft(e) { - case (from, (id, to)) => replaceFromIDs(Map(id -> to), from) - } - - if (cs.rhs == out) { - // The trivial example - Seq() - } else { - // The pattern as expression (input expression)(may contain free variables) - val (pattExpr, ieMap) = patternToExpression(cs.pattern, in.getType) - val freeVars = variablesOf(pattExpr).toSeq - val res = if (exists(_.isInstanceOf[NoTree])(pattExpr)) { - reporter.warning(cs.pattern.getPos, "Unapply patterns are not supported in IO-example extraction") - Seq() - } else if (freeVars.isEmpty) { - // The input contains no free vars. Trivially return input-output pair - Seq((pattExpr, doSubstitute(ieMap,cs.rhs))) - } else { - // Extract test cases such as case x if x == s => - ((pattExpr, ieMap, cs.optGuard) match { - case (Variable(id), Seq(), Some(Equals(Variable(id2), s))) if id == id2 => - Some((Seq((s, doSubstitute(ieMap, cs.rhs))))) - case (Variable(id), Seq(), Some(Equals(s, Variable(id2)))) if id == id2 => - Some((Seq((s, doSubstitute(ieMap, cs.rhs))))) - case (a, b, c) => - None - }) getOrElse { - - if(this.keepAbstractExamples) { - cs.optGuard match { - case Some(BooleanLiteral(false)) => - Seq() - case None => - Seq((pattExpr, cs.rhs)) - case Some(pred) => - Seq((Require(pred, pattExpr), cs.rhs)) - } - } else { - // If the input contains free variables, it does not provide concrete examples. - // We will instantiate them according to a simple grammar to get them. - val dataGen = new GrammarDataGen(evaluator) - - val theGuard = replace(Map(in -> pattExpr), cs.optGuard.getOrElse(BooleanLiteral(true))) - - dataGen.generateFor(freeVars, theGuard, examplesPerCase, 1000).toSeq map { vals => - val inst = freeVars.zip(vals).toMap - val inR = replaceFromIDs(inst, pattExpr) - val outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) - (inR, outR) - } - } - } - } - - if(this.keepAbstractExamples) res.map(expand) else res - } - } - - /** Check if two tests are compatible. - * Compatible should evaluate to the same value for the same identifier - */ - private def isCompatible(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = { - val ks = m1.keySet & m2.keySet - ks.nonEmpty && ks.map(m1) == ks.map(m2) - } - - /** Merge tests t1 and t2 if they are compatible. Return m1 if not. - */ - private def mergeTest(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = { - if (!isCompatible(m1, m2)) { - m1 - } else { - m1 ++ m2 - } - } - - /** we now need to consolidate different clusters of compatible tests together - * t1: a->1, c->3 - * t1: a->1, c->3 - * t2: a->1, b->4 - * => a->1, b->4, c->3 - */ - private def consolidateTests(ts: Set[Map[Identifier, Expr]]): Set[Map[Identifier, Expr]] = { - - var consolidated = Set[Map[Identifier, Expr]]() - for (t <- ts) { - consolidated += t - - consolidated = consolidated.map { c => - mergeTest(c, t) - } - } - consolidated - } - - case class IDExtractionException(msg: String) extends Exception(msg) - - /** Extract ids in ins/outs args, and compute corresponding extractors for values map - * - * Examples: - * (a,b) => - * a -> _.1 - * b -> _.2 - * - * Cons(a, Cons(b, c)) => - * a -> _.head - * b -> _.tail.head - * c -> _.tail.tail - */ - private def extractIds(e: Expr): Seq[(Identifier, PartialFunction[Expr, Expr])] = e match { - case Variable(id) => - List((id, { case e => e })) - case Tuple(vs) => - vs.map(extractIds).zipWithIndex.flatMap{ case (ids, i) => - ids.map{ case (id, e) => - (id, andThen({ case Tuple(vs) => vs(i) case e => throw new IDExtractionException("Expected Tuple, got " + e) }, e)) - } - } - case CaseClass(cct, args) => - args.map(extractIds).zipWithIndex.flatMap { case (ids, i) => - ids.map{ case (id, e) => - (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) case e => throw new IDExtractionException("Expected Case class of type " + cct + ", got " + e) } ,e)) - } - } - - case _ => - reporter.warning("Unexpected pattern in test-ids extraction: "+e) - Nil - } - - // Compose partial functions - private def andThen(pf1: PartialFunction[Expr, Expr], pf2: PartialFunction[Expr, Expr]): PartialFunction[Expr, Expr] = { - Function.unlift(pf1.lift(_) flatMap pf2.lift) - } -} diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala deleted file mode 100644 index 50e72c6770e95b81fc2517abd2a1148af35be16b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ -import purescala.Common.Tree -import purescala.ScalaPrinter -import purescala.PrinterOptions -import purescala.PrinterContext -import purescala.Definitions.{Definition, Program} -import purescala.DefOps -import leon.utils.RangePosition - -import java.io.File -class FileInterface(reporter: Reporter) { - - def updateFile(origFile: File, solutions: Map[SourceInfo, Expr])(implicit ctx: LeonContext) { - import java.io.{File, BufferedWriter, FileWriter} - val FileExt = """^(.+)\.([^.]+)$""".r - - origFile.getAbsolutePath match { - case FileExt(path, "scala") => - var i = 0 - def savePath = path+".scala."+i - - while (new File(savePath).isFile) { - i += 1 - } - - val origCode = readFile(origFile) - val backup = new File(savePath) - val newFile = new File(origFile.getAbsolutePath) - origFile.renameTo(backup) - - var newCode = origCode - for ( (ci, e) <- solutions) { - newCode = substitute(newCode, ci.source, e) - } - - val out = new BufferedWriter(new FileWriter(newFile)) - out.write(newCode) - out.close() - case _ => - - } - } - - def substitute(originalCode: String, fromTree: Tree, printer: (Int) => String): String = { - fromTree.getPos match { - case rp: RangePosition => - val from = rp.pointFrom - val to = rp.pointTo - - val before = originalCode.substring(0, from) - val after = originalCode.substring(to, originalCode.length) - - // Get base indentation of last line: - val lineChars = before.substring(before.lastIndexOf('\n')+1).toList - - val indent = lineChars.takeWhile(_ == ' ').size - - val res = printer(indent/2) - - before + res + after - - case p => - sys.error("Substitution requires RangePos on the input tree: "+fromTree +": "+fromTree.getClass+" GOT" +p) - } - } - - def insertAfter(originalCode: String, fromTree: Tree, printer: (Int) => String): String = { - fromTree.getPos match { - case rp: RangePosition => - val from = rp.pointFrom - val to = rp.pointTo - - val before = originalCode.substring(0, to) - val after = originalCode.substring(to, originalCode.length) - - // Get base indentation of last line: - val lineChars = before.substring(before.lastIndexOf('\n')+1).toList - - val indent = lineChars.takeWhile(_ == ' ').size - - val res = printer(indent/2) - - before + res + after - - case p => - sys.error("Substitution requires RangePos on the input tree: "+fromTree +": "+fromTree.getClass+" GOT" +p) - } - } - - - def substitute(str: String, fromTree: Tree, toTree: Tree, optpgm: Option[Program] = None)(implicit ctx: LeonContext): String = { - substitute(str, fromTree, (indent: Int) => { - val opts = PrinterOptions.fromContext(ctx) - val printProgram = (fromTree, toTree) match { - case (from: Definition, d: Definition) => - optpgm.map(p => - if(!p.containsDef(d)) { - DefOps.addDefs(p, Seq(d), from) - } else p) - case _ => - optpgm - } - val p = new ScalaPrinter(opts, printProgram) - p.pp(toTree)(PrinterContext(toTree, Nil, indent, p)) - p.toString - }) - } - - def readFile(file: File): String = { - scala.io.Source.fromFile(file).mkString - } -} diff --git a/src/main/scala/leon/synthesis/Histogram.scala b/src/main/scala/leon/synthesis/Histogram.scala deleted file mode 100644 index 7950d1ce714a55785e9776e64d4649ed793a02cc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Histogram.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis - -/** - * Histogram from 0 to `bound`, each value between 0 and 1 - * hist(c) = v means we have a `v` likelihood of finding a solution of cost `c` - */ -class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histogram] { - /** - */ - def and(that: Histogram): Histogram = { - val a = Array.fill(bound)(0d) - var i = 0 - while(i < bound) { - var j = 0 - while(j <= i) { - - val v1 = this.values(j) * that.values(i - j) - val v2 = a(i) - - a(i) = v1+v2 - (v1*v2) - - j += 1 - } - i += 1 - } - - val res = new Histogram(bound, a) - println("==== && ====") - println("this:"+ this) - println("that:"+ that) - println(" ==> "+res) - res - } - - /** - * hist1(c) || hist2(c) == hist1(c)+hist2(c) - hist1(c)*hist2(c) - */ - def or(that: Histogram): Histogram = { - val a = Array.fill(bound)(0d) - var i = 0 - while(i < bound) { - val v1 = this.values(i) - val v2 = that.values(i) - - a(i) = v1+v2 - (v1*v2) - i += 1 - } - - val res = new Histogram(bound, a) - println("==== || ====") - println("this:"+ this) - println("that:"+ that) - println(" ==> "+res) - res - } - - lazy val mode = { - var max = 0d - var argMax = -1 - var i = 0 - while(i < bound) { - if ((argMax < 0) || values(i) > max) { - argMax = i - max = values(i) - } - i += 1 - } - (max, argMax) - } - - lazy val firstNonZero = { - var i = 0 - var mini = -1 - while(i < bound && mini < 0) { - if (values(i) > 0) { - mini = i - } - i += 1 - } - if (mini >= 0) { - (values(mini), mini) - } else { - (0d, bound) - } - } - - lazy val moment = { - var i = 0 - var moment = 0d - var allV = 0d - while(i < bound) { - val v = values(i) - moment += v*i - allV += v - i += 1 - } - - moment/allV - } - - def isImpossible = mode._1 == 0 - - def compareByMode(that: Histogram) = { - val (m1, am1) = this.mode - val (m2, am2) = that.mode - - if (m1 == m2) { - am1 - am2 - } else { - if (m2 < m1) { - -1 - } else if (m2 == m1) { - 0 - } else { - +1 - } - } - } - - def rescaled(by: Double): Histogram = { - val a = new Array[Double](bound) - - var i = 0 - while(i < bound) { - val v = values(i) - - val nv = 1-Math.pow(1-v, by) - - a(i) = nv - - i += 1 - } - - new Histogram(bound, a) - } - - def compareByFirstNonZero(that: Histogram) = { - this.firstNonZero._2 - that.firstNonZero._2 - } - - def compareByMoment(that: Histogram) = { - this.moment - that.moment - } - - /** - * Should return v<0 if `this` < `that`, that is, `this` represents better - * solutions than `that`. - */ - def compare(that: Histogram) = { - compareByFirstNonZero(that) - } - - override def toString: String = { - var lastv = -1d - var fromi = -1 - val entries = new scala.collection.mutable.ArrayBuffer[((Int, Int), Double)]() - - - for (i <- 0 until bound) { - val v = values(i) - if (lastv < 0) { - lastv = v - fromi = i - } - - if (lastv != v) { - entries += (fromi, i-1) -> lastv - lastv = v - fromi = i - } - } - entries += (fromi, bound-1) -> lastv - - val info = for (((from, to), v) <- entries) yield { - val k = if (from == to) { - s"$from" - } else { - s"$from..$to" - } - - f"$k -> $v%1.3f" - } - - s"H($summary: ${info.mkString(", ")})" - } - - - def summary: String = { - //val (m, am) = maxInfo - val (m, am) = firstNonZero - - f"$m%1.4f->$am%-2d ($moment%1.3f)" - } -} - -object Histogram { - def clampV(v: Double): Double = { - if (v < 0) { - 0d - } else if (v > 1) { - 1d - } else { - v - } - } - - def point(bound: Int, at: Int, v: Double) = { - if (bound <= at) { - empty(bound) - } else { - new Histogram(bound, Array.fill(bound)(0d).updated(at, clampV(v))) - } - } - - def empty(bound: Int) = { - new Histogram(bound, Array.fill(bound)(0d)) - } - - def uniform(bound: Int, v: Double) = { - uniformFrom(bound, 0, v) - } - - def uniformFrom(bound: Int, from: Int, v: Double) = { - val vSafe = clampV(v) - var i = from - val a = Array.fill(bound)(0d) - while(i < bound) { - a(i) = vSafe - i += 1 - } - - new Histogram(bound, a) - } -} diff --git a/src/main/scala/leon/synthesis/InOutExample.scala b/src/main/scala/leon/synthesis/InOutExample.scala deleted file mode 100644 index 244515d3addb359a28e2507782536730f33cc5ab..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/InOutExample.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ - -sealed abstract class Example extends Printable { - def ins: Seq[Expr] - - def asString(implicit ctx: LeonContext) = { - def esToStr(es: Seq[Expr]): String = { - es.map(_.asString).mkString("(", ", ", ")") - } - - this match { - case InExample(ins) => esToStr(ins) - case InOutExample(ins, outs) => esToStr(ins)+" ~> "+esToStr(outs) - } - } -} - -case class InOutExample(ins: Seq[Expr], outs: Seq[Expr]) extends Example -case class InExample(ins: Seq[Expr]) extends Example diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala deleted file mode 100644 index 3069c90aeaa67a07c450684f5856210955d30dcd..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ -import purescala.TreeNormalizations.linearArithmeticForm -import purescala.Types._ -import purescala.Common._ -import evaluators._ - -import synthesis.Algebra._ - -object LinearEquations { - /** Eliminates one variable from normalizedEquation t + a1*x1 + ... + an*xn = 0 - * @return a mapping for each of the n variables in (pre, map, freshVars) */ - def elimVariable(evaluator: Evaluator, as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr], List[Identifier]) = { - require(normalizedEquation.size > 1) - require(normalizedEquation.tail.forall{case InfiniteIntegerLiteral(i) if i != BigInt(0) => true case _ => false}) - val t: Expr = normalizedEquation.head - val coefsVars: List[BigInt] = normalizedEquation.tail.map{case InfiniteIntegerLiteral(i) => i} - val orderedParams: Array[Identifier] = as.toArray - val coefsParams: List[BigInt] = linearArithmeticForm(t, orderedParams).map{case InfiniteIntegerLiteral(i) => i}.toList - //val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0 - val d: BigInt = gcd((coefsParams ++ coefsVars).toSeq) - - if(coefsVars.size == 1) { - val coef = coefsVars.head - (Equals(Modulo(t, InfiniteIntegerLiteral(coef)), InfiniteIntegerLiteral(0)), List(UMinus(Division(t, InfiniteIntegerLiteral(coef)))), List()) - } else if(d > 1) { - val newCoefsParams: List[Expr] = coefsParams.map(i => InfiniteIntegerLiteral(i/d) : Expr) - val newT = newCoefsParams.zip(InfiniteIntegerLiteral(1)::orderedParams.map(Variable).toList).foldLeft[Expr](InfiniteIntegerLiteral(0))((acc, p) => Plus(acc, Times(p._1, p._2))) - elimVariable(evaluator, as, newT :: normalizedEquation.tail.map{case InfiniteIntegerLiteral(i) => InfiniteIntegerLiteral(i/d) : Expr}) - } else { - val basis: Array[Array[BigInt]] = linearSet(evaluator, as, normalizedEquation.tail.map{case InfiniteIntegerLiteral(i) => i}.toArray) - val (pre, sol) = particularSolution(as, normalizedEquation) - val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", IntegerType, true)) - - val tbasis = basis.transpose - assert(freshVars.length == tbasis.length) - val basisWithFreshVars: Array[Array[Expr]] = freshVars.zip(tbasis).map{ - case (lambda, column) => column.map((i: BigInt) => Times(InfiniteIntegerLiteral(i), Variable(lambda)): Expr) - }.transpose - val combinationBasis: Array[Expr] = basisWithFreshVars.map((v: Array[Expr]) => v.foldLeft[Expr](InfiniteIntegerLiteral(0))((acc, e) => Plus(acc, e))) - assert(combinationBasis.length == sol.size) - val subst: List[Expr] = sol.zip(combinationBasis.toList).map(p => Plus(p._1, p._2): Expr) - - (pre, subst, freshVars.toList) - } - - } - - /** Computes a list of solutions to the equation c1*x1 + ... + cn*xn where coef = [c1 ... cn] - * @return the solution in the form of a list of n-1 vectors that form a basis for the set - * of solutions, that is res=(v1, ..., v{n-1}) and any solution x* to the original solution - * is a linear combination of the vi's - * Intuitively, we are building a "basis" for the "vector space" of solutions (although we are over - * integers, so it is not a vector space). - * we are returning a matrix where the columns are the vectors */ - def linearSet(evaluator: Evaluator, as: Set[Identifier], coef: Array[BigInt]): Array[Array[BigInt]] = { - - val K = Array.ofDim[BigInt](coef.length, coef.length-1) - for(i <- 0 until K.length) { - for(j <- 0 until K(i).length) { - if(i < j) - K(i)(j) = 0 - else if(i == j) { - K(j)(j) = gcd(coef.drop(j+1))/gcd(coef.drop(j)) - } - } - } - for(j <- 0 until K.length - 1) { - val (_, sols) = particularSolution(as, InfiniteIntegerLiteral(coef(j)*K(j)(j)) :: coef.drop(j+1).map(InfiniteIntegerLiteral).toList) - var i = 0 - while(i < sols.size) { - // seriously ??? - K(i+j+1)(j) = evaluator.eval(sols(i)).asInstanceOf[EvaluationResults.Successful[Expr]].value.asInstanceOf[InfiniteIntegerLiteral].value - i += 1 - } - } - - K - } - - /** @param as The parameters - * @param xs The variable for which we want to find one satisfiable assignment - * @return (pre, sol) with pre a precondition under which sol is a solution mapping to the xs */ - def particularSolution(as: Set[Identifier], xs: Set[Identifier], equation: Equals): (Expr, Map[Identifier, Expr]) = { - val lhs = equation.lhs - val rhs = equation.rhs - val orderedXs = xs.toArray - val normalized: Array[Expr] = linearArithmeticForm(Minus(lhs, rhs), orderedXs) - val (pre, sols) = particularSolution(as, normalized.toList) - (pre, orderedXs.zip(sols).toMap) - } - - /** @return a particular solution to t + c1x + c2y = 0, with (pre, (x0, y0)) */ - def particularSolution(as: Set[Identifier], t: Expr, c1: Expr, c2: Expr): (Expr, (Expr, Expr)) = { - val (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) = (c1, c2) - val (v1, v2) = extendedEuclid(i1, i2) - val d = gcd(i1, i2) - - val pre = Equals(Modulo(t, InfiniteIntegerLiteral(d)), InfiniteIntegerLiteral(0)) - - (pre, - ( - UMinus(Times(InfiniteIntegerLiteral(v1), Division(t, InfiniteIntegerLiteral(d)))), - UMinus(Times(InfiniteIntegerLiteral(v2), Division(t, InfiniteIntegerLiteral(d)))) - ) - ) - } - - /** the equation must at least contain the term t and one variable */ - def particularSolution(as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr]) = { - require(normalizedEquation.size >= 2) - val t: Expr = normalizedEquation.head - val coefs: List[BigInt] = normalizedEquation.tail.map{case InfiniteIntegerLiteral(i) => i} - val d = gcd(coefs.toSeq) - val pre = Equals(Modulo(t, InfiniteIntegerLiteral(d)), InfiniteIntegerLiteral(0)) - - if(normalizedEquation.size == 2) { - (pre, List(UMinus(Division(t, normalizedEquation(1))))) - } else if(normalizedEquation.size == 3) { - val (_, (w1, w2)) = particularSolution(as, t, normalizedEquation(1), normalizedEquation(2)) - (pre, List(w1, w2)) - } else { - val gamma1: Expr = normalizedEquation(1) - val coefs: List[BigInt] = normalizedEquation.drop(2).map{case InfiniteIntegerLiteral(i) => i} - val gamma2: Expr = InfiniteIntegerLiteral(gcd(coefs.toSeq)) - val (_, (w1, w)) = particularSolution(as, t, gamma1, gamma2) - val (_, sols) = particularSolution(as, UMinus(Times(w, gamma2)) :: normalizedEquation.drop(2)) - (pre, w1 :: sols) - } - - } -} diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala deleted file mode 100644 index 4f46997a83deed858f6cb13d4e1d086a00f77946..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions._ - -import graph._ -import strategies._ - -class PartialSolution(strat: Strategy, includeUntrusted: Boolean = false) { - - def includeSolution(s: Solution) = { - includeUntrusted || s.isTrusted - } - - def completeProblem(p: Problem) = { - Solution.choose(p) - } - - def solutionAround(n: Node): Expr => Option[Solution] = { - def solveWith(optn: Option[Node], sol: Solution): Option[Solution] = optn match { - case None => - Some(sol) - - case Some(n) => n.parent match { - case None => - Some(sol) - - case Some(on: OrNode) => - solveWith(on.parent, sol) - - case Some(an: AndNode) => - val ssols = for (d <- an.descendants) yield { - if (d == n) { - sol - } else { - getSolutionFor(d) - } - } - - an.ri.onSuccess(ssols).flatMap { nsol => - solveWith(an.parent, nsol) - } - } - } - - e : Expr => solveWith(Some(n), Solution(BooleanLiteral(true), Set(), e)) - - } - - def getSolutionFor(n: Node): Solution = { - n match { - case on: OrNode => - if (on.isSolved) { - val sols = on.generateSolutions() - sols.find(includeSolution) match { - case Some(sol) => - return sol - case _ => - } - } - - if (n.isExpanded) { - strat.bestAlternative(on) match { - case None => completeProblem(on.p) - case Some(d) => getSolutionFor(d) - } - } else { - completeProblem(on.p) - } - case an: AndNode => - if (an.isSolved) { - val sols = an.generateSolutions() - sols.find(includeSolution) match { - case Some(sol) => - return sol - case _ => - } - } - - if (n.isExpanded) { - an.ri.onSuccess(n.descendants.map(getSolutionFor)) match { - case Some(sol) => - sol - - case None => - completeProblem(an.ri.problem) - } - } else { - completeProblem(an.ri.problem) - } - } - } -} diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala deleted file mode 100644 index 903d9e25f8650fd607c6784a0a63d4b142a7822b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Problem.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Path -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Common._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Definitions.FunDef -import Witnesses._ - -/** Defines a synthesis triple of the form: - * ⟦ as ⟨ ws && pc | phi ⟩ xs ⟧ - * - * @param as The list of input identifiers so far - * @param ws The axioms and other already proven theorems - * @param pc The path condition so far - * @param phi The formula on `as` and `xs` to satisfy - * @param xs The list of output identifiers for which we want to compute a function - * @note Since the examples are not eagerly filtered by [[Rule]]s, some may not - * pass the path condition. Use [[qebFiltered]] to get the legal ones. - */ -case class Problem(as: List[Identifier], ws: Expr, pc: Path, phi: Expr, xs: List[Identifier], eb: ExamplesBank = ExamplesBank.empty) extends Printable { - - // require(eb.examples.forall(_.ins.size == as.size)) - - val TopLevelAnds(wsList) = ws - - def inType = tupleTypeWrap(as.map(_.getType)) - def outType = tupleTypeWrap(xs.map(_.getType)) - - def allAs = as ++ (pc.bindings.map(_._1) diff wsList.collect{ case Inactive(i) => i }) - - def asString(implicit ctx: LeonContext): String = { - def pad(padding: String, text: String) = text.lines.mkString(s"\n|$padding") - val pcws = pc withCond ws - - val ebInfo = "/"+eb.valids.size+","+eb.invalids.size+"/" - - s"""|⟦ α ${if (as.nonEmpty) as.map(_.asString).mkString(", ") else "()"} - | Π ${pad(" ", pcws.fullClause.asString)} - | φ ${pad(" ", phi.asString)} - | x ${if (xs.nonEmpty) xs.map(_.asString).mkString(", ") else "()"} - |⟧ $ebInfo""".stripMargin - } - - def withWs(es: Traversable[Expr]) = { - copy(ws = andJoin(wsList ++ es)) - } - - def qebFiltered(implicit sctx: SearchContext) = qeb.evalIns.filterIns(pc.fullClause) - - // Qualified example bank, allows us to perform operations (e.g. filter) with expressions - def qeb(implicit sctx: SearchContext) = QualifiedExamplesBank(this.as, this.xs, eb) - -} - -object Problem { - - def fromSpec( - spec: Expr, - pc: Path = Path.empty, - eb: ExamplesBank = ExamplesBank.empty, - fd: Option[FunDef] = None - ): Problem = { - val xs = (spec match { - case Lambda(args, _) => args.map(_.id) - case IsTyped(_, FunctionType(from, to)) => - from map (FreshIdentifier("x", _, alwaysShowUniqueID = true)) - case _ => - throw LeonFatalError(s"$spec is the spec of a choose but is not a function?") - }).toList - - val phi = application(simplifyLets(spec), xs map { _.toVariable}) - val as = (variablesOf(phi) ++ pc.variables -- xs).toList.sortBy(_.name) - - val sortedAs = fd match { - case None => as - case Some(fd) => - val argsIndex = fd.params.map(_.id).zipWithIndex.toMap.withDefaultValue(100) - as.sortBy(a => argsIndex(a)) - } - - val (pcs, wss) = pc.partition { - case w: Witness => false - case _ => true - } - - Problem(sortedAs, andJoin(wss), pcs, phi, xs, eb) - } - -} diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala deleted file mode 100644 index df7b2fd7bebcc8fb7c925bd6327789cf7271a444..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Rules.scala +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Constructors._ -import rules._ - -/** A Rule can be applied on a synthesis problem */ -abstract class Rule(val name: String) extends RuleDSL with Printable { - def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] - - val priority: RulePriority = RulePriorityDefault - - implicit val debugSection = leon.utils.DebugSectionSynthesis - - implicit val thisRule = this - - def asString(implicit ctx: LeonContext) = name -} - -abstract class NormalizingRule(name: String) extends Rule(name) { - override val priority = RulePriorityNormalizing -} - -abstract class PreprocessingRule(name: String) extends Rule(name) { - override val priority = RulePriorityPreprocessing -} - -/** Contains the list of all available rules for synthesis */ -object Rules { - - def all: List[Rule] = all(false, false) - /** Returns the list of all available rules for synthesis */ - def all(naiveGrammar: Boolean, introduceRecCalls: Boolean): List[Rule] = List[Rule]( - StringRender, - Unification.DecompTrivialClash, - Unification.OccursCheck, // probably useless - Disunification.Decomp, - ADTDual, - OnePoint, - Ground, - CaseSplit, - IndependentSplit, - IfSplit, - InputSplit, - UnusedInput, - EquivalentInputs, - UnconstrainedOutput, - if(naiveGrammar) NaiveCEGIS else CEGIS, - OptimisticGround, - GenericTypeEqualitySplit, - InequalitySplit, - rules.Assert, - DetupleInput, - ADTSplit, - InnerCaseSplit - ) ++ introduceRecCalls.option(IntroduceRecCalls) - -} - -/** When applying this to a [SearchContext] it returns a wrapped stream of solutions or a new list of problems. */ -abstract class RuleInstantiation(val description: String, - val onSuccess: SolutionBuilder = SolutionBuilderCloser()) - (implicit val problem: Problem, val rule: Rule) extends Printable { - - def apply(hctx: SearchContext): RuleApplication - - def asString(implicit ctx: LeonContext) = description -} - -object RuleInstantiation { - def apply(description: String)(f: => RuleApplication)(implicit problem: Problem, rule: Rule): RuleInstantiation = { - new RuleInstantiation(description) { - def apply(hctx: SearchContext): RuleApplication = f - } - } -} - -/** - * Wrapper class for a function returning a recomposed solution from a list of - * subsolutions - * - * We also need to know the types of the expected sub-solutions to use them in - * cost-models before having real solutions. - */ -abstract class SolutionBuilder { - val types: Seq[TypeTree] - - def apply(sols: List[Solution]): Option[Solution] -} - -case class SolutionBuilderDecomp(types: Seq[TypeTree], recomp: List[Solution] => Option[Solution]) extends SolutionBuilder { - def apply(sols: List[Solution]): Option[Solution] = { - assert(types.size == sols.size) - recomp(sols) - } -} - -/** - * Used by rules expected to close, no decomposition but maybe we already know - * the solution when instantiating - */ -case class SolutionBuilderCloser(osol: Option[Solution] = None) extends SolutionBuilder { - val types = Nil - def apply(sols: List[Solution]) = { - assert(sols.isEmpty) - osol - } -} - -/** - * Results of applying rule instantiations - * - * Can either close meaning a stream of solutions are available (can be empty, - * if it failed) - */ -sealed abstract class RuleApplication -/** Result of applying rule instantiation, finished, resulting in a stream of solutions */ -case class RuleClosed(solutions: Stream[Solution]) extends RuleApplication -/** Result of applying rule instantiation, resulting is a nnew list of problems */ -case class RuleExpanded(sub: List[Problem]) extends RuleApplication - -object RuleClosed { - def apply(s: Solution): RuleClosed = RuleClosed(Stream(s)) -} - -object RuleFailed { - def apply(): RuleClosed = RuleClosed(Stream.empty) -} - -/** - * Rule priorities, which drive the instantiation order. - */ -sealed abstract class RulePriority(val v: Int) extends Ordered[RulePriority] { - def compare(that: RulePriority) = this.v - that.v -} - -case object RulePriorityPreprocessing extends RulePriority(5) -case object RulePriorityNormalizing extends RulePriority(10) -case object RulePriorityHoles extends RulePriority(15) -case object RulePriorityDefault extends RulePriority(20) - -/** - * Common utilities used by rules - */ -trait RuleDSL { - this: Rule => - /** Replaces all first elements of `what` by their second element in the expression `ìn`*/ - def subst(what: (Identifier, Expr), in: Expr): Expr = replaceFromIDs(Map(what), in) - /** Replaces all keys of `what` by their key in the expression `ìn`*/ - def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in) - - val forward: List[Solution] => Option[Solution] = { ss => ss.headOption } - - /** Returns a function that transforms the precondition and term of the first Solution of a list using `f`. */ - def forwardMap(f : Expr => Expr) : List[Solution] => Option[Solution] = { - _.headOption map { s => - Solution(f(s.pre), s.defs, f(s.term), s.isTrusted) - } - } - - /** Groups sub-problems and a callback merging the solutions to produce a global solution.*/ - def decomp(sub: List[Problem], onSuccess: List[Solution] => Option[Solution], description: String) - (implicit problem: Problem): RuleInstantiation = { - - val subTypes = sub.map(_.outType) - - new RuleInstantiation(description, - SolutionBuilderDecomp(subTypes, onSuccess)) { - def apply(hctx: SearchContext) = RuleExpanded(sub) - } - } - - def solve(sol: Solution) - (implicit problem: Problem, ctx: LeonContext): RuleInstantiation = { - - new RuleInstantiation(s"Solve: ${sol.asString}", - SolutionBuilderCloser(Some(sol))) { - def apply(hctx: SearchContext) = RuleClosed(sol) - } - - } - - /** @param pc corresponds to the post-condition to reach the point where the solution is used. It - * will be used if the sub-solution has a non-true precondition. */ - def termWrap(f: Expr => Expr, pc: Expr = BooleanLiteral(true)): List[Solution] => Option[Solution] = { - case List(s) => - val pre = if (s.pre == BooleanLiteral(true)) { - BooleanLiteral(true) - } else { - and(pc, s.pre) - } - - Some(Solution(pre, s.defs, f(s.term), s.isTrusted)) - case _ => None - - } -} diff --git a/src/main/scala/leon/synthesis/Search.scala b/src/main/scala/leon/synthesis/Search.scala deleted file mode 100644 index 15efbc18f36377e2478f6479a3ada9c01ae11c63..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Search.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import strategies._ -import graph._ - -import scala.annotation.tailrec - -import leon.utils.Interruptible -import java.util.concurrent.atomic.AtomicBoolean - -class Search(val ctx: LeonContext, ci: SourceInfo, val strat: Strategy) extends Interruptible { - - val p: Problem = ci.problem - val g = new Graph(p) - - val interrupted = new AtomicBoolean(false) - - strat.init(g.root) - - def doExpand(n: Node, sctx: SynthesisContext): Unit = { - ctx.timers.synthesis.step.timed { - n match { - case an: AndNode => - ctx.timers.synthesis.applications.get(an.ri.asString(sctx)).timed { - an.expand(new SearchContext(sctx, ci.source, an, this)) - } - - case on: OrNode => - on.expand(new SearchContext(sctx, ci.source, on, this)) - } - } - } - - @tailrec - final def searchFrom(sctx: SynthesisContext, from: Node): Boolean = { - strat.getNextToExpand(from) match { - case Some(n) => - strat.beforeExpand(n) - - doExpand(n, sctx) - - strat.afterExpand(n) - - if (from.isSolved) { - true - } else if (interrupted.get) { - false - } else { - searchFrom(sctx, from) - } - case None => - false - } - } - - def traversePathFrom(n: Node, path: List[Int]): Option[Node] = path match { - case Nil => - Some(n) - case x :: xs => - if (n.isExpanded && n.descendants.size > x) { - traversePathFrom(n.descendants(x), xs) - } else { - None - } - } - - def traversePath(path: List[Int]): Option[Node] = { - traversePathFrom(g.root, path) - } - - def search(sctx: SynthesisContext): Stream[Solution] = { - if (searchFrom(sctx, g.root)) { - g.root.generateSolutions() - } else { - Stream.empty - } - } - - def interrupt(): Unit = { - interrupted.set(true) - strat.interrupt() - } - - def recoverInterrupt(): Unit = { - interrupted.set(false) - strat.recoverInterrupt() - } - - ctx.interruptManager.registerForInterrupts(this) -} diff --git a/src/main/scala/leon/synthesis/SearchContext.scala b/src/main/scala/leon/synthesis/SearchContext.scala deleted file mode 100644 index 63683d9278e0219eaeb32c5755e0c8ed25f79836..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/SearchContext.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import graph._ -import purescala.Expressions.Expr - -/** - * This is context passed down rules, and include search-wise context, as well - * as current search location information - */ -class SearchContext ( - sctx: SynthesisContext, - val source: Expr, - val currentNode: Node, - val search: Search -) extends SynthesisContext( - sctx, - sctx.settings, - sctx.functionContext, - sctx.program -) { - - def searchDepth = { - def depthOf(n: Node): Int = n.parent match { - case Some(n2) => 1+depthOf(n2) - case None => 0 - } - - depthOf(currentNode) - } - - def parentNode: Option[Node] = currentNode.parent - -} diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala deleted file mode 100644 index 435dd2b23d83c478361e667a11d3a34978f6a336..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Solution.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Types.{TypeTree,TupleType} -import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Path - -import leon.utils.Simplifiers - -// Defines a synthesis solution of the form: -// ⟨ P | T ⟩ -case class Solution(pre: Expr, defs: Set[FunDef], term: Expr, isTrusted: Boolean = true) extends Printable { - - def asString(implicit ctx: LeonContext) = { - "⟨ "+pre.asString+" | "+defs.map(_.asString).mkString(" ")+" "+term.asString+" ⟩" - } - - def guardedTerm = { - if (pre == BooleanLiteral(true)) { - term - } else if (pre == BooleanLiteral(false)) { - Error(term.getType, "Impossible program") - } else { - IfExpr(pre, term, Error(term.getType, "Precondition failed")) - } - } - - def toExpr = { - letDef(defs.toList, guardedTerm) - } - - def ifOnFunDef[T](originalFun: FunDef)(body: => T): T = { - val saved = originalFun.body - originalFun.body = Some(term) - val res = body - originalFun.body = saved - res - } - - // Projects a solution (ignore several output variables) - // - // e.g. Given solution for [[ a < .. > x1, x2, x3, x4 ]] and List(0, 1, 3) - // It produces a solution for [[ a < .. > x1, x2, x4 ]] - // - // Indices are 0-indexed - def project(indices: Seq[Int]): Solution = { - term.getType match { - case TupleType(ts) => - val t = FreshIdentifier("t", term.getType, true) - val newTerm = Let(t, term, tupleWrap(indices.map(i => tupleSelect(t.toVariable, i+1, indices.size)))) - - Solution(pre, defs, newTerm) - case _ => - this - } - } - - def toSimplifiedExpr(ctx: LeonContext, p: Program, within: FunDef): Expr = { - Simplifiers.bestEffort(ctx, p)(toExpr, Path(within.precOrTrue)) - } -} - -object Solution { - - def term(term: Expr, isTrusted: Boolean = true) = { - Solution(BooleanLiteral(true), Set(), term, isTrusted) - } - - def choose(p: Problem): Solution = { - Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef), p.phi))) - } - - def chooseComplete(p: Problem): Solution = { - Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef), p.pc and p.phi))) - } - - // Generate the simplest, wrongest solution, used for complexity lower bound - def simplest(t: TypeTree): Solution = { - Solution(BooleanLiteral(true), Set(), simplestValue(t)) - } - - def failed(implicit p: Problem): Solution = { - val tpe = tupleTypeWrap(p.xs.map(_.getType)) - Solution(BooleanLiteral(false), Set(), Error(tpe, "Rule failed!")) - } - - def UNSAT(implicit p: Problem): Solution = { - val tpe = tupleTypeWrap(p.xs.map(_.getType)) - Solution(BooleanLiteral(false), Set(), Error(tpe, "Spec is UNSAT for this path!")) - } -} diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala deleted file mode 100644 index 0a34afb69f951b7f16b295ed76b817e3f6b06d9f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Path -import purescala.Definitions._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import Witnesses._ - -case class SourceInfo(fd: FunDef, source: Expr, problem: Problem) - -object SourceInfo { - - class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Path)] { - def collect(e: Expr, path: Path) = e match { - case c: Choose => Some(c -> path) - case _ => None - } - } - - def extractFromProgram(ctx: LeonContext, prog: Program): List[SourceInfo] = { - val functions = ctx.findOption(GlobalOptions.optFunctions) map { _.toSet } - - def excludeByDefault(fd: FunDef): Boolean = { - fd.annotations contains "library" - } - - val fdFilter = { - import OptionsHelpers._ - filterInclusive(functions.map(fdMatcher(prog)), Some(excludeByDefault _)) - } - - // Look for choose() - val results = for (f <- prog.definedFunctions if f.body.isDefined && fdFilter(f); - ci <- extractFromFunction(ctx, prog, f)) yield { - ci - } - - if (results.isEmpty) { - ctx.reporter.warning("No 'choose' found. Maybe the functions you indicated do not exist?") - } - - results.sortBy(_.source.getPos) - } - - def extractFromFunction(ctx: LeonContext, prog: Program, fd: FunDef): Seq[SourceInfo] = { - - val term = Terminating(fd.applied) - - val eFinder = new ExamplesFinder(ctx, prog) - - // We are synthesizing, so all examples are valid ones - val functionEb = eFinder.extractFromFunDef(fd, partition = false) - - for ((ch, path) <- new ChooseCollectorWithPaths().traverse(fd)) yield { - val outerEb = if (path.isEmpty) { - functionEb - } else { - ExamplesBank.empty - } - - val p = Problem.fromSpec(ch.pred, path withCond term, outerEb, Some(fd)) - - val pcEb = eFinder.generateForPC(p.as, path.toClause, ctx, 20) - val chooseEb = eFinder.extractFromProblem(p) - val eb = (outerEb union chooseEb) union pcEb - - val betterP = p.copy(eb = eb) - - SourceInfo(fd, ch, betterP) - } - } -} diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala deleted file mode 100644 index 812adba389500f0d03a78a03df3fb1160b1da4d2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import solvers._ -import purescala.Definitions.{Program, FunDef} -import evaluators._ - -/** - * This is global information per entire search, contains necessary information - */ -class SynthesisContext( - context: LeonContext, - val settings: SynthesisSettings, - val functionContext: FunDef, - val program: Program -) extends LeonContext( - context.reporter, - context.interruptManager, - context.options, - context.files, - context.classDir, - context.timers -) { - - val solverFactory = SolverFactory.getFromSettings(context, program) - - lazy val defaultEvaluator = { - new DefaultEvaluator(context, program) - } - -} diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala deleted file mode 100644 index 9578b6ddb1be4b90fc6ebe912f61ab4fd74c4ccf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.ExprOps.replace -import purescala.ScalaPrinter -import purescala.Definitions.{Program, FunDef} - -import leon.utils._ -import graph._ - -object SynthesisPhase extends UnitPhase[Program] { - val name = "Synthesis" - val description = "Partial synthesis of \"choose\" constructs. Also used by repair during the synthesis stage." - - val optManual = LeonStringOptionDef("manual", "Manual search", default = "", "[cmd]") - val optCostModel = LeonStringOptionDef("costmodel", "Use a specific cost model for this search", "FIXME", "cm") - val optDerivTrees = LeonFlagOptionDef("derivtrees", "Generate derivation trees", false) - val optAllowPartial = LeonFlagOptionDef("partial", "Allow partial solutions", true) - val optIntroduceRecCalls = LeonFlagOptionDef("introreccalls", "Use a rule to introduce rec. calls outside of CEGIS", true) - - // CEGIS options - val optCEGISOptTimeout = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true ) - val optCEGISVanuatoo = LeonFlagOptionDef("cegis:vanuatoo", "Generate inputs using new korat-style generator", false) - val optCEGISNaiveGrammar = LeonFlagOptionDef("cegis:naive", "Use the old naive grammar for CEGIS", false) - val optCEGISMaxSize = LeonLongOptionDef("cegis:maxsize", "Maximum size of expressions synthesized by CEGIS", 7L, "N") - - override val definedOptions : Set[LeonOptionDef[Any]] = Set( - optManual, optCostModel, optDerivTrees, optAllowPartial, optIntroduceRecCalls, - optCEGISOptTimeout, optCEGISVanuatoo, optCEGISNaiveGrammar, optCEGISMaxSize - ) - - def processOptions(ctx: LeonContext): SynthesisSettings = { - val ms = ctx.findOption(optManual) - val timeout = ctx.findOption(GlobalOptions.optTimeout) - if (ms.isDefined && timeout.isDefined) { - ctx.reporter.warning("Defining timeout with manual search") - } - val costModel = { - ctx.findOption(optCostModel) match { - case None => CostModels.default - case Some(name) => CostModels.all.find(_.name.toLowerCase == name.toLowerCase) getOrElse { - var errorMsg = "Unknown cost model: " + name + "\n" + - "Defined cost models: \n" - - for (cm <- CostModels.all.toSeq.sortBy(_.name)) { - errorMsg += " - " + cm.name + (if (cm == CostModels.default) " (default)" else "") + "\n" - } - - ctx.reporter.fatalError(errorMsg) - } - } - } - - SynthesisSettings( - timeoutMs = timeout map { _ * 1000 }, - generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), - costModel = costModel, - rules = Rules.all(ctx.findOptionOrDefault(optCEGISNaiveGrammar), ctx.findOptionOrDefault(optIntroduceRecCalls)), - manualSearch = ms, - functions = ctx.findOption(GlobalOptions.optFunctions) map { _.toSet }, - cegisUseOptTimeout = ctx.findOptionOrDefault(optCEGISOptTimeout), - cegisUseVanuatoo = ctx.findOptionOrDefault(optCEGISVanuatoo), - cegisMaxSize = ctx.findOptionOrDefault(optCEGISMaxSize).toInt - ) - } - - def apply(ctx: LeonContext, program: Program): Unit = { - val options = processOptions(ctx) - - val chooses = SourceInfo.extractFromProgram(ctx, program) - - var functions = Set[FunDef]() - - chooses.toSeq.sortBy(_.fd.id).foreach { ci => - val fd = ci.fd - - val synthesizer = new Synthesizer(ctx, program, ci, options) - - val to = new TimeoutFor(ctx.interruptManager) - - to.interruptAfter(options.timeoutMs) { - val allowPartial = ctx.findOptionOrDefault(optAllowPartial) - - val (search, solutions) = synthesizer.validate(synthesizer.synthesize(), allowPartial) - - try { - if (options.generateDerivationTrees) { - val dot = new DotGenerator(search) - dot.writeFile("derivation"+dotGenIds.nextGlobal+".dot") - } - - solutions.headOption foreach { case (sol, _) => - val expr = sol.toSimplifiedExpr(ctx, program, ci.fd) - fd.body = fd.body.map(b => replace(Map(ci.source -> expr), b)) - functions += fd - } - - } finally { - synthesizer.shutdown() - } - } - } - - for (fd <- functions) { - ctx.reporter.info(ASCIIHelpers.title(fd.id.name)) - ctx.reporter.info(ScalaPrinter(fd, opgm = Some(program))) - ctx.reporter.info("") - } - - } - -} diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala deleted file mode 100644 index 5bbff9444ded17b5100b1bc5ba0c3949b7f4dc61..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import leon.purescala.Definitions.FunDef - -case class SynthesisSettings( - timeoutMs: Option[Long] = None, - generateDerivationTrees: Boolean = false, - costModel: CostModel = CostModels.default, - rules: Seq[Rule] = Rules.all, - manualSearch: Option[String] = None, - searchBound: Option[Int] = None, - functions: Option[Set[String]] = None, - functionsToIgnore: Set[FunDef] = Set(), - - // Cegis related options - cegisUseOptTimeout: Boolean = true, - cegisUseVanuatoo : Boolean = false, - cegisMaxSize: Int = 7 - -) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala deleted file mode 100644 index 0484fe57d02cef117003a4f732c33e58b7418859..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Expressions.Choose -import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.ScalaPrinter -import solvers._ -import leon.utils._ - -import scala.concurrent.duration._ - -import synthesis.strategies._ - -class Synthesizer(val context : LeonContext, - val program: Program, - val ci: SourceInfo, - val settings: SynthesisSettings) { - - val reporter = context.reporter - - lazy val sctx = new SynthesisContext(context, settings, ci.fd, program) - - implicit val debugSection = leon.utils.DebugSectionSynthesis - - def getSearch: Search = { - val strat0 = new CostBasedStrategy(context, settings.costModel) - - val strat1 = if (settings.manualSearch.isDefined) { - new ManualStrategy(context, settings.manualSearch, strat0) - } else { - strat0 - } - - val strat2 = settings.searchBound match { - case Some(b) => - BoundedStrategy(strat1, b) - case None => - strat1 - } - - new Search(context, ci, strat2) - } - - private var lastTime: Long = 0 - - def synthesize(): (Search, Stream[Solution]) = { - reporter.ifDebug { printer => - printer(ci.problem.eb.asString("Tests available for synthesis")(context)) - } - - val s = getSearch - - reporter.info(ASCIIHelpers.title(s"Synthesizing '${ci.fd.id}'")) - - val t = context.timers.synthesis.search.start() - - val sols = s.search(sctx) - - val diff = t.stop() - - lastTime = diff - - reporter.info("Finished in "+diff+"ms") - - - (s, sols) - } - - def validate(results: (Search, Stream[Solution]), allowPartial: Boolean): (Search, Stream[(Solution, Boolean)]) = { - val (s, sols) = results - - val result = sols.map { - case sol if sol.isTrusted => - (sol, Some(true)) - case sol => - validateSolution(s, sol, 5.seconds) - } - - // Print out report for synthesis, if necessary - reporter.ifDebug { printer => - import java.text.SimpleDateFormat - import java.util.Date - - val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") - val benchName = categoryName+"."+ci.fd.id.name - val time = lastTime/1000.0 - - val defs = visibleDefsFrom(ci.fd)(program).collect { - case cd: ClassDef => 1 + cd.fields.size - case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) - } - - val psize = defs.sum - - val (size, calls, proof) = result.headOption match { - case Some((sol, trusted)) => - val expr = sol.toSimplifiedExpr(context, program, ci.fd) - val pr = trusted match { - case Some(true) => "✓" - case Some(false) => "✗" - case None => "?" - } - (formulaSize(expr), functionCallsOf(expr).size, pr) - case _ => - (0, 0, "F") - } - - val date = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) - - val fw = new java.io.FileWriter("synthesis-report.txt", true) - - try { - fw.write(f"$date: $benchName%-50s | $psize%4d | $size%4d | $calls%4d | $proof%7s | $time%2.1f \n") - } finally { - fw.close - } - }(DebugSectionReport) - - (s, if (result.isEmpty && allowPartial) { - Stream((new PartialSolution(s.strat, true).getSolutionFor(s.g.root), false)) - } else { - // Discard invalid solutions - result collect { - case (sol, Some(true)) => (sol, true) - case (sol, None) => (sol, false) - } - }) - } - - def validateSolution(search: Search, sol: Solution, timeout: Duration): (Solution, Option[Boolean]) = { - import verification.VerificationPhase._ - import verification.VerificationContext - - val timer = context.timers.synthesis.validation - timer.start() - - reporter.info("Solution requires validation") - - val (npr, fd) = solutionToProgram(sol) - - val solverf = SolverFactory.getFromSettings(context, npr).withTimeout(timeout) - - try { - val vctx = new VerificationContext(context, npr, solverf) - val vcs = generateVCs(vctx, List(fd)) - val vcreport = checkVCs(vctx, vcs, stopWhen = _.isInvalid) - - if (vcreport.totalValid == vcreport.totalConditions) { - (sol, Some(true)) - } else if (vcreport.totalValid + vcreport.totalUnknown == vcreport.totalConditions) { - reporter.warning("Solution may be invalid:") - (sol, None) - } else { - reporter.error("Solution was invalid:") - reporter.error(ScalaPrinter(fd)) - reporter.error(vcreport.summaryString) - (new PartialSolution(search.strat, false).getSolutionFor(search.g.root), Some(false)) - } - } finally { - timer.stop() - solverf.shutdown() - } - } - - // Returns the new program and the new functions generated for this - def solutionToProgram(sol: Solution): (Program, FunDef) = { - // We replace the choose with the body of the synthesized solution - - val solutionExpr = sol.toSimplifiedExpr(context, program, ci.fd) - - val transformer = funDefReplacer { - case fd if fd eq ci.fd => - val nfd = fd.duplicate() - nfd.fullBody = replace(Map(ci.source -> solutionExpr), nfd.fullBody) - (fd.body, fd.postcondition) match { - case (Some(Choose(pred)), None) => - nfd.postcondition = Some(pred) - case _ => - } - Some(nfd) - case _ => None - } - val npr = transformProgram(transformer, program) - - (npr, transformer.transform(ci.fd)) - } - - def shutdown(): Unit = { - sctx.solverFactory.shutdown() - } -} - diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala deleted file mode 100644 index 98b14905e9f740baea0bddad25aa96ebefd1cb0b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/Witnesses.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis - -import leon.purescala.Common.Identifier -import leon.purescala._ -import Types._ -import Extractors._ -import Expressions.Expr -import PrinterHelpers._ - -object Witnesses { - - abstract class Witness extends Expr with Extractable with PrettyPrintable { - val getType = BooleanType - override def isSimpleExpr = true - } - - case class Guide(e: Expr) extends Witness { - def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((Seq(e), (es: Seq[Expr]) => Guide(es.head))) - - override def printWith(implicit pctx: PrinterContext): Unit = { - p"⊙{$e}" - } - } - - case class Terminating(fi: Expr) extends Witness { - def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some(( Seq(fi), { case Seq(fi) => Terminating(fi) })) - - override def printWith(implicit pctx: PrinterContext): Unit = { - p"↓$fi" - } - } - - case class Hint(e: Expr) extends Witness { - def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some(( Seq(e), { case Seq(e) => Hint(e) })) - - override def printWith(implicit pctx: PrinterContext): Unit = { - p"谶$e" - } - } - - case class Inactive(i: Identifier) extends Witness { - def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((Seq(), _ => this )) - override def printWith(implicit pctx: PrinterContext): Unit = { - p"inactive($i)" - } - - } -} diff --git a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala deleted file mode 100644 index 8f0eb70cd9f9939e2b82f16b3bb12e8827b0ea92..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package disambiguation - -import purescala.Types.FunctionType -import purescala.Common.{FreshIdentifier, Identifier} -import purescala.Constructors.{ and, tupleWrap } -import purescala.Definitions.{ FunDef, Program, ValDef } -import purescala.ExprOps -import purescala.Extractors.TopLevelAnds -import purescala.Expressions._ -import leon.utils.Simplifiers - -/** - * @author Mikael - */ -object ExamplesAdder { - def replaceGenericValuesByVariable(e: Expr): (Expr, Map[Expr, Expr]) = { - var assignment = Map[Expr, Expr]() - var extension = 'a' - var id = "" - (ExprOps.postMap({ expr => expr match { - case g@GenericValue(tpe, index) => - val newIdentifier = FreshIdentifier(tpe.id.name.take(1).toLowerCase() + tpe.id.name.drop(1) + extension + id, tpe.id.getType) - if(extension != 'z' && extension != 'Z') - extension = (extension.toInt + 1).toChar - else if(extension == 'z') // No more than 52 generic variables in practice? - extension = 'A' - else { - if(id == "") id = "1" else id = (id.toInt + 1).toString - } - - val newVar = Variable(newIdentifier) - assignment += g -> newVar - Some(newVar) - case _ => None - } })(e), assignment) - } -} - -class ExamplesAdder(ctx0: LeonContext, program: Program) { - import ExamplesAdder._ - var _removeFunctionParameters = false - - def setRemoveFunctionParameters(b: Boolean) = { _removeFunctionParameters = b; this } - - /** Accepts the nth alternative of a question (0 being the current one) */ - def acceptQuestion[T <: Expr](fd: FunDef, q: Question[T], alternativeIndex: Int): Unit = { - val newIn = tupleWrap(q.inputs) - val newOut = if(alternativeIndex == 0) q.current_output else q.other_outputs(alternativeIndex - 1) - addToFunDef(fd, Seq((newIn, newOut))) - } - - private def filterCases(cases: Seq[MatchCase]) = cases.filter(c => c.optGuard != Some(BooleanLiteral(false))) - - def addToExpr(_post: Expr, id: Identifier, inputVariables: Expr, newCases: Seq[MatchCase]): Expr = { - val post= Simplifiers.bestEffort(ctx0, program)(_post) - if(purescala.ExprOps.exists(_.isInstanceOf[Passes])(post)) { - post match { - case TopLevelAnds(exprs) => - val i = exprs.lastIndexWhere { x => x match { - case Passes(in, out, cases) if in == inputVariables && out == Variable(id) => true - case _ => false - } } - if(i == -1) { - Lambda(Seq(ValDef(id)), and(post, Passes(inputVariables, Variable(id), newCases))) - //ctx0.reporter.info("No top-level passes in postcondition, adding it: " + fd) - } else { - val newPasses = exprs(i) match { - case Passes(in, out, cases) => - Passes(in, out, (filterCases(cases) ++ newCases).distinct ) - case _ => ??? - } - val newPost = and(exprs.updated(i, newPasses) : _*) - Lambda(Seq(ValDef(id)), newPost) - //ctx0.reporter.info("Adding the example to the passes postcondition: " + fd) - } - } - } else { - Lambda(Seq(ValDef(id)), and(post, Passes(inputVariables, Variable(id), newCases))) - //ctx0.reporter.info("No passes in postcondition, adding it:" + fd) - } - - - } - - /** Adds the given input/output examples to the function definitions */ - def addToFunDef(fd: FunDef, examples: Seq[(Expr, Expr)]): Unit = { - val params = if(_removeFunctionParameters) fd.params.filter(x => !x.getType.isInstanceOf[FunctionType]) else fd.params - val inputVariables = tupleWrap(params.map(p => Variable(p.id): Expr)) - val newCases = examples.map{ case (in, out) => exampleToCase(in, out) } - fd.postcondition match { - case Some(Lambda(Seq(ValDef(id)), post)) => - fd.postcondition = Some(addToExpr(post, id, inputVariables, newCases)) - case None => - val id = FreshIdentifier("res", fd.returnType, false) - fd.postcondition = Some(addToExpr(BooleanLiteral(true), id, inputVariables, newCases)) - } - fd.body match { // TODO: Is it correct to discard the choose construct inside the body? - case Some(Choose(Lambda(Seq(ValDef(id)), bodyChoose))) => - fd.body = Some(Choose(addToExpr(bodyChoose, id, inputVariables, newCases))) - case _ => - } - } - - private def exampleToCase(in: Expr, out: Expr): MatchCase = { - val (inPattern, inGuard) = ExprOps.expressionToPattern(in) - if(inGuard == BooleanLiteral(true)) { - MatchCase(inPattern, None, out) - } else /*if (in == in_raw) { } *else*/ { - val id = FreshIdentifier("out", in.getType, true) - MatchCase(WildcardPattern(Some(id)), Some(Equals(Variable(id), in)), out) - } - } - } diff --git a/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala b/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala deleted file mode 100644 index 77d66a592c782c5f6a86dbe0e4f85a22d763d67a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala +++ /dev/null @@ -1,378 +0,0 @@ -package leon -package synthesis.disambiguation - -import purescala.Expressions._ -import purescala.ExprOps -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Types.{ TupleType, BooleanType} -import purescala.Common.{Identifier, FreshIdentifier} -import purescala.Definitions.{FunDef, Program, ValDef} -import purescala.DefOps -import scala.collection.mutable.ListBuffer -import leon.purescala.Definitions.TypedFunDef -import leon.verification.VerificationContext -import leon.verification.VerificationPhase -import leon.solvers._ -import scala.concurrent.duration._ -import leon.verification.VCStatus -import leon.verification.VCResult -import leon.evaluators.AbstractEvaluator -import java.util.IdentityHashMap -import leon.utils.Position -import scala.collection.JavaConversions._ - -case class InputNotCoveredException(msg: String, lineExpr: Identifier) extends Exception(msg) - -/** - * @author Mikael - * If possible, synthesizes a set of inputs for the function so that they cover all parts of the function. - * - * @param fds The set of functions to cover - * @param fd The calling function - */ -class InputCoverage(fd: FunDef, fds: Set[FunDef])(implicit c: LeonContext, p: Program) { - - /** If set, performs a cleaning up step to cover the whole function */ - var minimizeExamples: Boolean = true - - /** Triggers mapping recording between each expression in the bodies of the functions and the example which triggers its computation. */ - def recordMapping() = { doRecordMapping = true; this } - - /** Returns a mapping from the expressions in the source code to the set of inputs that cover them - * Call `.recordMapping()` before calling `.result()` to have it filled. */ - def getRecordMapping(): IdentityHashMap[Expr, Set[Seq[Expr]]] = { - var res = new IdentityHashMap[Expr, Set[Seq[Expr]]]() - for((expr, setFlags) <- flagMapping) { - res.put(expr, setFlags.flatMap(flag => - if(!recordedMapping.containsKey(flag)) { - throw new Exception(s"No expression associated to $flag") - } else - recordedMapping.get(flag))) - } - res - } - - /** If true, for each expression in the bodies of the functions, record which example triggers it.*/ - private var doRecordMapping: Boolean = false - - /** Mapping from identifiers of lines of code to the input that are covering them.*/ - private var recordedMapping = new IdentityHashMap[Identifier, Set[Seq[Expr]]]() - - /** Mapping between every expression and its flag that triggers its computation.*/ - private var flagMapping = new IdentityHashMap[Expr, Set[Identifier]]() - - /** If the sub-branches contain identifiers, it returns them unchanged. - Else it creates a new boolean indicating this branch. */ - private def wrapBranch(e: (Expr, Option[Seq[Identifier]])): (Expr, Some[Seq[Identifier]]) = e._2 match { - case None => - val lineColumnId = FreshIdentifier("l" + Math.abs(e._1.getPos.line) + "c" + Math.abs(e._1.getPos.col), BooleanType).copiedFrom(e._1) - recordMapping(lineColumnId, e._1) - (tupleWrap(Seq(e._1, Variable(lineColumnId))), Some(Seq(lineColumnId))) - case Some(Seq()) => - val lineColumnId = FreshIdentifier("l" + Math.abs(e._1.getPos.line) + "c" + Math.abs(e._1.getPos.col), BooleanType).copiedFrom(e._1) - - def putInLastBody(expr: Expr): Expr = { recordMapping(lineColumnId, expr); expr } match { - case Tuple(Seq(v, prev_b)) => Tuple(Seq(v, or(prev_b, lineColumnId.toVariable))).copiedFrom(expr) - case LetTuple(binders, value, body) => letTuple(binders, value, putInLastBody(body)).copiedFrom(expr) - case MatchExpr(scrut, Seq(MatchCase(TuplePattern(optId, binders), None, rhs))) => - MatchExpr(scrut, Seq(MatchCase(TuplePattern(optId, binders), None, putInLastBody(rhs)))).copiedFrom(expr) - case FunctionInvocation(tfd, args) => - val default_type = tfd.returnType match { - case TupleType(Seq(t, BooleanType)) => - t - case t => throw new Exception(s"Did not expect function $tfd to return something else than a tuple of a type and a boolean. Got $t") - } - val arg_id = FreshIdentifier("arg", default_type) - val arg_b = FreshIdentifier("bd", BooleanType) - letTuple(Seq(arg_id, arg_b), expr, Tuple(Seq(arg_id.toVariable, or(arg_b.toVariable, lineColumnId.toVariable)))) - case _ => throw new Exception(s"Unexpected branching case: $expr") - } - (putInLastBody(e._1), Some(Seq(lineColumnId))) - case e_2: Some[_] => - // No need to introduce a new boolean since if one of the child booleans is true, then this IfExpr has been called. - (e._1, e_2) - } - - private def recordMapping(flag: Identifier, expr: Expr): Unit = { - if(doRecordMapping) { - val flags: Set[Identifier] = Option(flagMapping.get(expr)).getOrElse(Set()) + flag - flagMapping.put(expr, flags) - } - } - - private def recordMapping(flags: Option[Seq[Identifier]], expr: Expr): Unit = { - if(doRecordMapping) { - for{flagsC <- flags - id <- flagsC} { - recordMapping(id, expr) - } - } - expr match { - case i: IfExpr => - case m: MatchExpr => - case Operator(es, builder) => - for(e <- es) recordMapping(flags, e) - } - } - - private def recordExample(input: Seq[Expr], flags: Set[Identifier]): Unit = { - if(doRecordMapping) { - for{flag <- flags} { - val newMappings: Set[Seq[Expr]] = Option(recordedMapping.get(flag)).getOrElse(Set()) + input - recordedMapping.put(flag, newMappings) - } - } - } - - /** Returns true if there are some branching to monitor in the expression */ - private def hasConditionals(e: Expr) = { - ExprOps.exists{ case i:IfExpr => true case m: MatchExpr => true case f: FunctionInvocation if fds(f.tfd.fd) || f.tfd.fd == fd => true case _ => false}(e) - } - - /** Merges two set of identifiers. - * None means that the attached expression is the original one, - * Some(ids) means that it has been augmented with booleans and ids are the "monitoring" boolean flags. */ - private def merge(a: Option[Seq[Identifier]], b: Option[Seq[Identifier]]) = { - (a, b) match { - case (None, None) => None - case (a, None) => a - case (None, b) => b - case (Some(a), Some(b)) => Some(a ++ b) - } - } - - /** For each branch in the expression, adds a boolean variable such that the new type of the expression is (previousType, Boolean) - * If no variable is output, then the type of the expression is not changed. - * If the expression is augmented with a boolean, returns the list of boolean variables which appear in the expression */ - private def markBranches(e: Expr): (Expr, Option[Seq[Identifier]]) = - if(!hasConditionals(e)) (e, None) else e match { - case IfExpr(cond, thenn, elze) => - val (c1, cv1) = markBranches(cond) - val (t1, tv1) = wrapBranch(markBranches(thenn)) - val (e1, ev1) = wrapBranch(markBranches(elze)) - recordMapping(tv1, thenn) - recordMapping(ev1, elze) - val mergedIds = merge(merge(cv1, tv1), ev1) - recordMapping(mergedIds, cond) - recordMapping(mergedIds, e) - cv1 match { - case None => - (IfExpr(c1, t1, e1).copiedFrom(e), mergedIds) - case cv1 => - val arg_id = FreshIdentifier("arg", BooleanType) - val arg_b = FreshIdentifier("bc", BooleanType) - (letTuple(Seq(arg_id, arg_b), c1, IfExpr(Variable(arg_id), t1, e1).copiedFrom(e)).copiedFrom(e), mergedIds) - } - case m@MatchExpr(scrut, cases) => - val (c, ids) = markBranches(ExprOps.matchToIfThenElse(m)) // And replace the last error else statement with a dummy flag. - def replaceFinalElse(e: Expr): (Expr, Identifier)= e match { - case IfExpr(c1, t1, e1) => - val (newElse, id) = replaceFinalElse(e1) - (IfExpr(c1, t1, newElse).copiedFrom(e), id) - case Tuple(Seq(Error(tpe, msg), Variable(i))) => - (Tuple(Seq(Error(tpe, msg), BooleanLiteral(false))), i) - } - val (new_c, id_to_remove) = replaceFinalElse(c) - (new_c, ids.map(_.filter(_ != id_to_remove))) - case Or(args) if args.length >= 1 => - val c = args.foldRight[Expr](BooleanLiteral(false).copiedFrom(e)){ - case (arg, prev) => - IfExpr(arg, BooleanLiteral(true), prev).copiedFrom(e) - } - markBranches(c.copiedFrom(e)) - case And(args) if args.length >= 1 => - val c = args.foldRight[Expr](BooleanLiteral(true).copiedFrom(e)){ - case (arg, prev) => - IfExpr(Not(arg), BooleanLiteral(false), prev).copiedFrom(e) - } - markBranches(c.copiedFrom(e)) - - case Operator(lhsrhs, builder) => - // The exprBuilder adds variable definitions needed to compute the arguments. - val (exprBuilder, children, tmpIds, ids) = (((e: Expr) => e, ListBuffer[Expr](), ListBuffer[Identifier](), None: Option[Seq[Identifier]]) /: lhsrhs) { - case ((exprBuilder, children, tmpIds, ids), arg) => - val (arg1, argv1) = markBranches(arg) - recordMapping(argv1, arg) - if(argv1.nonEmpty || isNewFunCall(arg1)) { - val arg_id = FreshIdentifier("arg", arg.getType) - val arg_b = FreshIdentifier("ba", BooleanType) - val f = (body: Expr) => letTuple(Seq(arg_id, arg_b), arg1, body).copiedFrom(body) - (exprBuilder andThen f, children += Variable(arg_id), tmpIds += arg_b, merge(ids, argv1)) - } else { - (exprBuilder, children += arg, tmpIds, ids) - } - } - recordMapping(ids, e) - e match { - case FunctionInvocation(tfd@TypedFunDef(fd, targs), args) if fds(fd) => - val new_fd = wrapFunDef(fd) - // Is different since functions will return a boolean as well. - tmpIds match { - case Seq() => - val funCall = FunctionInvocation(TypedFunDef(new_fd, targs).copiedFrom(tfd), children).copiedFrom(e) - (exprBuilder(funCall), if(new_fd != fd) merge(Some(Seq()), ids) else ids) - case idvars => - val res_id = FreshIdentifier("res", fd.returnType) - val res_b = FreshIdentifier("bb", BooleanType) - val finalIds = idvars :+ res_b - val finalExpr = - tupleWrap(Seq(Variable(res_id), or(finalIds.map(Variable(_)): _*))).copiedFrom(e) - val funCall = letTuple(Seq(res_id, res_b), FunctionInvocation(TypedFunDef(new_fd, targs), children).copiedFrom(e), finalExpr).copiedFrom(e) - (exprBuilder(funCall), merge(Some(Seq()), ids)) - } - case _ => - tmpIds match { - case Seq() => - (e, ids) - case idvars => - val finalExpr = tupleWrap(Seq(builder(children).copiedFrom(e), or(idvars.map(Variable): _*))).copiedFrom(e) - (exprBuilder(finalExpr), ids) - } - } - } - - /** The cache of transforming functions.*/ - private var cache = Map[FunDef, FunDef]() - - /** Records all booleans serving to identify which part of the code should be executed.*/ - private var booleanFlags = Seq[Identifier]() - - /** Augment function definitions which should have boolean markers, leave others untouched. */ - private def wrapFunDef(f: FunDef): FunDef = { - if(!(cache contains f)) { - if(fds(f) || f == fd) { - val new_fd = f.duplicate(returnType = TupleType(Seq(f.returnType, BooleanType))) - new_fd.body = None - cache += f -> new_fd - val (new_body, bvars) = wrapBranch(markBranches(f.body.get)) // Recursive call. - new_fd.body = Some(new_body) // TODO: Handle post-condition if they exist. - booleanFlags ++= bvars.get - } else { - cache += f -> f - } - } - cache(f) - } - - /** Returns true if the expression is a function call with a new function already. */ - private def isNewFunCall(e: Expr): Boolean = e match { - case FunctionInvocation(TypedFunDef(fd, targs), args) => - cache.values.exists { f => f == fd } - case _ => false - } - - /** Returns a stream of covering inputs for the function `f`, - * such that if `f` is evaluated on each of these inputs, all parts of `{ f } U fds` will have been executed at least once. - * - * The number of expressions in each element is the same as the number of arguments of `f` */ - def result(): Stream[Seq[Expr]] = { - cache = Map() - booleanFlags = Seq() - recordedMapping = new IdentityHashMap[Identifier, Set[Seq[Expr]]]() - flagMapping = new IdentityHashMap[Expr, Set[Identifier]]() - /* Approximative algorithm Algorithm: - * In innermost branches, replace each result by (result, bi) where bi is a boolean described later. - * def f(x: Int, z: Boolean): (Int, Boolean) = - * x match { - * case 0 | 1 => - * if(z) { - * ({val (r, b1) = if(z) (x, bt) else (-x, be) - * val (res, b) = f(r, z) - * (res, b || b1) - * case _ => - * val (res, b) = if(z) (x, b2) - * else (f(x-1,!z)+f(x-2,!z), b3) - * (res*f(x-1,!z), b) - * } - * Let B be the set of bi. - * For each b in B - * Set all b' in B to false except b to true - * Find a counter-example. - * yield it in the stream. - */ - - /* Change all return types to accommodate the new covering boolean */ - - val transformer = DefOps.funDefReplacer({ - (f: FunDef) => - if((fds contains f) || f == fd) { - val new_fd = wrapFunDef(f) - if(f == fd) { - val h = FreshIdentifier("h", TupleType(Seq(fd.returnType, BooleanType)), false) - new_fd.postcondition = Some(Lambda(Seq(ValDef(h)), Not(TupleSelect(Variable(h), 2)))) - } - Some(new_fd) - } else - None - }, { - (fi: FunctionInvocation, newFd: FunDef) => //if(cache contains fi.tfd.fd) { - Some(TupleSelect(FunctionInvocation(newFd.typed, fi.args), 1)) - //} else None - }) - - val program = DefOps.transformProgram(transformer, p) - - val start_fd = transformer.transform(fd) - - var coveredBooleans = Set[Identifier]() - // For each boolean flag, set it to true, and find a counter-example which should reach this line. - // For each new counter-example, abstract evaluate the original function to remove booleans which have been reached. - val covering_examples = - for(bvar <- booleanFlags.toStream if !coveredBooleans(bvar)) yield { - val transformer2 = DefOps.funDefReplacer { - (f: FunDef) => - if(ExprOps.exists { case Variable(id) => booleanFlags contains id case _ => false }(f.fullBody)) { - val new_f = f.duplicate() - new_f.fullBody = ExprOps.preMap { - case Variable(id) if id == bvar => Some(BooleanLiteral(true)) - case Variable(id) if booleanFlags contains id => Some(BooleanLiteral(false)) - case _ => None - }(f.fullBody) - Some(new_f) - } else None - - } - val program2 = DefOps.transformProgram(transformer2, program) - val start_fd2 = transformer2.transform(start_fd) - val tfactory = SolverFactory.getFromSettings(c, program2).withTimeout(10.seconds) - - val vctx = new VerificationContext(c, program2, tfactory) - val vcs = VerificationPhase.generateVCs(vctx, Seq(start_fd2)) - VerificationPhase.checkVCs(vctx, vcs).results(vcs(0)) match { - case Some(VCResult(VCStatus.Invalid(model), _, _)) => - val finalExprs = fd.paramIds.map(model) - val whoIsEvaluated = functionInvocation(start_fd, finalExprs) - val ae = new AbstractEvaluator(c, p) - val coveredFlagsByCounterExample = ae.eval(whoIsEvaluated).result match { - case Some((Tuple(Seq(_, booleans)), _)) => - val targettedIds = ExprOps.collect{ case Variable(id) => Set(id) case _ => Set[Identifier]() }(booleans) - coveredBooleans ++= targettedIds - targettedIds - case _ => - Set(bvar) - } - //println(s"Recording the example $finalExprs covering $coveredFlagsByCounterExample") - recordExample(finalExprs, coveredFlagsByCounterExample) - finalExprs -> coveredFlagsByCounterExample - case e => - throw InputNotCoveredException("Could not find an input to cover the line: " + bvar.getPos.line + " (at col " + bvar.getPos.col + ")\n" + e.getOrElse(""), bvar) - } - } - - val covering_examples2 = if(minimizeExamples) { - // Remove examples whose coverage set is included in some other example. - for((covering_example, flags_met) <- covering_examples - if !covering_examples.exists{ - case (other_example, other_flags) => - other_example != covering_example && - flags_met.subsetOf(other_flags) - } - ) yield covering_example - } else { - covering_examples.map(_._1) - } - - covering_examples2 filter (_.nonEmpty) - } -} diff --git a/src/main/scala/leon/synthesis/disambiguation/InputPatternCoverage.scala b/src/main/scala/leon/synthesis/disambiguation/InputPatternCoverage.scala deleted file mode 100644 index c4eb0967d30e8d016ce4542aa59679706a891a81..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/InputPatternCoverage.scala +++ /dev/null @@ -1,281 +0,0 @@ -package leon -package synthesis.disambiguation - -import purescala.Expressions._ -import purescala.ExprOps -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.Common._ -import purescala.Definitions.{FunDef, Program, TypedFunDef, ValDef} -import purescala.DefOps -import scala.collection.mutable.ListBuffer -import leon.LeonContext -import leon.purescala.Definitions._ -import leon.verification.VerificationContext -import leon.verification.VerificationPhase -import leon.solvers._ -import scala.concurrent.duration._ -import leon.verification.VCStatus -import leon.verification.VCResult -import leon.evaluators.AbstractEvaluator -import java.util.IdentityHashMap -import leon.utils.Position -import scala.collection.JavaConversions._ -import leon.evaluators.DefaultEvaluator -import leon.grammars.ValueGrammar -import leon.datagen.GrammarDataGen - -class InputPatternCoverageException(msg: String) extends - Exception(msg) - -case class PatternNotSupportedException(p: Pattern) extends - InputPatternCoverageException(s"The pattern $p is not supported for coverage.") - -case class PatternExtractionErrorException(p: Pattern, msg: String) extends - InputPatternCoverageException(s"The pattern $p cause problem during extraction: "+msg) - -/** - * @author Mikael - * If possible, synthesizes a set of inputs for the function so that they cover all parts of the function. - * Requires the function to only have pattern matchings without conditions, functions with single variable calls. - * - * @param fds The set of functions to cover - * @param fd The calling function - */ -class InputPatternCoverage(fd: TypedFunDef)(implicit c: LeonContext, p: Program) { - - def result(): Stream[Seq[Expr]] = coverFunDef(fd, Covered(Set(), Set()), None) - - type PathSegment = Either[Identifier, Int] - case class Path(orig: Identifier, route: List[PathSegment]) { - def :+(i: Identifier) = Path(orig, route :+ Left(i)) - def :+(i: Int) = Path(orig, route :+ Right(i)) - - def getType: TypeTree = { - (orig.getType /: route) { - case (TupleType(targs), Right(index)) => targs(index - 1) - case (cct@CaseClassType(ccd, targs), Left(id)) => cct.fieldsTypes(ccd.selectorID2Index(id)) - case e => throw new InputPatternCoverageException("Could not get type of path " + this + " because inconsistency " + e) - } - } - } - implicit def idToPath(i: Identifier): Path = Path(i, Nil) - - private case class Covered(alreadyCovered: Set[TypedFunDef], alreadyCoveredLambdas: Set[Lambda]) { - def apply(t: TypedFunDef) = alreadyCovered(t) - def apply(l: Lambda) = alreadyCoveredLambdas(l) - def +(t: TypedFunDef) = this.copy(alreadyCovered = alreadyCovered + t) - def +(l: Lambda) = this.copy(alreadyCoveredLambdas = alreadyCoveredLambdas + l) - } - - private def argsToMap(paramIds: Seq[Identifier], args: Option[Seq[Expr]]) = - args.map(args => paramIds.zip(args).toMap).getOrElse(Map.empty) - - private def coverFunLike(paramIds: Seq[Identifier], body: Expr, args: Option[Seq[Expr]], covered: Covered): Stream[Seq[Expr]] = { - val mapping = coverExpr(paramIds, body, covered, argsToMap(paramIds, args)) - leon.utils.StreamUtils.cartesianMap(mapping) map { mapping => - paramIds.map(i => convert(i)(mapping).getOrElse(a(i.getType))) - } - } - - private def uniqueValueOf(m: Map[Path, Stream[Expr]], id: Identifier): Expr = { - m.get(id) match { - case None => throw new InputPatternCoverageException(s"Did not find $id in $m") - case Some(s) => s.headOption.getOrElse(throw new InputPatternCoverageException(s"Empty value for $id in $m")) - } - } - - private def coverFunDef(f: TypedFunDef, covered: Covered, args: Option[Seq[Expr]]): Stream[Seq[Expr]] = if(covered(f)) { - Stream(f.paramIds.map(x => a(x.getType))) - } else { - f.body match { - case Some(body) => coverFunLike(f.paramIds, body, args, covered + f) - case None => - if(f.fd == p.library.mapMkString.get) { - args match { - case Some(Seq(mp, sepk_v, sep_entry, fkey, fvalue)) => - mp.getType match { - case MapType(keyType, valueType) => - val key1 = FreshIdentifier("k", keyType) - val key2 = FreshIdentifier("k", keyType) - val value1 = FreshIdentifier("v", valueType) - val value2 = FreshIdentifier("v", valueType) - val dumbbody = - tupleWrap(Seq(application(fkey, Seq(Variable(key1))), - application(fkey, Seq(Variable(key2))), - Variable(f.paramIds(1)), - Variable(f.paramIds(2)), - application(fvalue, Seq(Variable(value1))), - application(fvalue, Seq(Variable(value2))))) - val mapping = coverExpr(f.paramIds, dumbbody, covered + f, argsToMap(f.paramIds, args)) - val key1v = uniqueValueOf(mapping, key1) - val key2v = uniqueValueOf(mapping, key2) - val key2vUnique = if(key1v == key2v) { - all(keyType).filter(e => e != key1v).headOption.getOrElse(key2v) - } else key2v - val value1v = uniqueValueOf(mapping, value1) - val value2v = uniqueValueOf(mapping, value2) - val remaining_mapping = mapping - key1 - key2 - value1 - value2 - val firstValue = FiniteMap(Map(key1v -> value1v, key2vUnique -> value2v), keyType, valueType) - leon.utils.StreamUtils.cartesianMap(remaining_mapping) map { mapping => - firstValue +: - f.paramIds.tail.map(i => convert(i)(mapping).getOrElse(a(i.getType))) - } - case _ => throw new InputPatternCoverageException(s"Wrong usage of $f - no map type") - } - case _ => throw new InputPatternCoverageException(s"Wrong usage of $f") - } - }/* else if(f.fd == p.library.bagMkString.get) { - - } else if(f.fd == p.library.setMkString.get) { - - } */else throw new InputPatternCoverageException(s"empty body for function $f") - } - } - - private def coverLambda(l: Lambda, covered: Covered, args: Option[Seq[Expr]]): Stream[Seq[Expr]] = if(covered(l)) { - Stream(l.args.map(x => a(x.getType))) - } else { - coverFunLike(l.args.map(_.id), l.body, args, covered + l) - } - - private def mergeCoverage(a: Map[Path, Stream[Expr]], b: Map[Path, Stream[Expr]]): - Map[Path, Stream[Expr]] = { - if((a.keys.toSet intersect b.keys.toSet).nonEmpty) - throw new InputPatternCoverageException("Variable used twice: " + (a.keys.toSet intersect b.keys.toSet)+"\n"+a+"\n"+b) - a ++ b - } - - object Reconstructor { - def unapply(e: Expr): Option[Path] = e match { - case Variable(id) => Some(id) - case CaseClassSelector(cct, Reconstructor(path), ccid) => - Some(path :+ ccid) - case TupleSelect(Reconstructor(path), index) => - Some(path :+ index) - case _ => - None - } - } - - def compose(oldBindings: Map[Identifier, Expr], newBindings: Seq[Expr]): Seq[Expr] = { - newBindings.map { case Variable(id) => oldBindings.getOrElse(id, Variable(id)) case e => e } - } - - /** Map of g.left.symbol to the stream of expressions it could be assigned to */ - private def coverExpr(inputs: Seq[Identifier], e: Expr, covered: Covered, bindings: Map[Identifier, Expr]): Map[Path, Stream[Expr]] = { - println(s"Covering expr (inputs = $inputs, bindings = $bindings): \n$e") - val res : Map[Path, Stream[Expr]] = - e match { - case IfExpr(cond, thenn, elze) => throw new Exception("Requires only match/case pattern, got "+e) - case MatchExpr(Reconstructor(path), cases) if inputs.nonEmpty && inputs.headOption == Some(path.orig) => - val pathType = path.getType - val coveringExprs = cases.map(coverMatchCase(pathType, _, covered, bindings)) - val interleaved = leon.utils.StreamUtils.interleave[Expr](coveringExprs) - Map(path -> interleaved) - case FunctionInvocation(tfd@TypedFunDef(fd, targs), args @ (Reconstructor(path)+:tail)) => - Map(path -> coverFunDef(tfd, covered, Some(compose(bindings, args))).map(_.head)) - - case Reconstructor(path) if inputs.nonEmpty && inputs.headOption == Some(path.orig) => - Map(path -> Stream(a(path.getType))) - - case Application(Variable(f), args @ (Reconstructor(path)+:tail)) => - bindings.get(f) match { - case Some(l@Lambda(Seq(ValDef(i)), body))=> - Map(path -> coverLambda(l, covered, Some(compose(bindings, args))).map(_.head)) - case e => throw new InputPatternCoverageException("Wrong callback, expected lambda, got " + e + " (bindings = " + bindings + ")" ) - } - case Operator(lhsrhs, builder) => - if(lhsrhs.length == 0) { - Map() - } else { - lhsrhs.map(child => coverExpr(inputs, child, covered, bindings)).reduce(mergeCoverage) - } - } - res - } - - /** Returns an instance of the given type. Makes sure maps, sets and bags have at least two elements.*/ - private def a(t: TypeTree): Expr = { - t match { - case MapType(keyType, valueType) => - val pairs = all(keyType).take(2).toSeq.zip(all(valueType).take(2).toSeq).toMap - FiniteMap(pairs, keyType, valueType) - case SetType(elemType) => - val elems = all(elemType).take(2).toSet - FiniteSet(elems, elemType) - case BagType(elemType) => - val elems = all(elemType).take(2).toSeq - FiniteBag(elems.zip(1 to 2 map (IntLiteral)).toMap, elemType) - case _ => all(t).head - } - } - - /** Returns all instance of the given type */ - private def all(t: TypeTree): Stream[Expr] = { - val i = FreshIdentifier("i", t) - val datagen = new GrammarDataGen(new DefaultEvaluator(c, p), ValueGrammar) - val enumerated_inputs = datagen.generateMapping(Seq(i), BooleanLiteral(true), 10, 10).toStream - enumerated_inputs.toStream.map(_.head._2) - } - - def convert(topLevel: Path)(implicit binders: Map[Path, Expr]): Option[Expr] = { - binders.get(topLevel) match { - case None => - topLevel.getType match { - case cct@CaseClassType(ccd, targs) => - val args = ccd.fieldsIds.map(id => - (convert(topLevel :+ id) match { case Some(e) => e case None => return None }): Expr ) - Some(CaseClass(cct, args)) - case tt@TupleType(targs) => - val args = (1 to targs.length).map(index => - (convert(topLevel :+ index) match { case Some(e) => e case None => return None }): Expr ) - Some(Tuple(args.toSeq)) - case _ => None - } - case e => e - } - } - - // TODO: Take other constraints into account: Missed previous patterns ? - private def coverPattern(p: Pattern, inputType: TypeTree, binders: Map[Path, Expr], covered: Covered): Expr = p match { - case CaseClassPattern(binder, ct, subs) => - if(subs.exists{ case e: WildcardPattern => false case _ => true }) { - throw PatternNotSupportedException(p) - } - val args = subs.collect { case e: WildcardPattern => e } - CaseClass(ct, args.zipWithIndex.map{ - case (WildcardPattern(Some(o)), i) => - convert(o)(binders).getOrElse((throw PatternExtractionErrorException(p, s"Not able to recover value of ${o}")): Expr) - case (WildcardPattern(_), i) => a(ct.fieldsTypes(i)) - }) - case TuplePattern(binder, subs) => - if(subs.exists{ case e: WildcardPattern => false case _ => true }) { - throw PatternNotSupportedException(p) - } - val args = subs.collect { case e: WildcardPattern => e } - Tuple(args.zipWithIndex.map{ - case (WildcardPattern(Some(o)), i) => convert(o)(binders).getOrElse((throw PatternNotSupportedException(p)): Expr) - case (WildcardPattern(_), i) => - inputType match { - case TupleType(targs) => - a(targs(i)) - case _ => throw PatternNotSupportedException(p) - } - }) - case InstanceOfPattern(binder, ct) => - binder.map(b => convert(b)(binders).getOrElse((throw PatternNotSupportedException(p)): Expr)).getOrElse(a(ct)) - case LiteralPattern(ob, value) => value - case WildcardPattern(ob) => a(inputType) - case _ => throw PatternNotSupportedException(p) - } - - private def coverMatchCase(inputType: TypeTree, m: MatchCase, covered: Covered, bindings: Map[Identifier, Expr]) = m match { - case MatchCase(pattern, guard, rhs) => - val variables = pattern.binders.toSeq - val binders = coverExpr(variables, rhs, covered, bindings) - val cartesian = leon.utils.StreamUtils.cartesianMap(binders) - cartesian.map(binders => coverPattern(pattern, inputType, binders, covered)) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/disambiguation/InputRecCoverage.scala b/src/main/scala/leon/synthesis/disambiguation/InputRecCoverage.scala deleted file mode 100644 index 4a28941edc433cc76ab48f0dda01261f36fc5ade..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/InputRecCoverage.scala +++ /dev/null @@ -1,340 +0,0 @@ -package leon -package synthesis -package disambiguation - -import leon.LeonContext -import leon.evaluators.DefaultEvaluator -import leon.evaluators.DefaultRecContext -import purescala.Common.{Identifier, FreshIdentifier} -import purescala.Constructors._ -import purescala.DefOps -import purescala.Definitions.FunDef -import purescala.Definitions.Program -import purescala.ExprOps -import purescala.Definitions.{CaseClassDef, _} -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.TypeOps -import leon.datagen.GrammarDataGen -import leon.grammars.ValueGrammar -import leon.purescala.TypeOps -import leon.evaluators.AbstractEvaluator -import leon.purescala.TypeOps -import scala.collection.mutable.Queue -import java.util.IdentityHashMap - -object InputRecCoverage { - class W[T <: Expr](val e: T) { - def somewhere(f: Expr): Boolean = e eq f - // To ensure that the "equals" method of exprs is not used during the computation. - } - - /** Returns true if the expression contains strings or integers */ - def isMarkedWithStringOrInt(e: Expr) = - ExprOps.exists{ - case StringLiteral(_) => true - case InfiniteIntegerLiteral(_) => true - case IntLiteral(_) => true - case _ => false - }(e) -} - -class InputRecCoverage(fd: FunDef, fds: Set[FunDef])(implicit ctx: LeonContext, program: Program) { - import InputRecCoverage._ - val inputCoverage = new InputCoverage(fd, fds) - - /** Flattens a string concatenation into a list of expressions */ - def flatten(s: StringConcat): List[Expr] = s match { - case StringConcat(a: StringConcat, b: StringConcat) => flatten(a) ++ flatten(b) - case StringConcat(a, b: StringConcat) => a :: flatten(b) - case StringConcat(a: StringConcat, b) => flatten(a) ++ List(b) - case StringConcat(a, b) => List(a, b) - } - - /** Creates a string concatenation from a list of expression. The list should have size >= 2.*/ - def rebuild(e: List[Expr]): StringConcat = e match { - case List(a, b) => StringConcat(a, b) - case a::l => StringConcat(a, rebuild(l)) - case _ => throw new Exception("List with less than 2 elements?!") - } - - /** Flattens a string concatenation into a list of expressions */ - def permutations(s: StringConcat): Stream[StringConcat] = { - flatten(s).permutations.toStream.tail.map(x => rebuild(x).copiedFrom(s)) - } - - def allConcatenations(): Set[W[StringConcat]] = { - var concatenations = Set[W[StringConcat]]() - - def collectConcatenations(e: Expr, keepConcatenations: Boolean = true): Unit = e match { - case s@StringConcat(a, b) => if(keepConcatenations) concatenations += new W(s) // Stop at the first concatenation. - collectConcatenations(a, false) - collectConcatenations(b, false) - case Operator(es, builder) => for(e <- es) - collectConcatenations(e, true) - } - - collectConcatenations(fd.body.get) - for(fd <- fds) { - collectConcatenations(fd.body.get) - } - concatenations - } - - /** Assert function for testing if .result() is rec-covering. */ - def assertIsRecCovering(inputs: Stream[Seq[Expr]]): Unit = { - // Contains the set of top-level concatenations in all programs. - val concatenations = allConcatenations() - - // For each of these concatenations, we check that there is at least one input which if evaluated while it is reverse, the result would be different. - // If not, we expand the covering example. - - val originalEvaluator = new DefaultEvaluator(ctx, program) - val originalOutputs: Map[Seq[Expr], Expr] = inputs.map(input => input -> originalEvaluator.eval(functionInvocation(fd, input)).result.get).toMap - for(stringConcatW <- concatenations; stringConcat = stringConcatW.e; stringConcatReversed <- permutations(stringConcat)) { - - val transformer = DefOps.funDefReplacer { - (f: FunDef) => - if(f.body.exists(body => ExprOps.exists(stringConcat eq _)(body))) { - val new_f = f.duplicate() - new_f.body = f.body.map(body => ExprOps.preMap(e => if(stringConcat eq e) { Some(stringConcatReversed)} else None)(body)) - Some(new_f) - } else None - } - val permuttedProgram = DefOps.transformProgram(transformer, program) - val modifiedEvaluator = new DefaultEvaluator(ctx, permuttedProgram) - - val oneInputMakesItDifferent = inputs.exists(input => - modifiedEvaluator.eval(functionInvocation(transformer.transform(fd), input)).result.get != originalOutputs(input)) - - assert(oneInputMakesItDifferent, "No input made the change " + stringConcat + " -> " + stringConcatReversed + " produce a different result") - } - } - - /** Returns a stream of rec-covering inputs for the function `f` to cover all functions in `{ f } U fds`. - * - * This means that for each concatenation operation, there is an input example which can differentiate between this concatenation and any of its permutations if possible. - **/ - def result(): Stream[Seq[Expr]] = { - var identifiableInputs = Map[Seq[Expr], Seq[Expr]]() - var inputs = inputCoverage.recordMapping().result().map{input => - val res = input.map(QuestionBuilder.makeGenericValuesUnique) - identifiableInputs += input -> res - res - } - - // Contains the set of top-level concatenations in all programs. - val concatenations = allConcatenations() - - // For each of these concatenations, we check that there is at least one input which if evaluated while it is reverse, the result would be different. - // If not, we expand the covering example. - - val originalEvaluator = new DefaultEvaluator(ctx, program) - var originalOutputs: Map[Seq[Expr], Expr] = inputs.map(input => input -> originalEvaluator.eval(functionInvocation(fd, input)).result.get).toMap - for(stringConcatW <- concatenations; stringConcat = stringConcatW.e; stringConcatReversed <- permutations(stringConcat)) { - //val (permuttedProgram, idMap, fdMap, cdMap) = DefOps.replaceFunDefs(program)({ - val transformer = DefOps.funDefReplacer { - (f: FunDef) => - if(f.body.exists(body => ExprOps.exists(stringConcat eq _)(body))) { - val new_f = f.duplicate() - new_f.body = f.body.map(body => ExprOps.preMap(e => if(stringConcat eq e) { Some(stringConcatReversed)} else None)(body)) - Some(new_f) - } else None - } - val permuttedProgram = DefOps.transformProgram(transformer, program) - val modifiedEvaluator = new DefaultEvaluator(ctx, permuttedProgram) - - val oneInputMakesItDifferent = inputs.exists(input => - modifiedEvaluator.eval(functionInvocation(transformer.transform(fd), input)).result.get != originalOutputs(input)) - - if(!oneInputMakesItDifferent) { - // Now we need to find an input which makes a difference if possible, when modified. - println("No input make this concatenation differ in output when permutted: " + stringConcat + " -> " + stringConcatReversed) - println(" mappings:\n" + inputs.map(input => input + "->" + originalEvaluator.eval(functionInvocation(fd, input)).result.get).mkString("\n")) - println("New mappings:\n" + inputs.map(input => input + "->" + modifiedEvaluator.eval(functionInvocation(transformer.transform(fd), input)).result.get).mkString("\n")) - // First, make all its terminals (strings and numbers) unique. - val covering = inputCoverage.getRecordMapping() - val coveringInputs = Option(covering.get(stringConcat)).getOrElse(Set()).map(x => identifiableInputs.getOrElse(x, x)) - //println(s"Input that can cover $stringConcat are " + coveringInputs.mkString(", ")) - - val input = coveringInputs.head - var mappingAtStringConcatOpt: Option[AbstractEvaluator#RC] = None - val please = new AbstractEvaluator(ctx, program) { - override def e(expr: Expr)(implicit rctx: RC, gctx: GC) = { - if(expr eq stringConcat) { - //println(s"Found string concat $stringConcat. Mapping = " + rctx) - rctx.mappings.values.find{v => - !input.exists(i => ExprOps.exists{ case e if e eq v => true case _ => false}(i)) - } match { - case None => - case Some(v) => - throw new Exception(s"Value not present from input ($input): $v") - } - mappingAtStringConcatOpt = Some(rctx) - } - super.e(expr) - } - } - please eval functionInvocation(fd, input) - - // Now we now for each term of the stringConcat which is the sub-expression of the input which is used for computation, - // and we can replace it if - // 1) The function call is more general, => we make sure to insert a string or number which make it identifiable - // 2) If not possible, the function call is more general, but inserting strings or numbers is not feasible. - // 3) The function call can handle only finitely number of values => We duplicate the input to cover all these possible values. - - val mappingAtStringConcat = mappingAtStringConcatOpt.getOrElse(throw new Exception(s"Did not find an execution context for $stringConcat when evaluating $input")) - - val members = flatten(stringConcat) - println(s"For the input $input and concatenation $stringConcat") - var newInput = Seq(input) - val toReplace = new IdentityHashMap[Expr, List[Expr]]() - for(m <- members) { - m match { - case FunctionInvocation(TypedFunDef(fd, targs), Seq(Variable(id))) => - mappingAtStringConcat.mappings.get(id) match { - case Some(expr) => - //println(s"Mapping $m is computed on $expr (with formula ${mappingAtStringConcat.mappingsAbstract(id)})") - // expr is a sub-expression of input. - // fd is the function called with the argument expr. - if(!isMarkedWithStringOrInt(expr)) { - val mainArg = fd.paramIds(0) - markWithStringOrInt(mainArg.getType, tupleWrap(input)) match { // TODO: Enumerate all possible values if not markable. - case Some(expr_marked) => - println(s"Expr unisized: $expr_marked") - if(!input.exists(i => ExprOps.exists{ case e if e eq expr => true case _ => false}(i))) { - throw new Exception(s"Did not find $expr (${expr.##}) in $input") - } - toReplace.put(expr, List(expr_marked)) - case None => - println("Not possible to mark the string, reverting to enumeration strategies") - // If there is a finite number of values at some place, replace with each of them. - val exprs = if(TypeOps.typeCardinality(mainArg.getType).nonEmpty) { - println("Finite enumeration") - all(mainArg.getType).toList - } else { - println("Infinite enumeration. Taking 5 values") - // Else try to find other values which make them identifiable at some point. - all(mainArg.getType).take(5).toList - } - println(s"$expr -> $exprs") - toReplace.put(expr, exprs) - } - } // Else nothing to do, there is already a unique identifier to expr. - - case None => throw new Exception(s"No binding at evaluation time for $id ... something is wrong.") - } - case IntegerToString(Variable(id)) => // Nothing to do, already identified - case Int32ToString(Variable(id)) => // Nothing to do, already identified - case BooleanToString(Variable(id)) => // TODO: Enumerate all possible values - case CharToString(Variable(id)) => // Nothing to do, already identified - case Variable(id) => // TODO: Enumerate all possible values of this id or have an identifiable one ? - case StringLiteral(k) => // Nothing to do, already identified - case e => throw new Exception(s"Rec-coverage is not supported when concatenating something else than fun(id) and constants; got $m") - } - } - if(!toReplace.isEmpty()) { - val new_inputs: Seq[Seq[Expr]] = - leon.utils.SeqUtils.cartesianProduct(input.map(i => ExprOps.postFlatmap{ case e if toReplace.containsKey(e) => Some(toReplace.get(e)) case _ => None}(i))) - println(s"Added new input: ${new_inputs.mkString("\n")}") - for(new_input <- new_inputs) { - originalOutputs += new_input -> originalEvaluator.eval(functionInvocation(fd, new_input)).result.get - newInput = newInput :+ new_input - } - inputs = inputs.flatMap{ i => if(i == input) new_inputs else Some(input) }.distinct - println(s"inputs: ${inputs.mkString("\n")}") - } else { - println(s"Did not find anything to identify the expr $stringConcat") - } - // Now we find which arguments are given to the function - - } // Else that's fine, the example covers the case. - } - - inputs - } - - - /** Returns an instance of the given type */ - private def a(t: TypeTree): Expr = { - all(t).head - } - - /** Returns all instance of the given type */ - private def all(t: TypeTree): Stream[Expr] = { - val i = FreshIdentifier("i", t) - val datagen = new GrammarDataGen(new DefaultEvaluator(ctx, program), ValueGrammar) - val enumerated_inputs = datagen.generateMapping(Seq(i), BooleanLiteral(true), 10, 10).toStream - enumerated_inputs.toStream.map(_.head._2) - } - - /** Returns an expression of the given type that contains at least a String, an Integer or an Int32 if possible. If not, returns None. */ - private def buildMarkableValue(e: TypeTree): Option[Expr] = { - var markableValues = Map[TypeTree, Expr]() - - val toTest = Queue[TypeTree](e) - // Build all the types to test - var finalTypes = Set[TypeTree]() - - while(toTest.nonEmpty) { - val v = toTest.dequeue() - v match { - case cct@CaseClassType(ccd, targs) => - finalTypes += v - for(tpe <- cct.fieldsTypes if !(finalTypes contains tpe) && !(toTest contains tpe)) { - toTest.enqueue(tpe) - } - case act@AbstractClassType(acd, targs) => - finalTypes += v - for(tpe <- act.knownCCDescendants if !(finalTypes contains tpe) && !(toTest contains tpe)) { - toTest.enqueue(tpe) - } - case StringType | Int32Type | IntegerType => - markableValues += v -> a(v) - case _ => - } - } - - // Read all the types until all types have been flagged markable and non-markable. - // All remaining are non-markable. - - var modified = true - while(modified && !(markableValues contains e)) { - modified = finalTypes find { tpe => - tpe match { - case cct@CaseClassType(ccd, targs) => - cct.fields.find(t => markableValues contains t.getType) match { - case Some(fieldId) => - markableValues += tpe -> CaseClass(cct, cct.fields.map(tpField => - if(tpField == fieldId) markableValues(fieldId.getType) else a(tpField.getType))) - finalTypes -= tpe - true - case None => - false - } - case act@AbstractClassType(acd, targs) => - act.knownCCDescendants.find(cc => markableValues contains cc) match { - case None => false - case Some(cc) => - markableValues += tpe -> markableValues(cc) - finalTypes -= tpe - true - } - case _ => false - } - } match { - case Some(_) => true - case None => false - } - } - markableValues.get(e) - } - - private def markWithStringOrInt(e: TypeTree, originalExpr: Expr): Option[Expr] = { - buildMarkableValue(e).map{ value => - val Tuple(Seq(_, res)) = QuestionBuilder.makeGenericValuesUnique(Tuple(Seq(originalExpr, value))) - res - } - } -} diff --git a/src/main/scala/leon/synthesis/disambiguation/Question.scala b/src/main/scala/leon/synthesis/disambiguation/Question.scala deleted file mode 100644 index 985efaed19d16350f248d3147d9b2a2010d7e91c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/Question.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis.disambiguation - -import purescala.Expressions.Expr - -/** - * @author Mikael - */ -case class Question[T <: Expr](inputs: Seq[Expr], current_output: T, other_outputs: List[T]) \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala deleted file mode 100644 index 387bfc30858722fa781c2ca3ceccb658e5a9a2a1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis.disambiguation - -import datagen.GrammarDataGen -import synthesis.Solution -import evaluators.DefaultEvaluator -import purescala.Expressions._ -import purescala.ExprOps -import purescala.Types._ -import purescala.Common.Identifier -import purescala.Definitions.{FunDef, Program} -import purescala.DefOps -import grammars._ -import solvers.ModelBuilder -import scala.collection.mutable.ListBuffer -import evaluators.AbstractEvaluator -import scala.annotation.tailrec -import leon.evaluators.EvaluationResults -import leon.purescala.Common - -object QuestionBuilder { - /** Sort methods for questions. You can build your own */ - trait QuestionSortingType { - def apply[T <: Expr](e: Question[T]): Int - } - object QuestionSortingType { - case object IncreasingInputSize extends QuestionSortingType { - def apply[T <: Expr](q: Question[T]) = q.inputs.map(i => ExprOps.count(e => 1)(i)).sum - } - case object DecreasingInputSize extends QuestionSortingType{ - def apply[T <: Expr](q: Question[T]) = -IncreasingInputSize(q) - } - } - // Add more if needed. - - /** Sort methods for question's answers. You can (and should) build your own. */ - abstract class AlternativeSortingType[T <: Expr] extends Ordering[T] { self => - /** Prioritizes this comparison operator against the second one. */ - def &&(other: AlternativeSortingType[T]): AlternativeSortingType[T] = new AlternativeSortingType[T] { - def compare(e: T, f: T): Int = { - val ce = self.compare(e, f) - if(ce == 0) other.compare(e, f) else ce - } - } - } - object AlternativeSortingType { - /** Presents shortest alternatives first */ - case class ShorterIsBetter[T <: Expr]()(implicit c: LeonContext) extends AlternativeSortingType[T] { - def compare(e: T, f: T) = e.asString.length - f.asString.length - } - /** Presents balanced alternatives first */ - case class BalancedParenthesisIsBetter[T <: Expr]()(implicit c: LeonContext) extends AlternativeSortingType[T] { - def convert(e: T): Int = { - val s = e.asString - var openP, openB, openC = 0 - for(c <- s) c match { - case '(' if openP >= 0 => openP += 1 - case ')' => openP -= 1 - case '{' if openB >= 0 => openB += 1 - case '}' => openB -= 1 - case '[' if openC >= 0 => openC += 1 - case ']' => openC -= 1 - case _ => - } - Math.abs(openP) + Math.abs(openB) + Math.abs(openC) - } - def compare(e: T, f: T): Int = convert(e) - convert(f) - } - } - - /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ - object SpecialStringValueGrammar extends SimpleExpressionGrammar { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) - } - } - - /** Make all generic values uniquely identifiable among the final string (no value is a substring of another if possible) - * Duplicate generic values are not suitable for disambiguating questions since they remove an order. */ - def makeGenericValuesUnique(a: Expr): Expr = { - //println("Going to make this value unique:" + a) - var genVals = Set[Expr with Terminal](StringLiteral("")) - def freshenValue(g: Expr with Terminal): Option[Expr with Terminal] = g match { - case g: GenericValue => Some(GenericValue(g.tp, g.id + 1)) - case StringLiteral(s) => - val newS = if(s == "") "a" else s - val i = s.lastIndexWhere { c => c < 'a' || c > 'z' } - val prefix = s.take(i+1) - val suffix = s.drop(i+1) - val res = if(suffix.forall { _ == 'z' }) { - prefix + "a" + ("a" * suffix.length) - } else { - val last = suffix.reverse.prefixLength { _ == 'z' } - val digit2increase = suffix(suffix.length - last - 1) - prefix + (digit2increase.toInt + 1).toChar + ("a" * last) - } - Some(StringLiteral(res)) - case InfiniteIntegerLiteral(i) => Some(InfiniteIntegerLiteral(i+1)) - case IntLiteral(i) => if(i == Integer.MAX_VALUE) None else Some(IntLiteral(i+1)) - case CharLiteral(c) => if(c == Char.MaxValue) None else Some(CharLiteral((c+1).toChar)) - case otherLiteral => None - } - @tailrec @inline def freshValue(g: Expr with Terminal): Expr with Terminal = { - if(genVals contains g) - freshenValue(g) match { - case None => g - case Some(v) => freshValue(v) - } - else { - genVals += g - g - } - } - ExprOps.postMap{ e => e match { - case g:Expr with Terminal => - Some(freshValue(g)) - case _ => None - }}(a) - } - -} - -/** - * Builds a set of disambiguating questions for the problem - * - * {{{ - * def f(input: input.getType): T = - * [element of r.solution] - * }}} - * - * @tparam T A subtype of Expr that will be the type used in the Question[T] results. - * @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType - * @param filter A function filtering which outputs should be considered for comparison. - * It takes as input the sequence of outputs already considered for comparison, and the new output. - * It should return Some(result) if the result can be shown, and None else. - * - */ -class QuestionBuilder[T <: Expr]( - input: Seq[Identifier], - solutions: Stream[Solution], - filter: (Seq[T], Expr) => Option[T], - originalFun: Option[FunDef] = None)(implicit c: LeonContext, p: Program) { - import QuestionBuilder._ - private var _argTypes = input.map(_.getType) - private var _questionSorMethod: QuestionSortingType = QuestionSortingType.IncreasingInputSize - private var _alternativeSortMethod: AlternativeSortingType[T] = AlternativeSortingType.BalancedParenthesisIsBetter[T]() && AlternativeSortingType.ShorterIsBetter[T]() - private var solutionsToTake = 15 - private var expressionsToTake = 15 // TODO: At least cover the program ! - private var keepEmptyAlternativeQuestions: T => Boolean = Set() - private var value_enumerator: ExpressionGrammar = ValueGrammar - private var expressionsToTestFirst: Option[Stream[Seq[Expr]]] = None - - /** Sets the way to sort questions during enumeration. Not used at this moment. See [[QuestionSortingType]] */ - def setSortQuestionBy(questionSorMethod: QuestionSortingType) = { _questionSorMethod = questionSorMethod; this } - /** Sets the way to sort alternatives. See [[AlternativeSortingType]] */ - def setSortAlternativesBy(alternativeSortMethod: AlternativeSortingType[T]) = { _alternativeSortMethod = alternativeSortMethod; this } - /** Sets the argument type. Not needed if the input identifier is already assigned a type. */ - def setArgumentType(argTypes: List[TypeTree]) = { _argTypes = argTypes; this } - /** Sets the number of solutions to consider. Default is 15 */ - def setSolutionsToTake(n: Int) = { solutionsToTake = n; this } - /** Sets the number of expressions to consider. Default is 15 */ - def setExpressionsToTake(n: Int) = { expressionsToTake = n; this } - /** Sets if when there is no alternative, the question should be kept. */ - def setKeepEmptyAlternativeQuestions(b: T => Boolean) = {keepEmptyAlternativeQuestions = b; this } - /** Sets the way to enumerate expressions */ - def setValueEnumerator(v: ExpressionGrammar) = value_enumerator = v - /** Sets the expressions to test first */ - def setExpressionsToTestFirst(s: Option[Stream[Seq[Expr]]]) = expressionsToTestFirst = s - - private def run(s: Solution, elems: Seq[(Identifier, Expr)]): Option[Expr] = { - val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head) - s.ifOnFunDef(originalFun.getOrElse(new FunDef(Common.FreshIdentifier("useless"), Nil, Nil, UnitType))){ - val e = new AbstractEvaluator(c, newProgram) - val model = new ModelBuilder - model ++= elems - val modelResult = model.result() - val evaluation = e.eval(s.term, modelResult) - for{x <- evaluation.result - res = x._1 - simp = ExprOps.simplifyArithmetic(res)} - yield simp - } - } - - /** Given an input, the current output, a list of alternative programs, compute a question if there is any. */ - def computeQuestion(possibleInput: Seq[(Identifier, Expr)], currentOutput: T, alternatives: List[Solution]): Option[Question[T]] = { - val alternative_outputs = (ListBuffer[T](currentOutput) /: alternatives) { (prev, alternative) => - run(alternative, possibleInput) match { - case Some(alternative_output) if alternative_output != currentOutput => - filter(prev, alternative_output) match { - case Some(alternative_output_filtered) => - prev += alternative_output_filtered - case _ => prev - } - case _ => prev - } - }.drop(1).toList.distinct - if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(currentOutput)) { - Some(Question(possibleInput.map(_._2), currentOutput, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0))) - } else { - None - } - } - - def getExpressionsToTestFirst(): Option[Stream[Seq[(Identifier, Expr)]]] = expressionsToTestFirst map { inputs => - val inputs_generics = inputs.map(y => y.map(x => makeGenericValuesUnique(x))) - inputs_generics.map(in => input zip in) - } - - def getAllPossibleInputs(expressionsToTake: Int): Stream[Seq[(Identifier, Expr)]]= { - val datagen = new GrammarDataGen(new DefaultEvaluator(c, p), value_enumerator) - val enumerated_inputs = datagen.generateMapping(input, BooleanLiteral(true), expressionsToTake, expressionsToTake) - .map(inputs => - inputs.map(id_expr => - (id_expr._1, makeGenericValuesUnique(id_expr._2)))).toStream - enumerated_inputs - } - - def inputsToQuestions(solution: Stream[Solution], inputs: Stream[Seq[(Identifier, Expr)]]): Stream[Question[T]] = { - val solution = solutions.head - val alternatives = solutions.drop(1).take(solutionsToTake).toList - for { - possibleInput <- inputs - currentOutputNonFiltered <- run(solution, possibleInput) - currentOutput <- filter(Seq(), currentOutputNonFiltered) - question <- computeQuestion(possibleInput, currentOutput, alternatives) - } yield question - } - - /** Returns a list of input/output questions to ask to the user. */ - def resultAsStream(): Stream[Question[T]] = { - if(solutions.isEmpty) return Stream.empty - - getExpressionsToTestFirst() foreach { inputs_generics => - val res = inputsToQuestions(solutions, inputs_generics) - if(res.nonEmpty) return res - } - - val enumerated_inputs = getAllPossibleInputs(expressionsToTake) - val questions = inputsToQuestions(solutions, enumerated_inputs) - questions - }/* - - def result(): List[Question[T]] = { - resultAsStream().toList.sortBy(_questionSorMethod(_)) - }*/ -} diff --git a/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala b/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala deleted file mode 100644 index 3076cd12ef4db29cf5304e8607bf907e9fa4c671..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/grammars/ContextGrammar.scala +++ /dev/null @@ -1,434 +0,0 @@ -package leon -package synthesis -package grammars - -import scala.collection.mutable.ListBuffer - -class ContextGrammar[SymbolTag, TerminalData] { - /** A tagged symbol */ - abstract class Symbol { def tag: SymbolTag } - /** A tagged non-terminal */ - case class NonTerminal(tag: SymbolTag, vcontext: List[NonTerminal] = Nil, hcontext: List[Symbol] = Nil) extends Symbol - /** A tagged terminal */ - case class Terminal(tag: SymbolTag)(val terminalData: TerminalData) extends Symbol - - /** All possible right-hand-side of rules */ - case class Expansion(ls: List[List[Symbol]]) { - def symbols = ls.flatten - def ++(other: Expansion): Expansion = Expansion((ls ++ other.ls).distinct) - def ++(other: Option[Expansion]): Expansion = other match {case Some(e) => this ++ e case None => this } - def contains(s: Symbol): Boolean = { - ls.exists { l => l.exists { x => x == s } } - } - /** Maps symbols to other symbols */ - def map(f: Symbol => Symbol): Expansion = { - Expansion(ls.map(_.map(f))) - } - /** Maps symbols with the left context as second argument */ - def mapLeftContext(f: (Symbol, List[Symbol]) => Symbol): Expansion = { - Expansion(ls.map(l => (l.foldLeft(ListBuffer[Symbol]()){ - case (l: ListBuffer[Symbol], s: Symbol) => l += f(s, l.toList) - }).toList )) - } - } - - // Shortcuts and helpers. - - /** An expansion which has only 1 non-terminal for each right-hand side */ - object VerticalRHS{ - def apply(symbols: Seq[NonTerminal]) = Expansion(symbols.map(symbol => List(symbol)).toList) - def unapply(e: Expansion): Option[List[NonTerminal]] = - if(e.ls.size >= 1 && e.ls.forall(rhs => rhs.length == 1 && rhs.head.isInstanceOf[NonTerminal])) - Some(e.ls.map(rhs => rhs.head.asInstanceOf[NonTerminal])) - else None - } - /** An expansion which has only 1 right-hand-side with one terminal and non-terminals */ - object HorizontalRHS { - def apply(tag: Terminal, symbols: Seq[NonTerminal]) = Expansion(List(tag :: symbols.toList)) - def unapply(e: Expansion): Option[(Terminal, List[NonTerminal])] = e.ls match { - case List((t: Terminal)::(nts : List[_])) - if nts.forall(_.isInstanceOf[NonTerminal]) => - Some((t, nts.map(_.asInstanceOf[NonTerminal]))) - case _ => None - } - } - /** An expansion which has only 1 terminal and only 1 right-hand-side */ - object TerminalRHS { - def apply(t: Terminal) = HorizontalRHS(t, Nil) - def unapply(e: Expansion): Option[Terminal] = e.ls match { - case List(List(t: Terminal)) => Some(t) - case _ => None - } - } - - /** An expansion which starts with terminals and ends with something like HorizontalRHS or VerticalRHS */ - object AugmentedTerminalsRHS { - def apply(t: Seq[Terminal], e: Expansion) = Expansion(t.map (x => List(x: Symbol)).toList) ++ e - def unapply(e: Expansion): Option[(List[Terminal], Expansion)] = e.ls match { - case Nil => None - case _::Nil => Some((Nil, e)) - case head::tail => - (head, unapply(Expansion(tail))) match { - case (_, None) => None - case (List(t: Terminal), Some((ts, exp))) => Some((t::ts, exp)) - case (lnt@List(nt: NonTerminal), Some((Nil, exp))) => Some((Nil, Expansion(List(lnt)) ++ exp)) - case _ => None - } - } - } - - // Remove unreachable non-terminals - def clean(g: Grammar): Grammar = { - val reachable = leon.utils.fixpoint({ - (nt: Set[NonTerminal]) => - nt ++ (nt map (g.rules) flatMap (_.symbols) collect { case nt: NonTerminal => nt }) - }, 64)(g.startNonTerminals.toSet) - val nonReachable = g.nonTerminals.toSet -- reachable - val res = g.copy(rules = g.rules -- nonReachable) - res - } - - /** A grammar here has a start sequence instead of a start symbol */ - case class Grammar(start: Seq[Symbol], rules: Map[NonTerminal, Expansion]) { - /** Returns all non-terminals of the given grammar */ - def nonTerminals: Seq[NonTerminal] = { - (startNonTerminals ++ - (for{(lhs, rhs) <- rules; s <- Seq(lhs) ++ - (for(r <- rhs.symbols.collect{ case k: NonTerminal => k }) yield r)} yield s)).distinct - } - lazy val startNonTerminals: Seq[NonTerminal] = { - start.collect{ case k: NonTerminal => k } - } - - abstract class NonTerminalMapping { - // Conversion from old to new non-terminals to duplicate rules afterwards. - private lazy val originalMapping = Map[NonTerminal, List[NonTerminal]](startNonTerminals.map(x => x -> List(x)) : _*) - protected var mapping = originalMapping - // Resets the mapping transformation - def reset() = mapping = originalMapping - - def apply(elem: NonTerminal): Seq[NonTerminal] = mapping.getOrElse(elem, List(elem)) - - def maybeKeep(elem: NonTerminal): Seq[NonTerminal] = Nil - // Keeps expansion the same but applies the current mapping to all keys, creating only updates. - def mapKeys(rules: Map[NonTerminal, Expansion]) = { - val tmpRes2 = for{(lhs, expansion) <- rules.toSeq - new_lhs <- (maybeKeep(lhs) ++ mapping.getOrElse(lhs, Nil)).distinct - } yield { - lhs -> (new_lhs -> expansion) - } - val rulestmpRes2 = rules -- tmpRes2.map(_._1) - rulestmpRes2 ++ tmpRes2.map(_._2) - } - } - - def markovize_vertical_filtered(pred: NonTerminal => Boolean): Grammar = { - val nts = nonTerminals - val rulesSeq = rules.toSeq - def parents(nt: NonTerminal): Seq[NonTerminal] = { - rulesSeq.collect{ case (ntprev, expansion) if expansion.contains(nt) => ntprev } - } - object Mapping extends NonTerminalMapping { - mapping = Map[NonTerminal, List[NonTerminal]](startNonTerminals.map(x => x -> List(x)) : _*) - def updateMapping(nt: NonTerminal, topContext: List[NonTerminal]): NonTerminal = { - if(pred(nt)) { - val res = nt.copy(vcontext = topContext) - mapping += nt -> (res::mapping.getOrElse(nt, Nil)).distinct - res - } else nt - } - } - - val newRules = (for{ - nt <- nts - expansion = rules(nt) - } yield (nt -> (expansion.map{(s: Symbol) => s match { - case n:NonTerminal => Mapping.updateMapping(n, nt::nt.vcontext) - case e => e - }}))).toMap - - val newRules2 = Mapping.mapKeys(newRules) - Grammar(start, newRules2) - } - - /** Applies 1-markovization to the grammar (add 1 history to every node) */ - def markovize_vertical(): Grammar = { - markovize_vertical_filtered(_ => true) - } - - class MarkovizationContext(pred: NonTerminal => Boolean) { - val nts = nonTerminals - val rulesSeq = rules.toSeq - /** Set of abstract non-terminals */ - val ants = (nts filter { nt => - rules(nt) match { - case AugmentedTerminalsRHS(terminals, VerticalRHS(sons)) => pred(nt) - case _ => false - } - }).toSet - /** Set of case class non-terminals */ - val cnts = (nts filter { nt => // case class non terminals - rules(nt) match { - case AugmentedTerminalsRHS(terminals, HorizontalRHS(t, sons)) => true - case _ => false - } - }).toSet - var getAnts = Map[NonTerminal, NonTerminal]() - var getDesc = Map[NonTerminal, Set[NonTerminal]]() - nts foreach { nt => - rules(nt) match { - case AugmentedTerminalsRHS(terminals, VerticalRHS(sons)) => - sons.foreach{ son => - getAnts += son -> nt - } - getDesc += nt -> sons.toSet - case _ => false - } - } - def getTopmostANT(nt: NonTerminal): NonTerminal = if(getAnts contains(nt)) getTopmostANT(getAnts(nt)) else nt - def getAncestors(nt: NonTerminal): Set[NonTerminal] = getAnts.get(nt).map(getAncestors).getOrElse(Set.empty[NonTerminal]) + nt - val startANT = startNonTerminals.map(getTopmostANT).toSet - def getDescendants(nt: NonTerminal): Set[NonTerminal] = getDesc.get(nt).map((x: Set[NonTerminal]) => - x.flatMap((y: NonTerminal) => getDescendants(y) + y)).getOrElse(Set.empty[NonTerminal]) - } - - /** Perform horizontal markovization only on the provided non-terminals and their descendants. - * @param pred The predicate checking if non-terminals are concerned. - * @param recursive If the horizontal context is propagated to ancestors if they are on the RHS of their children. - */ - def markovize_horizontal_filtered(pred: NonTerminal => Boolean, recursive: Boolean): Grammar = { - var toDuplicate = Map[NonTerminal, Set[NonTerminal]]() - var originals = Map[NonTerminal, NonTerminal]() - def getOriginal(nt: NonTerminal): NonTerminal = { - originals.get(nt).map(nt2 => if(nt2 != nt) getOriginal(nt2) else nt2).getOrElse(nt) - } - val c = new MarkovizationContext(pred) { - def process_sequence(ls: Seq[Symbol]): List[Symbol] = { - val (_, res) = ((ListBuffer[Symbol](), ListBuffer[Symbol]()) /: ls) { - case ((lbold, lbnew), nt: NonTerminal) if pred(nt) => - val context_version = nt.copy(hcontext = lbold.toList) - toDuplicate += nt -> (toDuplicate.getOrElse(nt, Set.empty[NonTerminal]) + context_version) - if(context_version != nt) originals += context_version -> nt - for(descendant <- getDescendants(nt) if descendant != nt) { - val descendant_context_version = descendant.copy(hcontext = lbold.toList) - toDuplicate += descendant -> (toDuplicate.getOrElse(descendant, Set.empty[NonTerminal]) + descendant_context_version) - originals += descendant_context_version -> descendant - } - for(ascendant <- getAncestors(nt) if ascendant != nt) { - val acendant_context_version = ascendant.copy(hcontext = lbold.toList) - toDuplicate += ascendant -> (toDuplicate.getOrElse(ascendant, Set.empty[NonTerminal]) + acendant_context_version) - originals += acendant_context_version -> ascendant - } - (lbold += nt, lbnew += context_version) - case ((lbold, lbnew), s) => - (lbold += s, lbnew += s) - } - res.toList - } - - val newStart = process_sequence(start) - private val newRules = rules.map{ case (k, expansion) => - k -> (expansion match { - case AugmentedTerminalsRHS(terminals, VerticalRHS(children)) => - expansion - case AugmentedTerminalsRHS(terminals, HorizontalRHS(t, children)) => - val children_new = process_sequence(t +: children) // Identifies duplicates and differentiates them. - AugmentedTerminalsRHS(terminals, HorizontalRHS(t, children_new.tail.asInstanceOf[List[NonTerminal]])) - case _ => - expansion - }) - } - - val newRules2 = for{ - (k, v) <- newRules - kp <- (toDuplicate.getOrElse(k, Set()) + k) - } yield { - v match { - case AugmentedTerminalsRHS(terminals, VerticalRHS(children)) if toDuplicate.getOrElse(k, Set.empty).nonEmpty => - val newChildren = children.map(child => child.copy(hcontext = kp.hcontext)) - kp -> AugmentedTerminalsRHS(terminals, VerticalRHS(newChildren)) - case AugmentedTerminalsRHS(terminals, HorizontalRHS(t, children)) => - // Transmit the left context to all ancestors of this node. - val new_rhs = if(recursive) { - val ancestors = getAncestors(k) - val newChildren = children.map(child => - child match { case nt: NonTerminal if ancestors(getOriginal(nt)) => - nt.copy(hcontext = kp.hcontext) - case _ => child - } - ) - HorizontalRHS(t, newChildren) - } else v - kp -> new_rhs - case _ => - kp -> v - } - } - } - import c._ - clean(Grammar(newStart, newRules2.toMap)) - } - - /** Applies horizontal markovization to the grammar (add the left history to every node and duplicate rules as needed. - * Is idempotent. */ - def markovize_horizontal(): Grammar = { - markovize_horizontal_filtered(_ => true, false) - } - - /** Same as vertical markovization, but we add in the vertical context only the nodes coming from a "different abstract hierarchy". Top-level nodes count as a different hierarchy. - * Different abstract hierarchy means that nodes do not have the same ancestor. - * @param pred The top-most non-terminals to consider (i.e. abstract ones) - */ - def markovize_abstract_vertical_filtered(pred: NonTerminal => Boolean): Grammar = { - val c = new MarkovizationContext(pred) { - var toDuplicate = Map[NonTerminal, Set[NonTerminal]]() - for(t <- startNonTerminals) { - val topAnt = getTopmostANT(t) - toDuplicate += topAnt -> Set(topAnt) - } - val newRules = rules.map{ case (k, expansion) => - k -> expansion.map{ - case nt: NonTerminal => - if(ants(nt)) { - val antp = getTopmostANT(k) - val ancestors = getAncestors(nt) - val ancestorTop = getTopmostANT(nt) - if(antp == ancestorTop && !startANT(antp)) nt else { - for(ancestor <- ancestors) { - val ancestorCopy = ancestor.copy(vcontext = List(antp)) - getAnts += ancestorCopy -> ancestor - toDuplicate += ancestor -> (toDuplicate.getOrElse(ancestor, Set()) + ancestorCopy) - } - nt.copy(vcontext = List(antp)) - } - } else nt - case s => s - } - } - val newRules2 = for{ - (k, v) <- newRules - kp <- (toDuplicate.getOrElse(k, Set(k)) + k) - } yield { - kp -> v - } - //println("newRules2 = " + grammarToString(Grammar(start, newRules2))) - - def reportVContext(nt: NonTerminal, parentNt: NonTerminal, rules: Map[NonTerminal, Expansion]): NonTerminal = { - if((nt.vcontext.isEmpty || (getTopmostANT(parentNt) == getTopmostANT(nt) && parentNt.vcontext.nonEmpty)) && pred(getTopmostANT(nt))) { - val thecopy = nt.copy(vcontext = parentNt.vcontext) - if(!(rules contains thecopy)) { - getAnts += thecopy -> nt - toDuplicate += nt -> (toDuplicate.getOrElse(nt, Set()) + nt + thecopy) - } - thecopy - } else nt - } - val newRules3 = leon.utils.fixpoint((newRules: Map[NonTerminal, Expansion]) => { - toDuplicate = Map() - val res = for{ - (k, v) <- newRules - } yield { - v match { - case AugmentedTerminalsRHS(terminals, VerticalRHS(children)) => - k -> AugmentedTerminalsRHS(terminals, VerticalRHS(children.map(x => reportVContext(x, k, newRules)))) - case AugmentedTerminalsRHS(terminals, HorizontalRHS(t, children)) => - k -> AugmentedTerminalsRHS(terminals, HorizontalRHS(t, children.map(x => reportVContext(x, k, newRules)))) - case _ => k -> v - } - } - //println("newRules3 = " + grammarToString(Grammar(start, res))) - //println("toDuplicate = " + toDuplicate.map{ case (k, v) => symbolToString(k) + "->" + v.map(symbolToString)}) - val res2 = for{ - (k, v) <- res - kp <- toDuplicate.getOrElse(k, Set(k)) - } yield { - kp -> v - } - //println("newRules4 = " + grammarToString(Grammar(start, res2))) - res2}, 64)(newRules2) - } - import c._ - - clean(Grammar(start, newRules3)) - } - - /** Same as vertical markovization, but we add in the vertical context only the nodes coming from a "different abstract hierarchy" - * Different abstract hierarchy means that nodes do not have the same ancestor. - */ - def markovize_abstract_vertical(): Grammar = { - markovize_abstract_vertical_filtered(_ => true) - } - - /** More general form of markovization, which is similar to introducing states in a top-down tree transducer - * We duplicate all m non-terminals n times, and we replace each of them by one of their n variants. - * n == 0 yields nothing. - * n == 1 yields the original grammar - * n == 2 yield all grammars obtained from the original by duplicating each non-terminals and trying all variants. - **/ - def markovize_all(n: Int): Stream[Grammar] = { - def variant(nt: NonTerminal, i: Int) = { - nt.copy(vcontext = nt.vcontext ++ List.fill(i)(nt)) - } - val nonTerminalsRHS: Seq[NonTerminal] = { - (startNonTerminals ++ - (for{(lhs, rhs) <- rules - s <-rhs.symbols.collect{ case k: NonTerminal => k } } yield s)) - } - val nonTerminalsRHSNew: Seq[NonTerminal] = { - (startNonTerminals ++ - (for{i <- 0 until n // 0 keeps non-terminals the same. - (lhs, rhs) <- rules - s <-rhs.symbols.collect{ case k: NonTerminal => k } } yield s)) - } - val nonTerminalsRHSSet = nonTerminalsRHS.toSet - val variants = nonTerminalsRHSSet.map(nt => - nt -> (0 until n).toStream.map(i => variant(nt, i)) - ).toMap - - val variantMap = nonTerminalsRHSNew.zipWithIndex.map(nti => nti -> variants(nti._1)).toMap - val assignments = leon.utils.StreamUtils.cartesianMap(variantMap) - for(assignment <- assignments) yield { - var i = 0 - def indexed[T](f: Int => T): T = { - val res = f(i) - i += 1 - res - } - def copyOfNt(nt: NonTerminal): NonTerminal = indexed { i => assignment((nt, i)) } - def copy(t: Symbol): Symbol = t match { - case nt: NonTerminal => copyOfNt(nt) - case e => e - } - val newStart = this.start.map(copy) - val newRules = for{i <- 0 until n - (lhs, expansion) <- this.rules - new_expansion = expansion.map(copy) - } yield (variant(lhs, i) -> new_expansion) - Grammar(newStart, newRules.toMap) - } - } - } - - def symbolToString(symbol: Symbol): String = { - symbol match { - case s: NonTerminal => nonterminalToString(s) - case s: Terminal => terminalToString(s) - } - } - def nonterminalToString(nonterminal: NonTerminal): String = { - nonterminal.tag + (if(nonterminal.vcontext != Nil) "_v["+nonterminal.vcontext.map(x => symbolToString(x)).reduce(_ + "," + _) + "]" else "") + - (if(nonterminal.hcontext != Nil) "_h["+nonterminal.hcontext.map(x => symbolToString(x)).reduce(_ + "," + _)+"]" else "") - } - def terminalToString(terminal: Terminal): String = { - terminal.tag + (if(terminal.terminalData == "") "" else "_" + terminal.terminalData) - } - def reduce(l: Iterable[String], separator: String) = if(l == Nil) "" else l.reduce(_ + separator + _) - def expansionToString(expansion: Expansion): String = { - reduce(expansion.ls.map(l => reduce(l.map(x => symbolToString(x)), " ")), " | ") - } - - def grammarToString(grammar: Grammar) = { - "Start: " + reduce(grammar.start.map(s => symbolToString(s)), " ") + "\n" + - reduce(grammar.rules.map(kv => symbolToString(kv._1) + " -> " + expansionToString(kv._2)).toList.sorted, "\n") - } -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/grammars/package.scala b/src/main/scala/leon/synthesis/grammars/package.scala deleted file mode 100644 index 710fe0468c33e67cbf0b61d05495a5bdd1f090c6..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/grammars/package.scala +++ /dev/null @@ -1,28 +0,0 @@ -package leon -package synthesis - -import leon.grammars._ -import purescala.ExprOps._ -import purescala.Expressions.Expr -import purescala.Extractors.TopLevelAnds -import purescala.Types.{BooleanType, Int32Type, IntegerType} -import Witnesses.Hint - -package object grammars { - - def default(sctx: SynthesisContext, p: Problem, extraHints: Seq[Expr] = Seq()): ExpressionGrammar = { - val TopLevelAnds(ws) = p.ws - val hints = ws.collect{ case Hint(e) if formulaSize(e) >= 4 => e } - val inputs = p.allAs.map(_.toVariable) ++ hints ++ extraHints - val exclude = sctx.settings.functionsToIgnore - val recCalls = if (sctx.findOptionOrDefault(SynthesisPhase.optIntroduceRecCalls)) Empty() else SafeRecursiveCalls(sctx.program, p.ws, p.pc) - - BaseGrammar || - Closures || - EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || - OneOf(inputs) || - Constants(sctx.functionContext.fullBody) || - FunctionCalls(sctx.program, sctx.functionContext, inputs.map(_.getType), exclude) || - recCalls - } -} diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala deleted file mode 100644 index d1ba7a040c276b83b64ee7e7f8ecd07ee1adefcc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis -package graph - -import leon.utils.UniqueCounter - -import java.io.{File, FileWriter, BufferedWriter} - -class DotGenerator(search: Search) { - implicit val ctx = search.ctx - - val g = search.g - val strat = search.strat - - private val idCounter = new UniqueCounter[Unit] - idCounter.nextGlobal // Start with 1 - - def freshName(prefix: String) = { - prefix + idCounter.nextGlobal - } - - def writeFile(f: File): Unit = { - val out = new BufferedWriter(new FileWriter(f)) - out.write(asString) - out.close() - } - - def writeFile(path: String): Unit = writeFile(new File(path)) - - - def asString: String = { - val res = new StringBuffer() - - res append "digraph D {\n" - - // Print all nodes - val edges = collectEdges(g.root) - val nodes = edges.flatMap(e => Set(e._1, e._3)) - - var nodesToNames = Map[Node, String]() - - for (n <- nodes) { - val name = freshName("node") - - n match { - case ot: OrNode => - drawNode(res, name, ot) - case at: AndNode => - drawNode(res, name, at) - } - - nodesToNames += n -> name - } - - for ((f,i,t) <- edges) { - val label = f match { - case ot: OrNode => - "or" - case at: AndNode => - i.toString - } - - val style = if (f.selected contains t) { - ", style=\"bold\"" - } else { - "" - } - - res append " "+nodesToNames(f)+" -> "+nodesToNames(t) +" [label=\""+label+"\""+style+"]\n" - } - - res append "}\n" - - res.toString - } - - def limit(o: Any, length: Int = 200): String = { - val str = o.toString - if (str.length > length) { - str.substring(0, length-3)+"..." - } else { - str - } - } - - def nodeDesc(n: Node): String = n match { - case an: AndNode => an.ri.asString - case on: OrNode => on.p.asString - } - - def drawNode(res: StringBuffer, name: String, n: Node) { - - val index = n.parent.map(_.descendants.indexOf(n) + " ").getOrElse("") - - def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") - - val color = if (n.isSolved) { - "palegreen" - } else if (n.isDeadEnd) { - "firebrick" - } else if(n.isExpanded) { - "grey80" - } else { - "white" - } - - - res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\">" - - res append "<TR><TD BORDER=\"0\">"+escapeHTML(strat.debugInfoFor(n))+"</TD></TR>" - - res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(index + nodeDesc(n)))+"</TD></TR>" - - if (n.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.asString))+"</TD></TR>" - } - - res append "</TABLE>>, shape = \"none\" ];\n" - - } - - private def collectEdges(from: Node): Set[(Node, Int, Node)] = { - from.descendants.zipWithIndex.flatMap { case (d, i) => - Set((from, i, d)) ++ collectEdges(d) - }.toSet - } -} - -object dotGenIds extends UniqueCounter[Unit] diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala deleted file mode 100644 index 1a0c1440e7ab73c5e1c25dbf9e3d790fb9aabfe7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/graph/Graph.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package graph - -import leon.utils.StreamUtils.cartesianProduct -import leon.utils.DebugSectionSynthesis - -sealed class Graph(problem: Problem) { - val root = new RootNode(problem) - - // Returns closed/total - def getStats(from: Node = root): (Int, Int) = { - val isClosed = from.isDeadEnd || from.isSolved - val self = (if (isClosed) 1 else 0, 1) - - if (!from.isExpanded) { - self - } else { - from.descendants.foldLeft(self) { - case ((c,t), d) => - val (sc, st) = getStats(d) - (c+sc, t+st) - } - } - } -} - -sealed abstract class Node(val parent: Option[Node]) extends Printable { - - var descendants: List[Node] = Nil - // indicates whether this particular node has already been expanded - var isExpanded: Boolean = false - - def expand(implicit hctx: SearchContext) - - val p: Problem - - var isSolved: Boolean = false - def onSolved(desc: Node) - - // Solutions this terminal generates (!= None for terminals) - var solutions: Option[Stream[Solution]] = None - var selectedSolution = -1 - - var isDeadEnd: Boolean = false - - def isOpen = !isDeadEnd && !isSolved - - // For non-terminals, selected children for solution - var selected: List[Node] = Nil - - def composeSolutions(sols: List[Stream[Solution]]): Stream[Solution] - - // Generate solutions given selection+solutions - def generateSolutions(): Stream[Solution] = { - solutions.getOrElse { - composeSolutions(selected.map(n => n.generateSolutions())) - } - } -} - -/** Represents the conjunction of search nodes. - * @param parent Some node. None if it is the root node. - * @param ri The rule instantiation that created this AndNode. - **/ -class AndNode(parent: Option[Node], val ri: RuleInstantiation) extends Node(parent) { - val p = ri.problem - - override def asString(implicit ctx: LeonContext) = "\u2227 "+ri.asString - - def expand(implicit hctx: SearchContext): Unit = { - require(!isExpanded) - isExpanded = true - - def pad(prefix: String, message: String): String = { - val lines = message.split("\\n").toList - val padding = " " * (prefix.length + 1) - prefix + " " + lines.head + "\n" + lines.tail.map(padding + _).mkString("\n") - } - - import hctx.reporter.info - - val prefix = f"[${Option(ri).map(_.asString).getOrElse("?")}%-20s]" - - info(pad(prefix, ri.problem.asString)) - - ri.apply(hctx) match { - case RuleClosed(sols) => - solutions = Some(sols) - selectedSolution = 0 - - isSolved = sols.nonEmpty - - if (sols.isEmpty) { - info(s"$prefix Failed") - isDeadEnd = true - } else { - val sol = sols.head - val morePrefix = s"$prefix Solved ${if(sol.isTrusted) "" else "(untrusted)"} with: " - info(pad(morePrefix, sol.asString)) - } - - parent.foreach{ p => - if (isSolved) { - p.onSolved(this) - } - } - - case RuleExpanded(probs) => - info(s"$prefix Decomposed into:") - val morePrefix = s"$prefix -" - for(p <- probs) { - info(pad(morePrefix, p.asString)) - } - - descendants = probs.map(p => new OrNode(Some(this), p)) - - if (descendants.isEmpty) { - isDeadEnd = true - } - - selected = descendants - } - } - - def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { - cartesianProduct(solss).flatMap { - sols => ri.onSuccess(sols) - } - } - - private var solveds = Set[Node]() - - def onSolved(desc: Node): Unit = { - // We store everything within solveds - solveds += desc - - // Everything is solved correctly - if (solveds.size == descendants.size) { - isSolved = true - parent.foreach(_.onSolved(this)) - } - } - -} - -class OrNode(parent: Option[Node], val p: Problem) extends Node(parent) { - - override def asString(implicit ctx: LeonContext) = "\u2228 "+p.asString - - implicit val debugSection = DebugSectionSynthesis - - def getInstantiations(hctx: SearchContext): List[RuleInstantiation] = { - val rules = hctx.settings.rules - - val rulesPrio = rules.groupBy(_.priority).toSeq.sortBy(_._1) - - for ((prio, rs) <- rulesPrio) { - - val results = rs.flatMap{ r => - hctx.reporter.ifDebug(printer => printer("Testing rule: " + r)) - hctx.timers.synthesis.instantiations.get(r.asString(hctx)).timed { - r.instantiateOn(hctx, p) - } - }.toList - - if (results.nonEmpty) { - // We want to force all NormalizingRule's anyway, so no need to branch out. - // Just force the first one, and the rest may be applied afterwards. - return if (prio == RulePriorityNormalizing) results.take(1) else results - } - } - - Nil - } - - def expand(implicit hctx: SearchContext): Unit = { - require(!isExpanded) - - val ris = getInstantiations(hctx) - - descendants = ris.map(ri => new AndNode(Some(this), ri)) - selected = List() - - isExpanded = true - } - - def onSolved(desc: Node): Unit = { - isSolved = true - selected = List(desc) - parent.foreach(_.onSolved(this)) - } - - def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { - solss.toStream.flatten - } -} - -class RootNode(p: Problem) extends OrNode(None, p) diff --git a/src/main/scala/leon/synthesis/package.scala b/src/main/scala/leon/synthesis/package.scala deleted file mode 100644 index 50bb935c0037cee52cb580920643cc63a0d501dd..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/package.scala +++ /dev/null @@ -1,9 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon - -package object synthesis { - type Priority = Int - - val MAX_COST = 500 -} diff --git a/src/main/scala/leon/synthesis/programsets/DirectProgramSet.scala b/src/main/scala/leon/synthesis/programsets/DirectProgramSet.scala deleted file mode 100644 index 692bc2189c9e9717711c2533e0b825c9515d794f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/programsets/DirectProgramSet.scala +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis.programsets - -import leon.purescala -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ - -object DirectProgramSet { - def apply[T](p: Stream[T]): DirectProgramSet[T] = new DirectProgramSet(p) -} - -/** - * @author Mikael - */ -class DirectProgramSet[T](val p: Stream[T]) extends ProgramSet[T] { - def programs = p -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/programsets/JoinProgramSet.scala b/src/main/scala/leon/synthesis/programsets/JoinProgramSet.scala deleted file mode 100644 index 78a47950a0c4a03f4498c72b0aa8d40c68567d8d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/programsets/JoinProgramSet.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis.programsets - -import leon.purescala -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ - -/** - * @author Mikael - */ -object JoinProgramSet { - - def apply[T, U1, U2](sets: (ProgramSet[U1], ProgramSet[U2]), recombine: (U1, U2) => T): Join2ProgramSet[T, U1, U2] = { - new Join2ProgramSet(sets, recombine) - } - def apply[T, U](sets: Seq[ProgramSet[U]], recombine: Seq[U] => T): JoinProgramSet[T, U] = { - new JoinProgramSet(sets, recombine) - } - def direct[U1, U2](set1: ProgramSet[U1], set2: ProgramSet[U2]): Join2ProgramSet[(U1, U2), U1, U2] = { - new Join2ProgramSet((set1, set2), (u1: U1, u2: U2) => (u1, u2)) - } - def direct[U](sets: Seq[ProgramSet[U]]): JoinProgramSet[Seq[U], U] = { - new JoinProgramSet(sets, (id: Seq[U]) => id) - } -} - -class Join2ProgramSet[T, U1, U2](sets: (ProgramSet[U1], ProgramSet[U2]), recombine: (U1, U2) => T) extends ProgramSet[T] { - def programs: Stream[T] = utils.StreamUtils.cartesianProduct(sets._1.programs, sets._2.programs).map(recombine.tupled) -} - -class JoinProgramSet[T, U](sets: Seq[ProgramSet[U]], recombine: Seq[U] => T) extends ProgramSet[T] { - def programs: Stream[T] = utils.StreamUtils.cartesianProduct(sets.map(_.programs)).map(recombine) -} - diff --git a/src/main/scala/leon/synthesis/programsets/ProgramSet.scala b/src/main/scala/leon/synthesis/programsets/ProgramSet.scala deleted file mode 100644 index 2e2b1ab2f219aaa9869607a4ae3634a43e750945..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/programsets/ProgramSet.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis.programsets - -import leon.purescala -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ - -/** - * @author Mikael - */ -abstract class ProgramSet[T] { - def programs: Stream[T] -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/programsets/UnionProgramSet.scala b/src/main/scala/leon/synthesis/programsets/UnionProgramSet.scala deleted file mode 100644 index 4ebaeb566e7622a4c37bac6d4a525011046e25c9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/programsets/UnionProgramSet.scala +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis.programsets - -object UnionProgramSet { - def apply[T](sets: Seq[ProgramSet[T]]): UnionProgramSet[T] = { - new UnionProgramSet(sets) - } -} - -/** - * @author Mikael - */ -class UnionProgramSet[T](sets: Seq[ProgramSet[T]]) extends ProgramSet[T] { - def programs = utils.StreamUtils.interleave(sets.map(_.programs)) -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala deleted file mode 100644 index 40ab91bc6ad8d3a7843c38f3a96e46c1a0a9cd4c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ - -/** For a `case class A(b: B, c: C)` and expressions `X,Y,D` the latest not containing any output variable, replaces - * `A(X, Y) = D` - * by the following equivalent - * `D.isInstanceOf[A] && X = D.b && Y = D.c` - * */ -case object ADTDual extends NormalizingRule("ADTDual") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val xs = p.xs.toSet - val as = p.as.toSet - - val TopLevelAnds(exprs) = p.phi - - val (toRemove, toAdd) = exprs.collect { - case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) - - case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) - }.unzip - - if (toRemove.nonEmpty) { - val sub = p.copy(phi = andJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq), eb = ExamplesBank.empty) - - Some(decomp(List(sub), forward, "ADTDual")) - } else { - None - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala deleted file mode 100644 index 8f6cef04cc6f8c632ce884ec410537cd984fdaf8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import Witnesses._ - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Definitions._ - -/** Abstract data type split. If a variable is typed as an abstract data type, then - * it will create a match case statement on all known subtypes. */ -case object ADTSplit extends Rule("ADT Split.") { - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - // We approximate knowledge of types based on facts found at the top-level - // we don't care if the variables are known to be equal or not, we just - // don't want to split on two variables for which only one split - // alternative is viable. This should be much less expensive than making - // calls to a solver for each pair. - val facts: Map[Identifier, CaseClassType] = { - val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) - val instChecks: Seq[(Identifier, CaseClassType)] = as.collect { - case IsInstanceOf(Variable(a), cct: CaseClassType) => a -> cct - case Equals(Variable(a), CaseClass(cct, _)) => a -> cct - } - val boundCcs = p.pc.bindings.collect { case (id, CaseClass(cct, _)) => id -> cct } - instChecks.toMap ++ boundCcs - } - - val candidates = p.allAs.collect { - case IsTyped(id, act @ AbstractClassType(cd, tpes)) => - - val optCases = cd.knownDescendants.sortBy(_.id.name).collect { - case ccd : CaseClassDef => - val cct = CaseClassType(ccd, tpes) - - if (facts contains id) { - if (cct == facts(id)) { - Seq(ccd) - } else { - Nil - } - } else { - Seq(ccd) - } - } - - val cases = optCases.flatten - - if (cases.nonEmpty) { - Some((id, act, cases)) - } else { - None - } - } - - candidates.collect { - case Some((id, act, cases)) => - val oas = p.as.filter(_ != id) - - val subInfo0 = for(ccd <- cases) yield { - val isInputVar = p.as.contains(id) - val cct = CaseClassType(ccd, act.tps) - - val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList - - val whole = CaseClass(cct, args.map(Variable)) - - val subPhi = subst(id -> whole, p.phi) - val subPC = { - val classInv = cct.classDef.invariant.map { fd => - FunctionInvocation(fd.typed(cct.tps), Seq(whole)) - }.getOrElse(BooleanLiteral(true)) - val withSubst = (p.pc withCond classInv) map (subst(id -> whole, _)) - if (isInputVar) withSubst - else { - val mapping = cct.classDef.fields.zip(args).map { - case (f, a) => a -> caseClassSelector(cct, Variable(id), f.id) - } - withSubst.withCond(isInstOf(id.toVariable, cct)).withBindings(mapping) - } - } - val subWS = subst(id -> whole, p.ws) - - val eb2 = { - if (isInputVar) { - // Filter out examples where id has the wrong type, and fix input variables - // Note: It is fine to filter here as no evaluation is required - p.qeb.flatMapIns { inInfo => - inInfo.toMap.apply(id) match { - case CaseClass(`cct`, vs) => - List(vs ++ inInfo.filter(_._1 != id).map(_._2)) - case _ => - Nil - } - } - } else { - p.eb - } - } - val newAs = if (isInputVar) args ::: oas else p.as - val inactive = (!isInputVar).option(Inactive(id)) - val subProblem = Problem(newAs, subWS, subPC, subPhi, p.xs, eb2).withWs(Seq(Hint(whole)) ++ inactive) - val subPattern = CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id)))) - - (cct, subProblem, subPattern) - } - - val subInfo = subInfo0.sortBy{ case (cct, _, _) => - cct.fieldsTypes.count { t => t == act } - } - - val onSuccess: List[Solution] => Option[Solution] = { sols => - val (cases, globalPres) = (for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield { - val retrievedArgs = pattern.subPatterns.collect{ case WildcardPattern(Some(id)) => id } - val substs = (for ((field,arg) <- cct.classDef.fields zip retrievedArgs ) yield { - (arg, caseClassSelector(cct, id.toVariable, field.id)) - }).toMap - ( - SimpleCase(pattern, sol.term), - and(IsInstanceOf(Variable(id), cct), replaceFromIDs(substs, sol.pre)) - ) - }).unzip - - Some(Solution(orJoin(globalPres), sols.flatMap(_.defs).toSet, matchExpr(Variable(id), cases), sols.forall(_.isTrusted))) - } - - decomp(subInfo.map(_._2).toList, onSuccess, s"ADT Split on '${id.asString(hctx)}'") - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/Abduction.scala b/src/main/scala/leon/synthesis/rules/Abduction.scala deleted file mode 100644 index 4fe724e162d8fc030c7f39cceff8b2edc4498c3a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Abduction.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Common._ -import purescala.DefOps._ -import purescala.Expressions._ -import purescala.TypeOps.unify -import purescala.TypeOps.canBeSubtypeOf -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.Definitions._ -import purescala.Extractors._ -import leon.solvers.{SolverFactory, SimpleSolverAPI} -import leon.utils.Simplifiers - -object Abduction extends Rule("Abduction") { - override def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - Nil - } -} diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala deleted file mode 100644 index 4b29537db8118ddf70d68b0c1024a1b950650e95..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.ExprOps._ -import purescala.TypeOps._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.Constructors._ - -/** Moves the preconditions without output variables to the precondition. */ -case object Assert extends NormalizingRule("Assert") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.phi match { - case TopLevelAnds(exprs) => - val xsSet = p.xs.toSet - val (exprsA, others) = exprs.partition(e => (variablesOf(e) & xsSet).isEmpty) - - if (exprsA.nonEmpty) { - // If there is no more postcondition, then output the solution. - if (others.isEmpty) { - val simplestOut = simplestValue(tupleTypeWrap(p.xs.map(_.getType))) - - if (!isRealExpr(simplestOut)) { - None - } else { - Some(solve(Solution(pre = andJoin(exprsA), defs = Set(), term = simplestOut))) - } - } else { - val sub = p.copy(pc = p.pc withConds exprsA, phi = andJoin(others)) - - Some(decomp(List(sub), { - case List(s @ Solution(pre, defs, term, isTrusted)) => - Some(Solution(andJoin(exprsA :+ pre), defs, term, isTrusted)) - case _ => None - }, "Assert "+andJoin(exprsA).asString)) - } - } else { - None - } - case _ => - None - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala deleted file mode 100644 index f32364cc64b1abf72997ea61443b79dfe6f715c5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import leon.grammars._ -import leon.grammars.aspects._ - -/** Basic implementation of CEGIS that uses a naive grammar */ -case object NaiveCEGIS extends CEGISLike("Naive CEGIS") { - def getParams(sctx: SynthesisContext, p: Problem) = { - CegisParams( - grammar = grammars.default(sctx, p), - rootLabel = Label(_), - optimizations = false - ) - } -} - -/** More advanced implementation of CEGIS that uses a less permissive grammar - * and some optimizations - */ -case object CEGIS extends CEGISLike("CEGIS") { - def getParams(sctx: SynthesisContext, p: Problem) = { - CegisParams( - grammar = grammars.default(sctx, p), - rootLabel = Label(_).withAspect(Tagged(Tags.Top, 0, None)), - optimizations = true - ) - } -} diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala deleted file mode 100644 index 35d7c02d94e20e49ead20b6593743be46043cde8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ /dev/null @@ -1,992 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.TypeOps.depth -import purescala.DefOps._ -import purescala.Constructors._ - -import utils.MutableExpr -import solvers._ -import leon.grammars._ -import leon.grammars.aspects._ -import leon.utils._ - -import evaluators._ -import datagen._ - -import scala.collection.mutable.{HashMap => MutableMap} - -abstract class CEGISLike(name: String) extends Rule(name) { - - case class CegisParams( - grammar: ExpressionGrammar, - rootLabel: TypeTree => Label, - optimizations: Boolean, - maxSize: Option[Int] = None - ) - - def getParams(sctx: SynthesisContext, p: Problem): CegisParams - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - import hctx.reporter._ - - val exSolverTo = 500L - val cexSolverTo = 3000L - - // Track non-deterministic programs up to 100'000 programs, or give up - val nProgramsLimit = 100000 - - val timers = hctx.timers.synthesis.applications.CEGIS - - // CEGIS Flags to activate or deactivate features - val useOptTimeout = hctx.settings.cegisUseOptTimeout - val useVanuatoo = hctx.settings.cegisUseVanuatoo - - // The factor by which programs need to be reduced by testing before we validate them individually - val testReductionRatio = 10 - - val interruptManager = hctx.interruptManager - - val params = getParams(hctx, p) - - // If this CEGISLike forces a maxSize, take it, otherwise find it in the settings - val maxSize = params.maxSize.getOrElse(hctx.settings.cegisMaxSize) - - if (maxSize == 0) { - return Nil - } - - // Represents a non-deterministic program - object NonDeterministicProgram { - - // Current synthesized term size - private var termSize = 0 - - def unfolding = termSize - - private val targetType = tupleTypeWrap(p.xs.map(_.getType)) - - val grammar = params.grammar - - def rootLabel = params.rootLabel(targetType).withAspect(TypeDepthBound(depth(targetType) + 1)).withAspect(Sized(termSize)) - - def init(): Unit = { - updateCTree() - } - - /** - * Different view of the tree of expressions: - * - * Case used to illustrate the different views, assuming encoding: - * - * b1 => c1 == F(c2, c3) - * b2 => c1 == G(c4, c5) - * b3 => c6 == H(c4, c5) - * - * c1 -> Seq( - * (b1, F(_, _), Seq(c2, c3)) - * (b2, G(_, _), Seq(c4, c5)) - * ) - * c6 -> Seq( - * (b3, H(_, _), Seq(c7, c8)) - * ) - */ - private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() - - // cTree in expression form - private var cExpr: Expr = _ - - // Top-level C identifiers corresponding to p.xs - private var rootC: Identifier = _ - - // Blockers - private var bs: Set[Identifier] = Set() - - private var bsOrdered: Seq[Identifier] = Seq() - - // Generator of fresh cs that minimizes labels - class CGenerator { - private var buffers = Map[Label, Stream[Identifier]]() - - private var slots = Map[Label, Int]().withDefaultValue(0) - - private def streamOf(t: Label): Stream[Identifier] = Stream.continually( - FreshIdentifier(t.asString, t.getType, true) - ) - - def rewind(): Unit = { - slots = Map[Label, Int]().withDefaultValue(0) - } - - def getNext(t: Label) = { - if (!(buffers contains t)) { - buffers += t -> streamOf(t) - } - - val n = slots(t) - slots += t -> (n+1) - - buffers(t)(n) - } - } - - // Programs we have manually excluded - var excludedPrograms = Set[Set[Identifier]]() - // Still live programs (allPrograms -- excludedPrograms) - var prunedPrograms = Set[Set[Identifier]]() - - // Update the c-tree after an increase in termsize - def updateCTree(): Unit = { - timers.updateCTree.start() - def freshB() = { - val id = FreshIdentifier("B", BooleanType, true) - bs += id - id - } - - def defineCTreeFor(l: Label, c: Identifier): Unit = { - if (!(cTree contains c)) { - val cGen = new CGenerator() - - val alts = grammar.getProductions(l) - - val cTreeData = alts flatMap { gen => - - // Optimize labels - cGen.rewind() - - val subCs = for (sl <- gen.subTrees) yield { - val subC = cGen.getNext(sl) - defineCTreeFor(sl, subC) - subC - } - - if (subCs.forall(sc => cTree(sc).nonEmpty)) { - val b = freshB() - Some((b, gen.builder, subCs)) - } else None - } - - cTree += c -> cTreeData - } - } - - val cGen = new CGenerator() - - rootC = { - val c = cGen.getNext(rootLabel) - defineCTreeFor(rootLabel, c) - c - } - - ifDebug { printer => - printer("Grammar so far:") - grammar.printProductions(printer) - printer("") - } - - bsOrdered = bs.toSeq.sorted - cExpr = setCExpr() - - excludedPrograms = Set() - prunedPrograms = allPrograms().toSet - - timers.updateCTree.stop() - } - - // Returns a count of all possible programs - val allProgramsCount: () => Int = { - var nAltsCache = Map[Label, Int]() - - def countAlternatives(l: Label): Int = { - if (!(nAltsCache contains l)) { - val count = grammar.getProductions(l).map { gen => - gen.subTrees.map(countAlternatives).product - }.sum - nAltsCache += l -> count - } - nAltsCache(l) - } - - () => countAlternatives(rootLabel) - } - - /** - * Returns all possible assignments to Bs in order to enumerate all possible programs - */ - def allPrograms(): Traversable[Set[Identifier]] = { - - var cache = Map[Identifier, Seq[Set[Identifier]]]() - - val c = allProgramsCount() - - if (c > nProgramsLimit) { - debug(s"Exceeded program limit: $c > $nProgramsLimit") - return Seq() - } - - def allProgramsFor(c: Identifier): Seq[Set[Identifier]] = { - if (!(cache contains c)) { - val subs = for ((b, _, subcs) <- cTree(c)) yield { - if (subcs.isEmpty) { - Seq(Set(b)) - } else { - val subPs = subcs map (s => allProgramsFor(s)) - val combos = SeqUtils.cartesianProduct(subPs).map(_.flatten.toSet) - combos map (_ + b) - } - } - cache += c -> subs.flatten - } - cache(c) - } - - allProgramsFor(rootC) - - } - - private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]], - markedBs: Set[Identifier] = Set()): Unit = { - println(" -- -- -- -- -- ") - for ((c, alts) <- cTree) { - println() - println(f"$c%-4s :=") - for ((b, builder, cs) <- alts ) { - val markS = if (markedBs(b)) Console.GREEN else "" - val markE = if (markedBs(b)) Console.RESET else "" - - val ex = builder(cs.map(_.toVariable)).asString - - println(f" $markS ${b.asString}%-4s => $ex%-40s [${cs.map(_.asString).mkString(", ")}]$markE") - } - } - } - - // This represents the current solution of the synthesis problem. - // It is within the image of hctx.functionContext in innerProgram. - // It should be set to the solution you want to check at each time. - // Usually it will either be cExpr or a concrete solution. - private val solutionBox = MutableExpr(NoTree(p.outType)) - private def setSolution(e: Expr) = solutionBox.underlying = e - - // The program with the body of the current function replaced by the current partial solution - private val (outerToInnerTrans, innerProgram) = { - - val outerSolution = { - new PartialSolution(hctx.search.strat, true) - .solutionAround(hctx.currentNode)(solutionBox) - .getOrElse(fatalError("Unable to get outer solution")) - } - - val program0 = addFunDefs(hctx.program, outerSolution.defs, hctx.functionContext) - - val t = funDefReplacer{ - case fd if fd == hctx.functionContext => - val nfd = fd.duplicate() - - nfd.fullBody = postMap { - case src if src eq hctx.source => - Some(outerSolution.term) - - case _ => None - }(nfd.fullBody) - - Some(nfd) - - case fd => - Some(fd.duplicate()) - } - (t, transformProgram(t, program0)) - } - - /** - * Since CEGIS works with a copy of the program, it needs to map outer - * function calls to inner function calls and vice-versa. 'inner' refers - * to the CEGIS-specific program, 'outer' refers to the actual program on - * which we do synthesis. - */ - private def outerToInner(e: Expr) = outerToInnerTrans.transform(e)(Map.empty) - private def outerToInner(fd: FunDef) = outerToInnerTrans.transform(fd) - private def outerToInner(id: Identifier) = outerToInnerTrans.transform(id) - - - private val innerPc = p.pc map outerToInner - private val innerPhi = outerToInner(p.phi) - // Depends on the current solution - private val innerSpec = outerToInner( - letTuple(p.xs, solutionBox, p.phi) - ) - - - // The program with the c-tree functions - private var programCTree: Program = _ - - private var evaluator: DefaultEvaluator = _ - - // Updates the program with the C tree after recalculating all relevant FunDef's - private def setCExpr(): Expr = { - - // Computes a Seq of functions corresponding to the choices made at each non-terminal of the grammar, - // and an expression which calls the top-level one. - def computeCExpr(): (Expr, Seq[FunDef]) = { - var cToFd = Map[Identifier, FunDef]() - - def exprOf(alt: (Identifier, Seq[Expr] => Expr, Seq[Identifier])): Expr = { - val (_, builder, cs) = alt - - val e = builder(cs.map { c => - cToFd(c).applied - }) - - outerToInner(e) - } - - // Define all C-def - for ((c, alts) <- cTree) yield { - cToFd += c -> new FunDef(FreshIdentifier(c.asString, alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), c.getType) - } - - // Fill C-def bodies - for ((c, alts) <- cTree) { - val body = if (alts.nonEmpty) { - alts.init.foldLeft(exprOf(alts.last)) { - case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e) - } - } else { - Error(c.getType, s"Empty production rule: $c") - } - cToFd(c).fullBody = body - } - - // Top-level expression for rootC - val expr = cToFd(rootC).applied - - (expr, cToFd.values.toSeq) - } - - val (cExpr, newFds) = computeCExpr() - - programCTree = addFunDefs(innerProgram, newFds, outerToInner(hctx.functionContext)) - evaluator = new DefaultEvaluator(hctx, programCTree) - - cExpr - //println("-- "*30) - //println(programCTree.asString) - //println(".. "*30) - } - - // Tests a candidate solution against an example in the correct environment - // None -> evaluator error - def testForProgram(bValues: Set[Identifier])(ex: Example): Option[Boolean] = { - - def redundant(e: Expr): Boolean = { - val (op1, op2) = e match { - case Minus(o1, o2) => (o1, o2) - case Modulo(o1, o2) => (o1, o2) - case Division(o1, o2) => (o1, o2) - case BVMinus(o1, o2) => (o1, o2) - case BVRemainder(o1, o2) => (o1, o2) - case BVDivision(o1, o2) => (o1, o2) - - case And(Seq(Not(o1), Not(o2))) => (o1, o2) - case And(Seq(Not(o1), o2)) => (o1, o2) - case And(Seq(o1, Not(o2))) => (o1, o2) - case And(Seq(o1, o2)) => (o1, o2) - - case Or(Seq(Not(o1), Not(o2))) => (o1, o2) - case Or(Seq(Not(o1), o2)) => (o1, o2) - case Or(Seq(o1, Not(o2))) => (o1, o2) - case Or(Seq(o1, o2)) => (o1, o2) - - case SetUnion(o1, o2) => (o1, o2) - case SetIntersection(o1, o2) => (o1, o2) - case SetDifference(o1, o2) => (o1, o2) - - case Equals(Not(o1), Not(o2)) => (o1, o2) - case Equals(Not(o1), o2) => (o1, o2) - case Equals(o1, Not(o2)) => (o1, o2) - case Equals(o1, o2) => (o1, o2) - case _ => return false - } - op1 == op2 - } - - val outerSol = getExpr(bValues) - - val redundancyCheck = false - - // This program contains a simplifiable expression, - // which means it is equivalent to a simpler one - // Deactivated for now, since it does not seem to help - if (redundancyCheck && params.optimizations && exists(redundant)(outerSol)) { - excludeProgram(bs, true) - return Some(false) - } - - val innerSol = outerToInner(outerSol) - - def withBindings(e: Expr) = p.pc.bindings.foldRight(e){ - case ((id, v), bd) => let(outerToInner(id), outerToInner(v), bd) - } - - setSolution(innerSol) - - timers.testForProgram.start() - - val innerEnv = p.as.zip(ex.ins).map{ case (id, v) => - (outerToInner(id), outerToInner(v)) - }.toMap - - val res = ex match { - case InExample(ins) => - evaluator.eval(withBindings(innerSpec),innerEnv) - - case InOutExample(ins, outs) => - evaluator.eval( - withBindings(equality(innerSol, tupleWrap(outs))), - innerEnv - ) - } - timers.testForProgram.stop() - - res match { - case EvaluationResults.Successful(res) => - Some(res == BooleanLiteral(true)) - - case EvaluationResults.RuntimeError(err) => - debug("RE testing CE: "+err) - Some(false) - - case EvaluationResults.EvaluatorError(err) => - debug("Error testing CE: "+err) - //println("Program\n" + programCTree) - //println("InnerSpec: " + withBindings(innerSpec)) - //println("InnerEnv" + innerEnv) - None - } - - } - - // Returns the outer expression corresponding to a B-valuation - def getExpr(bValues: Set[Identifier]): Expr = { - - def getCValue(c: Identifier): Expr = { - cTree(c).find(i => bValues(i._1)).map { - case (b, builder, cs) => - builder(cs.map(getCValue)) - }.getOrElse { - Error(c.getType, "Impossible assignment of bs") - } - } - - getCValue(rootC) - } - - /** - * Here we check the validity of a (small) number of programs in isolation. - * We keep track of CEXs generated by invalid programs and preemptively filter the rest of the programs with them. - */ - def validatePrograms(bss: Set[Set[Identifier]]): Either[Seq[Seq[Expr]], Stream[Solution]] = { - - var cexs = Seq[Seq[Expr]]() - - var best: List[Solution] = Nil - - for (bs <- bss.toSeq) { - // We compute the corresponding expr and replace it in place of the C-tree - val outerSol = getExpr(bs) - val innerSol = outerToInner(outerSol) - setSolution(innerSol) - - val cnstr = innerPc and letTuple(p.xs map outerToInner, innerSol, Not(innerPhi)) - - //println("Program:") - //println(programCTree) - //println("Constraint:") - //println(cnstr) - - val eval = new DefaultEvaluator(hctx, programCTree) - - if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { - debug(s"Rejected by CEX: $outerSol") - excludeProgram(bs, true) - } else { - //println("Solving for: "+cnstr.asString) - - val solverf = SolverFactory.getFromSettings(hctx, programCTree).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - debug("Sending candidate to solver...") - def currentSolution(trusted: Boolean) = Solution(BooleanLiteral(true), Set(), outerSol, isTrusted = trusted) - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - debug(s"Proven invalid: $outerSol") - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") - cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) - - case Some(false) => - // UNSAT, valid program - debug("Found valid program!") - return Right(Stream(currentSolution(true))) - - case None => - debug("Found a non-verifiable solution...") - // Optimistic valid solution - best +:= currentSolution(false) - } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - } - } - } - - if (useOptTimeout && best.nonEmpty) { - // Interpret timeout in CE search as "the candidate is valid" - info(s"CEGIS could not prove the validity of the resulting ${best.size} expression(s)") - Right(best.toStream) - } else { - Left(cexs) - } - } - - def allProgramsClosed = prunedPrograms.isEmpty - def closeAllPrograms() = { - excludedPrograms ++= prunedPrograms - prunedPrograms = Set() - } - - // Explicitly remove program computed by bValues from the search space - // - // If the bValues comes from models, we make sure the bValues we exclude - // are minimal we make sure we exclude only Bs that are used. - def excludeProgram(bs: Set[Identifier], isMinimal: Boolean): Unit = { - - def filterBTree(c: Identifier): Set[Identifier] = { - val (b, _, subcs) = cTree(c).find(sub => bs(sub._1)).get - subcs.flatMap(filterBTree).toSet + b - } - - val bvs = if (isMinimal) { - bs - } else { - filterBTree(rootC) - } - - excludedPrograms += bvs - prunedPrograms -= bvs - } - - def unfold() = { - termSize += 1 - updateCTree() - } - - /** - * First phase of CEGIS: discover potential programs (that work on at least one input) - */ - def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { - timers.tentative.start() - val solverf = SolverFactory.getFromSettings(hctx, programCTree).withTimeout(exSolverTo) - val solver = solverf.getNewSolver() - - //println("Program: ") - //println("-"*80) - //println(programCTree.asString) - - setSolution(cExpr) - val toFind = innerPc and innerSpec - //println(" --- Constraints ---") - //println(" - "+toFind.asString) - try { - solver.assertCnstr(toFind) - - for ((c, alts) <- cTree) { - - val bs = alts.map(_._1) - - val either = for (a1 <- bs; a2 <- bs if a1 < a2) yield { - Or(Not(a1.toVariable), Not(a2.toVariable)) - } - - if (bs.nonEmpty) { - //println(" - "+andJoin(either).asString) - solver.assertCnstr(andJoin(either)) - - val oneOf = orJoin(bs.map(_.toVariable)) - //println(" - "+oneOf.asString) - solver.assertCnstr(oneOf) - } - } - - //println(" -- Excluded:") - for (ex <- excludedPrograms) { - val notThisProgram = Not(andJoin(ex.map(_.toVariable).toSeq)) - - //println(f" - ${notThisProgram.asString}%-40s ("+getExpr(ex)+")") - solver.assertCnstr(notThisProgram) - } - - solver.check match { - case Some(true) => - val model = solver.getModel - - val bModel = bs.filter(b => model.get(b).contains(BooleanLiteral(true))) - - //println("Tentative model: "+model.asString) - //println("Tentative model: "+bModel.filter(isBActive).map(_.asString).toSeq.sorted) - //println("Tentative expr: "+getExpr(bModel)) - - Some(Some(bModel)) - - case Some(false) => - //println("UNSAT!") - Some(None) - - case None => - /** - * If the remaining tentative programs are all infeasible, it - * might timeout instead of returning Some(false). We might still - * benefit from unfolding further - */ - None - } - } finally { - timers.tentative.stop() - solverf.reclaim(solver) - solverf.shutdown() - } - } - - /** - * Second phase of CEGIS: verify a given program by looking for CEX inputs - */ - def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { - timers.cex.start() - val solverf = SolverFactory.getFromSettings(hctx, programCTree).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - - try { - setSolution(cExpr) - solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) - solver.assertCnstr(innerPc and not(innerSpec)) - - //println("*"*80) - //println(Not(cnstr)) - //println(innerPc) - //println("*"*80) - //println(programCTree.asString) - //println("*"*80) - //Console.in.read() - - solver.check match { - case Some(true) => - val model = solver.getModel - val cex = p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) - - Some(Some(cex)) - - case Some(false) => - Some(None) - - case None => - None - } - } finally { - timers.cex.stop() - solverf.reclaim(solver) - solverf.shutdown() - } - } - } - - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - var result: Option[RuleApplication] = None - - val ndProgram = NonDeterministicProgram - ndProgram.init() - - implicit val ic = hctx - - debug("Acquiring initial list of examples") - - // To the list of known examples, we add an additional one produced by the solver - val solverExample = if (p.pc.isEmpty) { - List(InExample(p.as.map(a => simplestValue(a.getType)))) - } else { - val solverf = hctx.solverFactory - val solver = solverf.getNewSolver().setTimeout(exSolverTo*2) - - solver.assertCnstr(p.pc.toClause) - - try { - solver.check match { - case Some(true) => - val model = solver.getModel - List(InExample(p.as.map(a => model.getOrElse(a, simplestValue(a.getType))))) - - case Some(false) => - debug("Path-condition seems UNSAT") - return RuleFailed() - - case None => - if (!interruptManager.isInterrupted) { - warning("Solver could not solve path-condition") - } - Nil - //return RuleFailed() // This is not necessary though, but probably wanted - } - } finally { - solverf.reclaim(solver) - } - } - - val baseExampleInputs = p.qebFiltered.examples ++ solverExample - - ifDebug { debug => - baseExampleInputs.foreach { in => - debug(" - "+in.asString) - } - } - - /** - * We (lazily) generate additional tests for discarding potential programs with a data generator - */ - val nTests = if (p.pc.isEmpty) 50 else 20 - - val inputGenerator: Iterator[Example] = { - val complicated = exists{ - case FunctionInvocation(tfd, _) if tfd.fd == hctx.functionContext => true - case Choose(_) => true - case _ => false - }(p.pc.toClause) - - if (complicated) { - Iterator() - } else { - if (useVanuatoo) { - new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc.toClause, nTests, 3000).map(InExample) - } else { - val evaluator = new DualEvaluator(hctx, hctx.program) - new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc.toClause, nTests, 1000).map(InExample) - } - } - } - - // We keep number of failures per test to pull the better ones to the front - val failedTestsStats = new MutableMap[Example, Int]().withDefaultValue(0) - - // This is the starting test-base - val gi = new GrowableIterable[Example](baseExampleInputs, inputGenerator) - def hasInputExamples = gi.nonEmpty - - var n = 1 - - try { - do { - // Run CEGIS for one specific unfolding level - - // Unfold formula - ndProgram.unfold() - - val nInitial = ndProgram.prunedPrograms.size - debug(s"#Programs: $nInitial") - - def nPassing = ndProgram.prunedPrograms.size - - def programsReduced() = nPassing <= 10 || (nPassing <= 100 && nInitial / nPassing > testReductionRatio) - gi.canGrow = programsReduced - - def allInputExamples() = { - if (n == 10 || n == 50 || n % 500 == 0) { - gi.sortBufferBy(e => -failedTestsStats(e)) - } - n += 1 - gi.iterator - } - - //sctx.reporter.ifDebug{ printer => - // val limit = 100 - - // for (p <- prunedPrograms.take(limit)) { - // val ps = p.toSeq.sortBy(_.id).mkString(", ") - // printer(f" - $ps%-40s - "+ndProgram.getExpr(p)) - // } - // if(nInitial > limit) { - // printer(" - ...") - // } - //} - - debug(s"#Tests: >= ${gi.bufferedCount}") - ifDebug{ printer => - val es = allInputExamples() - for (e <- es.take(Math.min(gi.bufferedCount, 10))) { - printer(" - "+e.asString) - } - if(es.hasNext) { - printer(" - ...") - } - } - - // We further filter the set of working programs to remove those that fail on known examples - if (hasInputExamples) { - timers.filter.start() - for (bs <- ndProgram.prunedPrograms if !interruptManager.isInterrupted) { - val examples = allInputExamples() - var badExamples = List[Example]() - var stop = false - for (e <- examples if !stop && !badExamples.contains(e)) { - ndProgram.testForProgram(bs)(e) match { - case Some(true) => // ok, passes - case Some(false) => - // Program fails the test - stop = true - failedTestsStats(e) += 1 - debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") - ndProgram.excludeProgram(bs, true) - case None => - // Eval. error -> bad example - debug(s" Test $e crashed the evaluator, removing...") - badExamples ::= e - } - } - gi --= badExamples - } - timers.filter.stop() - } - - debug(s"#Programs passing tests: $nPassing out of $nInitial") - ifDebug{ printer => - for (p <- ndProgram.prunedPrograms.take(100)) { - printer(" - "+ndProgram.getExpr(p).asString) - } - if(nPassing > 100) { - printer(" - ...") - } - } - // CEGIS Loop at a given unfolding level - while (result.isEmpty && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) { - debug("Programs left: " + ndProgram.prunedPrograms.size) - - // Phase 0: If the number of remaining programs is small, validate them individually - if (programsReduced()) { - timers.validate.start() - val programsToValidate = ndProgram.prunedPrograms - debug(s"Will send ${programsToValidate.size} program(s) to validate individually") - ndProgram.validatePrograms(programsToValidate) match { - case Right(sols) => - // Found solution! Exit CEGIS - result = Some(RuleClosed(sols)) - case Left(cexs) => - debug(s"Found cexs! $cexs") - // Found some counterexamples - // (bear in mind that these will in fact exclude programs within validatePrograms()) - val newCexs = cexs.map(InExample) - newCexs foreach (failedTestsStats(_) += 1) - gi ++= newCexs - } - debug(s"#Programs after validating individually: ${ndProgram.prunedPrograms.size}") - timers.validate.stop() - } - - if (result.isEmpty && !ndProgram.allProgramsClosed) { - // Phase 1: Find a candidate program that works for at least 1 input - debug("Looking for program that works on at least 1 input...") - ndProgram.solveForTentativeProgram() match { - case Some(Some(bs)) => - debug(s"Found tentative model ${ndProgram.getExpr(bs)}, need to validate!") - // Phase 2: Validate candidate model - ndProgram.solveForCounterExample(bs) match { - case Some(Some(inputsCE)) => - debug("Found counter-example:" + inputsCE) - val ce = InExample(inputsCE) - // Found counterexample! Exclude this program - gi += ce - failedTestsStats(ce) += 1 - ndProgram.excludeProgram(bs, false) - - var bad = false - // Retest whether the newly found C-E invalidates some programs - for (p <- ndProgram.prunedPrograms if !bad) { - ndProgram.testForProgram(p)(ce) match { - case Some(true) => - case Some(false) => - debug(f" Program: ${ndProgram.getExpr(p).asString}%-80s failed on: ${ce.asString}") - failedTestsStats(ce) += 1 - ndProgram.excludeProgram(p, true) - case None => - debug(s" Test $ce failed, removing...") - gi -= ce - bad = true - } - } - - case Some(None) => - // Found no counter example! Program is a valid solution - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - - case None => - // We are not sure - debug("Unknown") - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) - } else { - // Ok, we failed to validate, exclude this program - ndProgram.excludeProgram(bs, false) - // TODO: Make CEGIS fail early when it times out when verifying 1 program? - // result = Some(RuleFailed()) - } - } - - case Some(None) => - debug("There exists no candidate program!") - ndProgram.closeAllPrograms() - - case None => - debug("Timeout while getting tentative program!") - ndProgram.closeAllPrograms() - // TODO: Make CEGIS fail early when it times out when looking for tentative program? - //result = Some(RuleFailed()) - } - } - } - - } while(ndProgram.unfolding < maxSize && result.isEmpty && !interruptManager.isInterrupted) - - if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() - result.getOrElse(RuleFailed()) - - } catch { - case e: Throwable => - warning("CEGIS crashed: "+e.getMessage) - e.printStackTrace() - RuleFailed() - } - } - }) - } -} diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala deleted file mode 100644 index fc36fb794599aa56ca21decd069f727f523cb209..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Extractors._ -import leon.grammars._ -import leon.grammars.aspects._ -import Witnesses._ - -case object CEGLESS extends CEGISLike("CEGLESS") { - def getParams(sctx: SynthesisContext, p: Problem) = { - val TopLevelAnds(clauses) = p.ws - - val guides = clauses.collect { - case Guide(e) => e - } - - sctx.reporter.ifDebug { printer => - printer("Guides available:") - for (g <- guides) { - printer(" - "+g.asString(sctx)) - } - } - - CegisParams( - grammar = grammars.default(sctx, p), - rootLabel = (tpe: TypeTree) => Label(tpe).withAspect(DepthBound(2)).withAspect(SimilarTo(guides)), - optimizations = false, - maxSize = Some((0 +: guides.map(depth(_) + 1)).max) - ) - } -} - - - diff --git a/src/main/scala/leon/synthesis/rules/CaseSplit.scala b/src/main/scala/leon/synthesis/rules/CaseSplit.scala deleted file mode 100644 index fd820f96dc4b0746849056e590deba79d1cdc993..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/CaseSplit.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.Constructors._ - -case object CaseSplit extends Rule("Case-Split") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.phi match { - case Or(os) => - List(split(os, "Split top-level Or")) - case _ => - Nil - } - } - - def split(alts: Seq[Expr], description: String)(implicit p: Problem): RuleInstantiation = { - val subs = alts.map(a => Problem(p.as, p.ws, p.pc, a, p.xs, p.eb)).toList - - val onSuccess: List[Solution] => Option[Solution] = { - case sols if sols.size == subs.size => - val pre = orJoin(sols.map(_.pre)) - val defs = sols.map(_.defs).reduceLeft(_ ++ _) - - val (prefix, last) = (sols.init, sols.last) - - val term = prefix.foldRight(last.term) { (s, t) => IfExpr(s.pre, s.term, t) } - - Some(Solution(pre, defs, term, sols.forall(_.isTrusted))) - - case _ => - None - } - - decomp(subs, onSuccess, description) - } -} - diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala deleted file mode 100644 index 898353504282bc498f8a3c069b2a98c3783275a1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import Witnesses._ -import purescala.Expressions._ -import purescala.Common._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Extractors.LetPattern - -/** Rule for detupling input variables, to be able to use their sub-expressions. For example, the input variable: - * {{{d: Cons(head: Int, tail: List)}}} - * will create the following input variables - * {{{head42: Int, tail57: List}}} - * Recomposition is available. - */ -case object DetupleInput extends NormalizingRule("Detuple In") { - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - /** Returns true if this identifier is a tuple or a case class */ - def typeCompatible(id: Identifier) = id.getType match { - case CaseClassType(t, _) if !t.isAbstract => true - case TupleType(ts) => true - case _ => false - } - - def isDecomposable(id: Identifier) = typeCompatible(id) && !p.wsList.contains(Inactive(id)) - - /* Decomposes a decomposable input identifier (eg of type Tuple or case class) - * into a list of fresh typed identifiers, the tuple of these new identifiers, - * and the mapping of those identifiers to their respective expressions. - */ - def decompose(id: Identifier): (List[Identifier], Expr, Expr => Seq[Expr]) = id.getType match { - case cct @ CaseClassType(ccd, _) if !ccd.isAbstract => - val newIds = cct.fields.map{ vd => FreshIdentifier(vd.id.name, vd.getType, true) } - - val tMap: (Expr => Seq[Expr]) = { - case CaseClass(ccd, fields) => fields - } - - (newIds.toList, CaseClass(cct, newIds.map(Variable)), tMap) - - case TupleType(ts) => - val newIds = ts.zipWithIndex.map{ case (t, i) => FreshIdentifier(id.name+"_"+(i+1), t, true) } - - val tMap: (Expr => Seq[Expr]) = { - case Tuple(fields) => fields - } - - (newIds.toList, tupleWrap(newIds.map(Variable)), tMap) - - case _ => sys.error("woot") - } - - if (p.allAs.exists(isDecomposable)) { - var subProblem = p.phi - var subPc = p.pc - var subWs = p.ws - var hints: Seq[Expr] = Nil - var patterns = List[(Identifier, Pattern)]() - var revMap = Map[Expr, Expr]().withDefault((e: Expr) => e) - var inactive = Set[Identifier]() - - var ebMapInfo = Map[Identifier, Expr => Seq[Expr]]() - - val subAs = p.allAs.map { a => - if (isDecomposable(a)) { - val (newIds, expr, tMap) = decompose(a) - val patts = newIds map (id => WildcardPattern(Some(id))) - val patt = a.getType match { - case TupleType(_) => - TuplePattern(None, patts) - case cct: CaseClassType => - CaseClassPattern(None, cct, patts) - } - - subProblem = subst(a -> expr, subProblem) - subPc = { - val classInv = a.getType match { - case _:TupleType => BooleanLiteral(true) - case cc: CaseClassType => - cc.classDef.invariant.map { fd => - FunctionInvocation(fd.typed(cc.tps), Seq(expr)) - }.getOrElse(BooleanLiteral(true)) - } - val withSubst = (subPc withCond classInv) map (subst(a -> expr, _)) - if (!p.pc.boundIds.contains(a)){ - withSubst - } else { - inactive += a - val mapping = mapForPattern(a.toVariable, patt) - withSubst.withBindings(mapping) - } - } - subWs = subst(a -> expr, subWs) - revMap += expr -> Variable(a) - hints +:= Hint(expr) - - patterns +:= a -> patt - - ebMapInfo += a -> tMap - - a -> newIds - } else { - a -> List(a) - } - }.toMap - - val eb = p.qeb.flatMapIns { info => - List(info.flatMap { case (id, v) => - ebMapInfo.get(id) match { - case Some(m) => - m(v) - case None => - List(v) - } - }) - } - - val newAs = p.as.flatMap(subAs) - - val (as, patts) = patterns.unzip - - val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - .withWs(hints) - .withWs(inactive.toSeq.map(Inactive)) - - val s = { (e: Expr) => - val body = simplePostTransform(revMap)(e) - (patts zip as).foldRight(body) { case ((patt, a), bd) => - LetPattern(patt, a.toVariable, bd) - } - } - - Some(decomp(List(sub), forwardMap(s), s"Detuple ${as.map(_.asString(hctx)).mkString(", ")}")) - } else { - None - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/Disunification.scala b/src/main/scala/leon/synthesis/rules/Disunification.scala deleted file mode 100644 index a380dec539110cd5ae0f030bc3b98da8a12609cc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Disunification.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ - -object Disunification { - case object Decomp extends Rule("Disunif. Decomp.") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - val (toRemove, toAdd) = exprs.collect { - case neq @ Not(Equals(cc1 @ CaseClass(cd1, args1), cc2 @ CaseClass(cd2, args2))) => - if (cc1 == cc2) { - (neq, List(BooleanLiteral(false))) - } else if (cd1 == cd2) { - (neq, (args1 zip args2).map(p => not(Equals(p._1, p._2)))) - } else { - (neq, List(BooleanLiteral(true))) - } - }.unzip - - if (toRemove.nonEmpty) { - val sub = p.copy(phi = orJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq), eb = ExamplesBank.empty) - - Some(decomp(List(sub), forward, this.name)) - } else { - None - } - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala deleted file mode 100644 index 02abccfd2488d506de43f6ada2aa7f713ec8fb51..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import leon.utils._ -import purescala.Path -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Types.CaseClassType - -case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - val simplifier = Simplifiers.bestEffort(hctx, hctx.program)(_:Expr) - - var subst = Map.empty[Identifier, Expr] - var reverseSubst = Map.empty[Identifier, Expr] - - var obsolete = Set.empty[Identifier] - var free = Set.empty[Identifier] - - def discoverEquivalences(p: Path): Path = { - - val vars = p.variables - val clauses = p.conditions - - val instanceOfs = clauses.collect { case IsInstanceOf(Variable(id), cct) if vars(id) => id -> cct }.toSet - - val equivalences = (for ((sid, cct: CaseClassType) <- instanceOfs) yield { - val fieldVals = for (f <- cct.classDef.fields) yield { - val fid = f.id - - p.bindings.collectFirst { - case (id, CaseClassSelector(`cct`, Variable(`sid`), `fid`)) => id - case (id, CaseClassSelector(`cct`, AsInstanceOf(Variable(`sid`), `cct`), `fid`)) => id - } - } - - if (fieldVals.forall(_.isDefined)) { - Some(sid -> CaseClass(cct, fieldVals.map(_.get.toVariable))) - } else if (fieldVals.exists(_.isDefined)) { - Some(sid -> CaseClass(cct, (cct.fields zip fieldVals).map { - case (_, Some(id)) => Variable(id) - case (vid, None) => Variable(vid.id.freshen) - })) - } else { - None - } - }).flatten - - val unbound = equivalences.flatMap(_._2.args.collect { case Variable(id) => id }) - obsolete ++= equivalences.map(_._1) - free ++= unbound - - def replace(e: Expr) = simplifier(replaceFromIDs(equivalences.toMap, e)) - subst = subst.mapValues(replace) ++ equivalences - - val reverse = equivalences.toMap.flatMap { case (id, CaseClass(cct, fields)) => - (cct.classDef.fields zip fields).map { case (vid, Variable(fieldId)) => - fieldId -> caseClassSelector(cct, asInstOf(Variable(id), cct), vid.id) - } - } - - reverseSubst ++= reverse.mapValues(replaceFromIDs(reverseSubst, _)) - - (p -- unbound) map replace - } - - // We could discover one equivalence, which could allow us to discover - // other equivalences: We do a fixpoint with limit 5. - val newPC = fixpoint({ (path: Path) => discoverEquivalences(path) }, 5)(p.pc) - - if (subst.nonEmpty) { - // XXX: must take place in this order!! obsolete & free is typically non-empty - val newAs = (p.as ++ free).distinct.filterNot(obsolete) - - val newBank = p.eb.flatMap { ex => - val mapping = (p.as zip ex.ins).toMap - val newIns = newAs.map(a => mapping.getOrElse(a, replaceFromIDs(mapping, reverseSubst(a)))) - List(ex match { - case ioe @ InOutExample(ins, outs) => ioe.copy(ins = newIns) - case ie @ InExample(ins) => ie.copy(ins = newIns) - }) - } - - val simplifierWithNewPC = Simplifiers.bestEffort(hctx, hctx.program)(_:Expr, newPC) - - val sub = p.copy( - as = newAs, - ws = replaceFromIDs(subst, p.ws), - pc = newPC, - phi = simplifierWithNewPC(replaceFromIDs(subst, p.phi)), - eb = newBank - ) - - val onSuccess = { - val reverse = subst.map(_.swap).mapValues(_.toVariable) - forwardMap(replace(reverse, _)) - } - - val substString = subst.map { case (f, t) => f.asString(hctx)+" -> "+t.asString(hctx) } - - List(decomp(List(sub), onSuccess, "Equivalent Inputs ("+substString.mkString(", ")+")")) - } else { - Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala b/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala deleted file mode 100644 index 8f1324828fc613a0bf277189b43bb583cd5a6bf5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Common.Identifier -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.Extractors.{IsTyped, TopLevelAnds} -import purescala.Types._ -import Witnesses._ - -/** For every pair of input variables of the same generic type, - * checks equality and output an If-Then-Else statement with the two new branches. - */ -case object GenericTypeEqualitySplit extends Rule("Eq. Split") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - // We approximate knowledge of equality based on facts found at the top-level - // we don't care if the variables are known to be equal or not, we just - // don't want to split on two variables for which only one split - // alternative is viable. This should be much less expensive than making - // calls to a solver for each pair. - def getFacts(e: Expr): Set[Set[Identifier]] = e match { - case Not(e) => getFacts(e) - case Equals(Variable(a), Variable(b)) => Set(Set(a,b)) - case _ => Set() - } - - val facts: Set[Set[Identifier]] = { - val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) - as.toSet.flatMap(getFacts) - } - - val candidates = p.allAs.combinations(2).collect { - case List(IsTyped(a1, TypeParameter(t1)), IsTyped(a2, TypeParameter(t2))) - if t1 == t2 && !facts(Set(a1, a2)) => - (a1, a2) - }.toList - - candidates.flatMap { - case (a1, a2) => - val v1 = Variable(a1) - val v2 = Variable(a2) - - val (f, t, isInput) = if (p.as contains a1) (a1, v2, true) else (a2, v1, p.as contains a2) - val eq = if (isInput) { - p.copy( - as = p.as.diff(Seq(f)), - pc = p.pc map (subst(f -> t, _)), - ws = subst(f -> t, p.ws), - phi = subst(f -> t, p.phi), - eb = p.qeb.removeIns(Set(f)) - ) - } else { - p.copy(pc = p.pc withCond Equals(v1,v2)).withWs(Seq(Inactive(f))) // FIXME! - } - - val neq = p.copy(pc = p.pc withCond not(Equals(v1, v2))) - - val subProblems = List(eq, neq) - - val onSuccess: List[Solution] => Option[Solution] = { - case sols @ List(sEQ, sNE) => - val pre = or( - and(Equals(v1, v2), sEQ.pre), - and(not(Equals(v1, v2)), sNE.pre) - ) - - val term = IfExpr(Equals(v1, v2), sEQ.term, sNE.term) - - Some(Solution(pre, sols.flatMap(_.defs).toSet, term, sols.forall(_.isTrusted))) - } - - Some(decomp(subProblems, onSuccess, s"Eq. Split on '$v1' and '$v2'")) - - case _ => - None - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala deleted file mode 100644 index 588f94c1807972d77f78cd9793f7f23fe4702557..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import solvers._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import scala.concurrent.duration._ - -case object Ground extends NormalizingRule("Ground") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - if (p.as.isEmpty) { - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - val solver = SimpleSolverAPI(hctx.solverFactory.withTimeout(10.seconds)) - - val result = solver.solveSAT(p.phi) match { - case (Some(true), model) => - val solExpr = tupleWrap(p.xs.map(valuateWithModel(model))) - - if (!isRealExpr(solExpr)) { - RuleFailed() - } else { - val sol = Solution(BooleanLiteral(true), Set(), solExpr) - RuleClosed(sol) - } - case (Some(false), model) => - RuleClosed(Solution.UNSAT(p)) - case _ => - RuleFailed() - } - - result - } - }) - } else { - None - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/IfSplit.scala b/src/main/scala/leon/synthesis/rules/IfSplit.scala deleted file mode 100644 index 2472858b8f84b715fb756e5ee79b957bd7d0e250..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/IfSplit.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import leon.solvers._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ - -case object IfSplit extends Rule("If-Split") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - def split(i: IfExpr, description: String): Option[RuleInstantiation] = { - - val solver = SimpleSolverAPI(SolverFactory.getFromSettings(hctx, hctx.program).withTimeout(1000)) - if ( - solver.solveVALID(p.pc implies i.cond ).contains(true) || - solver.solveVALID(p.pc implies not(i.cond)).contains(true) - ){ - // Condition can only go one way, no reason to split... - return None - } - - val subs = List( - Problem(p.as, p.ws, p.pc withCond i.cond, replace(Map(i -> i.thenn), p.phi), p.xs), - Problem(p.as, p.ws, p.pc withCond not(i.cond), replace(Map(i -> i.elze), p.phi), p.xs) - ) - - val onSuccess: List[Solution] => Option[Solution] = { - case sols if sols.size == 2 => - val List(ts, es) = sols - - val pre = or(and(i.cond, ts.pre), and(not(i.cond), es.pre)) - val defs = ts.defs ++ es.defs - val term = IfExpr(i.cond, ts.term, es.term) - - Some(Solution(pre, defs, term, sols.forall(_.isTrusted))) - - case _ => - None - } - - Some(decomp(subs, onSuccess, description)) - } - - val ifs = collect{ - case i: IfExpr => Set(i) - case _ => Set[IfExpr]() - }(p.phi) - - val xsSet = p.xs.toSet - - ifs.flatMap { - case i @ IfExpr(cond, _, _) => - if ((variablesOf(cond) & xsSet).isEmpty) { - split(i, s"If-Split on '${cond.asString}'") - } else { - None - } - } - - } -} - diff --git a/src/main/scala/leon/synthesis/rules/IndependentSplit.scala b/src/main/scala/leon/synthesis/rules/IndependentSplit.scala deleted file mode 100644 index 6c7c0904108a4ef9137a6b14419ce90492300ab5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/IndependentSplit.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Common._ - -/** Split output if parts of it are independent in the spec. - * Works in 2 phases: - * 1) Detuples output variables if they are tuples or case classes to their fields. - * 2) Tries to split spec in independent parts and solve them separately. - */ -case object IndependentSplit extends NormalizingRule("IndependentSplit") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - /**** Phase 1 ****/ - def isDecomposable(id: Identifier) = id.getType match { - case CaseClassType(_, _) => true - case TupleType(_) => true - case _ => false - } - - val (newP, recon) = if (p.xs.exists(isDecomposable)) { - var newPhi = p.phi - - val (subOuts, outerOuts) = p.xs.map { x => - x.getType match { - case ct: CaseClassType => - val newIds = ct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) } - - val newCC = CaseClass(ct, newIds.map(Variable)) - - newPhi = subst(x -> newCC, newPhi) - - (newIds, newCC) - case tt: TupleType => - val newIds = tt.bases.zipWithIndex.map{ case (t, ind) => - FreshIdentifier(s"${x}_${ind+1}", t, true) - } - val newTuple = Tuple(newIds map Variable) - newPhi = subst(x -> newTuple, newPhi) - (newIds, newTuple) - case _ => - (List(x), Variable(x)) - } - }.unzip - - val newOuts = subOuts.flatten - - val newEb = p.qeb.eb.flatMapOuts{ xs => List( (xs zip p.xs) flatMap { - case (CaseClass(_, args), id) if isDecomposable(id) => args.toList - case (Tuple(args), _) => args.toList - case (other, _) => List(other) - })} - - val newProb = Problem(p.as, p.ws, p.pc, newPhi, newOuts, newEb) - (newProb, letTuple(newOuts, _:Expr, tupleWrap(outerOuts))) - //, s"Detuple out ${p.xs.filter(isDecomposable).mkString(", ")}")) - } else { - (p, (e: Expr) => e) - } - - /**** Phase 2 ****/ - - val TopLevelAnds(clauses) = andJoin(newP.pc.conditions :+ newP.phi) - - var independentClasses = Set[Set[Identifier]]() - - // We group connect variables together - for (c <- clauses) { - val vs = variablesOf(c) - - var newClasses = Set[Set[Identifier]]() - - var thisClass = vs - - for (cl <- independentClasses) { - if ((cl & vs).nonEmpty) { - thisClass ++= cl - } else { - newClasses += cl - } - } - - independentClasses = newClasses + thisClass - } - - val outClasses = independentClasses.map(cl => cl & newP.xs.toSet).filter(_.nonEmpty) - - if (outClasses.size > 1) { - - val TopLevelAnds(phiClauses) = newP.phi - - val subs = (for (cl <- outClasses.toList) yield { - val xs = newP.xs.filter(cl) - - if (xs.nonEmpty) { - val phi = andJoin(phiClauses.filter(c => (variablesOf(c) & cl).nonEmpty)) - - val xsToRemove = newP.xs.filterNot(cl).toSet - - val eb = newP.qeb.removeOuts(xsToRemove) - - Some(newP.copy(phi = phi, xs = xs, eb = eb)) - } else { - None - } - }).flatten - - val onSuccess: List[Solution] => Option[Solution] = { sols => - - val infos = subs.map(_.xs).zip(sols.map(_.term)) - - val term = infos.foldLeft(tupleWrap(newP.xs.map(_.toVariable))) { - case (expr, (xs, term)) => - letTuple(xs, term, expr) - } - - Some(Solution(andJoin(sols.map(_.pre)), sols.flatMap(_.defs).toSet, recon(term), sols.forall(_.isTrusted))) - } - - List(decomp(subs, onSuccess, "Independent Clusters")) - } else { - Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala deleted file mode 100644 index fbd53f622d611f542558bd34699827f9630bd3d7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import leon.synthesis.Witnesses.Inactive -import purescala.Expressions._ -import purescala.Types._ -import purescala.Constructors._ -import purescala.Extractors._ - -/** For every pair of variables of an integer type plus 0 of that type, - * splits for inequality between these variables - * and reconstructs the subproblems with a (nested) if-then-else. - * - * Takes into account knowledge about equality/inequality in the path condition. - * - */ -case object InequalitySplit extends Rule("Ineq. Split.") { - - // Represents NEGATIVE knowledge - private abstract class Fact { - val l: Expr - val r: Expr - } - private case class LT(l: Expr, r: Expr) extends Fact - private case class EQ(l: Expr, r: Expr) extends Fact - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - def getFacts(e: Expr): Set[Fact] = e match { - case LessThan(a, b) => Set(LT(b,a), EQ(a,b)) - case LessEquals(a, b) => Set(LT(b,a)) - case GreaterThan(a, b) => Set(LT(a,b), EQ(a,b)) - case GreaterEquals(a, b) => Set(LT(a,b)) - case Equals(a, b) => Set(LT(a,b), LT(b,a)) - case Not(LessThan(a, b)) => Set(LT(a,b)) - case Not(LessEquals(a, b)) => Set(LT(a,b), EQ(a,b)) - case Not(GreaterThan(a, b)) => Set(LT(b,a)) - case Not(GreaterEquals(a, b)) => Set(LT(b,a), EQ(a,b)) - case Not(Equals(a, b)) => Set(EQ(a,b)) - case _ => Set() - } - - val facts: Set[Fact] = { - val TopLevelAnds(fromPhi) = p.phi - (fromPhi.toSet ++ p.pc.conditions ++ p.pc.bindings.map { case (id,e) => Equals(id.toVariable, e) }) flatMap getFacts - } - - val candidates = - (p.allAs.map(_.toVariable).filter(_.getType == Int32Type) :+ IntLiteral(0)).combinations(2).toList ++ - (p.allAs.map(_.toVariable).filter(_.getType == IntegerType) :+ InfiniteIntegerLiteral(0)).combinations(2).toList - - candidates.flatMap { - case List(v1, v2) => - - val lt = if (!facts.contains(LT(v1, v2))) { - val pc = LessThan(v1, v2) - Some(pc, p.copy(pc = p.pc withCond pc)) - } else None - - val gt = if (!facts.contains(LT(v2, v1))) { - val pc = GreaterThan(v1, v2) - Some(pc, p.copy(pc = p.pc withCond pc)) - } else None - - val eq: Option[(Equals, Problem)] = if (!facts.contains(EQ(v1, v2)) && !facts.contains(EQ(v2,v1))) { - val pc = Equals(v1, v2) - // Let's see if an input variable is involved - val (f, t, isInput) = (v1, v2) match { - case (Variable(a1), _) if p.as.contains(a1) => (a1, v2, true) - case (_, Variable(a2)) if p.as.contains(a2) => (a2, v1, true) - case (Variable(a1), _) => (a1, v2, false) - } - val newP = if (isInput) { - p.copy( - as = p.as.diff(Seq(f)), - pc = p.pc map (subst(f -> t, _)), - ws = subst(f -> t, p.ws), - phi = subst(f -> t, p.phi), - eb = p.qeb.removeIns(Set(f)) - ) - } else { - p.copy(pc = p.pc withCond pc).withWs(Seq(Inactive(f))) // equality in pc is fine for numeric types - } - - Some(pc, newP) - } else None - - val (pcs, subProblems) = List(eq, lt, gt).flatten.unzip - - if (pcs.size < 2) None - else { - - val onSuccess: List[Solution] => Option[Solution] = { sols => - val pre = orJoin(pcs.zip(sols).map { case (pc, sol) => and(pc, sol.pre) }) - - val term = pcs.zip(sols) match { - case Seq((pc1, s1), (_, s2)) => - IfExpr(pc1, s1.term, s2.term) - case Seq((pc1, s1), (pc2, s2), (_, s3)) => - IfExpr(pc1, s1.term, IfExpr(pc2, s2.term, s3.term)) - } - - Some(Solution(pre, sols.flatMap(_.defs).toSet, term, sols.forall(_.isTrusted))) - } - - Some(decomp(subProblems, onSuccess, s"Ineq. Split on '${v1.asString(hctx)}' and '${v2.asString(hctx)}'")) - } - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala b/src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala deleted file mode 100644 index dff8b6f23f66ad8a96957635313510e1dd525db5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.Constructors._ - -case object InnerCaseSplit extends Rule("Inner-Case-Split"){ - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.phi match { - case Or(_) => - // Inapplicable in this case, normal case-split has precedence here. - Nil - case _ => - var phi = p.phi - phi match { - case Not(And(es)) => - phi = orJoin(es.map(not)) - - case Not(Or(es)) => - phi = andJoin(es.map(not)) - - case _ => - } - - phi match { - case Or(os) => - List(rules.CaseSplit.split(os, "Inner case-split")) - - case And(as) => - val optapp = for ((a, i) <- as.zipWithIndex) yield { - a match { - case Or(os) => - Some(rules.CaseSplit.split(os.map(o => andJoin(as.updated(i, o))), "Inner case-split")) - - case _ => - None - } - } - - optapp.flatten - - case e => - Nil - } - } - } - -} - diff --git a/src/main/scala/leon/synthesis/rules/InputSplit.scala b/src/main/scala/leon/synthesis/rules/InputSplit.scala deleted file mode 100644 index fbdca3ad269868dde0c8865b348066e4905b877c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/InputSplit.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Path -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Types._ - -case object InputSplit extends Rule("In. Split") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.allAs.filter(_.getType == BooleanType).flatMap { a => - def getProblem(v: Boolean): Problem = { - def replaceA(e: Expr) = replaceFromIDs(Map(a -> BooleanLiteral(v)), e) - - val newPc: Path = { - val withoutA = p.pc -- Set(a) map replaceA - withoutA withConds (p.pc.bindings.collectFirst { case (`a`, res) => - if (v) res else not(res) - }) - } - - p.copy( - as = p.as.filterNot(_ == a), - ws = replaceA(p.ws), - pc = newPc, - phi = replaceA(p.phi), - eb = p.qeb.removeIns(Set(a)) - ) - } - - val sub1 = getProblem(true) - val sub2 = getProblem(false) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s1, s2) => - Some(Solution(or(and( Variable(a) , s1.pre), - and(Not(Variable(a)), s2.pre)), - s1.defs ++ s2.defs, - IfExpr(Variable(a), s1.term, s2.term), - s1.isTrusted && s2.isTrusted - )) - case _ => - None - } - - Some(decomp(List(sub1, sub2), onSuccess, s"Split on '$a'")) - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala deleted file mode 100644 index ffdcb5f7b2afa395fe951df9aed91bbd1ea5fe6c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Path -import purescala.Extractors.TopLevelAnds -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Common._ -import Witnesses.Terminating -import utils.Helpers.terminatingCalls - -case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val existingCalls = p.pc.bindings.collect { case (_, fi: FunctionInvocation) => fi }.toSet - - val calls = terminatingCalls(hctx.program, p.ws, p.pc, None, false) - .map(_._1).distinct.filterNot(existingCalls) - - if (calls.isEmpty) return Nil - - val (recs, paths) = calls.map { newCall => - val rec = FreshIdentifier("rec", newCall.getType, alwaysShowUniqueID = true) - val path = Path.empty withBinding (rec -> newCall) - (rec, path) - }.unzip - - val newWs = calls map Terminating - val TopLevelAnds(ws) = p.ws - val newProblem = p.copy( - pc = paths.foldLeft(p.pc)(_ merge _), - ws = andJoin(ws ++ newWs), - eb = p.eb - ) - - val onSuccess = forwardMap { e => - recs.zip(calls).foldRight(e) { - case ( (id, call), bd) => - Let(id, call, bd) - } - } - - List(decomp(List(newProblem), onSuccess, s"Introduce calls ${calls mkString ", "}")) - } -} diff --git a/src/main/scala/leon/synthesis/rules/OnePoint.scala b/src/main/scala/leon/synthesis/rules/OnePoint.scala deleted file mode 100644 index 1fb5ead1472a33eb1b7554e092dda512e8b13a4a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/OnePoint.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Common._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ - -/** If there is a top-level equality such as x = e where e does not contains x, then we can output the assignment and replace x anywhere else. */ -case object OnePoint extends NormalizingRule("One-point") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - def validOnePoint(x: Identifier, e: Expr) = { - !(variablesOf(e) contains x) - } - - val candidates = exprs.collect { - case eq @ Equals(Variable(x), e) if (p.xs contains x) && validOnePoint(x, e) => - (x, e, eq) - case eq @ Equals(e, Variable(x)) if (p.xs contains x) && validOnePoint(x, e) => - (x, e, eq) - } - - if (candidates.nonEmpty) { - val (x, e, eq) = candidates.head - - val others = exprs.filter(_ != eq) - val oxs = p.xs.filter(_ != x) - - val newProblem = Problem(p.as, p.ws, p.pc, subst(x -> e, andJoin(others)), oxs, p.qeb.removeOuts(Set(x))) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s @ Solution(pre, defs, term, isTrusted)) => - Some(Solution(pre, defs, letTuple(oxs, term, subst(x -> e, tupleWrap(p.xs.map(Variable)))), isTrusted)) - case _ => - None - } - - List(decomp(List(newProblem), onSuccess, s"One-point on $x = $e")) - } else { - Nil - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala deleted file mode 100644 index ce05a9b3c4115f794bc5e0290c5adb464bbfc90f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ - -import solvers._ - -import scala.concurrent.duration._ - -case object OptimisticGround extends Rule("Optimistic Ground") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - if (p.as.nonEmpty && p.xs.nonEmpty) { - val res = new RuleInstantiation(this.name) { - def apply(hctx: SearchContext) = { - - val solver = SimpleSolverAPI(hctx.solverFactory.withTimeout(50.millis)) - - val xss = p.xs.toSet - val ass = p.as.toSet - - var i = 0 - val maxTries = 3 - - var result: Option[RuleApplication] = None - var continue = true - var predicates: Seq[Expr] = Seq() - - while (result.isEmpty && i < maxTries && continue) { - val phi = p.pc and andJoin(p.phi +: predicates) - val notPhi = p.pc and andJoin(not(p.phi) +: predicates) - //println("SOLVING " + phi + " ...") - solver.solveSAT(phi) match { - case (Some(true), satModel) => - val newNotPhi = valuateWithModelIn(notPhi, xss, satModel) - - //println("REFUTING " + Not(newNotPhi) + "...") - solver.solveSAT(newNotPhi) match { - case (Some(true), invalidModel) => - // Found as such as the xs break, refine predicates - predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates - - case (Some(false), _) => - // Model appears valid, but it might be a fake expression (generic values) - val outExpr = tupleWrap(p.xs.map(valuateWithModel(satModel))) - - if (!isRealExpr(outExpr)) { - // It does contain a generic value, we skip - predicates = valuateWithModelIn(phi, xss, satModel) +: predicates - } else { - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), outExpr))) - } - case _ => - continue = false - result = None - } - - case (Some(false), _) => - if (predicates.isEmpty) { - result = Some(RuleClosed(Solution.UNSAT(p))) - } else { - continue = false - result = None - } - case _ => - continue = false - result = None - } - - i += 1 - } - - result.getOrElse(RuleFailed()) - } - } - List(res) - } else { - Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala deleted file mode 100644 index 92286ea472729495f77704b667e295e2f2c6847d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ /dev/null @@ -1,761 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer -import evaluators.AbstractEvaluator -import purescala.Definitions._ -import purescala.Common._ -import purescala.Types._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.TypeOps -import purescala.DefOps -import purescala.ExprOps -import purescala.SelfPrettyPrinter -import solvers.ModelBuilder -import solvers.string.StringSolver -import programsets.DirectProgramSet -import programsets.JoinProgramSet -import leon.utils.StreamUtils -import leon.synthesis.RulePriority - -/** A template generator for a given type tree. - * Extend this class using a concrete type tree, - * Then use the apply method to get a hole which can be a placeholder for holes in the template. - * Each call to the `.instantiate` method of the subsequent Template will provide different instances at each position of the hole. - */ -abstract class TypedTemplateGenerator(t: TypeTree) { - import StringRender.WithIds - /** Provides a hole which can be used multiple times in the expression. - * When calling .instantiateWithVars on the results, replaces each hole by a unique constant.*/ - def apply(f: Expr => Expr): TemplateGenerator = { - val id = FreshIdentifier("ConstToInstantiate", t, true) - new TemplateGenerator(f(Variable(id)), id, t) - } - def nested(f: Expr => WithIds[Expr]): TemplateGenerator = { - val id = FreshIdentifier("ConstToInstantiate", t, true) - val res = f(Variable(id)) - new TemplateGenerator(res._1, id, t, res._2) - } - class TemplateGenerator(template: Expr, varId: Identifier, t: TypeTree, initialHoles: List[Identifier] = Nil) { - private val optimizationVars = ListBuffer[Identifier]() ++= initialHoles - private def Const: Variable = { - val res = FreshIdentifier("const", t, true) - optimizationVars += res - Variable(res) - } - private def instantiate: Expr = { - ExprOps.postMap({ - case Variable(id) if id == varId => Some(Const) - case _ => None - })(template) - } - def instantiateWithVars: WithIds[Expr] = (instantiate, optimizationVars.toList) - } -} - -/** - * @author Mikael - */ -case object StringRender extends Rule("StringRender") { - override val priority = RulePriorityNormalizing - - // A type T augmented with a list of identifiers, for examples the free variables inside T - type WithIds[T] = (T, List[Identifier]) - - def EDIT_ME(n: Int): String = "_edit_me"+n+"_" - val EDIT_ME_REGEXP = "_edit_me\\d+_".r - def contains_EDIT_ME(s: String): Boolean = EDIT_ME_REGEXP.findFirstIn(s).nonEmpty - - var enforceDefaultStringMethodsIfAvailable = true - var enforceSelfStringMethodsIfAvailable = false - - val booleanTemplate = (a: Expr) => StringTemplateGenerator(Hole => IfExpr(a, Hole, Hole)) - - import StringSolver.{StringFormToken, Problem => SProblem, Equation, Assignment} - - /** Augment the left-hand-side to have possible function calls, such as x + "const" + customToString(_) ... - * Function calls will be eliminated when converting to a valid problem. - */ - sealed abstract class AugmentedStringFormToken - case class RegularStringFormToken(e: StringFormToken) extends AugmentedStringFormToken - case class OtherStringFormToken(e: Expr) extends AugmentedStringFormToken - type AugmentedStringForm = List[AugmentedStringFormToken] - - /** Augments the right-hand-side to have possible function calls, such as "const" + customToString(_) ... - * Function calls will be eliminated when converting to a valid problem. - */ - sealed abstract class AugmentedStringChunkRHS - case class RegularStringChunk(e: String) extends AugmentedStringChunkRHS - case class OtherStringChunk(e: Expr) extends AugmentedStringChunkRHS - type AugmentedStringLiteral = List[AugmentedStringChunkRHS] - - /** Converts an expression to a stringForm, suitable for StringSolver */ - def toStringForm(e: Expr, isVariableToLookFor: Identifier => Boolean, acc: List[AugmentedStringFormToken] = Nil)(implicit hctx: SearchContext): Option[AugmentedStringForm] = e match { - case StringLiteral(s) => - Some(RegularStringFormToken(Left(s))::acc) - case Variable(id) if isVariableToLookFor(id) => Some(RegularStringFormToken(Right(id))::acc) - case v@Variable(id) => Some(OtherStringFormToken(v)::acc) - case StringConcat(lhs, rhs) => - toStringForm(rhs, isVariableToLookFor, acc).flatMap(toStringForm(lhs, isVariableToLookFor, _)) - case e:Application => Some(OtherStringFormToken(e)::acc) - case e:FunctionInvocation => Some(OtherStringFormToken(e)::acc) - case _ => None - } - - /** Returns the string associated to the expression if it is computable */ - def toStringLiteral(e: Expr, isVariableToLookFor: Identifier => Boolean): Option[AugmentedStringLiteral] = e match { - case StringLiteral(s) => Some(List(RegularStringChunk(s))) - case StringConcat(lhs, rhs) => - toStringLiteral(lhs, isVariableToLookFor).flatMap(k => toStringLiteral(rhs, isVariableToLookFor).map(l => (k.init, k.last, l) match { - case (kinit, RegularStringChunk(s), RegularStringChunk(sp)::ltail) => - kinit ++ (RegularStringChunk(s + sp)::ltail) - case _ => k ++ l - })) - case e: Variable if(!isVariableToLookFor(e.id)) => Some(List(OtherStringChunk(e))) - case e: Application => Some(List(OtherStringChunk(e))) - case e: FunctionInvocation => Some(List(OtherStringChunk(e))) - case _ => None - } - - /** Converts an equality AugmentedStringForm == AugmentedStringLiteral to a list of equations - * For that, splits both strings on function applications. If they yield the same value, we can split, else it fails. */ - def toEquations(lhs: AugmentedStringForm, rhs: AugmentedStringLiteral): Option[List[Equation]] = { - def rec(lhs: AugmentedStringForm, rhs: AugmentedStringLiteral, - accEqs: ListBuffer[Equation], accLeft: ListBuffer[StringFormToken], accRight: StringBuffer): Option[List[Equation]] = (lhs, rhs) match { - case (Nil, Nil) => - (accLeft.toList, accRight.toString) match { - case (Nil, "") => Some(accEqs.toList) - case (lhs, rhs) => Some((accEqs += ((lhs, rhs))).toList) - } - case (OtherStringFormToken(e)::lhstail, OtherStringChunk(f)::rhstail) => - if(ExprOps.canBeHomomorphic(e, f).nonEmpty) { - rec(lhstail, rhstail, accEqs += ((accLeft.toList, accRight.toString)), ListBuffer[StringFormToken](), new StringBuffer) - } else None - case (OtherStringFormToken(e)::lhstail, Nil) => - None - case (Nil, OtherStringChunk(f)::rhstail) => - None - case (lhs, RegularStringChunk(s)::rhstail) => - rec(lhs, rhstail, accEqs, accLeft, accRight append s) - case (RegularStringFormToken(e)::lhstail, rhs) => - rec(lhstail, rhs, accEqs, accLeft += e, accRight) - } - rec(lhs, rhs, ListBuffer[Equation](), ListBuffer[StringFormToken](), new StringBuffer) - } - - /** Returns a stream of assignments compatible with input/output examples for the given template */ - def findAssignments(p: Program, inputs: Seq[Identifier], examples: ExamplesBank, template: Expr, variablesToReplace: Set[Identifier])(implicit hctx: SearchContext): Stream[Map[Identifier, String]] = { - //println(s"finding assignments in program\n$p") - val e = new AbstractEvaluator(hctx, p) - - @tailrec def gatherEquations(s: List[InOutExample], acc: ListBuffer[Equation] = ListBuffer()): Option[SProblem] = s match { - case Nil => Some(acc.toList) - case InOutExample(in, rhExpr)::q => - if(rhExpr.length == 1) { - val model = new ModelBuilder - model ++= inputs.zip(in) - val modelResult = model.result() - val evalResult = e.eval(template, modelResult) - evalResult.result match { - case None => - hctx.reporter.info("Eval = None : ["+template+"] in ["+inputs.zip(in)+"]") - hctx.reporter.info(evalResult) - None - case Some((sfExpr, abstractSfExpr)) => - //ctx.reporter.debug("Eval = ["+sfExpr+"] (from "+abstractSfExpr+")") - val sf = toStringForm(sfExpr, variablesToReplace) - val rhs = toStringLiteral(rhExpr.head, variablesToReplace) - (sf, rhs) match { - case (Some(sfget), Some(rhsget)) => - toEquations(sfget, rhsget) match { - case Some(equations) => - gatherEquations(q, acc ++= equations) - case None => - hctx.reporter.info("Could not extract equations from ["+sfget+"] == ["+rhsget+"]\n coming from ... == " + rhExpr) - None - } - case _ => - hctx.reporter.info("sf empty or rhs empty ["+sfExpr+"] = ["+rhExpr+"] => ["+sf+"] == ["+rhs+"]") - None - } - } - } else { - hctx.reporter.info("RHS.length != 1 : ["+rhExpr+"]") - None - } - } - - gatherEquations((examples.valids ++ examples.invalids).collect{ case io:InOutExample => io }.toList) match { - case Some(problem) => - StringSolver.solve(problem) - case None => Stream.empty - } - } - - /** With a given (template, fundefs, consts) will find consts so that (expr, funs) passes all the examples */ - def findSolutions(examples: ExamplesBank, - templateFunDefs: Stream[(WithIds[Expr], Seq[(FunDef, Stream[WithIds[Expr]])])])(implicit hctx: SearchContext, p: Problem): RuleApplication = { - // Fun is a stream of many function applications. - val funs = templateFunDefs.map{ case (template, funDefs) => - val funDefsSet = JoinProgramSet.direct(funDefs.map(fbody => fbody._2.map((fbody._1, _))).map(d => DirectProgramSet(d))) - JoinProgramSet.direct(funDefsSet, DirectProgramSet(Stream(template))).programs - } - - val wholeTemplates = StreamUtils.interleave(funs) - - def computeSolutions(funDefsBodies: Seq[(FunDef, WithIds[Expr])], template: WithIds[Expr]): Stream[Assignment] = { - val funDefs = for((funDef, body) <- funDefsBodies) yield { funDef.body = Some(body._1); funDef } - val newProgram = DefOps.addFunDefs(hctx.program, funDefs, hctx.functionContext) - val transformer = DefOps.funDefReplacer { fd => - if(fd == hctx.functionContext) { - val newfd = fd.duplicate() - newfd.body = Some(template._1) - Some(newfd) - } else None - } - val newProgram2 = DefOps.transformProgram(transformer, newProgram) - val newTemplate = ExprOps.postMap{ - case FunctionInvocation(TypedFunDef(fd, targs), exprs) => - Some(FunctionInvocation(TypedFunDef(transformer.transform(fd), targs), exprs)) - case _ => None - }(template._1) - val variablesToReplace = (template._2 ++ funDefsBodies.flatMap(_._2._2)).toSet - findAssignments(newProgram2, p.as.filter{ x => !x.getType.isInstanceOf[FunctionType] }, examples, newTemplate, variablesToReplace) - } - - val tagged_solutions = - for{(funDefs, template) <- wholeTemplates} yield computeSolutions(funDefs, template).map((funDefs, template, _)) - - solutionStreamToRuleApplication(p, leon.utils.StreamUtils.interleave(tagged_solutions))(hctx.program) - } - - /** Converts the stream of solutions to a RuleApplication */ - def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)])(implicit program: Program): RuleApplication = { - if(solutions.isEmpty) RuleFailed() else { - RuleClosed( - for((funDefsBodies, (singleTemplate, ids), assignment) <- solutions) yield { - var _i = 0 - def i = { _i += 1; _i } - val fds = for((fd, (body, ids)) <- funDefsBodies) yield { - val initMap = ids.map(_ -> StringLiteral(EDIT_ME(i))).toMap - fd.body = Some(ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), body))) - fd - } - val initMap = ids.map(_ -> StringLiteral(EDIT_ME(i))).toMap - val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate)) - val (finalTerm, finalDefs) = makeFunctionsUnique(term, fds.toSet) - - Solution(BooleanLiteral(true), finalDefs, finalTerm) - }) - } - } - - /** Crystallizes a solution so that it will not me modified if the body of fds is modified. */ - def makeFunctionsUnique(term: Expr, fds: Set[FunDef])(implicit program: Program): (Expr, Set[FunDef]) = { - var transformMap = Map[FunDef, FunDef]() - def mapExpr(body: Expr): Expr = { - ExprOps.preMap((e: Expr) => e match { - case FunctionInvocation(TypedFunDef(fd, _), args) if fds(fd) => Some(functionInvocation(getMapping(fd), args)) - case e => None - })(body) - } - - def getMapping(fd: FunDef): FunDef = { - transformMap.getOrElse(fd, { - val newfunDef = new FunDef(fd.id.freshen, fd.tparams, fd.params, fd.returnType) // With empty body - transformMap += fd -> newfunDef - newfunDef.body = fd.body.map(mapExpr _) - newfunDef - }) - } - - (mapExpr(term), fds.map(getMapping _)) - } - - - object StringTemplateGenerator extends TypedTemplateGenerator(StringType) - - trait FreshFunNameGenerator { - def funNames: Set[String] - def freshFunName(s: String): String = { - if(!funNames(s)) return s - var i = 1 - var s0 = s - do { - i += 1 - s0 = s + i - } while(funNames(s+i)) - s0 - } - } - trait PrettyPrinterProvider { - def provided_functions: Seq[Identifier] - } - type StringConverters = Map[TypeTree, List[Expr => Expr]] - - /** Creates an empty function definition for the dependent type */ - def createEmptyFunDef(ctx: FreshFunNameGenerator with PrettyPrinterProvider, vContext: List[TypeTree], hContext: List[TypeTree], typeToConvert: TypeTree)(implicit hctx: SearchContext): FunDef = { - def defaultFunName(t: TypeTree): String = t match { - case AbstractClassType(c, d) => c.id.asString(hctx) - case CaseClassType(c, d) => c.id.asString(hctx) - case t => t.asString(hctx) - } - - val funName2 = defaultFunName(typeToConvert) + ("" /: vContext) { - case (s, t) => "In" + defaultFunName(t) - } + (if(hContext.nonEmpty) hContext.length.toString else "") + "_s" - val funName3 = funName2.replaceAll("[^a-zA-Z0-9_]","") - val funName = funName3(0).toLower + funName3.substring(1) - val funId = FreshIdentifier(ctx.freshFunName(funName), alwaysShowUniqueID = false) - val argId= FreshIdentifier(typeToConvert.asString(hctx).toLowerCase().dropWhile(c => (c < 'a' || c > 'z') && (c < 'A' || c > 'Z')).headOption.getOrElse("x").toString, typeToConvert) - val tparams = hctx.functionContext.tparams - new FunDef(funId, tparams, ValDef(argId) :: ctx.provided_functions.map(ValDef(_)).toList, StringType) // Empty function. - } - - /** Assembles multiple MatchCase to a singleMatchExpr using the function definition fd */ - private val mergeMatchCases = (scrut: Expr) => (cases: Seq[WithIds[MatchCase]]) => (MatchExpr(scrut, cases.map(_._1)), cases.flatMap(_._2).toList) - private val mergeMatchCasesFd = (fd: FunDef) => mergeMatchCases(Variable(fd.params(0).id)) - - object FunDefTemplateGenerator { - protected val gcontext = new grammars.ContextGrammar[TypeTree, (Program, Set[FunDef]) => Stream[Expr => WithIds[Expr]]] - import gcontext._ - - protected val int32Symbol = NonTerminal(Int32Type) - protected val integerSymbol = NonTerminal(IntegerType) - protected val booleanSymbol = NonTerminal(BooleanType) - protected val stringSymbol = NonTerminal(StringType) - protected val bTemplateGenerator = (expr: Expr) => booleanTemplate(expr).instantiateWithVars - - // The pretty-printers are variable passed along in argument that have a type T => String for some type parameter T - def apply(argsInputs: Seq[Expr], argsPrettyPrinting: Seq[Identifier])(implicit hctx: SearchContext): GrammarBasedTemplateGenerator = { - implicit val program: Program = hctx.program - val startGrammar = Grammar( - (argsInputs.foldLeft(List[NonTerminal]()){(lb, i) => lb :+ NonTerminal(i.getType) }), - Map(int32Symbol -> TerminalRHS(Terminal(Int32Type)((prog, fds)=> Stream(expr => (Int32ToString(expr), Nil)))), - integerSymbol -> TerminalRHS(Terminal(IntegerType)((prog, fds) => Stream((expr => (IntegerToString(expr), Nil))))), - booleanSymbol -> TerminalRHS(Terminal(BooleanType)((prog, fds) => Stream((expr => (BooleanToString(expr), Nil)), bTemplateGenerator))), - stringSymbol -> TerminalRHS(Terminal(StringType)((prog, fds) => Stream((expr => (expr, Nil)),(expr => ((FunctionInvocation(program.library.escape.get.typed, Seq(expr)), Nil))))) - ))) - GrammarBasedTemplateGenerator(clean(exhaustive(startGrammar, argsPrettyPrinting)), argsInputs, argsPrettyPrinting) - } - - case class GrammarBasedTemplateGenerator(grammar: Grammar, argsInputs: Seq[Expr], argsPrettyPrinting: Seq[Identifier])(implicit hctx: SearchContext) { - /** Use with caution: These functions make the entire grammar increase exponentially in size*/ - def markovize_vertical() = copy(grammar=grammar.markovize_vertical()) - def markovize_horizontal() = copy(grammar=grammar.markovize_horizontal()) - def markovize_abstract_vertical() = copy(grammar=grammar.markovize_abstract_vertical()) - - /** Mark all occurrences of a given type so that we can differentiate its usage according to its rank from the left.*/ - def markovize_horizontal_nonterminal() = { - val selectedNt = getDuplicateCallsInSameRule() - //println("markovize_horizontal_nonterminal with "+selectedNt.map(symbolToString)+"...") - copy(grammar=grammar.markovize_horizontal_filtered(selectedNt, false)) - } - /** Mark all occurrences of a given type so that we can differentiate its usage according to its rank from the left.*/ - def markovize_horizontal_recursive_nonterminal() = { - val selectedNt = getDuplicateCallsInSameRule() - //println("markovize_horizontal_nonterminal with "+selectedNt.map(symbolToString)+"...") - copy(grammar=grammar.markovize_horizontal_filtered(selectedNt, true)) - } - - /** Mark all occurrences of a given type so that we can differentiate its usage depending from where it was taken from.*/ - def markovize_abstract_vertical_nonterminal() = { - //println("markovize_abstract_vertical_nonterminal...") - val selectedNt = getDirectlyRecursiveTypes() - copy(grammar=grammar.markovize_abstract_vertical_filtered(selectedNt)) - } - - /** Computes all possible variants of the grammar, from the simplest to the most complex.*/ - def allMarkovizations(): Stream[GrammarBasedTemplateGenerator] = { - Stream.from(1).flatMap(i => grammar.markovize_all(i).map(g => copy(grammar=g))) - } - - /** Mark all occurrences of a given type so that we can differentiate its usage depending from where it was taken from.*/ - def markovize_vertical_nonterminal() = { - //println("markovize_vertical_nonterminal...") - val selectedNt = getTypesAppearingAtMultiplePlaces() - copy(grammar=grammar.markovize_vertical_filtered(selectedNt)) - } - - def getAllTypes(): Set[TypeTree] = { - grammar.rules.keys.map(_.tag).toSet - } - - /** Monitoring data in the grammar */ - // Find all non-terminals which have a rule that use this type tree. - def getCallsfor(e: TypeTree): Seq[(NonTerminal, Expansion)] = { - grammar.rules.toSeq.filter{ case (k, v) => v.ls.exists(l => l.exists ( _.tag == e ))} - } - // Find all non-terminals which have the type tree on the RHS at least twice in the same rule. - // Used for horizontal markovization - def getDuplicateCallsInSameRule(): Set[NonTerminal] = { - def duplicates(l: List[Symbol], e: TypeTree) = { - val lnt = l.collect{case nt: NonTerminal => nt} - if(lnt.count ( _.tag == e ) >= 2) { - lnt.filter(_.tag == e) - } else Nil - } - getAllTypes().flatMap { e => - grammar.rules.toSeq.flatMap{ case (k, v) => - v.ls.flatMap{l => duplicates(l, e) } - }.toSet ++ duplicates(grammar.start.toList, e) - } - } - // Return types which call themselves in argument (and which might require vertical markovization to differentiate between an inner call and an outer call). - // Used for vertical markovization - def getDirectlyRecursiveTypes(): Set[NonTerminal] = { - grammar.rules.toSeq.flatMap{ case (k, v) => if(v match { - case AugmentedTerminalsRHS(_, VerticalRHS(children)) => children.exists(child => grammar.rules(child) match { - case AugmentedTerminalsRHS(_, HorizontalRHS(t, arguments)) => arguments.exists(_ == k) - case _ => false - }) - case _ => false}) - Seq(k) else Nil - }.toSet - } - // Returns non-terminals which appear on different RHS of different rules, and which require vertical markovization. - def getTypesAppearingAtMultiplePlaces(): Set[NonTerminal] = { - (grammar.rules.toSeq.flatMap{ case (k, v) => - v.ls.flatten - } ++ grammar.startNonTerminals). - groupBy { s => s }. - toSeq. - map(_._2). - filter(_.length >= 2). - flatMap(_.headOption). - collect{ case t: NonTerminal => t}. - toSet - } - - def buildAllFunDefTemplates(): Stream[(WithIds[Expr], Seq[(FunDef, Stream[WithIds[Expr]])])] = { - allMarkovizations().flatMap(_.buildFunDefTemplate(false)) - } - - /** Builds a set of fun defs out of the grammar */ - def buildFunDefTemplate(markovizations: Boolean = true): Stream[(WithIds[Expr], Seq[(FunDef, Stream[WithIds[Expr]])])] = { - // Collects all non-terminals. One non-terminal => One function. May regroup pattern matching in a separate simplifying phase. - val nts = grammar.nonTerminals - // Fresh function name generator. - val ctx = new FreshFunNameGenerator with PrettyPrinterProvider { - var funNames: Set[String] = Set() - override def freshFunName(s: String): String = { - val res = super.freshFunName(s) - funNames += res - res - } - def provided_functions = argsPrettyPrinting - } - // Matches a case class and returns its context type. - object TypedNonTerminal { - def unapply(nt: NonTerminal) = Some((nt.tag, nt.vcontext.map(_.tag), nt.hcontext.map(_.tag))) - } - /* We create FunDef for all on-terminals */ - val (funDefs, ctx2) = ((Map[NonTerminal, FunDef](), ctx) /: nts) { - case (mgen@(m, genctx), nt@TypedNonTerminal(tp, vct, hct)) => - (m + ((nt: NonTerminal) -> createEmptyFunDef(genctx, vct, hct, tp)), genctx) - } - // Look for default ways to pretty-print non-trivial functions. - val (newProg, newFuns) = //if(markovizations) - { - val functionsToAdd = funDefs.values.filter(fd => - fd.paramIds.headOption.map(x => - x.getType match { - case AbstractClassType(acd, targs) => hctx.program.library.List.get != acd - case _: MapType | _: SetType | _: BagType => false - case _ => true - } - ).getOrElse(false) - ).toSet - (DefOps.addFunDefs(hctx.program, functionsToAdd, hctx.functionContext), functionsToAdd) - } //else (hctx.program, Set[FunDef]()) - //println("In the following program: " + newProg) - - def rulesToBodies(e: Expansion, nt: NonTerminal, fd: FunDef): Stream[WithIds[Expr]] = { - val inputs = fd.params.map(_.id) - var customProgs: Stream[Expr => WithIds[Expr]] = Stream() - def filteredPrintersOf(t: Terminal): Stream[Expr => WithIds[Expr]] = { - val p = t.terminalData - val newFunsFiltered = fd.paramIds.headOption.map(_.getType) match { - case Some(forbiddenType) => - newFuns.filter(f => - f.paramIds.headOption.map(x => !TypeOps.isSubtypeOf(forbiddenType, x.getType)).getOrElse(true)) - case None => newFuns - } - //println("For fun " + fd.id.name +", taking only " + newFunsFiltered.map(_.id.name) + " into account") - p(newProg, newFunsFiltered) - } - def filteredPrintersOfAreNonEmpty(t: Terminal): Boolean = { - customProgs = filteredPrintersOf(t) - customProgs.nonEmpty - } - e match { - case TerminalRHS(terminal@Terminal(typeTree)) if filteredPrintersOfAreNonEmpty(terminal) => //Render this as a simple expression. - customProgs.map(f => f(Variable(inputs.head))) - case AugmentedTerminalsRHS(terminals, HorizontalRHS(terminal@Terminal(cct@CaseClassType(ccd, targs)), nts)) => // The subsequent calls of this function to sub-functions. - val fields = cct.classDef.fieldsIds.zip(cct.fieldsTypes) - val fieldstypes = fields.map{ case (id, tpe) => (tpe, (x: Expr) => CaseClassSelector(cct, x, id)) } - val builders = fieldstypes.flatMap(x => flattenTupleExtractors(x._1, x._2)) - - val childExprs = nts.zipWithIndex.map{ case (childNt:NonTerminal, childIndex) => - FunctionInvocation(TypedFunDef(funDefs(childNt), Seq()), List( - builders(childIndex)(Variable(inputs.head))) ++ - argsPrettyPrinting.map(Variable)) - } - terminals.map(filteredPrintersOf).toStream.flatten.map(_.apply(Variable(inputs.head))) #::: childExprs.map(x => (x, Nil)).permutations.toStream.map(interleaveIdentifiers) - case AugmentedTerminalsRHS(terminals, HorizontalRHS(terminal@Terminal(cct@TupleType(targs)), nts)) => // The subsequent calls of this function to sub-functions. - val fieldstypes = targs.zipWithIndex.map{case (tp, index) => (tp, (x: Expr) => TupleSelect(x, index+1)) } - val builders = fieldstypes.flatMap(x => flattenTupleExtractors(x._1, x._2)) - - val childExprs = nts.zipWithIndex.map{ case (childNt:NonTerminal, childIndex) => - FunctionInvocation(TypedFunDef(funDefs(childNt), Seq()), List( - builders(childIndex)(Variable(inputs.head))) ++ - argsPrettyPrinting.map(Variable)) - } - terminals.map(filteredPrintersOf).toStream.flatten.map(_.apply(Variable(inputs.head))) #::: childExprs.map(x => (x, Nil)).permutations.toStream.map(interleaveIdentifiers) - case AugmentedTerminalsRHS(terminals, VerticalRHS(children)) => // Match statement. - assert(inputs.length == 1 + argsPrettyPrinting.length) - val idInput = inputs.head - val scrut = Variable(idInput) - val matchCases = nt.tag match { - case AbstractClassType(acd, typeArgs) => - if(acd.id.name == "ThreadId") { - println("terminals for ThreadId:\n"+terminals.length) - println("result") - println(terminals.map(filteredPrintersOf).toStream.flatten.map(_.apply(Variable(inputs.head)))) - } - acd.knownCCDescendants map { ccd => - children.find(childNt => childNt.tag match { - case CaseClassType(`ccd`, `typeArgs`) => true - case _ => false - }) match { - case Some(nt) => - val matchInput = idInput.duplicate(tpe = nt.tag) - MatchCase(InstanceOfPattern(Some(matchInput), nt.tag.asInstanceOf[ClassType]), None, - FunctionInvocation(TypedFunDef(funDefs(nt), Seq()), List(Variable(matchInput)) ++ argsPrettyPrinting.map(Variable))) - case None => throw new Exception(s"Could not find $ccd in the children non-terminals $children") - } - } - case t => - throw new Exception(s"Should have been Vertical RHS, got $t. Rule:\n$nt -> $e\nFunDef:\n$fd") - } - - terminals.map(filteredPrintersOf).toStream.flatten.map(_.apply(Variable(inputs.head))) #::: Stream((MatchExpr(scrut, matchCases): Expr, Nil: List[Identifier])) - case _ => throw new Exception("No toString conversion found for " + nt) - } - } - - //println("Extracting functions from grammar:\n" + grammarToString(grammar).replaceAll("\\$", "_").replaceAll("\\[T3\\]", "T3").replaceAll("\\(|\\)","").replaceAll("<function1>","")) - - // We create the bodies of these functions - val possible_functions = for((nt, fd) <- funDefs.toSeq) yield { - val bodies: Stream[WithIds[Expr]] = rulesToBodies(grammar.rules(nt), nt, fd) - (fd, bodies/*.map{ b => println("Testing another body for " + fd.id.name + "\n" + b); b}*/) - } - - val inputExprs = grammar.startNonTerminals.zipWithIndex.map{ case (childNt, childIndex) => - (FunctionInvocation(TypedFunDef(funDefs(childNt), Seq()), Seq(argsInputs(childIndex)) ++ argsPrettyPrinting.map(Variable)), Nil) - } - //println("Found grammar\n" + grammarToString(grammar)) - - val startExprStream = inputExprs.permutations.toStream.map(inputs => - interleaveIdentifiers(inputs) - ) - - startExprStream.map(i => (i, possible_functions)) #::: // 1) Expressions without markovizations - (if(markovizations) { - Stream(markovize_horizontal_recursive_nonterminal(), markovize_horizontal_nonterminal()) - .flatMap(grammar => grammar.buildFunDefTemplateAndContinue( - gtg => gtg.allMarkovizations().flatMap(m => m.buildFunDefTemplate(false)))) - //markovize_abstract_vertical_nonterminal().buildFunDefTemplateAndContinue( _. - //markovize_vertical_nonterminal().buildFunDefTemplate(false)))) - } else Stream.empty) - // The Stream[WithIds[Expr]] is given thanks to the first formula with the start symbol. - // The FunDef are computed by recombining vertical rules into one pattern matching, and each expression using the horizontal children. - } - - def buildFunDefTemplateAndContinue(continueWith: GrammarBasedTemplateGenerator => (Stream[(WithIds[Expr], Seq[(FunDef, Stream[WithIds[Expr]])])])): (Stream[(WithIds[Expr], Seq[(FunDef, Stream[WithIds[Expr]])])]) = { - buildFunDefTemplate(false) #::: (continueWith(this)) - } - } - - protected def flattenTupleType(t: TypeTree): Seq[TypeTree] = { - t match { - case TupleType(targs) => targs.flatMap(flattenTupleType) - case t => Seq(t) - } - } - protected def flattenTupleExtractors(t: TypeTree, builder: Expr => Expr): Seq[Expr => Expr] = { - t match { - case TupleType(targs) => targs.zipWithIndex.flatMap{ - case (t, i) => flattenTupleExtractors(t, builder andThen ((x: Expr) => TupleSelect(x, i+1))) - } - case t => Seq(builder) - } - } - - protected def customPrettyPrinters(inputType: TypeTree, program: Program, allowedPrinters: Set[FunDef])(implicit hctx: SearchContext): List[Expr => WithIds[Expr]] = { - val toExclude: Set[FunDef] = - if(hctx.functionContext.paramIds.headOption.map(x => TypeOps.isSubtypeOf(inputType, x.getType)).getOrElse(false)) - Set(hctx.functionContext) - else - Set() - //println("Looking for pp of type " + inputType + ", excluded = " + toExclude.map(_.id.name) + ", allowed = " + allowedPrinters.map(_.id.name)) - val exprs1s: Stream[(Lambda, Expr => WithIds[Expr])] = (new SelfPrettyPrinter) - .allowFunctions(allowedPrinters + hctx.functionContext) - .excludeFunctions(toExclude + hctx.program.library.escape.get) - .withPossibleParameters - .prettyPrintersForType(inputType)(hctx, program) - .map{ case (l, identifiers) => (l, (input: Expr) => (application(l, Seq(input)), identifiers)) } // Use already pre-defined pretty printers. - val e = exprs1s.toList.sortBy{ case (Lambda(_, FunctionInvocation(tfd, _)), _) if tfd.fd == hctx.functionContext => 0 case _ => 1}.map(_._2) - //println("looking for functions to print " + inputType + ", got " + e) - e - } - - def constantPatternMatching(act: AbstractClassType)(implicit hctx: SearchContext): Stream[Expr => WithIds[Expr]] = { - val allKnownDescendantsAreCCAndHaveZeroArgs = act.knownCCDescendants.forall { - case CaseClassType(ccd, tparams2) => ccd.fields.isEmpty - case _ => false - } - if(act.id.name == "ThreadId") println("For ThreadId, allKnownDescendantsAreCCAndHaveZeroArgs = " + allKnownDescendantsAreCCAndHaveZeroArgs) - if(allKnownDescendantsAreCCAndHaveZeroArgs) { - val cases = (ListBuffer[WithIds[MatchCase]]() /: act.knownCCDescendants) { - case (acc, cct @ CaseClassType(ccd, tparams2)) => - val typeMap = ccd.typeArgs.zip(tparams2).toMap - val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) - val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) - val rhs = StringLiteral(ccd.id.asString) - MatchCase(pattern, None, rhs) - acc += ((MatchCase(pattern, None, rhs), Nil)) - case (acc, e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e) - } - if(act.id.name == "ThreadId") println("For ThreadId, cases = " + cases) - if(cases.nonEmpty) { - Stream((x: Expr) => mergeMatchCases(x)(cases)) - } else Stream.Empty - } else Stream.Empty - } - - /** Used to produce rules such as Cons => Elem List without context*/ - protected def horizontalChildren(n: NonTerminal)(implicit hctx: SearchContext): Option[Expansion] = n match { - case NonTerminal(cct@CaseClassType(ccd: CaseClassDef, tparams2), vc, hc) => - val flattenedTupleds = cct.fieldsTypes.flatMap(flattenTupleType) - val customs = (p: Program, fds: Set[FunDef]) => customPrettyPrinters(cct, p,fds).toStream - Some(AugmentedTerminalsRHS(Seq(Terminal(cct)(customs)), - HorizontalRHS(Terminal(cct)((prog, fds) => Stream.empty), flattenedTupleds.map(NonTerminal(_))))) - case NonTerminal(cct@TupleType(fieldsTypes), vc, hc) => - val flattenedTupleds = fieldsTypes.flatMap(flattenTupleType) - val customs = (p: Program, fds: Set[FunDef]) => customPrettyPrinters(cct, p, fds).toStream - Some(AugmentedTerminalsRHS(Seq(Terminal(cct)(customs)), - HorizontalRHS(Terminal(cct)((prog, fds) => Stream.empty), flattenedTupleds.map(NonTerminal(_))))) - case NonTerminal(otherType, vc, hc) if !otherType.isInstanceOf[AbstractClassType] => - val customs = (p: Program, fds: Set[FunDef]) => customPrettyPrinters(otherType, p, fds).toStream - //if(customs.nonEmpty) { - Some(AugmentedTerminalsRHS(Seq(Terminal(otherType)(customs)), - TerminalRHS(Terminal(otherType)((prog, fds) => Stream.empty)))) - //} else None - case _ => None - } - /** Used to produce rules such as List => Cons | Nil without context */ - protected def verticalChildren(n: NonTerminal)(implicit hctx: SearchContext): Option[Expansion] = n match { - case NonTerminal(act@AbstractClassType(acd: AbstractClassDef, tps), vc, hc) => - if(act.id.name == "ThreadId") println("Creating custom pretty printers for type ThreadId") - val customs = (p: Program, fds: Set[FunDef]) => constantPatternMatching(act) #::: customPrettyPrinters(act, p, fds).toStream - Some(AugmentedTerminalsRHS(Seq(Terminal(act)(customs)), - VerticalRHS(act.knownDescendants.map(tag => NonTerminal(tag))))) - case _ => None - } - - /** Find all direct calls to existing variables render the given type */ - protected def terminalChildren(n: NonTerminal, prettyPrinters: Seq[Identifier]): Option[Expansion] = n match { - case NonTerminal(tp: TypeParameter, vc, hc) => - val possible_pretty_printers = prettyPrinters.map(x => (x, x.getType)).collect{ case (id, FunctionType(tp, StringType)) => id} - val callers = possible_pretty_printers.toStream.map{ - case id => (x: Expr) => (Application(Variable(id), Seq(x)), Nil) - } - Some(TerminalRHS(Terminal(tp)((p, fds) => callers))) - case _ => None - } - - /** Find all dependencies and merge them into one grammar */ - protected def extendGrammar(prettyPrinters: Seq[Identifier])(grammar: Grammar)(implicit hctx: SearchContext): Grammar = { - val nts = grammar.nonTerminals - (grammar /: nts) { - case (grammar, n) => - /** If the grammar does not contain any rule for n, add them */ - if(!(grammar.rules contains n)) { - grammar.copy(rules = - grammar.rules + - (n -> ( - Expansion(Nil) ++ - terminalChildren(n, prettyPrinters) ++ - horizontalChildren(n) ++ - verticalChildren(n)))) - } else grammar - } - } - - /** Applies the transformation extendGrammar until the grammar reaches its fix point. */ - protected def exhaustive(grammar: Grammar, prettyPrinters: Seq[Identifier])(implicit hctx: SearchContext): Grammar = { - leon.utils.fixpoint(extendGrammar(prettyPrinters))(grammar) - } - } - - /** Transforms a sequence of identifiers into a single expression - * with new string constant identifiers interleaved between, before and after them. */ - def interleaveIdentifiers(exprs: Seq[WithIds[Expr]]): WithIds[Expr] = { - if(exprs.isEmpty) { - StringTemplateGenerator(Hole => Hole).instantiateWithVars - } else { - StringTemplateGenerator.nested(Hole => { - val res = ((StringConcat(Hole, exprs.head._1), exprs.head._2) /: exprs.tail) { - case ((finalExpr, finalIds), (expr, ids)) => (StringConcat(StringConcat(finalExpr, Hole), expr), finalIds ++ ids) - } - (StringConcat(res._1, Hole), res._2) - }).instantiateWithVars - } - } - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - //hctx.reporter.debug("StringRender:Output variables="+p.xs+", their types="+p.xs.map(_.getType)) - if(hctx.currentNode.parent.nonEmpty) Nil else - p.xs match { - case List(IsTyped(v, StringType)) => - val examplesFinder = new ExamplesFinder(hctx, hctx.program) - .setKeepAbstractExamples(true) - .setEvaluationFailOnChoose(true) - - val examples = examplesFinder.extractFromProblem(p) - - val abstractStringConverters: StringConverters = p.as.flatMap { case x => - x.getType match { - case FunctionType(Seq(aType), StringType) => - List((aType, (arg: Expr) => application(Variable(x), Seq(arg)))) - case _ => Nil - } - }.groupBy(_._1).mapValues(_.map(_._2)) - - val (inputVariables, functionVariables) = p.as.partition ( x => x.getType match { - case f: FunctionType => false - case _ => true - }) - - val ruleInstantiations = ListBuffer[RuleInstantiation]() - val originalInputs = inputVariables.map(Variable) - ruleInstantiations += RuleInstantiation("String conversion") { - val synthesisResult = FunDefTemplateGenerator(originalInputs, functionVariables).buildFunDefTemplate(true) - - findSolutions(examples, synthesisResult) - } - - ruleInstantiations.toList - - case _ => Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala deleted file mode 100644 index 089c3da45c45c22ef9ff4f84bb924a33f076bce8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.TypeOps._ - -case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val unconstr = (p.xs.toSet -- variablesOf(p.phi)).filter { x => - isRealExpr(simplestValue(x.getType)) - } - - if (unconstr.nonEmpty) { - val sub = p.copy(xs = p.xs.filterNot(unconstr), eb = p.qeb.removeOuts(unconstr)) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s) => - val term = letTuple(sub.xs, s.term, tupleWrap(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))) - - Some(Solution(s.pre, s.defs, term, s.isTrusted)) - case _ => - None - } - - Some(decomp(List(sub), onSuccess, s"Unconst. out ${p.xs.filter(unconstr).mkString(", ")}")) - } else { - None - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/Unification.scala b/src/main/scala/leon/synthesis/rules/Unification.scala deleted file mode 100644 index 419d256503de66f157094149e39cce0b0729d282..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/Unification.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ - -object Unification { - case object DecompTrivialClash extends NormalizingRule("Unif Dec./Clash/Triv.") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - val (toRemove, toAdd) = exprs.collect { - case eq @ Equals(cc1 @ CaseClass(cd1, args1), cc2 @ CaseClass(cd2, args2)) => - if (cc1 == cc2) { - (eq, List(BooleanLiteral(true))) - } else if (cd1 == cd2) { - (eq, (args1 zip args2).map((Equals(_, _)).tupled)) - } else { - (eq, List(BooleanLiteral(false))) - } - }.unzip - - if (toRemove.nonEmpty) { - val sub = p.copy(phi = andJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - - List(decomp(List(sub), forward, this.name)) - } else { - Nil - } - } - } - - // This rule is probably useless; it never happens except in crafted - // examples, and will be found by OptimisticGround anyway. - case object OccursCheck extends NormalizingRule("Unif OccursCheck") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - val isImpossible = exprs.exists { - case eq @ Equals(cc : CaseClass, Variable(id)) if variablesOf(cc) contains id => - true - case eq @ Equals(Variable(id), cc : CaseClass) if variablesOf(cc) contains id => - true - case _ => - false - } - - if (isImpossible) { - List(solve(Solution.UNSAT)) - } else { - Nil - } - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/UnusedInput.scala b/src/main/scala/leon/synthesis/rules/UnusedInput.scala deleted file mode 100644 index 81e03a064db7a7eec92c191815230fc4cdbc657c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/UnusedInput.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.ExprOps._ -import purescala.TypeOps._ - -case object UnusedInput extends NormalizingRule("UnusedInput") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val unused = (p.as.toSet -- variablesOf(p.phi) -- p.pc.variables -- variablesOf(p.ws)).filter { a => - !isParametricType(a.getType) - } - - if (unused.nonEmpty) { - val sub = p.copy(as = p.as.filterNot(unused), eb = p.qeb.removeIns(unused)) - - List(decomp(List(sub), forward, s"Unused inputs ${p.as.filter(unused).mkString(", ")}")) - } else { - Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala b/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala deleted file mode 100644 index 0c1aa15e42f030ff24496dd4be05bc96ea1258b7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Path -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Definitions._ - -/** For every inductive variable, outputs a recursive solution if it exists */ -case object ADTInduction extends Rule("ADT Induction") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - /* All input variables which are inductive in the post condition, along with their abstract data type. */ - val candidates = p.as.collect { - case IsTyped(origId, act: AbstractClassType) if isInductiveOn(hctx.solverFactory)(p.pc, origId) => (origId, act) - } - - val instances = for (candidate <- candidates) yield { - val (origId, ct) = candidate - // All input variables except the current inductive one. - val oas = p.as.filterNot(_ == origId) - - // The return type (all output variables). - val resType = tupleTypeWrap(p.xs.map(_.getType)) - - // A new fresh variable similar to the current inductive one to perform induction. - val inductOn = FreshIdentifier(origId.name, origId.getType, true) - - // Duplicated arguments names based on existing remaining input variables. - val residualArgs = oas.map(id => FreshIdentifier(id.name, id.getType, true)) - - // Mapping from existing input variables to the new duplicated ones. - val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap - - // The value definition to be used in arguments of the recursive method. - val residualArgDefs = residualArgs.map(ValDef(_)) - - // Returns true if the case class has a field of type the one of the induction variable - // E.g. for `List` it returns true since `Cons(a: T, q: List[T])` and Cons is a List[T] - def isAlternativeRecursive(ct: CaseClassType): Boolean = { - ct.fields.exists(_.getType == origId.getType) - } - - // True if one of the case classes has a field with the type being the one of the induction variable - val isRecursive = ct.knownCCDescendants.exists(isAlternativeRecursive) - - // Map for getting a formula in the context of within the recursive function - val substMap = residualMap + (origId -> Variable(inductOn)) - - if (isRecursive) { - - // Transformation of conditions, variables and axioms to use the inner variables of the inductive function. - val innerPhi = substAll(substMap, p.phi) - val innerPC = p.pc map (substAll(substMap, _)) - val innerWS = substAll(substMap, p.ws) - - val subProblemsInfo = for (cct <- ct.knownCCDescendants) yield { - var recCalls = Map[List[Identifier], List[Expr]]() - var postFs = List[Expr]() - - val newIds = cct.fields.map(vd => FreshIdentifier(vd.id.name, vd.getType, true)).toList - - val inputs = (for (id <- newIds) yield { - if (id.getType == origId.getType) { - val postXs = p.xs map (id => FreshIdentifier("r", id.getType, true)) - val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable) - - recCalls += postXs -> (Variable(id) +: residualArgs.map(id => Variable(id))) - - postFs ::= substAll(postXsMap + (inductOn -> Variable(id)), innerPhi) - id :: postXs - } else { - List(id) - } - }).flatten - - val subPhi = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerPhi) - val subPC = innerPC map (substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), _)) - val subWS = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerWS) - - val subPre = IsInstanceOf(Variable(origId), cct) - - val subProblem = Problem(inputs ::: residualArgs, subWS, subPC withConds postFs, subPhi, p.xs) - - (subProblem, subPre, cct, newIds, recCalls) - } - - val onSuccess: List[Solution] => Option[Solution] = { - case sols => - var globalPre = List[Expr]() - - // The recursive inner function - val newFun = new FunDef(FreshIdentifier("rec", alwaysShowUniqueID = true), Nil, ValDef(inductOn) +: residualArgDefs, resType) - - val cases = for ((sol, (problem, pre, cct, ids, calls)) <- sols zip subProblemsInfo) yield { - globalPre ::= and(pre, sol.pre) - - val caze = CaseClassPattern(None, cct, ids.map(id => WildcardPattern(Some(id)))) - SimpleCase(caze, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => letTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) - } - - // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) - if (false && sols.exists(_.pre != BooleanLiteral(true))) { - // Required to avoid an impossible cases, which suffices to - // allow invalid programs. This might be too strong though: we - // might only have to enforce it on solutions of base cases. - None - } else { - val outerPre = orJoin(globalPre) - val funPre = p.pc withCond outerPre map (substAll(substMap, _)) - val funPost = substAll(substMap, p.phi) - val idPost = FreshIdentifier("res", resType) - - newFun.precondition = funPre - newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) - - newFun.body = Some(matchExpr(Variable(inductOn), cases)) - - Some(Solution(outerPre, - sols.flatMap(_.defs).toSet + newFun, - FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable)), - sols.forall(_.isTrusted) - )) - } - } - - Some(decomp(subProblemsInfo.map(_._1).toList, onSuccess, s"ADT Induction on '${origId.asString}'")) - } else { // If none of the descendants of the type is recursive, then nothing can be done. - None - } - } - - instances.flatten - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala b/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala deleted file mode 100644 index e105ef9a71aaf5b2df8a23f1b1f7907360135202..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Path -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Definitions._ - -case object ADTLongInduction extends Rule("ADT Long Induction") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val candidates = p.as.collect { - case IsTyped(origId, act @ AbstractClassType(cd, tpe)) if isInductiveOn(hctx.solverFactory)(p.pc, origId) => (origId, act) - } - - val instances = for (candidate <- candidates) yield { - val (origId, ct) = candidate - val oas = p.as.filterNot(_ == origId) - - - val resType = tupleTypeWrap(p.xs.map(_.getType)) - - val inductOn = FreshIdentifier(origId.name, origId.getType, true) - val residualArgs = oas.map(id => FreshIdentifier(id.name, id.getType, true)) - val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap - val residualArgDefs = residualArgs.map(ValDef(_)) - - def isAlternativeRecursive(ct: CaseClassType): Boolean = { - ct.fields.exists(_.getType == origId.getType) - } - - val isRecursive = ct.knownCCDescendants.exists(isAlternativeRecursive) - - // Map for getting a formula in the context of within the recursive function - val substMap = residualMap + (origId -> Variable(inductOn)) - - - if (isRecursive) { - case class InductCase(ids: List[Identifier], - calls: List[Identifier], - pattern: Pattern, - outerPC: Path, - trMap: Map[Identifier, Expr]) - - val init = InductCase(inductOn :: residualArgs, List(), WildcardPattern(Some(inductOn)), Path.empty, Map(inductOn -> Variable(inductOn))) - - def isRec(id: Identifier) = id.getType == origId.getType - - def unrollPattern(id: Identifier, cct: CaseClassType, withIds: List[Identifier])(on: Pattern): Pattern = on match { - case WildcardPattern(Some(pid)) if pid == id => - CaseClassPattern(None, cct, withIds.map(id => WildcardPattern(Some(id)))) - - case CaseClassPattern(binder, sccd, sub) => - CaseClassPattern(binder, sccd, sub.map(unrollPattern(id, cct, withIds))) - - case _ => on - } - - def unroll(ic: InductCase): List[InductCase] = { - if (ic.ids.exists(isRec)) { - val InductCase(ids, calls, pat, pc, trMap) = ic - - (for (id <- ids if isRec(id)) yield { - for (cct <- ct.knownCCDescendants) yield { - val subIds = cct.fields.map(vd => FreshIdentifier(vd.id.name, vd.getType, true)).toList - - val newIds = ids.filterNot(_ == id) ++ subIds - val newCalls = if (subIds.nonEmpty) { - List(subIds.find(isRec)).flatten - } else { - List() - } - - //println(ccd) - //println(subIds) - val newPattern = unrollPattern(id, cct, subIds)(pat) - - val newMap = trMap.mapValues(v => substAll(Map(id -> CaseClass(cct, subIds.map(Variable))), v)) - - InductCase(newIds, newCalls, newPattern, pc withCond IsInstanceOf(Variable(id), cct), newMap) - } - }).flatten - } else { - List(ic) - } - } - - val cases = unroll(init).flatMap(unroll) - - val innerPhi = substAll(substMap, p.phi) - val innerPC = p.pc map (substAll(substMap, _)) - val innerWS = substAll(substMap, p.ws) - - val subProblemsInfo = for (c <- cases) yield { - val InductCase(ids, calls, pat, pc, trMap) = c - - // generate subProblem - - var recCalls = Map[List[Identifier], List[Expr]]() - - val subPC = innerPC map (substAll(trMap, _)) - val subWS = substAll(trMap, innerWS) - val subPhi = substAll(trMap, innerPhi) - - var postXss = List[Identifier]() - var postFs = List[Expr]() - - for (cid <- calls) { - val postXs = p.xs map (id => FreshIdentifier("r", id.getType, true)) - postXss = postXss ::: postXs - val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable) - postFs = substAll(postXsMap + (inductOn -> Variable(cid)), innerPhi) :: postFs - - recCalls += postXs -> (Variable(cid) +: residualArgs.map(id => Variable(id))) - } - - val subProblem = Problem(c.ids ::: postXss, subWS, subPC withConds postFs, subPhi, p.xs) - //println(subProblem) - //println(recCalls) - (subProblem, pat, recCalls, pc) - } - - val onSuccess: List[Solution] => Option[Solution] = { - case sols => - var globalPre = List.empty[Path] - - val newFun = new FunDef(FreshIdentifier("rec", alwaysShowUniqueID = true), Nil, ValDef(inductOn) +: residualArgDefs, resType) - - val cases = for ((sol, (problem, pat, calls, pc)) <- sols zip subProblemsInfo) yield { - globalPre ::= (pc withCond sol.pre) - - SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => letTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) - } - - // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) - if (false && sols.exists(_.pre != BooleanLiteral(true))) { - // Required to avoid an impossible cases, which suffices to - // allow invalid programs. This might be too strong though: we - // might only have to enforce it on solutions of base cases. - None - } else { - val outerPre = orJoin(globalPre.map(_.toClause)) - val funPre = p.pc withCond outerPre map (substAll(substMap, _)) - val funPost = substAll(substMap, p.phi) - val idPost = FreshIdentifier("res", resType) - - newFun.precondition = funPre - newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) - - newFun.body = Some(matchExpr(Variable(inductOn), cases)) - - Some(Solution(outerPre, - sols.flatMap(_.defs).toSet + newFun, - FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable)), - sols.forall(_.isTrusted) - )) - } - } - - Some(decomp(subProblemsInfo.map(_._1), onSuccess, s"ADT Long Induction on '${origId.asString}'")) - } else { - None - } - } - - instances.flatten - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/AsChoose.scala b/src/main/scala/leon/synthesis/rules/unused/AsChoose.scala deleted file mode 100644 index 06270b939929a96985719b78a3f58ee58d9ce931..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/AsChoose.scala +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.synthesis.rules.unused - -import leon.synthesis._ - -case object AsChoose extends Rule("As Choose") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - Some(solve(Solution.choose(p))) - } -} - diff --git a/src/main/scala/leon/synthesis/rules/unused/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/unused/BottomUpTegis.scala deleted file mode 100644 index aaead75b49dd07882654b0940d1279625107bc33..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/BottomUpTegis.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Types._ -import purescala.Constructors._ -import evaluators._ -import codegen.CodeGenParams - -import leon.grammars._ - -import bonsai.enumerators._ - -case object BottomUpTEGIS extends BottomUpTEGISLike("BU TEGIS") { - def getGrammar(sctx: SynthesisContext, p: Problem) = { - grammars.default(sctx, p) - } - - def getRootLabel(tpe: TypeTree): Label = Label(tpe) -} - -abstract class BottomUpTEGISLike(name: String) extends Rule(name) { - def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar - - def getRootLabel(tpe: TypeTree): Label - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - val tests = p.eb.valids.collect { - case io: InOutExample => (io.ins, io.outs) - } - - if (tests.nonEmpty) { - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - - val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) - //val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) - //val evaluator = new DefaultEvaluator(sctx.context, sctx.program) - val evaluator = new DualEvaluator(hctx, hctx.program, params = evalParams) - - val grammar = getGrammar(hctx, p) - - val nTests = tests.size - - var compiled = Map[ProductionRule[Label, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() - - /** - * Compile Generators to functions from Expr to Expr. The compiled - * generators will be passed to the enumerator - */ - def compile(gen: ProductionRule[Label, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { - compiled.getOrElse(gen, { - val executor = if (gen.subTrees.isEmpty) { - - { (vecs: Vector[Vector[Expr]]) => - val expr = gen.builder(Nil) - val res = tests.map { case (is, o) => (p.as zip is).toMap }.flatMap { case inputs => - evaluator.eval(expr, inputs) match { - case EvaluationResults.Successful(out) => Some(out) - case _ => None - } - }.toVector - - if (res.size == nTests) { - Some(res) - } else { - None - } - } - } else { - val args = gen.subTrees.map(t => FreshIdentifier("arg", t.getType, true)) - val argsV = args.map(_.toVariable) - val expr = gen.builder(argsV) - val ev = evaluator.compile(expr, args).get - - { (vecs: Vector[Vector[Expr]]) => - val res = (0 to nTests-1).toVector.flatMap { i => - val inputs = new solvers.Model((args zip vecs.map(_(i))).toMap) - ev(inputs) match { - case EvaluationResults.Successful(out) => Some(out) - case _ => - None - } - } - - if (res.size == nTests) { - Some(res) - } else { - None - } - } - } - - compiled += gen -> executor - executor - }) - } - - val targetType = tupleTypeWrap(p.xs.map(_.getType)) - val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - - val enum = new BottomUpEnumerator[Label, Expr, Expr, ProductionRule[Label, Expr]]( - grammar.getProductions(_)(hctx), - wrappedTests, - { (vecs, gen) => - compile(gen)(vecs) - }, - 3 - ) - - val matches = enum.iterator(getRootLabel(targetType)) - - if (matches.hasNext) { - RuleClosed(Solution(BooleanLiteral(true), Set(), matches.next(), isTrusted = false)) - } else { - RuleFailed() - } - } - }) - } else { - Nil - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/DetupleOutput.scala b/src/main/scala/leon/synthesis/rules/unused/DetupleOutput.scala deleted file mode 100644 index 18067e9529596329d44cacf9672931c6e0178414..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/DetupleOutput.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Types._ -import purescala.Constructors._ - -case object DetupleOutput extends Rule("Detuple Out") { - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - def isDecomposable(id: Identifier) = id.getType match { - case CaseClassType(t, _) if !t.isAbstract => true - case _ => false - } - - if (p.xs.exists(isDecomposable)) { - var subProblem = p.phi - - val (subOuts, outerOuts) = p.xs.map { x => - if (isDecomposable(x)) { - val ct = x.getType.asInstanceOf[CaseClassType] - - val newIds = ct.fields.map{ vd => FreshIdentifier(vd.id.name, vd.getType, true) } - - val newCC = CaseClass(ct, newIds.map(Variable)) - - subProblem = subst(x -> newCC, subProblem) - - (newIds, newCC) - } else { - (List(x), Variable(x)) - } - }.unzip - - val newOuts = subOuts.flatten - - val sub = Problem(p.as, p.ws, p.pc, subProblem, newOuts) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(sol) => - Some(Solution(sol.pre, sol.defs, letTuple(newOuts, sol.term, tupleWrap(outerOuts)), sol.isTrusted)) - case _ => - None - } - - Some(decomp(List(sub), onSuccess, s"Detuple out ${p.xs.filter(isDecomposable).mkString(", ")}")) - } else { - None - } - - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala b/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala deleted file mode 100644 index 67d5d49484a7c3f9a2b2201f5a3350f5cc8c71c4..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Common.FreshIdentifier -import purescala.Constructors._ -import purescala.Definitions.{FunDef, ValDef} -import purescala.Expressions._ -import purescala.Extractors.IsTyped -import purescala.Types.IntegerType - -case object IntInduction extends Rule("Int Induction") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.as match { - case List(IsTyped(origId, IntegerType)) => - val tpe = tupleTypeWrap(p.xs.map(_.getType)) - - val inductOn = FreshIdentifier(origId.name, origId.getType, true) - - val postXs = p.xs map (id => FreshIdentifier("r", id.getType, true)) - - val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable) - - val newPhi = subst(origId -> Variable(inductOn), p.phi) - val newPc = p.pc map (subst(origId -> Variable(inductOn), _)) - val newWs = subst(origId -> Variable(inductOn), p.ws) - val postCondGT = substAll(postXsMap + (origId -> Minus(Variable(inductOn), InfiniteIntegerLiteral(1))), p.phi) - val postCondLT = substAll(postXsMap + (origId -> Plus(Variable(inductOn), InfiniteIntegerLiteral(1))), p.phi) - - val subBase = Problem(List(), subst(origId -> InfiniteIntegerLiteral(0), p.ws), p.pc map (subst(origId -> InfiniteIntegerLiteral(0), _)), subst(origId -> InfiniteIntegerLiteral(0), p.phi), p.xs) - val subGT = Problem(inductOn :: postXs, newWs, newPc withCond and(GreaterThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondGT), newPhi, p.xs) - val subLT = Problem(inductOn :: postXs, newWs, newPc withCond and(LessThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondLT), newPhi, p.xs) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(base, gt, lt) => - if (base.pre != BooleanLiteral(true) || (gt.pre != BooleanLiteral(true) && lt.pre != BooleanLiteral(true))) { - // Required to avoid an impossible base-case, which suffices to - // allow invalid programs. - None - } else { - val preIn = or(and(Equals(Variable(inductOn), InfiniteIntegerLiteral(0)), base.pre), - and(GreaterThan(Variable(inductOn), InfiniteIntegerLiteral(0)), gt.pre), - and(LessThan(Variable(inductOn), InfiniteIntegerLiteral(0)), lt.pre)) - val preOut = subst(inductOn -> Variable(origId), preIn) - - val newFun = new FunDef(FreshIdentifier("rec", alwaysShowUniqueID = true), Nil, Seq(ValDef(inductOn)), tpe) - val idPost = FreshIdentifier("res", tpe) - - newFun.precondition = Some(preIn) - newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), p.phi))) - - newFun.body = Some( - IfExpr(Equals(Variable(inductOn), InfiniteIntegerLiteral(0)), - base.toExpr, - IfExpr(GreaterThan(Variable(inductOn), InfiniteIntegerLiteral(0)), - letTuple(postXs, FunctionInvocation(newFun.typed, Seq(Minus(Variable(inductOn), InfiniteIntegerLiteral(1)))), gt.toExpr), - letTuple(postXs, FunctionInvocation(newFun.typed, Seq(Plus(Variable(inductOn), InfiniteIntegerLiteral(1)))), lt.toExpr))) - ) - - - Some(Solution(preOut, base.defs++gt.defs++lt.defs+newFun, FunctionInvocation(newFun.typed, Seq(Variable(origId))), - Seq(base, gt, lt).forall(_.isTrusted))) - } - case _ => - None - } - - Some(decomp(List(subBase, subGT, subLT), onSuccess, s"Int Induction on '$origId'")) - case _ => - None - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala deleted file mode 100644 index 7c2d7b251a02c0cdd431298fab67e4419b747a8d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.TreeNormalizations._ -import purescala.Types._ -import LinearEquations.elimVariable -import evaluators._ - -case object IntegerEquation extends Rule("Integer Equation") { - def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] = if(!problem.xs.exists(_.getType == Int32Type)) Nil else { - - val TopLevelAnds(exprs) = problem.phi - - val (eqs, others) = exprs.partition(_.isInstanceOf[Equals]) - var candidates: Seq[Expr] = eqs - var allOthers: Seq[Expr] = others - - val evaluator = new DefaultEvaluator(hctx, hctx.program) - - var vars: Set[Identifier] = Set() - var eqxs: List[Identifier] = List() - - var optionNormalizedEq: Option[List[Expr]] = None - while(candidates.nonEmpty && optionNormalizedEq == None) { - val eq@Equals(_,_) = candidates.head - candidates = candidates.tail - - vars = variablesOf(eq) - eqxs = problem.xs.toSet.intersect(vars).toList - - try { - optionNormalizedEq = Some(linearArithmeticForm(Minus(eq.lhs, eq.rhs), eqxs.toArray).toList) - } catch { - case NonLinearExpressionException(_) => - allOthers = allOthers :+ eq - } - } - allOthers = allOthers ++ candidates - - optionNormalizedEq match { - case None => Nil - case Some(normalizedEq0) => { - - val eqas = problem.as.toSet.intersect(vars) - - val (neqxs, normalizedEq1) = eqxs.zip(normalizedEq0.tail).filterNot{ case (_, IntLiteral(0)) => true case _ => false}.unzip - val normalizedEq = normalizedEq0.head :: normalizedEq1 - - if(normalizedEq.size == 1) { - val eqPre = Equals(normalizedEq.head, IntLiteral(0)) - val newProblem = Problem(problem.as, problem.ws, problem.pc withCond eqPre, andJoin(allOthers), problem.xs) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s @ Solution(pre, defs, term, isTrusted)) => - Some(Solution(and(eqPre, pre), defs, term, isTrusted)) - case _ => - None - } - - List(decomp(List(newProblem), onSuccess, this.name)) - - } else { - val (eqPre0, eqWitness, freshxs) = elimVariable(evaluator, eqas, normalizedEq) - val eqPre = eqPre0 match { - case Equals(Modulo(_, IntLiteral(1)), _) => BooleanLiteral(true) - case c => c - } - - val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplifyArithmetic(e))}.toMap - val freshFormula0 = simplifyArithmetic(replace(eqSubstMap, andJoin(allOthers))) - - var freshInputVariables: List[Identifier] = Nil - var equivalenceConstraints: Map[Expr, Expr] = Map() - val freshFormula = simplePreTransform({ - case d @ Division(_, _) => { - assert(variablesOf(d).intersect(problem.xs.toSet).isEmpty) - val newVar = FreshIdentifier("d", Int32Type, true) - freshInputVariables ::= newVar - equivalenceConstraints += (Variable(newVar) -> d) - Variable(newVar) - } - case e => e - })(freshFormula0) - - val ys: List[Identifier] = problem.xs.filterNot(neqxs.contains(_)) - val subproblemxs: List[Identifier] = freshxs ++ ys - - val newProblem = Problem(problem.as ++ freshInputVariables, problem.ws, problem.pc withCond eqPre, freshFormula, subproblemxs) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s @ Solution(pre, defs, term, isTrusted)) => { - val freshPre = replace(equivalenceConstraints, pre) - val freshTerm = replace(equivalenceConstraints, term) - val freshsubxs = subproblemxs.map(id => FreshIdentifier(id.name, id.getType)) - val id2res: Map[Expr, Expr] = - freshsubxs.zip(subproblemxs).map{case (id1, id2) => (Variable(id1), Variable(id2))}.toMap ++ - neqxs.map(id => (Variable(id), eqSubstMap(Variable(id)))).toMap - Some(Solution( - and(eqPre, freshPre), - defs, - simplifyArithmetic(simplifyLets( - letTuple(subproblemxs, freshTerm, replace(id2res, tupleWrap(problem.xs.map(Variable)))) - )), - isTrusted - )) - } - - case _ => - None - } - - - if (subproblemxs.isEmpty) { - // we directly solve - List(solve(onSuccess(List(Solution((problem.pc withCond eqPre).toClause, Set(), UnitLiteral()))).get)) - } else { - List(decomp(List(newProblem), onSuccess, this.name)) - } - } - } - } - - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/unused/IntegerInequalities.scala deleted file mode 100644 index e809fe8e04bf11fa456c605030dbb458b814c9c8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/IntegerInequalities.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.TreeNormalizations.linearArithmeticForm -import purescala.TreeNormalizations.NonLinearExpressionException -import purescala.Types._ -import purescala.Constructors._ -import purescala.Definitions._ -import leon.synthesis.Algebra.lcm - -case object IntegerInequalities extends Rule("Integer Inequalities") { - def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = problem.phi - - //assume that we only have inequalities - var lhsSides: List[Expr] = Nil - var exprNotUsed: List[Expr] = Nil - //normalized all inequalities to LessEquals(t, 0) - exprs.foreach{ - case LessThan(a, b) => lhsSides ::= Plus(Minus(a, b), IntLiteral(1)) - case LessEquals(a, b) => lhsSides ::= Minus(a, b) - case GreaterThan(a, b) => lhsSides ::= Plus(Minus(b, a), IntLiteral(1)) - case GreaterEquals(a, b) => lhsSides ::= Minus(b, a) - case e => exprNotUsed ::= e - } - - val ineqVars = lhsSides.foldLeft(Set[Identifier]())((acc, lhs) => acc ++ variablesOf(lhs)) - val nonIneqVars = exprNotUsed.foldLeft(Set[Identifier]())((acc, x) => acc ++ variablesOf(x)) - val candidateVars = ineqVars.intersect(problem.xs.toSet).diff(nonIneqVars) - - val processedVars: Set[(Identifier, Int)] = candidateVars.flatMap(v => { - try { - val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(v)).toList) - if(normalizedLhs.isEmpty) - Some((v, 0)) - else - Some((v, lcm(normalizedLhs.map{ case List(t, IntLiteral(i)) => if(i == 0) 1 else i.abs case _ => sys.error("shouldn't happen") }))) - } catch { - case NonLinearExpressionException(_) => - None - } - }) - - if (processedVars.isEmpty) { - Nil - } else { - val processedVar = processedVars.toList.sortWith((t1, t2) => t1._2 <= t2._2).head._1 - - val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar) - - - val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(processedVar)).toList) - var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t - var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x - normalizedLhs.foreach{ - case List(t, IntLiteral(i)) => - if(i > 0) upperBounds ::= (expandAndSimplifyArithmetic(UMinus(t)), i) - else if(i < 0) lowerBounds ::= (expandAndSimplifyArithmetic(t), -i) - else exprNotUsed ::= LessEquals(t, IntLiteral(0)) //TODO: make sure that these are added as preconditions - case err => sys.error("unexpected from normal form: " + err) - } - - val L = if(upperBounds.isEmpty && lowerBounds.isEmpty) -1 else lcm((upperBounds ::: lowerBounds).map(_._2)) - - //optimization when coef = 1 and when ub - lb is a constant greater than LCM - //upperBounds = upperBounds.filterNot{case (ub, uc) => if(uc == 1) { - // exprNotUsed ++= lowerBounds.map{case (lb, lc) => LessEquals(lb, Times(IntLiteral(lc), ub))} - // true - // } else - // false - //} - //lowerBounds = lowerBounds.filterNot{case (lb, lc) => if(lc == 1) { - // exprNotUsed ++= upperBounds.map{case (ub, uc) => LessEquals(Times(IntLiteral(uc), lb), ub)} - // true - // } else - // false - //} - //upperBounds = upperBounds.filterNot{case (ub, uc) => { - // lowerBounds.forall{case (lb, lc) => { - // expandAndSimplifyArithmetic(Minus(ub, lb)) match { - // case IntLiteral(n) => L - 1 <= n - // case _ => false - // }}} - //}} - - - //define max function - val maxValDefs: Seq[ValDef] = lowerBounds.map(_ => ValDef(FreshIdentifier("b", Int32Type))) - val maxFun = new FunDef(FreshIdentifier("max"), Nil, maxValDefs, Int32Type) - def maxRec(bounds: List[Expr]): Expr = bounds match { - case (x1 :: x2 :: xs) => { - val v = FreshIdentifier("m", Int32Type) - Let(v, IfExpr(LessThan(x1, x2), x2, x1), maxRec(Variable(v) :: xs)) - } - case (x :: Nil) => x - case Nil => sys.error("cannot build a max expression with no argument") - } - if(lowerBounds.nonEmpty) - maxFun.body = Some(maxRec(maxValDefs.map(vd => Variable(vd.id)).toList)) - def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun.typed, xs) - //define min function - val minValDefs: Seq[ValDef] = upperBounds.map(_ => ValDef(FreshIdentifier("b", Int32Type))) - val minFun = new FunDef(FreshIdentifier("min"), Nil, minValDefs, Int32Type) - def minRec(bounds: List[Expr]): Expr = bounds match { - case (x1 :: x2 :: xs) => { - val v = FreshIdentifier("m", Int32Type) - Let(v, IfExpr(LessThan(x1, x2), x1, x2), minRec(Variable(v) :: xs)) - } - case (x :: Nil) => x - case Nil => sys.error("cannot build a min expression with no argument") - } - if(upperBounds.nonEmpty) - minFun.body = Some(minRec(minValDefs.map(vd => Variable(vd.id)).toList)) - def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun.typed, xs) - val floorFun = new FunDef(FreshIdentifier("floorDiv"), Nil, Seq( - ValDef(FreshIdentifier("x", Int32Type)), - ValDef(FreshIdentifier("x", Int32Type))), Int32Type) - val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Nil, Seq( - ValDef(FreshIdentifier("x", Int32Type)), - ValDef(FreshIdentifier("x", Int32Type))), Int32Type) - ceilingFun.body = Some(IntLiteral(0)) - def floorDiv(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun.typed, Seq(x, y)) - def ceilingDiv(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun.typed, Seq(x, y)) - - val witness: Expr = if(upperBounds.isEmpty) { - if(lowerBounds.size > 1) max(lowerBounds.map{case (b, c) => ceilingDiv(b, IntLiteral(c))}) - else ceilingDiv(lowerBounds.head._1, IntLiteral(lowerBounds.head._2)) - } else { - if(upperBounds.size > 1) min(upperBounds.map{case (b, c) => floorDiv(b, IntLiteral(c))}) - else floorDiv(upperBounds.head._1, IntLiteral(upperBounds.head._2)) - } - - if(otherVars.isEmpty) { //here we can simply evaluate the precondition and return a witness - - val constraints: List[Expr] = for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) - yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) - val pre = And(exprNotUsed ++ constraints) - List(solve(Solution(pre, Set(), tupleWrap(Seq(witness))))) - } else { - - val involvedVariables = (upperBounds++lowerBounds).foldLeft(Set[Identifier]())((acc, t) => { - acc ++ variablesOf(t._1) - }).intersect(problem.xs.toSet) //output variables involved in the bounds of the process variables - var newPre: Expr = BooleanLiteral(true) - if(involvedVariables.isEmpty) { - newPre = And( - for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) - yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) - ) - lowerBounds = Nil - upperBounds = Nil - } - - val remainderIds: List[Identifier] = upperBounds.map(_ => FreshIdentifier("k", Int32Type, true)) - val quotientIds: List[Identifier] = lowerBounds.map(_ => FreshIdentifier("l", Int32Type, true)) - - val newUpperBounds: List[Expr] = upperBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)} - val newLowerBounds: List[Expr] = lowerBounds.map{case (bound, coef) => Times(IntLiteral(L/coef), bound)} - - - val subProblemFormula = expandAndSimplifyArithmetic(And( - newUpperBounds.zip(remainderIds).zip(quotientIds).flatMap{ - case ((b, k), l) => Equals(b, Plus(Times(IntLiteral(L), Variable(l)), Variable(k))) :: - newLowerBounds.map(lbound => LessEquals(Variable(k), Minus(b, lbound))) - } ++ exprNotUsed)) - val subProblemxs: List[Identifier] = quotientIds ++ otherVars - val subProblem = Problem(problem.as ++ remainderIds, problem.ws, problem.pc, subProblemFormula, subProblemxs) - - def onSuccess(sols: List[Solution]): Option[Solution] = sols match { - case List(s @ Solution(pre, defs, term, isTrusted)) => - if(remainderIds.isEmpty) { - Some(Solution( - And(newPre, pre), - defs, - letTuple(subProblemxs, term, - Let(processedVar, witness, - tupleWrap(problem.xs.map(Variable)))), - isTrusted - )) - } else if(remainderIds.size > 1) { - sys.error("TODO") - } else { - val k = remainderIds.head - - val loopCounter = Variable(FreshIdentifier("i", Int32Type, true)) - val concretePre = replace(Map(Variable(k) -> loopCounter), pre) - val concreteTerm = replace(Map(Variable(k) -> loopCounter), term) - val returnType = tupleTypeWrap(problem.xs.map(_.getType)) - val funDef = new FunDef(FreshIdentifier("rec", alwaysShowUniqueID = true), Nil, Seq(ValDef(loopCounter.id)), returnType) - val funBody = expandAndSimplifyArithmetic(IfExpr( - LessThan(loopCounter, IntLiteral(0)), - Error(returnType, "No solution exists"), - IfExpr( - concretePre, - letTuple(subProblemxs, concreteTerm, - Let(processedVar, witness, - tupleWrap(problem.xs.map(Variable))) - ), - FunctionInvocation(funDef.typed, Seq(Minus(loopCounter, IntLiteral(1)))) - ) - )) - funDef.body = Some(funBody) - - Some(Solution(And(newPre, pre), defs + funDef, FunctionInvocation(funDef.typed, Seq(IntLiteral(L-1))))) - } - case _ => - None - } - - List(decomp(List(subProblem), onSuccess, this.name)) - } - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/OptimisticInjection.scala b/src/main/scala/leon/synthesis/rules/unused/OptimisticInjection.scala deleted file mode 100644 index b3d35f8a090d8f1b85532040b558cb8abf5a0c8f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/OptimisticInjection.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ - -case object OptimisticInjection extends Rule("Opt. Injection") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - val eqfuncalls = exprs.collect{ - case eq @ Equals(FunctionInvocation(fd, args), e) => - ((fd, e), args, eq : Expr) - case eq @ Equals(e, FunctionInvocation(fd, args)) => - ((fd, e), args, eq : Expr) - } - - val candidates = eqfuncalls.groupBy(_._1).filter(_._2.size > 1) - if (candidates.nonEmpty) { - - var newExprs = exprs - for (cands <- candidates.values) { - val cand = cands.take(2) - val toRemove = cand.map(_._3).toSet - val argss = cand.map(_._2) - val args = argss(0) zip argss(1) - - newExprs ++= args.map{ case (l, r) => Equals(l, r) } - newExprs = newExprs.filterNot(toRemove) - } - - val sub = p.copy(phi = andJoin(newExprs)) - - Some(decomp(List(sub), forward, s"Injection ${candidates.keySet.map(_._1.id).mkString(", ")}")) - } else { - None - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/SelectiveInlining.scala b/src/main/scala/leon/synthesis/rules/unused/SelectiveInlining.scala deleted file mode 100644 index 9dd4e68e01ede5e953c3b0f06a573be505056dc3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/SelectiveInlining.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Constructors._ -import purescala.Expressions.{Equals, Expr, FunctionInvocation} -import purescala.Extractors.TopLevelAnds - -case object SelectiveInlining extends Rule("Sel. Inlining") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(exprs) = p.phi - - val eqfuncalls = exprs.collect{ - case eq @ Equals(FunctionInvocation(fd, args), e) => - ((fd, e), args, eq : Expr) - case eq @ Equals(e, FunctionInvocation(fd, args)) => - ((fd, e), args, eq : Expr) - } - - val candidates = eqfuncalls.groupBy(_._1).filter(_._2.size > 1) - if (candidates.nonEmpty) { - - var newExprs = exprs - for (cands <- candidates.values) { - val cand = cands.take(2) - val toRemove = cand.map(_._3).toSet - val argss = cand.map(_._2) - val args = argss(0) zip argss(1) - - newExprs ++= args.map{ case (l, r) => Equals(l, r) } - newExprs = newExprs.filterNot(toRemove) - } - - val sub = p.copy(phi = andJoin(newExprs)) - - Some(decomp(List(sub), forward, s"Inlining ${candidates.keySet.map(_._1.id).mkString(", ")}")) - } else { - None - } - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/SygusCVC4.scala b/src/main/scala/leon/synthesis/rules/unused/SygusCVC4.scala deleted file mode 100644 index 232eae5646734d639425b2875cd440c459f87fb5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/SygusCVC4.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import leon.solvers.sygus.CVC4SygusSolver - -import leon.grammars._ - -case object SygusCVC4 extends Rule("SygusCVC4") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - - val s = new CVC4SygusSolver(hctx, hctx.program, p) - - s.checkSynth() match { - case Some(expr) => - RuleClosed(Solution.term(expr, isTrusted = false)) - case None => - RuleFailed() - } - } - }) - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/TEGIS.scala b/src/main/scala/leon/synthesis/rules/unused/TEGIS.scala deleted file mode 100644 index a22a048f0b302cb9e2b2cb67afd974ce351fdfd6..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/TEGIS.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import leon.grammars._ - -case object TEGIS extends TEGISLike("TEGIS") { - def getParams(sctx: SynthesisContext, p: Problem) = { - TegisParams( - grammar = grammars.default(sctx, p), - rootLabel = Label(_) - ) - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala deleted file mode 100644 index 05f8242e6d2dac1eaf9a91182900ffbd06ff119e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Expressions._ -import purescala.Types._ -import purescala.Constructors._ - -import datagen._ -import evaluators._ -import codegen.CodeGenParams -import leon.grammars._ -import leon.utils.GrowableIterable - -import scala.collection.mutable.{HashMap => MutableMap} - -import bonsai.enumerators._ - -abstract class TEGISLike(name: String) extends Rule(name) { - case class TegisParams( - grammar: ExpressionGrammar, - rootLabel: TypeTree => Label, - enumLimit: Int = 10000, - reorderInterval: Int = 50 - ) - - def getParams(sctx: SynthesisContext, p: Problem): TegisParams - - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - List(new RuleInstantiation(this.name) { - def apply(hctx: SearchContext): RuleApplication = { - implicit val ci = hctx - val params = getParams(hctx, p) - val grammar = params.grammar - - val nTests = if (p.pc.isEmpty) 50 else 20 - - val useVanuatoo = hctx.settings.cegisUseVanuatoo - - val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) { - new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc.toClause, nTests, 3000) - } else { - val evaluator = new DualEvaluator(hctx, hctx.program) - new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc.toClause, nTests, 1000) - } - - val gi = new GrowableIterable[Seq[Expr]](p.eb.examples.map(_.ins).distinct, inputGenerator) - - val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0) - - var n = 1 - def allInputExamples() = { - if (n == 10 || n == 50 || n % 500 == 0) { - gi.sortBufferBy(e => -failedTestsStats(e)) - } - n += 1 - gi.iterator - } - - if (gi.nonEmpty) { - - val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) - val evaluator = new DualEvaluator(hctx, hctx.program, params = evalParams) - - val enum = new MemoizedEnumerator[Label, Expr, ProductionRule[Label, Expr]](grammar.getProductions) - - val targetType = tupleTypeWrap(p.xs.map(_.getType)) - - val timers = hctx.timers.synthesis.rules.tegis - - val allExprs = enum.iterator(params.rootLabel(targetType)) - - var candidate: Option[Expr] = None - - def findNext(): Option[Expr] = { - candidate = None - timers.generating.start() - allExprs.take(params.enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => - val exprToTest = letTuple(p.xs, e, p.phi) - - //sctx.reporter.debug("Got expression "+e.asString) - timers.testing.start() - if (allInputExamples().forall{ t => - val res = evaluator.eval(exprToTest, p.as.zip(t).toMap) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - //sctx.reporter.debug("Test "+t.map(_.asString)+" passed!") - true - case _ => - //sctx.reporter.debug("Test "+t.map(_.asString)+" failed!") - failedTestsStats += t -> (failedTestsStats(t)+1) - false - } - res - }) { - candidate = Some(tupleWrap(Seq(e))) - } - timers.testing.stop() - } - timers.generating.stop() - - candidate - } - - val toStream = Stream.continually(findNext()).takeWhile(_.nonEmpty).map( e => - Solution(BooleanLiteral(true), Set(), e.get, isTrusted = false) - ) - - RuleClosed(toStream) - } else { - hctx.reporter.debug("No test available") - RuleFailed() - } - } - }) - } -} diff --git a/src/main/scala/leon/synthesis/rules/unused/TEGLESS.scala b/src/main/scala/leon/synthesis/rules/unused/TEGLESS.scala deleted file mode 100644 index 8b0c019d495b936fe5293224ca01c0efa83f76c9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/unused/TEGLESS.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package rules.unused - -import purescala.Types._ -import purescala.Extractors._ -import Witnesses._ -import leon.grammars._ -import leon.grammars.aspects.{SimilarTo, DepthBound} - -case object TEGLESS extends TEGISLike("TEGLESS") { - def getParams(sctx: SynthesisContext, p: Problem) = { - - val TopLevelAnds(clauses) = p.ws - - val guides = clauses.collect { - case Guide(expr) => expr - } - - sctx.reporter.ifDebug { printer => - printer("Guides available:") - for (g <- guides) { - printer(" - "+g) - } - } - - TegisParams( - grammar = grammars.default(sctx, p), - rootLabel = { (tpe: TypeTree) => Label(tpe).withAspect(DepthBound(2)).withAspect(SimilarTo(guides)) } - ) - } -} diff --git a/src/main/scala/leon/synthesis/strategies/BoundedStrategy.scala b/src/main/scala/leon/synthesis/strategies/BoundedStrategy.scala deleted file mode 100644 index 72074c5954ac2e643570543d06875af58d993ec2..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/BoundedStrategy.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import synthesis.graph._ - -case class BoundedStrategy(underlying: Strategy, bound: Int) extends WrappedStrategy(underlying) { - private[this] var nSteps = 0 - - override def getNextToExpand(from: Node): Option[Node] = { - if (nSteps < bound) { - super.getNextToExpand(from) - } else { - None - } - } - - override def afterExpand(n: Node) = { - super.afterExpand(n); - nSteps += 1 - } -} diff --git a/src/main/scala/leon/synthesis/strategies/CostBasedStrategy.scala b/src/main/scala/leon/synthesis/strategies/CostBasedStrategy.scala deleted file mode 100644 index de1f7d4105c1661a48f6d8a5696da1750c33bc51..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/CostBasedStrategy.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import synthesis.graph._ - -class CostBasedStrategy(ctx: LeonContext, cm: CostModel) extends Strategy { - private var bestSols = Map[Node, Option[Solution]]() - private var bestCosts = Map[Node, Cost]() - - override def init(root: RootNode): Unit = { - super.init(root) - computeBestSolutionFor(root) - } - - def computeBestSolutionFor(n: Node): Option[Solution] = { - val res = if (n.isSolved) { - Some(n.generateSolutions().head) - } else if (n.isDeadEnd) { - None - } else if (!n.isExpanded) { - n match { - case an: AndNode => - an.ri.onSuccess match { - case SolutionBuilderCloser(_) => - Some(Solution.simplest(an.p.outType)) - - case SolutionBuilderDecomp(types, recomp) => - recomp(types.toList.map(Solution.simplest)) - } - case on: OrNode => - Some(Solution.simplest(n.p.outType)) - } - } else { - n match { - case an: AndNode => - val subs = an.descendants.map(bestSolutionFor) - - if (subs.forall(_.isDefined)) { - an.ri.onSuccess(subs.flatten) - } else { - None - } - case on: OrNode => - on.descendants.foreach(bestSolutionFor) - - bestSolutionFor(on.descendants.minBy(bestCosts)) - } - } - - bestSols += n -> res - bestCosts += n -> res.map(cm.solution _).getOrElse(cm.impossible) - - res - } - - def bestAlternative(on: OrNode): Option[Node] = { - if (on.isDeadEnd) { - None - } else { - Some(on.descendants.minBy(bestCosts)) - } - } - - def bestSolutionFor(n: Node): Option[Solution] = { - bestSols.get(n) match { - case Some(os) => os - case None => computeBestSolutionFor(n) - } - } - - def recomputeCost(n: Node): Unit = { - val oldCost = bestCosts(n) - computeBestSolutionFor(n) - - if (bestCosts(n) != oldCost) { - n.parent.foreach(recomputeCost) - } - } - - override def afterExpand(n: Node): Unit = { - super.afterExpand(n) - - for (d <- n.descendants) { - bestSolutionFor(d) - } - - recomputeCost(n) - } - - def debugInfoFor(n: Node) = bestCosts.get(n).map(_.toString).getOrElse("?") -} diff --git a/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala b/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala deleted file mode 100644 index 9f821d593201b488dbf6fccdd59d51d812feb8c4..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/ManualStrategy.scala +++ /dev/null @@ -1,286 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import purescala.Common.FreshIdentifier - -import graph._ - -class ManualStrategy(ctx: LeonContext, initCmd: Option[String], strat: Strategy) extends Strategy { - implicit val ctx_ = ctx - - import ctx.reporter._ - - abstract class Command - case class Cd(path: List[Int]) extends Command - case object Parent extends Command - case object Quit extends Command - case object Noop extends Command - case object Best extends Command - case object Tree extends Command - case object Help extends Command - - // Manual search state: - var rootNode: Node = _ - - var path = List[Int]() - - override def init(n: RootNode) = { - super.init(n) - strat.init(n) - - rootNode = n - } - - def currentNode(path: List[Int]): Node = { - def findFrom(n: Node, path: List[Int]): Node = { - path match { - case Nil => n - case p :: ps => - findDescendent(n, p) match { - case Some(d) => - findFrom(d, ps) - case None => - n - } - } - } - - findFrom(rootNode, path) - } - - override def beforeExpand(n: Node) = { - super.beforeExpand(n) - strat.beforeExpand(n) - } - - override def afterExpand(n: Node) = { - super.afterExpand(n) - strat.afterExpand(n) - - // Backtrack view to a point where node is neither closed nor solved - if (n.isDeadEnd || n.isSolved) { - val backtrackTo = findAncestor(n, n => !n.isDeadEnd && !n.isSolved) - - path = backtrackTo.map(pathTo).getOrElse(Nil) - } - } - - private def findAncestor(n: Node, f: Node => Boolean): Option[Node] = { - n.parent.flatMap { n => - if (f(n)) Some(n) else findAncestor(n, f) - } - } - - private def pathTo(n: Node): List[Int] = { - n.parent match { - case None => Nil - case Some(p) => pathTo(p) :+ p.descendants.indexOf(n) - } - } - - - def bestAlternative(n: OrNode) = strat.bestAlternative(n) - - - def printGraph() { - def title(str: String) = "\u001b[1m" + str + "\u001b[0m" - def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" - def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" - def expanded(str: String) = "\u001b[33m" + str + "\u001b[0m" - - def displayNode(n: Node, inTitle: Boolean = false): String = { - n match { - case an: AndNode => - val app = an.ri.asString(ctx) - s"(${debugInfoFor(n)}) ${indent(app, inTitle)}" - case on: OrNode => - val p = on.p.asString(ctx) - s"(${debugInfoFor(n)}) ${indent(p, inTitle)}" - } - } - - def indent(a: String, inTitle: Boolean): String = { - a.replaceAll("\n", "\n"+(" "*(if(inTitle) 11 else 13))) - } - - def pathToString(cd: List[Int]): String = { - cd.map(i => f"$i%2d").mkString(" ") - } - - val c = currentNode(path) - - println("-"*120) - val at = path.lastOption.map(p => pathToString(List(p))).getOrElse(" R") - - println(title(at+" \u2510 "+displayNode(c, true))) - - for ((sn, i) <- c.descendants.zipWithIndex) { - val sp = List(i) - - if (sn.isSolved) { - println(solved(" "+pathToString(sp)+" \u2508 "+displayNode(sn))) - } else if (sn.isDeadEnd) { - println(failed(" "+pathToString(sp)+" \u2508 "+displayNode(sn))) - } else if (sn.isExpanded) { - println(expanded(" "+pathToString(sp)+" \u2508 "+displayNode(sn))) - } else { - println(" "+pathToString(sp)+" \u2508 "+displayNode(sn)) - } - } - println("-"*120) - } - - var continue = true - - def findDescendent(n: Node, index: Int): Option[Node] = { - n.descendants.zipWithIndex.find(_._2 == index).map(_._1) - } - - def manualGetNext(): Option[Node] = { - val c = currentNode(path) - - if (!c.isExpanded) { - Some(c) - } else { - printGraph() - - nextCommand() match { - case Quit => - None - - case Help => - val tOpen = "\u001b[1m" - val tClose = "\u001b[0m" - println(s"""| - |${tOpen}Available commands: $tClose - |$tOpen (cd) N $tClose Expand descendant N - |$tOpen cd .. $tClose Go one level up - |$tOpen b $tClose Expand best descendant - |$tOpen t $tClose Display the partial solution around the current node - |$tOpen q $tClose Quit the search - |$tOpen h $tClose Display this message - |""".stripMargin) - manualGetNext() - case Parent => - if (path.nonEmpty) { - path = path.dropRight(1) - } else { - error("Already at root node!") - } - - manualGetNext() - - case Tree => - val hole = FreshIdentifier("\u001b[1;31m??? \u001b[0m", c.p.outType) - val ps = new PartialSolution(this, true) - - ps.solutionAround(c)(hole.toVariable) match { - case Some(sol) => - println("-"*120) - println(sol.toExpr.asString) - case None => - error("woot!") - } - manualGetNext() - - case Best => - strat.bestNext(c) match { - case Some(n) => - val i = c.descendants.indexOf(n) - path = path :+ i - Some(currentNode(path)) - - case None => - error("Woot?") - manualGetNext() - } - - - case Cd(Nil) => - error("Woot?") - None - - case Cd(next :: rest) => - findDescendent(c, next) match { - case Some(_) => - path = path :+ next - case None => - warning("Unknown descendant: "+next) - } - - if (rest.nonEmpty) { - cmdQueue = Cd(rest) :: cmdQueue - } - manualGetNext() - } - } - } - - override def getNextToExpand(root: Node): Option[Node] = { - manualGetNext() - } - - def debugInfoFor(n: Node) = strat.debugInfoFor(n) - - var cmdQueue = initCmd.map( str => parseCommands(parseString(str))).getOrElse(Nil) - - private def parseString(s: String): List[String] = { - Option(s).map(_.trim.split("\\s+|,").toList).getOrElse(fatalError("End of stream")) - } - - private def nextCommand(): Command = cmdQueue match { - case c :: cs => - cmdQueue = cs - c - - case Nil => - print("Next action? (h for help) "+path.mkString(" ")+" $ ") - val line = scala.io.StdIn.readLine() - val parts = parseString(line) - - cmdQueue = parseCommands(parts) - nextCommand() - } - - private def parseCommands(tokens: List[String]): List[Command] = tokens match { - case "cd" :: ".." :: ts => - Parent :: parseCommands(ts) - - case "cd" :: ts => - val path = ts.takeWhile { t => t.forall(_.isDigit) } - - if (path.isEmpty) { - parseCommands(ts) - } else { - Cd(path.map(_.toInt)) :: parseCommands(ts.drop(path.size)) - } - - case "t" :: ts => - Tree :: parseCommands(ts) - - case "b" :: ts => - Best :: parseCommands(ts) - - case "h" :: ts => - Help :: parseCommands(ts) - - case "q" :: ts => - Quit :: Nil - - case Nil | "" :: Nil => - Nil - - case ts => - val path = ts.takeWhile { t => t.forall(_.isDigit) } - - if (path.isEmpty) { - error("Unknown command "+ts.head) - parseCommands(ts.tail) - } else { - Cd(path.map(_.toInt)) :: parseCommands(ts.drop(path.size)) - } - } -} diff --git a/src/main/scala/leon/synthesis/strategies/Strategy.scala b/src/main/scala/leon/synthesis/strategies/Strategy.scala deleted file mode 100644 index 5b6253bd9b63ded24843de6764b0b003a88af4a1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/Strategy.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import synthesis.graph._ - -import leon.utils.Interruptible - -abstract class Strategy extends Interruptible { - - // Nodes to consider next - private var openNodes = Set[Node]() - - def init(root: RootNode): Unit = { - openNodes += root - } - - /** - * Standard-next for AndNodes, strategy-best for OrNodes - */ - def bestNext(n: Node): Option[Node] = { - n match { - case an: AndNode => - an.descendants.find(_.isOpen) - - case on: OrNode => - bestAlternative(on) - } - } - - def bestAlternative(on: OrNode): Option[Node] - - def getNextToExpand(root: Node): Option[Node] = { - if (openNodes(root)) { - Some(root) - } else if (openNodes.isEmpty) { - None - } else { - bestNext(root).flatMap(getNextToExpand) - } - } - - def beforeExpand(n: Node): Unit = {} - - def afterExpand(n: Node): Unit = { - openNodes -= n - openNodes ++= n.descendants - } - - def interrupt() = {} - - def recoverInterrupt() = {} - - def debugInfoFor(n: Node): String -} - - diff --git a/src/main/scala/leon/synthesis/strategies/TimeSlicingStrategy.scala b/src/main/scala/leon/synthesis/strategies/TimeSlicingStrategy.scala deleted file mode 100644 index 2ed3b56f22d85c5ea29b8fcf671d63434af1476d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/TimeSlicingStrategy.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import synthesis.graph._ - -class TimeSlicingStrategy(ctx: LeonContext) extends Strategy { - - var timeSpent = Map[Node, Long]().withDefaultValue(0l) - - def bestAlternative(on: OrNode): Option[Node] = { - on.descendants.filter(_.isOpen).sortBy(timeSpent).headOption - } - - def recordTime(from: Node, t: Long): Unit = { - timeSpent += from -> (timeSpent(from) + t) - - from.parent.foreach { - recordTime(_, t) - } - } - - var tstart: Long = 0; - - override def beforeExpand(n: Node): Unit = { - super.beforeExpand(n) - tstart = System.currentTimeMillis - } - - override def afterExpand(n: Node): Unit = { - super.afterExpand(n) - - val t = System.currentTimeMillis - tstart - recordTime(n, t) - } - - def debugInfoFor(n: Node) = timeSpent.get(n).map(_.toString).getOrElse("?") -} diff --git a/src/main/scala/leon/synthesis/strategies/WrappedStrategy.scala b/src/main/scala/leon/synthesis/strategies/WrappedStrategy.scala deleted file mode 100644 index 8e281ea3d1cc524fe46ef9a8fcdcbfe746bdffa1..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/strategies/WrappedStrategy.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package strategies - -import synthesis.graph._ - -class WrappedStrategy(underlying: Strategy) extends Strategy { - - override def init(root: RootNode) = underlying.init(root) - - override def getNextToExpand(from: Node): Option[Node] = { - underlying.getNextToExpand(from) - } - - override def bestAlternative(on: OrNode): Option[Node] = { - underlying.bestAlternative(on) - } - - override def beforeExpand(n: Node) = { - underlying.beforeExpand(n) - } - - override def afterExpand(n: Node) = { - underlying.afterExpand(n); - } - - override def interrupt() = underlying.interrupt() - - override def recoverInterrupt() = underlying.recoverInterrupt() - - def debugInfoFor(n: Node) = underlying.debugInfoFor(n) -} diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala deleted file mode 100644 index 767b703c08e2dbb175712e016cf36a3d96050dd5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package utils - -import purescala.Path -import purescala.Definitions._ -import purescala.Types._ -import purescala.Extractors._ -import purescala.TypeOps._ -import purescala.Expressions._ -import purescala.DefOps._ -import purescala.Common._ -import Witnesses._ - -object Helpers { - /** - * Filter functions potentially returning target type - * - * If the function takes type parameters, it will try to find an assignment - * such that the function returns the target type. - * - * The return is thus a set of typed functions. - */ - def functionsReturning(fds: Set[FunDef], tpe: TypeTree): Set[TypedFunDef] = { - fds.flatMap { fd => - canBeSubtypeOf(fd.returnType, tpe) match { - case Some(tpsMap) => - Some(fd.typed(fd.typeArgs.map(tp => tpsMap.getOrElse(tp, tp)))) - case None => - None - } - } - } - - /** Given an initial set of function calls provided by a list of [[Terminating]], - * returns function calls that will hopefully be safe to call recursively from within this initial function calls. - * - * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. - * - * @param prog The current program - * @param ws Helper predicates that contain [[Terminating]]s with the initial calls - * @param pc The path condition - * @param tpe The expected type for the returned function calls. If absent, all types are permitted. - * @return A list of pairs (safe function call, holes), - * where holes stand for the rest of the arguments of the function. - */ - def terminatingCalls(prog: Program, ws: Expr, pc: Path, tpe: Option[TypeTree], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { - - val TopLevelAnds(wss) = ws - - val gs: List[Terminating] = wss.toList.collect { - case t : Terminating => t - } - - def subExprsOf(expr: Expr, v: Variable): Option[(Variable, Expr)] = expr match { - case CaseClassSelector(cct, r, _) => subExprsOf(r, v) - case (r: Variable) if leastUpperBound(r.getType, v.getType).isDefined => Some(r -> v) - case _ => None - } - - val z = InfiniteIntegerLiteral(0) - val one = InfiniteIntegerLiteral(1) - val knownSmallers = (pc.bindings.flatMap { - // @nv: used to check both Equals(id, selector) and Equals(selector, id) - case (id, s @ CaseClassSelector(cct, r, _)) => subExprsOf(s, id.toVariable) - case _ => None - } ++ pc.conditions.flatMap { - case GreaterThan(v: Variable, `z`) => - Some(v -> Minus(v, one)) - case LessThan(`z`, v: Variable) => - Some(v -> Minus(v, one)) - case LessThan(v: Variable, `z`) => - Some(v -> Plus(v, one)) - case GreaterThan(`z`, v: Variable) => - Some(v -> Plus(v, one)) - case _ => None - }).groupBy(_._1).mapValues(v => v.map(_._2)) - - def argsSmaller(e: Expr, tpe: TypeTree): Seq[Expr] = e match { - case CaseClass(cct, args) => - (cct.fieldsTypes zip args).collect { - case (t, e) if isSubtypeOf(t, tpe) => - List(e) ++ argsSmaller(e, tpe) - }.flatten - case v: Variable => - knownSmallers.getOrElse(v, Seq()) - case _ => Nil - } - - val res = gs.flatMap { - case Terminating(FunctionInvocation(tfd, args)) if tpe forall (isSubtypeOf(tfd.returnType, _)) => - val ids = tfd.params.map(vd => FreshIdentifier("<hole>", vd.getType, true)).toList - - for (((a, i), tpe) <- args.zipWithIndex zip tfd.params.map(_.getType); - smaller <- argsSmaller(a, tpe)) yield { - val newArgs = (if (introduceHoles) ids.map(_.toVariable) else args).updated(i, smaller) - (FunctionInvocation(tfd, newArgs), if(introduceHoles) Some(ids.toSet - ids(i)) else None) - } - case _ => - Nil - } - - res - } - - - /** - * All functions we call use in synthesis, which includes: - * - all functions in main units - * - all functions imported, or methods of classes imported - */ - def functionsAvailable(p: Program): Set[FunDef] = { - visibleFunDefsFromMain(p) - } - - -} diff --git a/src/main/scala/leon/synthesis/utils/MutableExpr.scala b/src/main/scala/leon/synthesis/utils/MutableExpr.scala deleted file mode 100644 index f30fd8104d60eee795988e0199b36082ee9e8a18..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/utils/MutableExpr.scala +++ /dev/null @@ -1,21 +0,0 @@ -package leon -package synthesis.utils - -import purescala.Expressions.Expr -import purescala.Extractors.Extractable -import purescala.{PrinterHelpers, PrinterContext, PrettyPrintable} - -/** A mutable expression box useful for CEGIS */ -case class MutableExpr(var underlying: Expr) extends Expr with Extractable with PrettyPrintable { - def getType = underlying.getType - - def extract: Option[(Seq[Expr], (Seq[Expr]) => Expr)] = Some( - Seq(underlying), - { case Seq(e) => underlying = e; this } - ) - - override def printWith(implicit pctx: PrinterContext): Unit = { - import PrinterHelpers._ - p"$underlying" - } -} \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala deleted file mode 100644 index 4bbd5ad693dea93ff9ac4f1bafff2365bbd88fb5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package synthesis -package utils - -import purescala.DefOps.funDefsFromMain -import purescala.Definitions._ - -object SynthesisProblemExtractionPhase extends SimpleLeonPhase[Program, (Program, Map[FunDef, Seq[SourceInfo]])] { - val name = "Synthesis Problem Extraction" - val description = "Synthesis Problem Extraction" - - def apply(ctx: LeonContext, p: Program): (Program, Map[FunDef, Seq[SourceInfo]]) = { - // Look for choose() - val results = for (f <- funDefsFromMain(p).toSeq.sortBy(_.id) if f.body.isDefined) yield { - f -> SourceInfo.extractFromFunction(ctx, p, f) - } - - (p, results.toMap) - } - -} - diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala deleted file mode 100644 index a4b90940ca0d8730916f2cc93a538a4e937a3401..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Path -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Constructors._ -import purescala.Common._ - -import scala.collection.mutable.{Map => MutableMap} - -final case class Chain(relations: List[Relation]) { - - private def identifier: Map[(Relation, Relation), Int] = { - (relations zip (relations.tail :+ relations.head)).groupBy(p => p).mapValues(_.size) - } - - override def equals(obj: Any): Boolean = obj match { - case (chain : Chain) => chain.identifier == identifier - case _ => false - } - - override def hashCode(): Int = identifier.hashCode() - - lazy val funDef : FunDef = relations.head.funDef - lazy val funDefs : Set[FunDef] = relations.map(_.funDef).toSet - - lazy val size: Int = relations.size - - private lazy val inlining : Seq[(Seq[ValDef], Expr)] = { - def rec(list: List[Relation], funDef: TypedFunDef, args: Seq[Expr]): Seq[(Seq[ValDef], Expr)] = list match { - case Relation(_, _, fi @ FunctionInvocation(fitfd, nextArgs), _) :: xs => - val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated)) - val subst = funDef.paramSubst(args) - val expr = replaceFromIDs(subst, hoistIte(expandLets(matchToIfThenElse(tfd.body.get)))) - val mappedArgs = nextArgs.map(e => replaceFromIDs(subst, tfd.translated(e))) - - (tfd.params, expr) +: rec(xs, tfd, mappedArgs) - case Nil => Seq.empty - } - - val body = hoistIte(expandLets(matchToIfThenElse(funDef.body.get))) - val tfd = funDef.typed(funDef.tparams.map(_.tp)) - (tfd.params, body) +: rec(relations, tfd, tfd.params.map(_.toVariable)) - } - - lazy val finalParams : Seq[ValDef] = inlining.last._1 - - def loop(initialArgs: Seq[Identifier] = Seq.empty, finalArgs: Seq[Identifier] = Seq.empty): Path = { - def rec(relations: List[Relation], funDef: TypedFunDef, subst: Map[Identifier, Identifier]): Path = { - val Relation(_, path, FunctionInvocation(fitfd, args), _) = relations.head - val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated)) - - val translate : Expr => Expr = { - val free : Set[Identifier] = path.variables -- funDef.fd.params.map(_.id) - val freeMapping : Map[Identifier,Identifier] = free.map(id => id -> { - FreshIdentifier(id.name, funDef.translated(id.getType), true).copiedFrom(id) - }).toMap - - val map : Map[Expr, Expr] = (subst ++ freeMapping).map(p => p._1.toVariable -> p._2.toVariable) - (e: Expr) => replace(map, funDef.translated(e)) - } - - lazy val newArgs = args.map(translate) - - path.map(translate) merge (relations.tail match { - case Nil => - Path.empty withBindings (finalArgs zip newArgs) - case xs => - val params = tfd.params.map(_.id) - val freshParams = tfd.params.map(arg => FreshIdentifier(arg.id.name, arg.getType, true)) - Path.empty withBindings (freshParams zip newArgs) merge rec(xs, tfd, (params zip freshParams).toMap) - }) - } - - rec(relations, funDef.typed, (funDef.params.map(_.id) zip initialArgs).toMap) - } - - /* - def reentrant(other: Chain) : Seq[Expr] = { - assert(funDef == other.funDef) - val bindingSubst = funDef.params.map(vd => vd.id -> vd.id.freshen).toMap - val firstLoop = loop(finalSubst = bindingSubst) - val secondLoop = other.loop(initialSubst = bindingSubst) - firstLoop ++ secondLoop - } - */ - - lazy val cycles : Seq[List[Relation]] = relations.indices.map { index => - val (start, end) = relations.splitAt(index) - end ++ start - } - - def compose(that: Chain) : Set[Chain] = { - val map = relations.zipWithIndex.map(p => p._1.call.tfd.fd -> ((p._2 + 1) % relations.size)).groupBy(_._1).mapValues(_.map(_._2)) - val tmap = that.relations.zipWithIndex.map(p => p._1.funDef -> p._2).groupBy(_._1).mapValues(_.map(_._2)) - val keys = map.keys.toSet & tmap.keys.toSet - - for { - fd <- keys - i1 <- map(fd) - (start1, end1) = relations.splitAt(i1) - called = if (start1.isEmpty) relations.head.funDef else start1.last.call.tfd.fd - i2 <- tmap(called) - (start2, end2) = that.relations.splitAt(i2) - } yield Chain(start1 ++ end2 ++ start2 ++ end1) - } - - lazy val inlined: Seq[Expr] = inlining.map(_._2) -} - -trait ChainBuilder extends RelationBuilder { self: Strengthener with RelationComparator => - - protected type ChainSignature = (FunDef, Set[RelationSignature]) - - protected def funDefChainSignature(funDef: FunDef): ChainSignature = { - funDef -> (checker.program.callGraph.transitiveCallees(funDef) + funDef).map(funDefRelationSignature) - } - - private val chainCache : MutableMap[FunDef, (Set[FunDef], Set[Chain], ChainSignature)] = MutableMap.empty - - def getChains(funDef: FunDef)(implicit solver: Processor with Solvable): (Set[FunDef], Set[Chain]) = chainCache.get(funDef) match { - case Some((subloop, chains, signature)) if signature == funDefChainSignature(funDef) => subloop -> chains - case _ => { - val relationConstraints : MutableMap[Relation, SizeConstraint] = MutableMap.empty - - def decreasing(relations: List[Relation]): Boolean = { - val constraints = relations.map(relation => relationConstraints.getOrElse(relation, { - val Relation(funDef, path, FunctionInvocation(_, args), _) = relation - val args0 = funDef.params.map(_.toVariable) - val constraint = if (solver.definitiveALL(path implies self.softDecreasing(args0, args))) { - if (solver.definitiveALL(path implies self.sizeDecreasing(args0, args))) { - StrongDecreasing - } else { - WeakDecreasing - } - } else { - NoConstraint - } - - relationConstraints(relation) = constraint - constraint - })).toSet - - !constraints(NoConstraint) && constraints(StrongDecreasing) - } - - def chains(seen: Set[FunDef], chain: List[Relation]) : (Set[FunDef], Set[Chain]) = { - val Relation(_, _, FunctionInvocation(tfd, _), _) :: _ = chain - val fd = tfd.fd - - if (!checker.program.callGraph.transitivelyCalls(fd, funDef)) { - Set.empty[FunDef] -> Set.empty[Chain] - } else if (fd == funDef) { - Set.empty[FunDef] -> Set(Chain(chain.reverse)) - } else if (seen(fd)) { - Set(fd) -> Set.empty[Chain] - } else { - val results = getRelations(fd).map(r => chains(seen + fd, r :: chain)) - val (funDefs, allChains) = results.unzip - (funDefs.flatten, allChains.flatten) - } - } - - val results = getRelations(funDef).map(r => chains(Set.empty, r :: Nil)) - val (funDefs, allChains) = results.unzip - - val loops = funDefs.flatten - val filteredChains = allChains.flatten.filter(chain => !decreasing(chain.relations)) - - chainCache(funDef) = (loops, filteredChains, funDefChainSignature(funDef)) - - loops -> filteredChains - } - } -} diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala deleted file mode 100644 index 5e68311c05713a6e7e3e4657e912315e96c9ebf7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Path -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.TypeOps._ -import purescala.Constructors._ -import purescala.Common._ - -trait ChainComparator { self : StructuralSize => - val checker: TerminationChecker - - def structuralDecreasing(e1: Expr, e2s: Seq[(Path, Expr)]): Seq[Expr] = flatTypesPowerset(e1.getType).toSeq.map { - recons => andJoin(e2s.map { case (path, e2) => - path implies GreaterThan(self.fullSize(recons(e1)), self.fullSize(recons(e2))) - }) - } - - /* - def structuralDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Expr = e1.getType match { - case ContainerType(def1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => - structuralDecreasing(CaseClassSelector(def1, e1, id1), e2s.map { case (path, e2) => - e2.getType match { - case ContainerType(def2, fields2) => (path, CaseClassSelector(def2, e2, fields2(index)._1)) - case _ => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) - } - }) - }) - case TupleType(types1) => Or((0 until types1.length) map { case index => - structuralDecreasing(tupleSelect(e1, index + 1), e2s.map { case (path, e2) => - e2.getType match { - case TupleType(_) => (path, tupleSelect(e2, index + 1)) - case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) - } - }) - }) - case c: ClassType => And(e2s map { case (path, e2) => - e2.getType match { - case c2: ClassType => Implies(And(path), GreaterThan(self.size(e1), self.size(e2))) - case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) - } - }) - case _ => BooleanLiteral(false) - } - */ - - private sealed abstract class NumericEndpoint { - def inverse: NumericEndpoint = this match { - case UpperBoundEndpoint => LowerBoundEndpoint - case LowerBoundEndpoint => UpperBoundEndpoint - case InnerEndpoint => AnyEndpoint - case AnyEndpoint => InnerEndpoint - case NoEndpoint => NoEndpoint - } - def <(that: NumericEndpoint) : Boolean = (this, that) match { - case (UpperBoundEndpoint, AnyEndpoint) => true - case (LowerBoundEndpoint, AnyEndpoint) => true - case (InnerEndpoint, AnyEndpoint) => true - case (NoEndpoint, AnyEndpoint) => true - case (InnerEndpoint, UpperBoundEndpoint) => true - case (InnerEndpoint, LowerBoundEndpoint) => true - case (NoEndpoint, UpperBoundEndpoint) => true - case (NoEndpoint, LowerBoundEndpoint) => true - case (NoEndpoint, InnerEndpoint) => true - case _ => false - } - def <=(that: NumericEndpoint) : Boolean = (this, that) match { - case (t1, t2) if t1 < t2 => true - case (t1, t2) if t1 == t2 => true - case _ => false - } - def min(that: NumericEndpoint) : NumericEndpoint = { - if (this <= that) this else if (that <= this) that else InnerEndpoint - } - def max(that: NumericEndpoint) : NumericEndpoint = { - if (this <= that) that else if (that <= this) this else AnyEndpoint - } - } - - private case object UpperBoundEndpoint extends NumericEndpoint - private case object LowerBoundEndpoint extends NumericEndpoint - private case object InnerEndpoint extends NumericEndpoint - private case object AnyEndpoint extends NumericEndpoint - private case object NoEndpoint extends NumericEndpoint - - private def numericEndpoint(value: Expr, cluster: Set[Chain]) = { - - object Value { - val vars = variablesOf(value) - assert(vars.size == 1) - - def simplifyBinaryArithmetic(e1: Expr, e2: Expr) : Boolean = { - val (inE1, inE2) = (variablesOf(e1) == vars, variablesOf(e2) == vars) - if (inE1 && inE2) false else if (inE1) unapply(e1) else if (inE2) unapply(e2) else { - scala.sys.error("How the heck did we get here?!?") - } - } - - def unapply(expr: Expr): Boolean = if (variablesOf(expr) != vars) false else expr match { - case Plus(e1, e2) => simplifyBinaryArithmetic(e1, e2) - case Minus(e1, e2) => simplifyBinaryArithmetic(e1, e2) - // case Times(e1, e2) => ... Need to make sure multiplier is not negative! - case e => e == value - } - } - - def matches(expr: Expr) : NumericEndpoint = expr match { - case And(es) => es.map(matches).foldLeft[NumericEndpoint](AnyEndpoint)(_ min _) - case Or(es) => es.map(matches).foldLeft[NumericEndpoint](NoEndpoint)(_ max _) - case Not(e) => matches(e).inverse - case GreaterThan(Value(), e) if variablesOf(e).isEmpty => LowerBoundEndpoint - case GreaterThan(e, Value()) if variablesOf(e).isEmpty => UpperBoundEndpoint - case GreaterEquals(Value(), e) if variablesOf(e).isEmpty => LowerBoundEndpoint - case GreaterEquals(e, Value()) if variablesOf(e).isEmpty => UpperBoundEndpoint - case Equals(Value(), e) if variablesOf(e).isEmpty => InnerEndpoint - case Equals(e, Value()) if variablesOf(e).isEmpty => InnerEndpoint - case LessThan(e1, e2) => matches(GreaterThan(e2, e1)) - case LessEquals(e1, e2) => matches(GreaterEquals(e2, e1)) - case _ => NoEndpoint - } - - def endpoint(expr: Expr) : NumericEndpoint = expr match { - case IfExpr(cond, thenn, elze) => matches(cond) match { - case NoEndpoint => - endpoint(thenn) min endpoint(elze) - case ep => - val terminatingThen = functionCallsOf(thenn).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) - val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) - val thenEndpoint = if (terminatingThen) ep max endpoint(thenn) else endpoint(thenn) - val elzeEndpoint = if (terminatingElze) ep.inverse max endpoint(elze) else endpoint(elze) - thenEndpoint max elzeEndpoint - } - case _ => NoEndpoint - } - - cluster.foldLeft[NumericEndpoint](AnyEndpoint) { (acc, chain) => - acc min chain.inlined.foldLeft[NumericEndpoint](NoEndpoint) { (acc, expr) => - acc max endpoint(expr) - } - } - } - - def numericConverging(e1: Expr, e2s: Seq[(Path, Expr)], cluster: Set[Chain]) : Seq[Expr] = flatType(e1.getType).toSeq.flatMap { - recons => recons(e1) match { - case e if e.getType == IntegerType => - val endpoint = numericEndpoint(e, cluster) - - val uppers = if (endpoint == UpperBoundEndpoint || endpoint == AnyEndpoint) { - Some(andJoin(e2s map { case (path, e2) => path implies GreaterThan(e, recons(e2)) })) - } else { - None - } - - val lowers = if (endpoint == LowerBoundEndpoint || endpoint == AnyEndpoint) { - Some(andJoin(e2s map { case (path, e2) => path implies LessThan(e, recons(e2)) })) - } else { - None - } - - uppers ++ lowers - case _ => Seq.empty - } - } -} - -// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala deleted file mode 100644 index 1fcf89fbcd7e6e54b78be94780e45c1436ecc7e8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Constructors._ - -class ChainProcessor( - val checker: TerminationChecker, - val modules: ChainBuilder with ChainComparator with Strengthener with StructuralSize -) extends Processor with Solvable { - - val name: String = "Chain Processor" - - def run(problem: Problem) = { - reporter.debug("- Strengthening postconditions") - modules.strengthenPostconditions(problem.funSet)(this) - - reporter.debug("- Strengthening applications") - modules.strengthenApplications(problem.funSet)(this) - - reporter.debug("- Running ChainBuilder") - val chainsMap : Map[FunDef, (Set[FunDef], Set[Chain])] = problem.funSet.map { funDef => - funDef -> modules.getChains(funDef)(this) - }.toMap - - val loopPoints = chainsMap.foldLeft(Set.empty[FunDef]) { case (set, (fd, (fds, chains))) => set ++ fds } - - if (loopPoints.size > 1) { - reporter.debug("-+> Multiple looping points, can't build chain proof") - None - } else { - val funDef = loopPoints.headOption getOrElse { - chainsMap.collectFirst { case (fd, (fds, chains)) if chains.nonEmpty => fd }.get - } - - val chains = chainsMap(funDef)._2 - - val e1 = tupleWrap(funDef.params.map(_.toVariable)) - val e2s = chains.toSeq.map { chain => - val freshParams = chain.finalParams.map(arg => FreshIdentifier(arg.id.name, arg.id.getType, true)) - (chain.loop(finalArgs = freshParams), tupleWrap(freshParams.map(_.toVariable))) - } - - reporter.debug("-+> Searching for structural size decrease") - - val structuralFormulas = modules.structuralDecreasing(e1, e2s) - val structuralDecreasing = structuralFormulas.exists(formula => definitiveALL(formula)) - - reporter.debug("-+> Searching for numerical converging") - - val numericFormulas = modules.numericConverging(e1, e2s, chains) - val numericDecreasing = numericFormulas.exists(formula => definitiveALL(formula)) - - if (structuralDecreasing || numericDecreasing) - Some(problem.funDefs map Cleared) - else { - val maybeReentrant = chains.flatMap(c1 => chains.flatMap(c2 => c1 compose c2)).exists { - chain => maybeSAT(chain.loop().toClause) - } - - if (!maybeReentrant) - Some(problem.funDefs map Cleared) - else - None - } - } - } -} diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala deleted file mode 100644 index 059ebdafdadd93cd3c36abfefc464743562b6240..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ - -import scala.collection.mutable.{Map => MutableMap} - -class ComplexTerminationChecker(context: LeonContext, initProgram: Program) extends ProcessingPipeline(context, initProgram) { - - val name = "Complex Termination Checker" - val description = "A modular termination checker with a few basic modules™" - - val modules = - new StructuralSize - with ArgsSizeSumRelationComparator - with ChainComparator - with Strengthener - with RelationBuilder - with ChainBuilder { - val checker = ComplexTerminationChecker.this - } - - val modulesLexicographic = - new StructuralSize - with LexicographicRelationComparator - with Strengthener - with RelationBuilder { - val checker = ComplexTerminationChecker.this - } - - val modulesOuter = - new StructuralSize - with ArgsOuterSizeRelationComparator - with Strengthener - with RelationBuilder { - val checker = ComplexTerminationChecker.this - } - - val modulesBV = - new StructuralSize - with BVRelationComparator - with Strengthener - with RelationBuilder { - val checker = ComplexTerminationChecker.this - } - - def processors = List( - new RecursionProcessor(this, modules), - // RelationProcessor is the only Processor which benefits from trying a different RelationComparator - new RelationProcessor(this, modulesBV), - new RelationProcessor(this, modulesOuter), - new RelationProcessor(this, modulesLexicographic), - new RelationProcessor(this, modules), - new ChainProcessor(this, modules), - new SelfCallsProcessor(this), - new LoopProcessor(this, modules) - ) - -} diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala deleted file mode 100644 index 643c62b3a9bdd01efa3ba9964441553e69661307..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ - -import scala.collection.mutable.{Map => MutableMap} - -class LoopProcessor(val checker: TerminationChecker, val modules: ChainBuilder with Strengthener with StructuralSize, k: Int = 10) extends Processor with Solvable { - - val name: String = "Loop Processor" - - def run(problem: Problem) = { - reporter.debug("- Strengthening applications") - modules.strengthenApplications(problem.funSet)(this) - - reporter.debug("- Running ChainBuilder") - val chains : Set[Chain] = problem.funSet.flatMap(fd => modules.getChains(fd)(this)._2) - - reporter.debug("- Searching for loops") - val nonTerminating: MutableMap[FunDef, Result] = MutableMap.empty - - (0 to k).foldLeft(chains) { (cs, index) => - reporter.debug("-+> Iteration #" + index) - for (chain <- cs if !nonTerminating.isDefinedAt(chain.funDef) && - (chain.funDef.params zip chain.finalParams).forall(p => p._1.getType == p._2.getType)) { - val freshParams = chain.funDef.params.map(arg => FreshIdentifier(arg.id.name, arg.getType, true)) - val path = chain.loop(finalArgs = freshParams) - - val srcTuple = tupleWrap(chain.funDef.params.map(_.toVariable)) - val resTuple = tupleWrap(freshParams.map(_.toVariable)) - - definitiveSATwithModel(path and equality(srcTuple, resTuple)) match { - case Some(model) => - val args = chain.funDef.params.map(arg => model(arg.id)) - val res = if (chain.relations.exists(_.inLambda)) MaybeBroken(chain.funDef, args) else Broken(chain.funDef, args) - nonTerminating(chain.funDef) = res - case None => - } - } - - cs.flatMap(c1 => chains.flatMap(c2 => c1.compose(c2))) - } - - if (nonTerminating.nonEmpty) - Some(nonTerminating.values.toSeq) - else - None - } -} - -// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/ProcessingPipeline.scala b/src/main/scala/leon/termination/ProcessingPipeline.scala deleted file mode 100644 index 94d5566d3e2127bcaf3448eb3aa16ff26b902f02..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/ProcessingPipeline.scala +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import purescala.Definitions._ -import utils._ - -trait Problem { - def funSet: Set[FunDef] - def funDefs: Seq[FunDef] - def contains(fd: FunDef): Boolean = funSet(fd) - - override def toString : String = funDefs.map(_.id).mkString("Problem(", ",", ")") -} - -object Problem { - def apply(fds: Set[FunDef]): Problem = new Problem { - val funSet = fds - lazy val funDefs = funSet.toSeq - } - - def apply(fds: Seq[FunDef]): Problem = new Problem { - val funDefs = fds - lazy val funSet = funDefs.toSet - } -} - -sealed abstract class Result(funDef: FunDef) -case class Cleared(funDef: FunDef) extends Result(funDef) -case class Broken(funDef: FunDef, args: Seq[Expr]) extends Result(funDef) -case class MaybeBroken(funDef: FunDef, args: Seq[Expr]) extends Result(funDef) - -abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) extends TerminationChecker(context, initProgram) { - implicit val debugSection = utils.DebugSectionTermination - - import scala.collection.mutable.{PriorityQueue, Map => MutableMap, Set => MutableSet} - - def processors: List[Processor] - - private lazy val processorArray: Array[Processor] = { - assert(processors.nonEmpty) - processors.toArray - } - - private val reporter: Reporter = context.reporter - - implicit object ProblemOrdering extends Ordering[(Problem, Int)] { - def compare(a: (Problem, Int), b: (Problem, Int)): Int = { - val ((aProblem, aIndex), (bProblem, bIndex)) = (a,b) - val (aDefs, bDefs) = (aProblem.funSet, bProblem.funSet) - - val aCallees: Set[FunDef] = aDefs.flatMap(program.callGraph.transitiveCallees) - val bCallees: Set[FunDef] = bDefs.flatMap(program.callGraph.transitiveCallees) - - lazy val aCallers: Set[FunDef] = aDefs.flatMap(program.callGraph.transitiveCallers) - lazy val bCallers: Set[FunDef] = bDefs.flatMap(program.callGraph.transitiveCallers) - - val aCallsB = bDefs.subsetOf(aCallees) - val bCallsA = aDefs.subsetOf(bCallees) - - if (aCallsB && !bCallsA) { - -1 - } else if (bCallsA && !aCallsB) { - 1 - } else { - val smallerPool = bCallees.size compare aCallees.size - if (smallerPool != 0) smallerPool else { - val largerImpact = aCallers.size compare bCallers.size - if (largerImpact != 0) largerImpact else { - bIndex compare aIndex - } - } - } - } - } - - private val problems = new PriorityQueue[(Problem, Int)] - def running = problems.nonEmpty - - private val clearedMap : MutableMap[FunDef, String] = MutableMap.empty - private val brokenMap : MutableMap[FunDef, (String, Seq[Expr])] = MutableMap.empty - private val maybeBrokenMap : MutableMap[FunDef, (String, Seq[Expr])] = MutableMap.empty - - private val unsolved : MutableSet[Problem] = MutableSet.empty - private val dependencies : MutableSet[Problem] = MutableSet.empty - - def isProblem(fd: FunDef): Boolean = { - lazy val callees = program.callGraph.transitiveCallees(fd) - lazy val problemDefs = problems.flatMap(_._1.funDefs).toSet - unsolved.exists(_.contains(fd)) || - dependencies.exists(_.contains(fd)) || - unsolved.exists(_.funDefs exists callees) || - dependencies.exists(_.funDefs exists callees) || - problemDefs(fd) || (problemDefs intersect callees).nonEmpty - } - - private def printQueue() { - val sb = new StringBuilder() - sb.append("- Problems in Queue:\n") - for(p @ (problem, index) <- problems) { - sb.append(" -> Problem awaiting processor #") - sb.append(index + 1) - sb.append(" (") - sb.append(processorArray(index).name) - sb.append(")") - if (p == problems.head) sb.append(" <- next") - sb.append("\n") - for(funDef <- problem.funDefs) { - sb.append(" " + funDef.id + "\n") - } - } - reporter.debug(sb.toString) - } - - private def printResult(results: List[Result]) { - val sb = new StringBuilder() - sb.append("- Processing Result:\n") - for(result <- results) result match { - case Cleared(fd) => sb.append(f" ${fd.id}%-10s Cleared\n") - case Broken(fd, args) => sb.append(f" ${fd.id}%-10s ${"Broken for arguments: " + args.mkString("(", ",", ")")}\n") - case MaybeBroken(fd, args) => sb.append(f" ${fd.id}%-10s ${"HO construct application breaks for arguments: " + args.mkString("(", ",", ")")}\n") - } - reporter.debug(sb.toString) - } - - private val terminationMap : MutableMap[FunDef, TerminationGuarantee] = MutableMap.empty - def terminates(funDef: FunDef): TerminationGuarantee = { - val guarantee = { - terminationMap.get(funDef) orElse - brokenMap.get(funDef).map { case (reason, args) => LoopsGivenInputs(reason, args) } orElse - maybeBrokenMap.get(funDef).map { case (reason, args) => MaybeLoopsGivenInputs(reason, args)} getOrElse { - val callees = program.callGraph.transitiveCallees(funDef) - val broken = brokenMap.keys.toSet ++ maybeBrokenMap.keys - val brokenCallees = callees intersect broken - if (brokenCallees.nonEmpty) { - CallsNonTerminating(brokenCallees) - } else if (isProblem(funDef)) { - NoGuarantee - } else { - clearedMap.get(funDef).map(Terminates).getOrElse( - if (!running) { - val verified = verifyTermination(funDef) - for (fd <- verified) terminates(fd) // fill in terminationMap - terminates(funDef) - } else { - if (!dependencies.exists(_.contains(funDef))) { - dependencies ++= generateProblems(funDef) - } - NoGuarantee - } - ) - } - } - } - - if (!running) terminationMap(funDef) = guarantee - guarantee - } - - private def generateProblems(funDef: FunDef): Seq[Problem] = { - val funDefs = program.callGraph.transitiveCallees(funDef) + funDef - val pairs = program.callGraph.allCalls.filter { case (fd1, fd2) => funDefs(fd1) && funDefs(fd2) } - val callGraph = pairs.groupBy(_._1).mapValues(_.map(_._2)) - val allComponents = SCC.scc(callGraph) - - val (problemComponents, nonRec) = allComponents.partition { fds => - fds.flatMap(fd => program.callGraph.transitiveCallees(fd)) exists fds - } - - for (fd <- funDefs -- problemComponents.toSet.flatten) clearedMap(fd) = "Non-recursive" - val newProblems = problemComponents.filter(fds => fds.forall { fd => !terminationMap.isDefinedAt(fd) }) - newProblems.map(fds => Problem(fds.toSeq)) - } - - def verifyTermination(funDef: FunDef): Set[FunDef] = { - reporter.debug("Verifying termination of " + funDef.id) - val terminationProblems = generateProblems(funDef) - problems ++= terminationProblems.map(_ -> 0) - - val it = new Iterator[(String, List[Result])] { - def hasNext : Boolean = running - def next() : (String, List[Result]) = { - printQueue() - val (problem, index) = problems.head - val processor : Processor = processorArray(index) - reporter.debug("Running " + processor.name) - val result = processor.run(problem) - reporter.debug(" +-> " + (if (result.isDefined) "Success" else "Failure")+ "\n") - - // dequeue and enqueue atomically to make sure the queue always - // makes sense (necessary for calls to terminates(fd)) - problems.dequeue() - result match { - case None if index == processorArray.length - 1 => - unsolved += problem - case None => - problems.enqueue(problem -> (index + 1)) - case Some(results) => - val impacted = problem.funDefs.flatMap(fd => program.callGraph.transitiveCallers(fd)) - val reenter = unsolved.filter(p => (p.funDefs intersect impacted).nonEmpty) - problems.enqueue(reenter.map(_ -> 0).toSeq : _*) - unsolved --= reenter - } - - if (dependencies.nonEmpty) { - problems.enqueue(dependencies.map(_ -> 0).toSeq : _*) - dependencies.clear - } - - processor.name -> result.toList.flatten - } - } - - for ((reason, results) <- it; result <- results) result match { - case Cleared(fd) => clearedMap(fd) = reason - case Broken(fd, args) => brokenMap(fd) = (reason, args) - case MaybeBroken(fd, args) => maybeBrokenMap(fd) = (reason, args) - } - - terminationProblems.flatMap(_.funDefs).toSet - } -} diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala deleted file mode 100644 index 024385ebf5c138d8c744053c9deee780bd10175b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/Processor.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import purescala.Common._ -import purescala.Definitions._ - -import scala.concurrent.duration._ - -import leon.solvers._ - -trait Processor { - - val name: String - - val checker : TerminationChecker - - implicit val debugSection = utils.DebugSectionTermination - val reporter = checker.context.reporter - - def run(problem: Problem): Option[Seq[Result]] -} - -trait Solvable extends Processor { - - val modules: Strengthener with StructuralSize - - private val solver: SolverFactory[Solver] = - SolverFactory.getFromSettings(checker.context, checker.program).withTimeout(1.seconds) - - type Solution = (Option[Boolean], Map[Identifier, Expr]) - - private def withoutPosts[T](block: => T): T = { - val dangerousFunDefs = checker.functions.filter(fd => !checker.terminates(fd).isGuaranteed) - val backups = dangerousFunDefs.toList map { fd => - val p = fd.postcondition - fd.postcondition = None - () => fd.postcondition = p - } - - val res : T = block // force evaluation now - backups.foreach(_()) - res - } - - def maybeSAT(problem: Expr): Boolean = withoutPosts { - SimpleSolverAPI(solver).solveSAT(problem)._1 getOrElse true - } - - def definitiveALL(problem: Expr): Boolean = withoutPosts { - SimpleSolverAPI(solver).solveSAT(Not(problem))._1.exists(!_) - } - - def definitiveSATwithModel(problem: Expr): Option[Model] = withoutPosts { - val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) - if (sat.isDefined && sat.get) Some(model) else None - } -} - diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala deleted file mode 100644 index 622723e2b12a2dea284f9f4bb9a0b630c861a0cf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import purescala.Common._ - -import scala.annotation.tailrec - -class RecursionProcessor(val checker: TerminationChecker, val rb: RelationBuilder) extends Processor { - - val name: String = "Recursion Processor" - - private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = { - @tailrec - def rec(e: Expr, fst: Boolean): Boolean = e match { - case Variable(aid) if aid == id => !fst - case CaseClassSelector(_, cc, _) => rec(cc, false) - case _ => false - } - rec(expr, true) - } - - def run(problem: Problem) = if (problem.funDefs.size > 1) None else { - val funDef = problem.funDefs.head - val relations = rb.getRelations(funDef) - val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(tfd, _), _) => tfd.fd == funDef }) - - if (others.exists({ case Relation(_, _, FunctionInvocation(tfd, _), _) => !checker.terminates(tfd.fd).isGuaranteed })) { - None - } else { - val decreases = funDef.params.zipWithIndex.exists({ case (arg, index) => - recursive.forall({ case Relation(_, path, FunctionInvocation(_, args), _) => - args(index) match { - // handle case class deconstruction in match expression! - case Variable(id) => path.bindings.exists { - case (vid, ccs) if vid == id => isSubtreeOf(ccs, arg.id) - case _ => false - } - case expr => isSubtreeOf(expr, arg.id) - } - }) - }) - - if (decreases) - Some(Cleared(funDef) :: Nil) - else - None - } - } -} diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala deleted file mode 100644 index 5feef247e396feefa1bc35d47f45033f1590031e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Definitions._ - -import scala.collection.mutable.{Map => MutableMap} - -final case class Relation(funDef: FunDef, path: Path, call: FunctionInvocation, inLambda: Boolean) { - override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.id + call.args.mkString("(",",",")") + "," + inLambda + ")" -} - -trait RelationBuilder { self: Strengthener => - - val checker: TerminationChecker - - protected type RelationSignature = (FunDef, Option[Expr], Option[Expr], Option[Expr], Boolean, Set[(FunDef, Boolean)]) - - protected def funDefRelationSignature(fd: FunDef): RelationSignature = { - val strengthenedCallees = checker.program.callGraph.callees(fd).map(fd => fd -> strengthened(fd)) - (fd, fd.precondition, fd.body, fd.postcondition, checker.terminates(fd).isGuaranteed, strengthenedCallees) - } - - private val relationCache : MutableMap[FunDef, (Set[Relation], RelationSignature)] = MutableMap.empty - - def getRelations(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { - case Some((relations, signature)) if signature == funDefRelationSignature(funDef) => relations - case _ => { - val collector = new CollectorWithPaths[Relation] { - var inLambda: Boolean = false - - override def rec(e: Expr, path: Path): Expr = e match { - case l : Lambda => - val old = inLambda - inLambda = true - val res = super.rec(e, path) - inLambda = old - res - case _ => - super.rec(e, path) - } - - def collect(e: Expr, path: Path): Option[Relation] = e match { - case fi @ FunctionInvocation(f, args) if checker.functions(f.fd) => - Some(Relation(funDef, path, fi, inLambda)) - case _ => None - } - - override def walk(e: Expr, path: Path) = e match { - case FunctionInvocation(tfd, args) => - val funDef = tfd.fd - Some(FunctionInvocation(tfd, (funDef.params.map(_.id) zip args) map { case (id, arg) => - rec(arg, path withCond self.applicationConstraint(funDef, id, arg, args)) - })) - case _ => None - } - } - - val relations = collector.traverse(funDef).toSet - relationCache(funDef) = (relations, funDefRelationSignature(funDef)) - relations - } - } -} diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala deleted file mode 100644 index a7cba806d82c39c79603b0169e4d6cb17cbd7cb8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/RelationComparator.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import leon.purescala.Constructors._ -import leon.purescala.Types.Int32Type - -trait RelationComparator { self : StructuralSize => - - val comparisonMethod: String - - def isApplicableFor(p: Problem): Boolean - - /** strictly decreasing: args1 > args2 */ - def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr - - /** weakly decreasing: args1 >= args2 */ - def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr -} - - -trait ArgsSizeSumRelationComparator extends RelationComparator { self : StructuralSize => - - val comparisonMethod = "comparing sum of argument sizes" - - def isApplicableFor(p: Problem): Boolean = true - - def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = { - GreaterThan(self.fullSize(tupleWrap(args1)), self.fullSize(tupleWrap(args2))) - } - - def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = { - GreaterEquals(self.fullSize(tupleWrap(args1)), self.fullSize(tupleWrap(args2))) - } -} - - -trait ArgsOuterSizeRelationComparator extends RelationComparator { self : StructuralSize => - - val comparisonMethod = "comparing outer structural sizes of argument types" - - def isApplicableFor(p: Problem): Boolean = true - - def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = { - GreaterThan(self.outerSize(tupleWrap(args1)), self.outerSize(tupleWrap(args2))) - } - - def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = { - GreaterEquals(self.outerSize(tupleWrap(args1)), self.outerSize(tupleWrap(args2))) - } -} - - -trait LexicographicRelationComparator extends RelationComparator { self : StructuralSize => - - val comparisonMethod = "comparing argument lists lexicographically" - - def isApplicableFor(p: Problem): Boolean = true - - def sizeDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = { - lexicographicDecreasing(s1, s2, strict = true, sizeOfOneExpr = e => self.fullSize(e)) - } - - def softDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = { - lexicographicDecreasing(s1, s2, strict = false, sizeOfOneExpr = e => self.fullSize(e)) - } -} - -// for bitvector Ints -trait BVRelationComparator extends RelationComparator { self : StructuralSize => - - /* - Note: It might seem that comparing the sum of all Int arguments is more flexible, but on - bitvectors, this causes overflow problems, so we won't be able to prove anything! - So that's why this function is useless: - - def bvSize(args: Seq[Expr]): Expr = { - val absValues: Seq[Expr] = args.collect{ - case e if e.getType == Int32Type => FunctionInvocation(typedAbsIntFun, Seq(e)) - } - absValues.foldLeft[Expr](IntLiteral(0)){ case (sum, expr) => BVPlus(sum, expr) } - } - */ - - val comparisonMethod = "comparing Int arguments lexicographically" - - def isApplicableFor(p: Problem): Boolean = { - p.funDefs.forall(fd => fd.params.exists(valdef => valdef.getType == Int32Type)) - } - - def bvSize(e: Expr): Expr = FunctionInvocation(typedAbsIntFun, Seq(e)) - - /* Note: We swap the arguments to the `lexicographicDecreasing` call - * since bvSize maps into negative ints! (avoids -Integer.MIN_VALUE overflow) */ - - def sizeDecreasing(s10: Seq[Expr], s20: Seq[Expr]): Expr = { - val s1 = s10.filter(_.getType == Int32Type) - val s2 = s20.filter(_.getType == Int32Type) - lexicographicDecreasing(s2, s1, strict = true, sizeOfOneExpr = bvSize) - } - - def softDecreasing(s10: Seq[Expr], s20: Seq[Expr]): Expr = { - val s1 = s10.filter(_.getType == Int32Type) - val s2 = s20.filter(_.getType == Int32Type) - lexicographicDecreasing(s2, s1, strict = false, sizeOfOneExpr = bvSize) - } -} - - -// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala deleted file mode 100644 index 980c015f1cbe29647a9c00fbc308d64fb5fa80ca..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import leon.purescala.Expressions._ -import leon.purescala.Constructors._ -import leon.purescala.Definitions._ - -class RelationProcessor( - val checker: TerminationChecker, - val modules: RelationBuilder with RelationComparator with Strengthener with StructuralSize - ) extends Processor with Solvable { - - val name: String = "Relation Processor " + modules.comparisonMethod - - def run(problem: Problem): Option[Seq[Result]] = { - if (!modules.isApplicableFor(problem)) return None - - reporter.debug("- Strengthening postconditions") - modules.strengthenPostconditions(problem.funSet)(this) - - reporter.debug("- Strengthening applications") - modules.strengthenApplications(problem.funSet)(this) - - val formulas = problem.funDefs.map({ funDef => - funDef -> modules.getRelations(funDef).collect({ - case Relation(_, path, FunctionInvocation(tfd, args), _) if problem.funSet(tfd.fd) => - val args0 = funDef.params.map(_.toVariable) - def constraint(expr: Expr) = path implies expr - val greaterThan = modules.sizeDecreasing(args0, args) - val greaterEquals = modules.softDecreasing(args0, args) - (tfd.fd, (constraint(greaterThan), constraint(greaterEquals))) - }) - }) - - sealed abstract class Result - case object Success extends Result - case class Dep(deps: Set[FunDef]) extends Result - case object Failure extends Result - - reporter.debug("- Searching for structural size decrease") - val decreasing = formulas.map({ case (fd, formulas) => - val solved = formulas.map({ case (fid, (gt, ge)) => - if (definitiveALL(gt)) Success - else if (definitiveALL(ge)) Dep(Set(fid)) - else Failure - }) - - val result = if(solved.contains(Failure)) Failure else { - val deps = solved.collect({ case Dep(fds) => fds }).flatten - if (deps.isEmpty) Success - else Dep(deps) - } - fd -> result - }) - - val (terminating, nonTerminating) = { - def currentReducing(fds: Set[FunDef], deps: List[(FunDef, Set[FunDef])]): (Set[FunDef], List[(FunDef, Set[FunDef])]) = { - val (okDeps, nokDeps) = deps.partition({ case (fd, deps) => deps.subsetOf(fds) }) - val newFds = fds ++ okDeps.map(_._1) - (newFds, nokDeps) - } - - def fix[A,B](f: (A,B) => (A,B), a: A, b: B): (A,B) = { - val (na, nb) = f(a, b) - if(na == a && nb == b) (a,b) else fix(f, na, nb) - } - - val ok = decreasing.collect({ case (fd, Success) => fd }) - val nok = decreasing.collect({ case (fd, Dep(fds)) => fd -> fds }).toList - val (allOk, allNok) = fix(currentReducing, ok.toSet, nok) - (allOk, allNok.map(_._1).toSet ++ decreasing.collect({ case (fd, Failure) => fd })) - } - - assert(terminating ++ nonTerminating == problem.funSet) - - if (nonTerminating.isEmpty) - Some(terminating.map(Cleared).toSeq) - else - None - } -} diff --git a/src/main/scala/leon/termination/SelfCallsProcessor.scala b/src/main/scala/leon/termination/SelfCallsProcessor.scala deleted file mode 100644 index 6a14750224fa8766ffe6c9a4784e38a07d793b2d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/SelfCallsProcessor.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ -import purescala.Common._ -import purescala.Expressions._ - -class SelfCallsProcessor(val checker: TerminationChecker) extends Processor { - - val name: String = "Self Calls Processor" - - def run(problem: Problem): Option[Seq[Result]] = { - reporter.debug("- Self calls processor...") - - val nonTerminating = problem.funDefs - .filter(fd => fd.hasBody && alwaysCalls(fd.body.get, fd)) - - if (nonTerminating.nonEmpty) - Some(nonTerminating.map(fd => Broken(fd, Seq(Variable(FreshIdentifier("any input")))))) - else - None - } - - - def alwaysCalls(expr: Expr, f: FunDef): Boolean = { - val seenFunDefs = collection.mutable.HashSet[FunDef]() - - def rec(e0: Expr): Boolean = e0 match { - case Assert(pred: Expr, error: Option[String], body: Expr) => rec(pred) || rec(body) - case Let(binder: Identifier, value: Expr, body: Expr) => rec(value) || rec(body) - case LetDef(fds, body: Expr) => rec(body) // don't enter fds because we don't know if it will be called - case FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) => - tfd.fd == f /* <-- success in proving non-termination */ || - args.exists(arg => rec(arg)) || (tfd.fd.hasBody && (!seenFunDefs.contains(tfd.fd)) && { - seenFunDefs += tfd.fd - rec(tfd.fd.body.get) - }) - case Application(caller: Expr, args: Seq[Expr]) => rec(caller) || args.exists(arg => rec(arg)) - case Lambda(args: Seq[ValDef], body: Expr) => false // we don't know if it will be called - //case Forall(args: Seq[ValDef], body: Expr) ? - case IfExpr(cond: Expr, thenn: Expr, elze: Expr) => rec(cond) // don't enter thenn/elze - case Tuple (exprs: Seq[Expr]) => exprs.exists(ex => rec(ex)) - case TupleSelect(tuple: Expr, index: Int) => rec(tuple) - case MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) => rec(scrutinee) - // case Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) ? - case And(exprs: Seq[Expr]) => rec(exprs.head) // only the first expr will definitely be executed, if it returns false, - // nothing more will be executed due to short-curcuit evaluation - case Or(exprs: Seq[Expr]) => rec(exprs.head) - // case Implies(lhs: Expr, rhs: Expr) short-circuit evaluation as well? - case Not(expr: Expr) => rec(expr) - case Equals(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case CaseClass(ct, args: Seq[Expr]) => args.exists(arg => rec(arg)) - case IsInstanceOf(expr: Expr, ct) => rec(expr) - case AsInstanceOf(expr: Expr, ct) => rec(expr) - case CaseClassSelector(ct, caseClassExpr, selector) => rec(caseClassExpr) - case Plus(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case Minus(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case UMinus(expr: Expr) => rec(expr) - case Times(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case Division(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case Modulo(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case LessThan(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case GreaterThan(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case LessEquals(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - case GreaterEquals(lhs: Expr, rhs: Expr) => rec(lhs) || rec(rhs) - /* TODO marker trait for Bit-vector arithmetic and treat them all at once */ - // TODO set & map operations - case _ => false - } - - rec(expr) - } -} diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala deleted file mode 100644 index a5864eeb4490c8d58a54589166b26933b329591d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import utils._ - -import scala.collection.mutable.{ Map => MutableMap } - -import scala.annotation.tailrec - -class SimpleTerminationChecker(context: LeonContext, program: Program) extends TerminationChecker(context, program) { - - val name = "T1" - val description = "The simplest form of Terminator™" - - private lazy val callGraph: Map[FunDef, Set[FunDef]] = - program.callGraph.allCalls.groupBy(_._1).mapValues(_.map(_._2)) // one liner from hell - - private lazy val components = SCC.scc(callGraph) - val allVertices = callGraph.keySet ++ callGraph.values.flatten - - val sccArray = components.toArray - val cSize = sccArray.length - - val funDefToSCCIndex = (callGraph.keySet ++ callGraph.values.flatten).map(v => - v -> (0 until cSize).find(i => sccArray(i)(v)).get).toMap - - val sccGraph = (0 until cSize).map({ i => - val dsts = for { - v <- sccArray(i) - c <- callGraph.getOrElse(v, Set.empty) - } yield funDefToSCCIndex(c) - i -> dsts - }).toMap - - private val answerCache = MutableMap.empty[FunDef, TerminationGuarantee] - - def terminates(funDef: FunDef) = answerCache.getOrElse(funDef, { - val g = forceCheckTermination(funDef) - answerCache(funDef) = g - g - }) - - private def forceCheckTermination(funDef: FunDef): TerminationGuarantee = { - // We would have to clarify what it means to terminate. - // We would probably need something along the lines of: - // "Terminates for all values satisfying prec." - if (funDef.hasPrecondition) - return NoGuarantee - - // This is also too confusing for me to think about now. - if (!funDef.hasBody) - return NoGuarantee - - val sccIndex = funDefToSCCIndex.getOrElse(funDef, { - return NoGuarantee - }) - val sccCallees = sccGraph(sccIndex) - - // We check all functions that are in a "lower" scc. These must - // terminate for all inputs in any case. - val sccLowerCallees = sccCallees - sccIndex - val lowerDefs = sccLowerCallees.flatMap(sccArray(_)) - val lowerOK = lowerDefs.forall(terminates(_).isGuaranteed) - if (!lowerOK) - return NoGuarantee - - // Now all we need to do is check the functions in the same - // scc. But who knows, maybe none of these are called? - if (!sccCallees(sccIndex)) { - // (the distinction isn't exactly useful...) - if (sccCallees.isEmpty) - return Terminates("no calls") - else - return Terminates("by subcalls") - } - - // So now we know the function is recursive (or mutually - // recursive). Maybe it's just self-recursive? - if (sccArray(sccIndex).size == 1) { - assert(sccArray(sccIndex) == Set(funDef)) - // Yes it is ! - // Now we apply a simple recipe: we check that in each (self) - // call, at least one argument is of an ADT type and decreases. - // Yes, it's that restrictive. - val callsOfInterest = { (e: Expr) => - functionCallsOf(simplifyLets(matchToIfThenElse(e))).filter(_.tfd.fd == funDef) - } - - val callsToAnalyze = callsOfInterest(funDef.fullBody) - - val funDefArgsIDs = funDef.params.map(_.id).toSet - - if (callsToAnalyze.forall { fi => - fi.args.exists { arg => - isSubTreeOfArg(arg, funDefArgsIDs) - } - }) { - return Terminates("decreasing") - } else { - return NoGuarantee - } - } - - // Handling mutually recursive functions is beyond my willpower. - NoGuarantee - } - - private def isSubTreeOfArg(expr: Expr, args: Set[Identifier]): Boolean = { - @tailrec - def rec(e: Expr, fst: Boolean): Boolean = e match { - case Variable(id) if args(id) => !fst - case CaseClassSelector(_, cc, _) => rec(cc, false) - case _ => false - } - rec(expr, true) - } -} diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala deleted file mode 100644 index 21b95787f08b65274086e069d8470040fcabe005..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/Strengthener.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Path -import purescala.Expressions._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.Constructors._ - -import scala.collection.mutable.{Set => MutableSet, Map => MutableMap} - -trait Strengthener { self : RelationComparator => - - val checker: TerminationChecker - - implicit object CallGraphOrdering extends Ordering[FunDef] { - def compare(a: FunDef, b: FunDef): Int = { - val aCallsB = checker.program.callGraph.transitivelyCalls(a, b) - val bCallsA = checker.program.callGraph.transitivelyCalls(b, a) - if (aCallsB && !bCallsA) -1 - else if (bCallsA && !aCallsB) 1 - else 0 - } - } - - private val strengthenedPost : MutableSet[FunDef] = MutableSet.empty - - def strengthenPostconditions(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { - // Strengthen postconditions on all accessible functions by adding size constraints - val callees : Set[FunDef] = funDefs.flatMap(fd => checker.program.callGraph.transitiveCallees(fd)) - val sortedCallees : Seq[FunDef] = callees.toSeq.sorted - - for (funDef <- sortedCallees if !strengthenedPost(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) { - def strengthen(cmp: (Seq[Expr], Seq[Expr]) => Expr): Boolean = { - val old = funDef.postcondition - val postcondition = { - val res = FreshIdentifier("res", funDef.returnType, true) - val post = old.map{application(_, Seq(Variable(res)))}.getOrElse(BooleanLiteral(true)) - val sizePost = cmp(funDef.params.map(_.toVariable), Seq(res.toVariable)) - Lambda(Seq(ValDef(res)), and(post, sizePost)) - } - - funDef.postcondition = Some(postcondition) - - val prec = matchToIfThenElse(funDef.precOrTrue) - val body = matchToIfThenElse(funDef.body.get) - val post = matchToIfThenElse(postcondition) - val formula = implies(prec, application(post, Seq(body))) - - // @nv: one must also check satisfiability here as if both formula and - // !formula are UNSAT, we will proceed to invalid strenghtening - if (!solver.maybeSAT(formula) || !solver.definitiveALL(formula)) { - funDef.postcondition = old - false - } else { - true - } - } - - // test if size is smaller or equal to input - val weekConstraintHolds = strengthen(self.softDecreasing) - - if (weekConstraintHolds) { - // try to improve postcondition with strictly smaller - strengthen(self.sizeDecreasing) - } - - strengthenedPost += funDef - } - } - - sealed abstract class SizeConstraint - case object StrongDecreasing extends SizeConstraint - case object WeakDecreasing extends SizeConstraint - case object NoConstraint extends SizeConstraint - - private val strengthenedApp : MutableSet[FunDef] = MutableSet.empty - - protected def strengthened(fd: FunDef): Boolean = strengthenedApp(fd) - - private val appConstraint : MutableMap[(FunDef, Identifier), SizeConstraint] = MutableMap.empty - - def applicationConstraint(fd: FunDef, id: Identifier, arg: Expr, args: Seq[Expr]): Expr = arg match { - case Lambda(fargs, body) => appConstraint.get(fd -> id) match { - case Some(StrongDecreasing) => self.sizeDecreasing(args, fargs.map(_.toVariable)) - case Some(WeakDecreasing) => self.softDecreasing(args, fargs.map(_.toVariable)) - case _ => BooleanLiteral(true) - } - case _ => BooleanLiteral(true) - } - - def strengthenApplications(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { - val transitiveFunDefs = funDefs ++ funDefs.flatMap(checker.program.callGraph.transitiveCallees) - val sortedFunDefs = transitiveFunDefs.toSeq.sorted - - for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) { - - val appCollector = new CollectorWithPaths[(Identifier,Path,Seq[Expr])] { - def collect(e: Expr, path: Path): Option[(Identifier, Path, Seq[Expr])] = e match { - case Application(Variable(id), args) => Some((id, path, args)) - case _ => None - } - } - - val applications = appCollector.traverse(funDef).distinct - - val funDefArgs = funDef.params.map(_.toVariable) - - val allFormulas = for ((id, path, appArgs) <- applications) yield { - val soft = path implies self.softDecreasing(funDefArgs, appArgs) - val hard = path implies self.sizeDecreasing(funDefArgs, appArgs) - id -> ((soft, hard)) - } - - val formulaMap = allFormulas.groupBy(_._1).mapValues(_.map(_._2).unzip) - - val constraints = for ((id, (weakFormulas, strongFormulas)) <- formulaMap) yield id -> { - if (solver.definitiveALL(andJoin(weakFormulas.toSeq))) { - if (solver.definitiveALL(andJoin(strongFormulas.toSeq))) { - StrongDecreasing - } else { - WeakDecreasing - } - } else { - NoConstraint - } - } - - val funDefHOArgs = funDef.params.map(_.id).filter(_.getType.isInstanceOf[FunctionType]).toSet - - val fiCollector = new CollectorWithPaths[(Path, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] { - def collect(e: Expr, path: Path): Option[(Path, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] = e match { - case FunctionInvocation(tfd, args) if (funDefHOArgs intersect args.collect({ case Variable(id) => id }).toSet).nonEmpty => - Some((path, args, (args zip tfd.fd.params).collect { - case (Variable(id), vd) if funDefHOArgs(id) => id -> ((tfd.fd, vd.id)) - })) - case _ => None - } - } - - val invocations = fiCollector.traverse(funDef) - val id2invocations : Seq[(Identifier, ((FunDef, Identifier), Path, Seq[Expr]))] = - for { - p <- invocations - c <- p._3 - } yield c._1 -> (c._2, p._1, p._2) - val invocationMap: Map[Identifier, Seq[((FunDef, Identifier), Path, Seq[Expr])]] = - id2invocations.groupBy(_._1).mapValues(_.map(_._2)) - - def constraint(id: Identifier, passings: Seq[((FunDef, Identifier), Path, Seq[Expr])]): SizeConstraint = { - if (constraints.get(id) == Some(NoConstraint)) NoConstraint - else if (passings.exists(p => appConstraint.get(p._1) == Some(NoConstraint))) NoConstraint - else passings.foldLeft[SizeConstraint](constraints.getOrElse(id, StrongDecreasing)) { - case (constraint, (key, path, args)) => - - lazy val strongFormula = path implies self.sizeDecreasing(funDefArgs, args) - lazy val weakFormula = path implies self.softDecreasing(funDefArgs, args) - - (constraint, appConstraint.get(key)) match { - case (_, Some(NoConstraint)) => scala.sys.error("Whaaaat!?!? This shouldn't happen...") - case (_, None) => NoConstraint - case (NoConstraint, _) => NoConstraint - case (StrongDecreasing | WeakDecreasing, Some(StrongDecreasing)) => - if (solver.definitiveALL(weakFormula)) StrongDecreasing - else NoConstraint - case (StrongDecreasing, Some(WeakDecreasing)) => - if (solver.definitiveALL(strongFormula)) StrongDecreasing - else if (solver.definitiveALL(weakFormula)) WeakDecreasing - else NoConstraint - case (WeakDecreasing, Some(WeakDecreasing)) => - if (solver.definitiveALL(weakFormula)) WeakDecreasing - else NoConstraint - } - } - } - - val outers = invocationMap.mapValues(_.filter(_._1._1 != funDef)) - funDefHOArgs.foreach { id => appConstraint(funDef -> id) = constraint(id, outers.getOrElse(id, Seq.empty)) } - - val selfs = invocationMap.mapValues(_.filter(_._1._1 == funDef)) - funDefHOArgs.foreach { id => appConstraint(funDef -> id) = constraint(id, selfs.getOrElse(id, Seq.empty)) } - - strengthenedApp += funDef - } - } -} diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala deleted file mode 100644 index 6d6bf68945cc72759f6d07f87fdee29b3e38a725..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ /dev/null @@ -1,257 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Definitions._ -import purescala.Constructors._ -import purescala.Common._ - -import scala.collection.mutable.{Map => MutableMap} - -trait StructuralSize { - - /* Absolute value for BigInt type - * - * def absBigInt(x: BigInt): BigInt = if (x >= 0) x else -x - */ - val typedAbsBigIntFun: TypedFunDef = { - val x = FreshIdentifier("x", IntegerType, alwaysShowUniqueID = true) - val absFun = new FunDef(FreshIdentifier("absBigInt", alwaysShowUniqueID = true), Seq(), Seq(ValDef(x)), IntegerType) - absFun.body = Some(IfExpr( - GreaterEquals(Variable(x), InfiniteIntegerLiteral(0)), - Variable(x), - UMinus(Variable(x)) - )) - absFun.typed - } - - /* Negative absolute value for Int type - * - * To avoid -Integer.MIN_VALUE overflow, we use negative absolute value - * for bitvector integers. - * - * def absInt(x: Int): Int = if (x >= 0) -x else x - */ - val typedAbsIntFun: TypedFunDef = { - val x = FreshIdentifier("x", Int32Type, alwaysShowUniqueID = true) - val absFun = new FunDef(FreshIdentifier("absInt", alwaysShowUniqueID = true), Seq(), Seq(ValDef(x)), Int32Type) - absFun.body = Some(IfExpr( - GreaterEquals(Variable(x), IntLiteral(0)), - BVUMinus(Variable(x)), - Variable(x) - )) - absFun.typed - } - - /* Absolute value for Int (32 bit) type into mathematical integers - * - * We use a recursive function here as the bv2int functionality provided - * through SMT solvers is waaaaay too slow. Recursivity requires the - * postcondition for verification efforts to succeed. - * - * def absInt(x: Int): BigInt = (if (x == 0) { - * BigInt(0) - * } else if (x > 0) { - * 1 + absInt(x - 1) - * } else { - * 1 + absInt(-(x + 1)) // avoids -Integer.MIN_VALUE overflow - * }) ensuring (_ >= 0) - */ - def typedAbsInt2IntegerFun: TypedFunDef = { - val x = FreshIdentifier("x", Int32Type, alwaysShowUniqueID = true) - val absFun = new FunDef(FreshIdentifier("absInt", alwaysShowUniqueID = true), Seq(), Seq(ValDef(x)), IntegerType) - absFun.body = Some(IfExpr( - Equals(Variable(x), IntLiteral(0)), - InfiniteIntegerLiteral(0), - IfExpr( - GreaterThan(Variable(x), IntLiteral(0)), - Plus(InfiniteIntegerLiteral(1), FunctionInvocation(absFun.typed, Seq(BVMinus(Variable(x), IntLiteral(1))))), - Plus(InfiniteIntegerLiteral(1), FunctionInvocation(absFun.typed, Seq(BVUMinus(BVPlus(Variable(x), IntLiteral(1)))))) - ))) - val res = FreshIdentifier("res", IntegerType, alwaysShowUniqueID = true) - absFun.postcondition = Some(Lambda(Seq(ValDef(res)), GreaterEquals(Variable(res), InfiniteIntegerLiteral(0)))) - absFun.typed - } - - private val fullCache: MutableMap[TypeTree, FunDef] = MutableMap.empty - - /* Fully recursive size computation - * - * Computes (positive) size of leon types by summing up all sub-elements - * accessible in the type definition. - */ - def fullSize(expr: Expr) : Expr = { - def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = { - // we want to reuse generic size functions for sub-types - val classDef = ct.root.classDef - val argumentType = classDef.typed - val typeParams = classDef.tparams.map(_.tp) - - fullCache.get(argumentType) match { - case Some(fd) => fd - case None => - val argument = ValDef(FreshIdentifier("x", argumentType, true)) - val formalTParams = typeParams.map(TypeParameterDef) - val fd = new FunDef(FreshIdentifier("fullSize", alwaysShowUniqueID = true), formalTParams, Seq(argument), IntegerType) - fullCache(argumentType) = fd - - val body = simplifyLets(matchToIfThenElse(matchExpr(argument.toVariable, cases(argumentType)))) - val postId = FreshIdentifier("res", IntegerType) - val postcondition = Lambda(Seq(ValDef(postId)), GreaterThan(Variable(postId), InfiniteIntegerLiteral(0))) - - fd.body = Some(body) - fd.postcondition = Some(postcondition) - fd - } - } - - def caseClassType2MatchCase(c: CaseClassType): MatchCase = { - val arguments = c.fields.map(vd => FreshIdentifier(vd.id.name, vd.getType)) - val argumentPatterns = arguments.map(id => WildcardPattern(Some(id))) - val sizes = arguments.map(id => fullSize(Variable(id))) - val result = sizes.foldLeft[Expr](InfiniteIntegerLiteral(1))(Plus) - purescala.Extractors.SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) - } - - expr.getType match { - case (ct: ClassType) => - val fd = funDef(ct.root, { - case (act: AbstractClassType) => act.knownCCDescendants map caseClassType2MatchCase - case (cct: CaseClassType) => Seq(caseClassType2MatchCase(cct)) - }) - FunctionInvocation(TypedFunDef(fd, ct.tps), Seq(expr)) - case TupleType(argTypes) => argTypes.zipWithIndex.map({ - case (_, index) => fullSize(tupleSelect(expr, index + 1, true)) - }).foldLeft[Expr](InfiniteIntegerLiteral(0))(plus) - case IntegerType => - FunctionInvocation(typedAbsBigIntFun, Seq(expr)) - case Int32Type => - FunctionInvocation(typedAbsInt2IntegerFun, Seq(expr)) - case _ => InfiniteIntegerLiteral(0) - } - } - - private val outerCache: MutableMap[TypeTree, FunDef] = MutableMap.empty - - object ContainerType { - def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match { - case cct @ CaseClassType(ccd, _) => - if (cct.fields.exists(arg => purescala.TypeOps.isSubtypeOf(arg.getType, cct.root))) None - else if (ccd.hasParent && ccd.parent.get.knownDescendants.size > 1) None - else Some((cct, cct.fields.map(arg => arg.id -> arg.getType))) - case _ => None - } - } - - def flatTypesPowerset(tpe: TypeTree): Set[Expr => Expr] = { - def powerSetToFunSet(l: TraversableOnce[Expr => Expr]): Set[Expr => Expr] = { - l.toSet.subsets.filter(_.nonEmpty).map{ - (reconss : Set[Expr => Expr]) => (e : Expr) => - tupleWrap(reconss.toSeq map { f => f(e) }) - }.toSet - } - - def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { - case ContainerType(cct, fields) => - powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => - rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId))) - }) - case TupleType(tpes) => - powerSetToFunSet(tpes.indices.flatMap { case index => - rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true))) - }) - case _ => Set((e: Expr) => e) - } - - rec(tpe) - } - - def flatType(tpe: TypeTree): Set[Expr => Expr] = { - def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { - case ContainerType(cct, fields) => - fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => - rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId))) - }.toSet - case TupleType(tpes) => - tpes.indices.flatMap { case index => - rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true))) - }.toSet - case _ => Set((e: Expr) => e) - } - - rec(tpe) - } - - /* Recursively computes outer datastructure size - * - * Computes the structural size of a datastructure but only considers - * the outer-most datatype definition. - * - * eg. for List[List[T]], the inner list size is not considered - */ - def outerSize(expr: Expr) : Expr = { - def dependencies(ct: ClassType): Set[ClassType] = { - def deps(ct: ClassType): Set[ClassType] = ct.fieldsTypes.collect { case ct: ClassType => ct.root }.toSet - utils.fixpoint((cts: Set[ClassType]) => cts ++ cts.flatMap(deps))(Set(ct)) - } - - flatType(expr.getType).foldLeft[Expr](InfiniteIntegerLiteral(0)) { case (i, f) => - val e = f(expr) - plus(i, e.getType match { - case ct: ClassType => - val root = ct.root - val fd = outerCache.getOrElse(root.classDef.typed, { - val tcd = root.classDef.typed - val id = FreshIdentifier("x", tcd, true) - val fd = new FunDef(FreshIdentifier("outerSize", alwaysShowUniqueID = true), - root.classDef.tparams, - Seq(ValDef(id)), - IntegerType) - outerCache(tcd) = fd - - fd.body = Some(MatchExpr(Variable(id), tcd.knownCCDescendants map { cct => - val args = cct.fields.map(_.id.freshen) - purescala.Extractors.SimpleCase( - CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id)))), - args.foldLeft[Expr](InfiniteIntegerLiteral(1)) { case (e, id) => - plus(e, id.getType match { - case ct: ClassType if dependencies(tcd)(ct.root) => outerSize(Variable(id)) - case _ => InfiniteIntegerLiteral(0) - }) - }) - })) - - val res = FreshIdentifier("res", IntegerType, true) - fd.postcondition = Some(Lambda(Seq(ValDef(res)), GreaterEquals(Variable(res), InfiniteIntegerLiteral(0)))) - fd - }) - - FunctionInvocation(fd.typed(root.tps), Seq(e)) - - case _ => InfiniteIntegerLiteral(0) - }) - } - } - - def lexicographicDecreasing(s1: Seq[Expr], s2: Seq[Expr], strict: Boolean, sizeOfOneExpr: Expr => Expr): Expr = { - // Note: The Equal and GreaterThan ASTs work for both BigInt and Bitvector - - val sameSizeExprs = for ((arg1, arg2) <- s1 zip s2) yield Equals(sizeOfOneExpr(arg1), sizeOfOneExpr(arg2)) - - val greaterBecauseGreaterAtFirstDifferentPos = - orJoin(for (firstDifferent <- 0 until scala.math.min(s1.length, s2.length)) yield and( - andJoin(sameSizeExprs.take(firstDifferent)), - GreaterThan(sizeOfOneExpr(s1(firstDifferent)), sizeOfOneExpr(s2(firstDifferent))) - )) - - if (s1.length > s2.length || (s1.length == s2.length && !strict)) { - or(andJoin(sameSizeExprs), greaterBecauseGreaterAtFirstDifferentPos) - } else { - greaterBecauseGreaterAtFirstDifferentPos - } - } -} diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala deleted file mode 100644 index c6dec9a6d2c3a51cb361064aeabdef64a7cbbf31..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/TerminationChecker.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.DefOps._ - -abstract class TerminationChecker(val context: LeonContext, initProgram: Program) extends LeonComponent { - val program = { - transformProgram(funDefReplacer{ fd => Some(fd.duplicate()) }, initProgram) - } - - val functions = visibleFunDefsFromMain(program) - - def terminates(funDef : FunDef) : TerminationGuarantee -} - -sealed abstract class TerminationGuarantee { - def isGuaranteed: Boolean -} - -abstract class Terminating(justification: String) extends TerminationGuarantee { - override def isGuaranteed: Boolean = true -} - -case class Terminates(justification: String) extends Terminating(justification) - -abstract class NonTerminating extends TerminationGuarantee { - override def isGuaranteed: Boolean = false -} - -case class LoopsGivenInputs(justification: String, args: Seq[Expr]) extends NonTerminating -case class MaybeLoopsGivenInputs(justification: String, args: Seq[Expr]) extends NonTerminating - -case class CallsNonTerminating(calls: Set[FunDef]) extends NonTerminating - -case object NoGuarantee extends TerminationGuarantee { - override def isGuaranteed: Boolean = false -} diff --git a/src/main/scala/leon/termination/TerminationPhase.scala b/src/main/scala/leon/termination/TerminationPhase.scala deleted file mode 100644 index a7802c490efb494f419e8158439d3c8dbe848f3f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/TerminationPhase.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ - -object TerminationPhase extends SimpleLeonPhase[Program, TerminationReport] { - val name = "Termination" - val description = "Check termination of PureScala functions" - - def apply(ctx: LeonContext, program: Program): TerminationReport = { - val startTime = System.currentTimeMillis - -// val tc = new SimpleTerminationChecker(ctx, program) - val tc = new ComplexTerminationChecker(ctx, program) - - def excludeByDefault(fd: FunDef): Boolean = fd.annotations contains "library" - - val fdFilter = { - import OptionsHelpers._ - - filterInclusive(ctx.findOption(GlobalOptions.optFunctions).map(fdMatcher(program)), Some(excludeByDefault _)) - } - - val toVerify = tc.program.definedFunctions.filter(fdFilter).sortWith((fd1, fd2) => fd1.getPos < fd2.getPos) - - val results = toVerify.map { funDef => - funDef -> tc.terminates(funDef) - } - val endTime = System.currentTimeMillis - - new TerminationReport(ctx, tc.program, results, (endTime - startTime).toDouble / 1000.0d) - } -} diff --git a/src/main/scala/leon/termination/TerminationReport.scala b/src/main/scala/leon/termination/TerminationReport.scala deleted file mode 100644 index 47ef92268834a47ab584720f7169989def23531b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/termination/TerminationReport.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package termination - -import purescala.Definitions._ -import utils.Report -import utils.ASCIIHelpers._ -import leon.purescala.PrettyPrinter -import leon.purescala.SelfPrettyPrinter - -case class TerminationReport(ctx: LeonContext, program: Program, results : Seq[(FunDef,TerminationGuarantee)], time : Double) extends Report { - - def summaryString : String = { - var t = Table("Termination summary") - - for ((fd, g) <- results) t += Row(Seq( - Cell(fd.id.asString(ctx)), - Cell { - val result = if (g.isGuaranteed) "\u2713" else "\u2717" - val verdict = g match { - case LoopsGivenInputs(reason, args) => - val niceArgs = args.map { v => - SelfPrettyPrinter.print(v, PrettyPrinter(v))(ctx, program) - } - "Non-terminating for call: " + niceArgs.mkString(fd.id + "(", ",", ")") - case CallsNonTerminating(funDefs) => - "Calls non-terminating functions " + funDefs.map(_.id).mkString(",") - case Terminates(reason) => - "Terminates (" + reason + ")" - case _ => g.toString - } - s"$result $verdict" - } - )) - - t += Separator - - t += Row(Seq(Cell( - f"Analysis time: $time%7.3f", - spanning = 2 - ))) - - t.render - } - - def evaluationString : String = { - val sb = new StringBuilder - for((fd,g) <- results) { - val guar = g match { - case NoGuarantee => "u" - case t => if (t.isGuaranteed) "t" else "n" - } - sb.append(f"- ${fd.id.name}%-30s $guar\n") - } - sb.toString - } -} diff --git a/src/main/scala/leon/transformations/DepthInstPhase.scala b/src/main/scala/leon/transformations/DepthInstPhase.scala deleted file mode 100644 index 96d47147e6a8439e8d37105feff523248963599c..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/DepthInstPhase.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.utils._ -import invariant.util.Util._ - -object DepthCostModel { - val typedMaxFun = TypedFunDef(InstUtil.maxFun, Seq()) - - def costOf(e: Expr): Int = - e match { - case FunctionInvocation(fd, args) => 1 - case t: Terminal => 0 - case _ => 1 - } - - def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) -} - -class DepthInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { - import DepthCostModel._ - - def inst = Depth - - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { - //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) - val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)).filter(_.hasBody) - instFunSet.map(x => (x, List(Depth))).toMap - } - - def additionalfunctionsToAdd(): Seq[FunDef] = Seq()// - max functions are inlined, so they need not be added - - def instrumentMatchCase(me: MatchExpr, mc: MatchCase, - caseExprCost: Expr, scrutineeCost: Expr): Expr = { - val costMatch = costOfExpr(me) - def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = 0 - combineDepthIds(costMatch, List(caseExprCost, scrutineeCost)) - } - - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None)(implicit fd: FunDef, letIdMap: Map[Identifier, Identifier]): Expr = { - val costOfParent = costOfExpr(e) - e match { - case Variable(id) if letIdMap.contains(id) => - // add the cost of instrumentation here - Plus(costOfParent, si.selectInst(fd)(letIdMap(id).toVariable, inst)) - - case t: Terminal => costOfParent - case FunctionInvocation(tfd, args) => - val depthvar = subInsts.last - val remSubInsts = subInsts.slice(0, subInsts.length - 1) - val costofOp = { - costOfParent match { - case InfiniteIntegerLiteral(x) if (x == 0) => depthvar - case _ => Plus(costOfParent, depthvar) - } - } - combineDepthIds(costofOp, remSubInsts) - case e : Let => - //in this case, ignore the depth of the value, it will included if the bounded variable is - // used in the body - combineDepthIds(costOfParent, subInsts.tail) - case _ => - val costofOp = costOfParent - combineDepthIds(costofOp, subInsts) - } - } - - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], - thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { - - val cinst = condInst.toList - val tinst = thenInst.toList - val einst = elzeInst.toList - - (combineDepthIds(zero, cinst ++ tinst), combineDepthIds(zero, cinst ++ einst)) - } - - def combineDepthIds(costofOp: Expr, subeInsts: Seq[Expr]): Expr = { - if (subeInsts.size == 0) costofOp - else if (subeInsts.size == 1) Plus(costofOp, subeInsts(0)) - else { - //optimization: remove duplicates from 'subeInsts' as 'max' is an idempotent operation - val head +: tail = subeInsts.distinct - val summand = tail.foldLeft(head: Expr)((acc, id) => { - (acc, id) match { - case (InfiniteIntegerLiteral(x), _) if (x == 0) => id - case (_, InfiniteIntegerLiteral(x)) if (x == 0) => acc - case _ => - FunctionInvocation(typedMaxFun, Seq(acc, id)) - } - }) - costofOp match { - case InfiniteIntegerLiteral(x) if (x == 0) => summand - case _ => Plus(costofOp, summand) - } - } - } -} diff --git a/src/main/scala/leon/transformations/InstrumentationUtil.scala b/src/main/scala/leon/transformations/InstrumentationUtil.scala deleted file mode 100644 index 5c26b194a55900adcea0383fe1b469c299bab0af..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/InstrumentationUtil.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import leon.utils.Library -import invariant.util._ -import invariant.util._ -import Util._ -import ProgramUtil._ -import PredicateUtil._ - -sealed abstract class Instrumentation { - val getType: TypeTree - val name: String - def isInstVariable(e: Expr): Boolean = { - e match { - case FunctionInvocation(TypedFunDef(fd, _), _) if (fd.id.name == name && fd.annotations("library")) => - true - case _ => false - } - } - override def toString = name -} - -object Time extends Instrumentation { - override val getType = IntegerType - override val name = "time" -} -object Depth extends Instrumentation { - override val getType = IntegerType - override val name = "depth" -} -object Rec extends Instrumentation { - override val getType = IntegerType - override val name = "rec" -} - -/** - * time per recursive step. - */ -object TPR extends Instrumentation { - override val getType = IntegerType - override val name = "tpr" -} - -object Stack extends Instrumentation { - override val getType = IntegerType - override val name = "stack" -} -//add more instrumentation variables - -object InstUtil { - - val InstTypes = Seq(Time, Depth, Rec, TPR, Stack) - - val maxFun = { - val xid = FreshIdentifier("x", IntegerType) - val yid = FreshIdentifier("y", IntegerType) - val varx = xid.toVariable - val vary = yid.toVariable - val args = Seq(xid, yid) - val maxType = FunctionType(Seq(IntegerType, IntegerType), IntegerType) - val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), args.map(arg => ValDef(arg)), IntegerType) - - val cond = GreaterEquals(varx, vary) - mfd.body = Some(IfExpr(cond, varx, vary)) - mfd.addFlag(Annotation("theoryop", Seq())) - mfd - } - - def userFunctionName(fd: FunDef) = { - val splits = fd.id.name.split("-") - if(!splits.isEmpty) splits(0) - else "" - } - - def getInstMap(fd: FunDef) = { - val resvar = getResId(fd).get.toVariable // note: every instrumented function has a postcondition - val insts = fd.id.name.split("-").tail // split the name of the function w.r.t '-' - (insts.zipWithIndex).foldLeft(Map[Expr, String]()) { - case (acc, (instName, i)) => - acc + (TupleSelect(resvar, i + 2) -> instName) - } - } - - def getInstExpr(fd: FunDef, inst: Instrumentation) = { - val resvar = getResId(fd).get.toVariable // note: every instrumented function has a postcondition - val insts = fd.id.name.split("-").tail // split the name of the function w.r.t '-' - val index = insts.indexOf(inst.name) - if (index >= 0) - Some(TupleSelect(resvar, index + 2)) - else None - } - - def getInstVariableMap(fd: FunDef) = { - getInstMap(fd).map { - case (ts, instName) => - (ts -> Variable(FreshIdentifier(instName, IntegerType))) - } - } - - def isInstrumented(fd: FunDef, instType: Instrumentation) = { - fd.id.name.split("-").contains(instType.toString) - } - - def isInstrumented(fd: FunDef) = { - val comps = fd.id.name.split("-") - InstTypes.exists { x => comps.contains(x.toString) } - } - - def resultExprForInstVariable(fd: FunDef, instType: Instrumentation) = { - getInstVariableMap(fd).collectFirst { - case (k, Variable(id)) if (id.name == instType.toString) => k - } - } - - def replaceInstruVars(e: Expr, fd: FunDef): Expr = { - val resvar = getResId(fd).get.toVariable - val newres = FreshIdentifier(resvar.id.name, resvar.getType).toVariable - replace(getInstVariableMap(fd) + (TupleSelect(resvar, 1) -> newres), e) - } - - /** - * Checks if the given expression is a resource bound of the given function. - */ - def isResourceBoundOf(fd: FunDef)(e: Expr) = { - val instExprs = InstTypes.map(getInstExpr(fd, _)).collect { - case Some(inste) => inste - }.toSet - !instExprs.isEmpty && isArithmeticRelation(e).get && - exists { - case sub: TupleSelect => instExprs(sub) - case _ => false - }(e) - } -} diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala deleted file mode 100644 index f5444dec5970169d6adca7375322e0e776c625f4..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import leon.purescala.ScalaPrinter - -import invariant.factories._ -import invariant.util._ -import Util._ -import ProgramUtil._ -import PredicateUtil._ -import TypeUtil._ -import invariant.structure._ - -abstract class ProgramTypeTransformer { - protected var defmap = Map[ClassDef, ClassDef]() - protected var idmap = Map[Identifier, Identifier]() - protected var newFundefs = Map[FunDef, FunDef]() - - def mapField(cdef: CaseClassDef, fieldId: Identifier): Identifier = { - (cdef.fieldsIds.collectFirst { - case fid @ _ if (fid.name == fieldId.name) => fid - }).get - } - - def mapClass[T <: ClassDef](cdef: T): T = { - if (defmap.contains(cdef)) { - defmap(cdef).asInstanceOf[T] - } else { - cdef match { - case ccdef: CaseClassDef => - val newparent = if (ccdef.hasParent) { - val absType = ccdef.parent.get - Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) - } else None - val newclassDef = ccdef.duplicate(id = FreshIdentifier(ccdef.id.name, ccdef.id.getType, true), parent = newparent) - - //important: register a child if a parent was newly created. - if (newparent.isDefined) - newparent.get.classDef.registerChild(newclassDef) - - defmap += (ccdef -> newclassDef) - newclassDef.setFields(ccdef.fields.map(mapDecl)) - newclassDef.asInstanceOf[T] - - case acdef: AbstractClassDef => - val newparent = if (acdef.hasParent) { - val absType = acdef.parent.get - Some(AbstractClassType(mapClass(absType.classDef), absType.tps)) - } else None - val newClassDef = acdef.duplicate(id = FreshIdentifier(acdef.id.name, acdef.id.getType, true), parent = newparent) - defmap += (acdef -> newClassDef) - newClassDef.asInstanceOf[T] - } - } - } - - def mapId(id: Identifier): Identifier = { - val newtype = mapType(id.getType) - val newId = idmap.getOrElse(id, { - //important need to preserve distinction between template variables and ordinary variables - val freshId = if (TemplateIdFactory.IsTemplateIdentifier(id)) TemplateIdFactory.copyIdentifier(id) - else FreshIdentifier(id.name, newtype, true) - idmap += (id -> freshId) - freshId - }) - newId - } - - def mapDecl(decl: ValDef): ValDef = { - decl.copy(id = mapId(decl.id)) - } - - def mapType(tpe: TypeTree): TypeTree = { - tpe match { - case t @ RealType => mapNumericType(t) - case t @ IntegerType => mapNumericType(t) - case AbstractClassType(adef, tps) => AbstractClassType(mapClass(adef), tps) - case CaseClassType(cdef, tps) => CaseClassType(mapClass(cdef), tps) - case TupleType(bases) => TupleType(bases.map(mapType)) - case _ => tpe - } - } - - def mapNumericType(tpe: TypeTree): TypeTree - - def mapLiteral(lit: Literal[_]): Literal[_] - - def transform(program: Program): Program = { - //create a new fundef for each function in the program - //Unlike functions, classes are created lazily as required. - newFundefs = program.definedFunctions.map((fd) => { - val newFunType = FunctionType(fd.tparams.map((currParam) => currParam.tp), fd.returnType) - val newfd = new FunDef(FreshIdentifier(fd.id.name, newFunType, true), fd.tparams, fd.params.map(mapDecl), mapType(fd.returnType)) - (fd, newfd) - }).toMap - - /** - * Here, we assume that tuple-select and case-class-select have been reduced - */ - def transformExpr(e: Expr): Expr = e match { - case l: Literal[_] => mapLiteral(l) - case v @ Variable(inId) => mapId(inId).toVariable - case FunctionInvocation(TypedFunDef(intfd, tps), args) => FunctionInvocation(TypedFunDef(newFundefs(intfd), tps), args.map(transformExpr)) - case CaseClass(CaseClassType(classDef, tps), args) => CaseClass(CaseClassType(mapClass(classDef), tps), args.map(transformExpr)) - case IsInstanceOf(expr, CaseClassType(classDef, tps)) => IsInstanceOf(transformExpr(expr), CaseClassType(mapClass(classDef), tps)) - case CaseClassSelector(CaseClassType(classDef, tps), expr, fieldId) => { - val newtype = CaseClassType(mapClass(classDef), tps) - CaseClassSelector(newtype, transformExpr(expr), mapField(newtype.classDef, fieldId)) - } - //need to handle 'let' and 'letTuple' specially - case Let(binder, value, body) => Let(mapId(binder), transformExpr(value), transformExpr(body)) - case t: Terminal => t - /*case UnaryOperator(arg, op) => op(transformExpr(arg)) - case BinaryOperator(arg1, arg2, op) => op(transformExpr(arg1), transformExpr(arg2))*/ - case Operator(args, op) => op(args.map(transformExpr)) - } - - //create a body, pre, post for each newfundef - newFundefs.foreach((entry) => { - val (fd, newfd) = entry - - //add a new precondition - newfd.precondition = - if (fd.precondition.isDefined) - Some(transformExpr(fd.precondition.get)) - else None - - //add a new body - newfd.body = if (fd.hasBody) { - //replace variables by constants if possible - val simpBody = matchToIfThenElse(fd.body.get) - Some(transformExpr(simpBody)) - } else Some(NoTree(fd.returnType)) - - // FIXME - //add a new postcondition - newfd.fullBody = if (fd.postcondition.isDefined && newfd.body.isDefined) { - val Lambda(Seq(ValDef(resid)), pexpr) = fd.postcondition.get - val tempRes = mapId(resid).toVariable - Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id)), transformExpr(pexpr))) - // Some(mapId(resid), transformExpr(pexpr)) - } else NoTree(fd.returnType) - - fd.flags.foreach(newfd.addFlag(_)) - }) - - val newprog = copyProgram(program, (defs: Seq[Definition]) => defs.map { - case fd: FunDef => newFundefs(fd) - case cd: ClassDef => mapClass(cd) - case d @ _ => throw new IllegalStateException("Unknown Definition: " + d) - }) - newprog - } -} - -class IntToRealProgram extends ProgramTypeTransformer { - - private var realToIntId = Map[Identifier, Identifier]() - - def mapNumericType(tpe: TypeTree) = { - require(isNumericType(tpe)) - tpe match { - case IntegerType => RealType - case _ => tpe - } - } - - def mapLiteral(lit: Literal[_]): Literal[_] = lit match { - case IntLiteral(v) => FractionalLiteral(v, 1) - case _ => lit - } - - def apply(program: Program): Program = { - - val newprog = transform(program) - //reverse the map - realToIntId = idmap.map(entry => (entry._2 -> entry._1)) - //println("After Real Program Conversion: \n" + ScalaPrinter.apply(newprog)) - //print all the templates - /*newprog.definedFunctions.foreach((fd) => { - val funinfo = FunctionInfoFactory.getFunctionInfo(fd) - if (funinfo.isDefined && funinfo.get.hasTemplate) - println("Function: " + fd.id + " template --> " + funinfo.get.getTemplate) - })*/ - newprog - } - - /** - * Assuming that the model maps only variables - */ - def unmapModel(model: Map[Identifier, Expr]): Map[Identifier, Expr] = { - model.map((pair) => { - val (key, value) = if (realToIntId.contains(pair._1)) { - (realToIntId(pair._1), pair._2) - } else pair - (key -> value) - }) - } -} - -class RealToIntProgram extends ProgramTypeTransformer { - val debugIntToReal = false - val bone = BigInt(1) - - def mapNumericType(tpe: TypeTree) = { - require(isNumericType(tpe)) - tpe match { - case RealType => IntegerType - case _ => tpe - } - } - - def mapLiteral(lit: Literal[_]): Literal[_] = lit match { - case FractionalLiteral(v, `bone`) => InfiniteIntegerLiteral(v) - case FractionalLiteral(_, _) => throw new IllegalStateException("Cannot convert real to integer: " + lit) - case _ => lit - } - - def apply(program: Program): Program = { - - val newprog = transform(program) - - if (debugIntToReal) - println("Program to Verify: \n" + ScalaPrinter.apply(newprog)) - - newprog - } - - def mappedFun(fd: FunDef): FunDef = newFundefs(fd) -} diff --git a/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala b/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala deleted file mode 100644 index 7df3f03f081d7160f66d44ced7bd8f0550fafc11..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/NonRecursiveTimePhase.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.utils._ -import leon.invariant.util.Util._ - -import scala.collection.mutable.{Map => MutableMap} - -object tprCostModel { - def costOf(e: Expr): Int = e match { - case FunctionInvocation(fd, _) if !fd.hasBody => 0 // uninterpreted functions - case FunctionInvocation(fd, args) => 1 - case t: Terminal => 0 - case _ => 1 - } - def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) -} - -class TPRInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { - import tprCostModel._ - - def inst = TPR - - val sccs = cg.graph.sccs.flatMap { scc => - scc.map(fd => (fd -> scc.toSet)) - }.toMap - - //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) - val tprFuncs = getRootFuncs() - val timeFuncs = tprFuncs.foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)).filter(_.hasBody) - - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { - var emap = MutableMap[FunDef,List[Instrumentation]]() - def update(fd: FunDef, inst: Instrumentation) { - if (emap.contains(fd)) - emap(fd) :+= inst - else emap.update(fd, List(inst)) - } - tprFuncs.map(fd => update(fd, TPR)) - timeFuncs.map(fd => update(fd, Time)) - emap.toMap - } - - def additionalfunctionsToAdd() = Seq() - - def instrumentMatchCase( - me: MatchExpr, - mc: MatchCase, - caseExprCost: Expr, - scrutineeCost: Expr): Expr = { - val costMatch = costOfExpr(me) - - def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = { - def patCostRecur(pattern: Pattern, innerPat: Boolean, countLeafs: Boolean): Int = { - pattern match { - case InstanceOfPattern(_, _) => { - if (innerPat) 2 else 1 - } - case WildcardPattern(None) => 0 - case WildcardPattern(Some(id)) => { - if (countLeafs && innerPat) 1 - else 0 - } - case CaseClassPattern(_, _, subPatterns) => { - (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => - acc + patCostRecur(subPat, true, countLeafs)) - } - case TuplePattern(_, subPatterns) => { - (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => - acc + patCostRecur(subPat, true, countLeafs)) - } - case LiteralPattern(_, _) => if (innerPat) 2 else 1 - case _ => - throw new NotImplementedError(s"Pattern $pattern not handled yet!") - } - } - me.cases.take(me.cases.indexOf(mc)).foldLeft(0)( - (acc, currCase) => acc + patCostRecur(currCase.pattern, false, false)) + - patCostRecur(mc.pattern, false, true) - } - Plus(costMatch, Plus( - Plus(InfiniteIntegerLiteral(totalCostOfMatchPatterns(me, mc)), - caseExprCost), - scrutineeCost)) - } - - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) - (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = e match { - case t: Terminal => costOfExpr(t) - case FunctionInvocation(tfd, args) => { - val remSubInsts = if (tprFuncs.contains(tfd.fd)) - subInsts.slice(0, subInsts.length - 1) - else subInsts - if (sccs(fd)(tfd.fd)) { - remSubInsts.foldLeft(costOfExpr(e) : Expr)( - (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) - } - else { - val allSubInsts = remSubInsts :+ si.selectInst(tfd.fd)(funInvResVar.get, Time) - allSubInsts.foldLeft(costOfExpr(e) : Expr)( - (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) - } - } - case _ => - subInsts.foldLeft(costOfExpr(e) : Expr)( - (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) - } - - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], - thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { - val costIf = costOfExpr(e) - (Plus(costIf, Plus(condInst.get, thenInst.get)), - Plus(costIf, Plus(condInst.get, elzeInst.get))) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala deleted file mode 100644 index 76d34c195bce7603eac0864f8ded92968bbb0b9a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import invariant.factories._ -import invariant.util._ -import Util._ -import ProgramUtil._ -import PredicateUtil._ -import TypeUtil._ -import invariant.structure.FunctionUtils._ - -import purescala.ScalaPrinter -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ - -object MultFuncs { - def getMultFuncs(domain: TypeTree): (FunDef, FunDef) = { - //a recursive function that represents multiplication of two positive arguments - val pivMultFun = { - val xid = FreshIdentifier("x", domain) - val yid = FreshIdentifier("y", domain) - val varx = xid.toVariable - val vary = yid.toVariable - val args = Seq(xid, yid) - val funcType = FunctionType(Seq(domain, domain), domain) - val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) - val tmfd = TypedFunDef(mfd, Seq()) - - //define a body (a) using mult(x,y) = if(x == 0 || y ==0) 0 else mult(x-1,y) + y - val cond = Or(Equals(varx, zero), Equals(vary, zero)) - val xminus1 = Minus(varx, one) - val yminus1 = Minus(vary, one) - val elze = Plus(FunctionInvocation(tmfd, Seq(xminus1, vary)), vary) - mfd.body = Some(IfExpr(cond, zero, elze)) - - //add postcondition - val resvar = FreshIdentifier("res", domain).toVariable - val post0 = GreaterEquals(resvar, zero) - - //define alternate definitions of multiplication as postconditions - //(a) res = !(x==0 || y==0) => mult(x,y-1) + x - val guard = Not(cond) - val defn2 = Equals(resvar, Plus(FunctionInvocation(tmfd, Seq(varx, yminus1)), varx)) - val post1 = Implies(guard, defn2) - - // mfd.postcondition = Some((resvar.id, And(Seq(post0, post1)))) - mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id)), And(Seq(post0, post1)))) - //set function properties (for now, only monotonicity) - mfd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) //"distributive" ? - mfd - } - - //a function that represents multiplication, this transitively calls pmult - val multFun = { - val xid = FreshIdentifier("x", domain) - val yid = FreshIdentifier("y", domain) - val args = Seq(xid, yid) - val funcType = FunctionType(Seq(domain, domain), domain) - val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) - val tpivMultFun = TypedFunDef(pivMultFun, Seq()) - - //the body is defined as mult(x,y) = val px = if(x < 0) -x else x; - //val py = if(y<0) -y else y; val r = pmult(px,py); - //if(x < 0 && y < 0 || x >= 0 && y >= 0) r else -r - val varx = xid.toVariable - val vary = yid.toVariable - val modx = IfExpr(LessThan(varx, zero), UMinus(varx), varx) - val mody = IfExpr(LessThan(vary, zero), UMinus(vary), vary) - val px = FreshIdentifier("px", domain, false) - val py = FreshIdentifier("py", domain, false) - val call = Let(px, modx, Let(py, mody, FunctionInvocation(tpivMultFun, Seq(px, py).map(_.toVariable)))) - val bothPive = And(GreaterEquals(varx, zero), GreaterEquals(vary, zero)) - val bothNive = And(LessThan(varx, zero), LessThan(vary, zero)) - val res = FreshIdentifier("r", domain, false) - val body = Let(res, call, IfExpr(Or(bothPive, bothNive), res.toVariable, UMinus(res.toVariable))) - fd.body = Some(body) - //set function properties - fd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) - fd - } - - (pivMultFun, multFun) - } -} - -class NonlinearityEliminator(skipAxioms: Boolean, domain: TypeTree) { - import MultFuncs._ - require(isNumericType(domain)) - - val debugNLElim = false - - val one = InfiniteIntegerLiteral(1) - val zero = InfiniteIntegerLiteral(0) - - val (pivMultFun, multFun) = getMultFuncs(domain) - - //TOOD: note associativity property of multiplication is not taken into account - def apply(program: Program): Program = { - //create a fundef for each function in the program - val newFundefs = userLevelFunctions(program).map { fd => - val newFunType = FunctionType(fd.tparams.map(_.tp), fd.returnType) - val newfd = new FunDef(FreshIdentifier(fd.id.name, newFunType, false), fd.tparams, fd.params, fd.returnType) - (fd -> newfd) - }.toMap - - //note, handling templates variables is slightly tricky as we need to preserve a*x as it is - val tmult = TypedFunDef(multFun, Seq()) - var addMult = false - def replaceFun(ine: Expr, allowedVars: Set[Identifier] = Set()): Expr = { - simplePostTransform(e => e match { - case fi @ FunctionInvocation(tfd1, args) if newFundefs.contains(tfd1.fd) => - FunctionInvocation(TypedFunDef(newFundefs(tfd1.fd), tfd1.tps), args) - - case Times(Variable(id), e2) if (allowedVars.contains(id)) => e - case Times(e1, Variable(id)) if (allowedVars.contains(id)) => e - - case Times(e1, e2) if (!e1.isInstanceOf[Literal[_]] && !e2.isInstanceOf[Literal[_]]) => { - //replace times by a mult function - addMult = true - FunctionInvocation(tmult, Seq(e1, e2)) - } - //note: include mult function if division operation is encountered - //division is handled during verification condition generation. - case Division(_, _) => { - addMult = true - e - } - case _ => e - })(ine) - } - - //create a body, pre, post for each newfundef - newFundefs.foreach { - case (fd, newfd) => - //add a new precondition - newfd.precondition = - if (fd.precondition.isDefined) - Some(replaceFun(fd.precondition.get)) - else None - //add a new body - newfd.body = if (fd.hasBody) { - //replace variables by constants if possible - val simpBody = simplifyLets(fd.body.get) - Some(replaceFun(simpBody)) - } else None - //add a new postcondition - newfd.postcondition = if (fd.postcondition.isDefined) { - //we need to handle template and postWoTemplate specially - val Lambda(resultBinders, _) = fd.postcondition.get - val tmplExpr = fd.templateExpr - val newpost = if (fd.hasTemplate) { - val FunctionInvocation(tmpfd, Seq(Lambda(tmpvars, tmpbody))) = tmplExpr.get - val newtmp = FunctionInvocation(tmpfd, Seq(Lambda(tmpvars, - replaceFun(tmpbody, tmpvars.map(_.id).toSet)))) - fd.postWoTemplate match { - case None => - newtmp - case Some(postExpr) => - And(replaceFun(postExpr), newtmp) - } - } else - replaceFun(fd.getPostWoTemplate) - Some(Lambda(resultBinders, newpost)) - } else None - fd.flags.foreach(newfd.addFlag(_)) - } - val transProg = copyProgram(program, (defs: Seq[Definition]) => { - defs.map { - case fd: FunDef if newFundefs.contains(fd) => newFundefs(fd) - case d => d - } - }) - val newprog = - if (addMult) - addDefs(transProg, Seq(multFun, pivMultFun), transProg.units.find(_.isMainUnit).get.definedFunctions.last) - else transProg - if (debugNLElim) - println("After Nonlinearity Elimination: \n" + ScalaPrinter.apply(newprog)) - - newprog - } -} diff --git a/src/main/scala/leon/transformations/ProgramSimplifier.scala b/src/main/scala/leon/transformations/ProgramSimplifier.scala deleted file mode 100644 index cc4ce609d0da066e70bbd8347dbe9d8220cafcf3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/ProgramSimplifier.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.purescala.ScalaPrinter -import leon.utils._ -import invariant.util._ -import Util._ -import ProgramUtil._ -import PredicateUtil._ -import invariant.util.ExpressionTransformer._ -import invariant.structure.FunctionUtils._ -import invariant.util.LetTupleSimplification._ - -/** - * A simplifier phase that eliminates tuples that are not needed - * from function bodies, and also performs other simplifications. - * Note: performing simplifications during instrumentation - * will affect the validity of the information stored in function info. - */ -object ProgramSimplifier { - val debugSimplify = false - - def apply(program: Program, instFuncs: Seq[FunDef]): Program = { - val funMap = ((userLevelFunctions(program) ++ instFuncs).distinct).foldLeft(Map[FunDef, FunDef]()) { - case (accMap, fd) => { - val freshId = FreshIdentifier(fd.id.name, fd.returnType) - val newfd = new FunDef(freshId, fd.tparams, fd.params, fd.returnType) - accMap + (fd -> newfd) - } - } - def mapExpr(ine: Expr, fd: FunDef): Expr = { - val replaced = simplePostTransform((e: Expr) => e match { - case FunctionInvocation(tfd, args) if funMap.contains(tfd.fd) => - FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args) - case _ => e - })(ine) - - // Note: simplify only instrumented functions - // One might want to add the maximum function to the program in the stack - // and depth instrumentation phases if inlineMax is removed from here - if (InstUtil.isInstrumented(fd)) { - val allSimplifications = - simplifyTuples _ andThen - simplifyMax _ andThen - simplifyLetsAndLetsWithTuples _ andThen - simplifyAdditionsAndMax _ andThen - inlineMax _ - allSimplifications(replaced) - } else replaced - } - - for ((from, to) <- funMap) { - to.fullBody = mapExpr(from.fullBody, from) - //copy annotations - from.flags.foreach(to.addFlag(_)) - } - val newprog = copyProgram(program, (defs: Seq[Definition]) => defs.map { - case fd: FunDef if funMap.contains(fd) => funMap(fd) - case d => d - }) - - if (debugSimplify) - println("After Simplifications: \n" + ScalaPrinter.apply(newprog)) - newprog - } -} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/RecursionCountInstrumenter.scala b/src/main/scala/leon/transformations/RecursionCountInstrumenter.scala deleted file mode 100644 index 729a8d45a6d6955fd31f1694fce03dade7c89e6d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/RecursionCountInstrumenter.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.utils._ -import invariant.util.Util._ - -class RecursionCountInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { - - def inst = Rec - - val sccs = cg.graph.sccs.flatMap { scc => - scc.map(fd => (fd -> scc.toSet)) - }.toMap - - /** - * Instrument only those functions that are in the same sccs of the root functions - */ - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { - val instFunSet = getRootFuncs().flatMap(sccs.apply _).filter(_.hasBody) - instFunSet.map(x => (x, List(Rec))).toMap - } - - override def additionalfunctionsToAdd(): Seq[FunDef] = Seq.empty[FunDef] - - def addSubInstsIfNonZero(subInsts: Seq[Expr], init: Expr): Expr = { - subInsts.foldLeft(init) { - case (acc, subinst) if subinst != zero => - if (acc == zero) subinst - else Plus(acc, subinst) - } - } - - def instrumentMatchCase(me: MatchExpr, - mc: MatchCase, - caseExprCost: Expr, - scrutineeCost: Expr): Expr = { - Plus(caseExprCost, scrutineeCost) - } - - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) - (implicit fd: FunDef, leIdtMap: Map[Identifier,Identifier]): Expr = e match { - case FunctionInvocation(TypedFunDef(callee, _), _) if sccs(fd)(callee) => - //this is a recursive call - //Note that the last element of subInsts is the instExpr of the invoked function - addSubInstsIfNonZero(subInsts, one) - case FunctionInvocation(TypedFunDef(callee, _), _) if si.funcInsts.contains(callee) && si.funcInsts(callee).contains(this.inst) => - //this is not a recursive call, so do not consider the cost of the callee - //Note that the last element of subInsts is the instExpr of the invoked function - addSubInstsIfNonZero(subInsts.take(subInsts.size - 1), zero) - case _ => - //add the cost of every sub-expression - addSubInstsIfNonZero(subInsts, zero) - } - - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], thenInst: Option[Expr], - elzeInst: Option[Expr]): (Expr, Expr) = { - - val cinst = condInst.toList - val tinst = thenInst.toList - val einst = elzeInst.toList - - (addSubInstsIfNonZero(cinst ++ tinst, zero), - addSubInstsIfNonZero(cinst ++ einst, zero)) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala deleted file mode 100644 index ada1386798e5a7adb10453f1e405543c52cfee41..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala +++ /dev/null @@ -1,499 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.purescala.ScalaPrinter -import leon.utils._ -import invariant.util._ -import Util._ -import ProgramUtil._ -import PredicateUtil._ -import invariant.util.CallGraphUtil -import invariant.structure.FunctionUtils._ -import scala.collection.mutable.{ Map => MutableMap } - -/** - * An instrumentation phase that performs a sequence of instrumentations - */ - -object InstrumentationPhase extends TransformationPhase { - val name = "Instrumentation Phase" - val description = "Instruments the program for all counters needed" - - def apply(ctx: LeonContext, program: Program): Program = { - val si = new SerialInstrumenter(program) - val instprog = si.apply - //println("Instrumented Program: "+ScalaPrinter.apply(instprog, purescala.PrinterOptions(printUniqueIds = true))) - instprog - } -} - -class SerialInstrumenter(program: Program, - exprInstOpt: Option[(Map[FunDef, FunDef], SerialInstrumenter, FunDef) => ExprInstrumenter] = None) { - val debugInstrumentation = false - - val exprInstFactory = exprInstOpt.getOrElse((x: Map[FunDef, FunDef], y: SerialInstrumenter, z: FunDef) => new ExprInstrumenter(x, y)(z)) - - val instToInstrumenter: Map[Instrumentation, Instrumenter] = - Map(Time -> new TimeInstrumenter(program, this), - Depth -> new DepthInstrumenter(program, this), - Rec -> new RecursionCountInstrumenter(program, this), - Stack -> new StackSpaceInstrumenter(program, this), - TPR -> new TPRInstrumenter(program, this)) - - // a map from functions to the list of instrumentations to be performed for the function - val funcInsts = { - var emap = MutableMap[FunDef, List[Instrumentation]]() - def update(fd: FunDef, inst: Instrumentation) { - if (emap.contains(fd)) - emap(fd) = (emap(fd) :+ inst).distinct - else emap.update(fd, List(inst)) - } - instToInstrumenter.values.foreach { m => - m.functionsToInstrument.foreach({ - case (fd, instsToPerform) => - instsToPerform.foreach(instToPerform => update(fd, instToPerform)) - }) - } - emap.toMap - } - val instFuncs = funcInsts.keySet - - def instrumenters(fd: FunDef) = funcInsts(fd) map instToInstrumenter.apply _ - def instTypes(fd: FunDef) = funcInsts(fd).map(_.getType) - /** - * Index of the instrumentation 'inst' in result tuple that would be created. - * The return value will be >= 2 as the actual result value would be at index 1 - */ - def instIndex(fd: FunDef)(ins: Instrumentation) = funcInsts(fd).indexOf(ins) + 2 - def selectInst(fd: FunDef)(e: Expr, ins: Instrumentation) = TupleSelect(e, instIndex(fd)(ins)) - - def apply: Program = { - - if (instFuncs.isEmpty) program - else { - //create new functions. Augment the return type of a function iff the postcondition uses - //the instrumentation variable or if the function is transitively called from such a function - var funMap = Map[FunDef, FunDef]() - (userLevelFunctions(program) ++ instFuncs).distinct.foreach { fd => - if (instFuncs.contains(fd)) { - val newRetType = TupleType(fd.returnType +: instTypes(fd)) - // let the names of the function encode the kind of instrumentations performed - val freshId = FreshIdentifier(fd.id.name + "-" + funcInsts(fd).map(_.name).mkString("-"), newRetType) - val newfd = new FunDef(freshId, fd.tparams, fd.params, newRetType) - funMap += (fd -> newfd) - } else { - //here we need not augment the return types but do need to create a new copy - val freshId = FreshIdentifier(fd.id.name, fd.returnType) - val newfd = new FunDef(freshId, fd.tparams, fd.params, fd.returnType) - funMap += (fd -> newfd) - } - } - - def mapExpr(ine: Expr): Expr = { - simplePostTransform((e: Expr) => e match { - case FunctionInvocation(tfd, args) if funMap.contains(tfd.fd) => - if (instFuncs.contains(tfd.fd)) - TupleSelect(FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args), 1) - else - FunctionInvocation(TypedFunDef(funMap(tfd.fd), tfd.tps), args) - case _ => e - })(ine) - } - - def mapBody(body: Expr, from: FunDef, to: FunDef) = { - val res = - if (from.isExtern) { - // this is an extern function, we must only rely on the specs - // so make the body empty - NoTree(to.returnType) - } else if (instFuncs.contains(from)) { - exprInstFactory(funMap, this, from)(body) - } else - mapExpr(body) - res - } - - def mapPost(pred: Expr, from: FunDef, to: FunDef) = { - pred match { - case Lambda(Seq(ValDef(fromRes)), postCond) if (instFuncs.contains(from)) => - val toResId = FreshIdentifier(fromRes.name, to.returnType, true) - val newpost = postMap((e: Expr) => e match { - case Variable(`fromRes`) => - Some(TupleSelect(toResId.toVariable, 1)) - - case _ if funcInsts(from).exists(_.isInstVariable(e)) => - val inst = funcInsts(from).find(_.isInstVariable(e)).get - Some(TupleSelect(toResId.toVariable, instIndex(from)(inst))) - - case _ => - None - })(postCond) - Lambda(Seq(ValDef(toResId)), mapExpr(newpost)) - case _ => - mapExpr(pred) - } - } - - // Map the bodies and preconditions - for ((from, to) <- funMap) { - //copy annotations - from.flags.foreach(to.addFlag(_)) - to.fullBody = from.fullBody match { - case b @ NoTree(_) => NoTree(to.returnType) - case Require(pre, body) => - //here 'from' does not have a postcondition but 'to' will always have a postcondition - val toPost = - Lambda(Seq(ValDef(FreshIdentifier("res", to.returnType))), BooleanLiteral(true)) - val bodyPre = - Require(mapExpr(pre), mapBody(body, from, to)) - Ensuring(bodyPre, toPost) - - case Ensuring(Require(pre, body), post) => - Ensuring(Require(mapExpr(pre), mapBody(body, from, to)), - mapPost(post, from, to)) - - case Ensuring(body, post) => - Ensuring(mapBody(body, from, to), mapPost(post, from, to)) - - case body => - val toPost = - Lambda(Seq(ValDef(FreshIdentifier("res", to.returnType))), BooleanLiteral(true)) - Ensuring(mapBody(body, from, to), toPost) - } - } - - val additionalFuncs = funMap.flatMap { - case (k, _) => - if (instFuncs(k)) - instrumenters(k).flatMap(_.additionalfunctionsToAdd) - else List() - }.toList.distinct - - val newprog = copyProgram(program, (defs: Seq[Definition]) => - defs.map { - case fd: FunDef if funMap.contains(fd) => - funMap(fd) - case d => d - } ++ additionalFuncs) - if (debugInstrumentation) - println("After Instrumentation: \n" + ScalaPrinter.apply(newprog)) - - ProgramSimplifier(newprog, instFuncs.toSeq) - } - } -} - -class ExprInstrumenter(funMap: Map[FunDef, FunDef], serialInst: SerialInstrumenter)(implicit currFun: FunDef) { - val retainMatches = true - - val insts = serialInst.funcInsts(currFun) - val instrumenters = serialInst.instrumenters(currFun) - val instIndex = serialInst.instIndex(currFun) _ - val selectInst = serialInst.selectInst(currFun) _ - val instTypes = serialInst.instTypes(currFun) - - // Should be called only if 'expr' has to be instrumented - // Returned Expr is always an expr of type tuple (Expr, Int) - def tupleify(e: Expr, subs: Seq[Expr], recons: Seq[Expr] => Expr)(implicit letIdMap: Map[Identifier, Identifier]): Expr = { - // When called for: - // Op(n1,n2,n3) - // e = Op(n1,n2,n3) - // subs = Seq(n1,n2,n3) - // recons = { Seq(newn1,newn2,newn3) => Op(newn1, newn2, newn3) } - // - // This transformation should return, informally: - // - // LetTuple((e1,t1), transform(n1), - // LetTuple((e2,t2), transform(n2), - // ... - // Tuple(recons(e1, e2, ...), t1 + t2 + ... costOfExpr(Op) - // ... - // ) - // ) - // - // You will have to handle FunctionInvocation specially here! - tupleifyRecur(e, subs, recons, List(), Map()) - } - - def tupleifyRecur(e: Expr, subs: Seq[Expr], recons: Seq[Expr] => Expr, subeVals: List[Expr], - subeInsts: Map[Instrumentation, List[Expr]])(implicit letIdMap: Map[Identifier, Identifier]): Expr = { - //note: subs.size should be zero if e is a terminal - if (subs.size == 0) { - e match { - case v @ Variable(id) => - val valPart = if (letIdMap.contains(id)) { - TupleSelect(letIdMap(id).toVariable, 1) //this takes care of replacement - } else v - val instPart = instrumenters map (_.instrument(v, Seq())) - Tuple(valPart +: instPart) - - case t: Terminal => - val instPart = instrumenters map (_.instrument(t, Seq())) - val finalRes = Tuple(t +: instPart) - finalRes - - // TODO: We are ignoring the construction cost of fields. Fix this. - case f @ FunctionInvocation(TypedFunDef(fd, tps), args) => - if (!fd.hasLazyFieldFlag) { - val newfd = TypedFunDef(funMap(fd), tps) - val newFunInv = FunctionInvocation(newfd, subeVals) - //create a variables to store the result of function invocation - if (serialInst.instFuncs(fd)) { - //this function is also instrumented - val resvar = Variable(FreshIdentifier("e", newfd.returnType, true)) - val valexpr = TupleSelect(resvar, 1) - val instexprs = instrumenters.map { m => - val calleeInst = - if (serialInst.funcInsts(fd).contains(m.inst) && fd.isUserFunction) { - List(serialInst.selectInst(fd)(resvar, m.inst)) - } else List() // ignoring fields here - //Note we need to ensure that the last element of list is the instval of the finv - m.instrument(e, subeInsts.getOrElse(m.inst, List()) ++ calleeInst, Some(resvar)) - } - Let(resvar.id, newFunInv, Tuple(valexpr +: instexprs)) - } else { - val resvar = Variable(FreshIdentifier("e", newFunInv.getType, true)) - val instexprs = instrumenters.map { m => - m.instrument(e, subeInsts.getOrElse(m.inst, List())) - } - Let(resvar.id, newFunInv, Tuple(resvar +: instexprs)) - } - - } else - throw new UnsupportedOperationException("Lazy fields are not handled in instrumentation." + - " Consider using the --lazy option and rewrite your program using lazy constructor `$`") - - case _ => - val exprPart = recons(subeVals) - val instexprs = instrumenters.zipWithIndex.map { - case (menter, i) => menter.instrument(e, subeInsts.getOrElse(menter.inst, List())) - } - Tuple(exprPart +: instexprs) - } - } else { - val currExp = subs.head - val resvar = Variable(FreshIdentifier("e", TupleType(currExp.getType +: instTypes), true)) - val eval = TupleSelect(resvar, 1) - val instMap = insts.map { inst => - (inst -> (subeInsts.getOrElse(inst, List()) :+ selectInst(resvar, inst))) - }.toMap - //process the remaining arguments - val recRes = tupleifyRecur(e, subs.tail, recons, subeVals :+ eval, instMap) - //transform the current expression - val newCurrExpr = transform(currExp) - Let(resvar.id, newCurrExpr, recRes) - } - } - - /** - * TODO: need to handle new expression trees - * Match statements without guards are now instrumented directly - */ - def transform(e: Expr)(implicit letIdMap: Map[Identifier, Identifier]): Expr = e match { - // Assume that none of the matchcases has a guard. It has already been converted into an if then else - case me @ MatchExpr(scrutinee, matchCases) => - val containsGuard = matchCases.exists(currCase => currCase.optGuard.isDefined) - if (containsGuard) { - def rewritePM(me: MatchExpr): Option[Expr] = { - val MatchExpr(scrut, cases) = me - val condsAndRhs = for (cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) - val realCond = cse.optGuard match { - case Some(g) => patCond withCond replaceFromIDs(map, g) - case None => patCond - } - val newRhs = replaceFromIDs(map, cse.rhs) - (realCond.toClause, newRhs) - } - val bigIte = condsAndRhs.foldRight[Expr]( - Error(me.getType, "Match is non-exhaustive").copiedFrom(me))((p1, ex) => { - if (p1._1 == BooleanLiteral(true)) { - p1._2 - } else { - IfExpr(p1._1, p1._2, ex) - } - }) - Some(bigIte) - } - transform(rewritePM(me).get) - } else { - val instScrutinee = - Variable(FreshIdentifier("scr", TupleType(scrutinee.getType +: instTypes), true)) - - def transformMatchCaseList(mCases: Seq[MatchCase]): Seq[MatchCase] = { - def transformMatchCase(mCase: MatchCase) = { - val MatchCase(pattern, guard, expr) = mCase - val newExpr = { - val exprVal = - Variable(FreshIdentifier("expr", TupleType(expr.getType +: instTypes), true)) - val newExpr = transform(expr) - val instExprs = instrumenters map { m => - m.instrumentMatchCase(me, mCase, selectInst(exprVal, m.inst), - selectInst(instScrutinee, m.inst)) - } - val letBody = Tuple(TupleSelect(exprVal, 1) +: instExprs) - Let(exprVal.id, newExpr, letBody) - } - MatchCase(pattern, guard, newExpr) - } - if (mCases.length == 0) Seq[MatchCase]() - else { - transformMatchCase(mCases.head) +: transformMatchCaseList(mCases.tail) - } - } - val matchExpr = MatchExpr(TupleSelect(instScrutinee, 1), - transformMatchCaseList(matchCases)) - Let(instScrutinee.id, transform(scrutinee), matchExpr) - } - - case Let(i, v, b) => { - val (ni, nv) = { - val ir = Variable(FreshIdentifier("ir", TupleType(v.getType +: instTypes), true)) - val transv = transform(v) - (ir, transv) - } - val r = Variable(FreshIdentifier("r", TupleType(b.getType +: instTypes), true)) - val transformedBody = transform(b)(letIdMap + (i -> ni.id)) - val instexprs = instrumenters map { m => - m.instrument(e, List(selectInst(ni, m.inst), selectInst(r, m.inst))) - } - Let(ni.id, nv, - Let(r.id, transformedBody, Tuple(TupleSelect(r, 1) +: instexprs))) - } - - case ife @ IfExpr(cond, th, elze) => { - val (nifCons, condInsts) = { - val rescond = Variable(FreshIdentifier("c", TupleType(cond.getType +: instTypes), true)) - val condInstPart = insts.map { inst => (inst -> selectInst(rescond, inst)) }.toMap - val recons = (e1: Expr, e2: Expr) => { - Let(rescond.id, transform(cond), IfExpr(TupleSelect(rescond, 1), e1, e2)) - } - (recons, condInstPart) - } - val (nthenCons, thenInsts) = { - val resthen = Variable(FreshIdentifier("th", TupleType(th.getType +: instTypes), true)) - val thInstPart = insts.map { inst => (inst -> selectInst(resthen, inst)) }.toMap - val recons = (theninsts: List[Expr]) => { - Let(resthen.id, transform(th), Tuple(TupleSelect(resthen, 1) +: theninsts)) - } - (recons, thInstPart) - } - val (nelseCons, elseInsts) = { - val reselse = Variable(FreshIdentifier("el", TupleType(elze.getType +: instTypes), true)) - val elInstPart = insts.map { inst => (inst -> selectInst(reselse, inst)) }.toMap - val recons = (einsts: List[Expr]) => { - Let(reselse.id, transform(elze), Tuple(TupleSelect(reselse, 1) +: einsts)) - } - (recons, elInstPart) - } - val (finalThInsts, finalElInsts) = instrumenters.foldLeft((List[Expr](), List[Expr]())) { - case ((thinsts, elinsts), menter) => - val inst = menter.inst - val (thinst, elinst) = menter.instrumentIfThenElseExpr(ife, - Some(condInsts(inst)), Some(thenInsts(inst)), Some(elseInsts(inst))) - (thinsts :+ thinst, elinsts :+ elinst) - } - val nthen = nthenCons(finalThInsts) - val nelse = nelseCons(finalElInsts) - nifCons(nthen, nelse) - } - - // For all other operations, we go through a common tupleifier. - case n @ Operator(ss, recons) => - tupleify(e, ss, recons) - - case t: Terminal => - tupleify(e, Seq(), { case Seq() => t }) - } - - def apply(e: Expr): Expr = { - // Apply transformations - val newe = - if (retainMatches) e - else matchToIfThenElse(liftExprInMatch(e)) - val transformed = transform(newe)(Map()) - val bodyId = FreshIdentifier("bd", transformed.getType, true) - val instExprs = instrumenters map { m => - m.instrumentBody(newe, - selectInst(bodyId.toVariable, m.inst)) - } - Let(bodyId, transformed, - Tuple(TupleSelect(bodyId.toVariable, 1) +: instExprs)) - } - - def liftExprInMatch(ine: Expr): Expr = { - def helper(e: Expr): Expr = { - e match { - case MatchExpr(strut, cases) => strut match { - case t: Terminal => e - case _ => { - val freshid = FreshIdentifier("m", strut.getType, true) - Let(freshid, strut, MatchExpr(freshid.toVariable, cases)) - } - } - case _ => e - } - } - - if (retainMatches) helper(ine) - else simplePostTransform(helper)(ine) - } -} - -/** - * Implements procedures for a specific instrumentation - */ -abstract class Instrumenter(program: Program, si: SerialInstrumenter) { - - def inst: Instrumentation - - protected val cg = CallGraphUtil.constructCallGraph(program, onlyBody = true) - - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] - - def additionalfunctionsToAdd(): Seq[FunDef] - - def instrumentBody(bodyExpr: Expr, instExpr: Expr)(implicit fd: FunDef): Expr = instExpr - - def getRootFuncs(prog: Program = program): Set[FunDef] = { - prog.definedFunctions.filter { fd => - (fd.hasPostcondition && exists(inst.isInstVariable)(fd.postcondition.get)) - }.toSet - } - - /** - * Given an expression to be instrumented - * and the instrumentation of each of its subexpressions, - * computes an instrumentation for the procedure. - * The sub-expressions correspond to expressions returned - * by Expression Extractors. - * fd is the function containing the expression `e` - */ - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None)(implicit fd: FunDef, letIdMap: Map[Identifier, Identifier]): Expr - - /** - * Instrument procedure specialized for if-then-else - */ - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], - thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) - - /** - * This function is expected to combine the cost of the scrutinee, - * the pattern matching and the expression. - * The cost model for pattern matching is left to the user. - * As matches with guards are converted to ifThenElse statements, - * the user may want to make sure that the cost model for pattern - * matching across match statements and ifThenElse statements is consistent - */ - def instrumentMatchCase(me: MatchExpr, mc: MatchCase, - caseExprCost: Expr, scrutineeCost: Expr): Expr -} diff --git a/src/main/scala/leon/transformations/StackSpaceInstrumenter.scala b/src/main/scala/leon/transformations/StackSpaceInstrumenter.scala deleted file mode 100644 index 51c36bb93356fed867ecbbd0276eaf336f891606..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/StackSpaceInstrumenter.scala +++ /dev/null @@ -1,338 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.utils._ - -class StackSpaceInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { - val typedMaxFun = TypedFunDef(InstUtil.maxFun, Seq()) - val optimiseTailCalls = true - - def inst = Stack - - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { - // find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) - val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)).filter(_.hasBody) - instFunSet.map(x => (x, List(Stack))).toMap - } - - def additionalfunctionsToAdd(): Seq[FunDef] = Seq() //Seq(InstUtil.maxFun) - max functions are inlined, so they need not be added - - def addSubInstsIfNonZero(subInsts: Seq[Expr], init: Expr): Expr = { - subInsts.foldLeft(init)((acc: Expr, subeTime: Expr) => { - (subeTime, acc) match { - case (InfiniteIntegerLiteral(x), _) if (x == 0) => acc - case (_, InfiniteIntegerLiteral(x)) if (x == 0) => subeTime - case _ => FunctionInvocation(typedMaxFun, Seq(acc, subeTime)) - } - }) - } - - // Check if a given function call is a tail recursive call - def isTailCall(call: FunctionInvocation, fd: FunDef): Boolean = { - if (fd.body.isDefined) { - def helper(e: Expr): Boolean = { - e match { - case FunctionInvocation(_,_) if (e == call) => true - case Let(binder, value, body) => helper(body) - case LetDef(_,body) => helper(body) - case IfExpr(_,thenExpr, elseExpr) => helper(thenExpr) || helper(elseExpr) - case MatchExpr(_, mCases) => { - mCases.exists(currCase => helper(currCase.rhs)) - } - case _ => false - } - } - helper(fd.body.get) - } - else false - } - - def instrumentMatchCase(me: MatchExpr, mc: MatchCase, - caseExprCost: Expr, scrutineeCost: Expr): Expr = { - - def costOfMatchPattern(me: MatchExpr, mc: MatchCase): Expr = { - val costOfMatchPattern = 1 - InfiniteIntegerLiteral(costOfMatchPattern) - } - - addSubInstsIfNonZero(Seq(costOfMatchPattern(me, mc), caseExprCost, scrutineeCost), InfiniteIntegerLiteral(0)) - } - - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) - (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = { - - e match { - case t: Terminal => InfiniteIntegerLiteral(0) - case FunctionInvocation(callFd, args) => { - // Need to extimate the size of the activation frame of this function. - // #Args + - // #LocalVals + - // #Temporaries created (assume tree-like evaluation of expressions. This will the maximum - // number of temporaries allocated. Also because we assume all the - // temporaries are allocated on the stack and not used only from registers) - - val numTemps = - if (callFd.body.isDefined) { - val (temp, stack) = estimateTemporaries(callFd.body.get) - temp + stack - } else 0 - val retVar = subInsts.last - val remSubInsts = subInsts.slice(0, subInsts.length - 1) - val totalInvocationCost = { - // model scala's tail recursion optimization here - if ((isTailCall(FunctionInvocation(callFd, args), fd) && fd.id == callFd.id) && optimiseTailCalls) - InfiniteIntegerLiteral(0) - else - retVar - } - val subeTimesExpr = addSubInstsIfNonZero(remSubInsts, InfiniteIntegerLiteral(0)) - - subeTimesExpr match { - case InfiniteIntegerLiteral(x) if (x == 0) => totalInvocationCost - case _ => - addSubInstsIfNonZero(remSubInsts :+ totalInvocationCost, InfiniteIntegerLiteral(0)) - } - } - case _ => addSubInstsIfNonZero(subInsts, InfiniteIntegerLiteral(0)) - } - } - - override def instrumentBody(bodyExpr: Expr, instExpr: Expr)(implicit fd: FunDef): Expr = { - val minActivationRecSize = 2 - val (temps, stack) = estimateTemporaries(bodyExpr) - //println(temps + " " + stack) - Plus(instExpr, InfiniteIntegerLiteral(temps + stack + fd.params.length + - 1 /*object ref*/ + - 1 /*return variable before jumping*/ + - minActivationRecSize /*Sometimes for some reason, there are holes in local vars*/)) - } - - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], thenInst: Option[Expr], - elzeInst: Option[Expr]): (Expr, Expr) = { - import invariant.util.Util._ - - val cinst = condInst.toList - val tinst = thenInst.toList - val einst = elzeInst.toList - - (addSubInstsIfNonZero(cinst ++ tinst, zero), - addSubInstsIfNonZero(cinst ++ einst, zero)) - } - - /* Tries to estimate the depth of the operand stack and the temporaries - (excluding arguments) needed by the bytecode. As the JVM might perform - some optimizations when actually executing the bytecode, what we compute - here is an upper bound on the memory needed to evaluate the expression - */ - // (temporaries, stackSize) - def estimateTemporaries(e: Expr): (Int, Int) = { - e match { - /* Like vals */ - case Let(binder: Identifier, value: Expr, body: Expr) => { - // One for the val created + Temps in expr on RHS of initilisation + Rem. body - val (valTemp, valStack) = estimateTemporaries(value) - val (bodyTemp, bodyStack) = estimateTemporaries(body) - (1 + valTemp + bodyTemp, Math.max(valStack, bodyStack)) - } - - case LetDef(fds, body: Expr) => { - // The function definition does not take up stack space. Goes into the constant pool - estimateTemporaries(body) - } - - case FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) => { - // One for the object reference. + stack for computing arguments and also the - // fact that the arguments go into the stack - val (temp, stack) = - args.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - - (temp + 1 /*possibly if the ret val needs to be stored for future use*/, stack + 1) - } - - case MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) => { - val (recTemp, recStack) = estimateTemporaries(rec) - val (temp, stack) = - args.foldLeft(((recTemp, Math.max(args.length, recStack)), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - - (temp + 1 /*possibly if the ret val needs to be stored for future use*/, stack + 1) - } - - case Application(caller: Expr, args: Seq[Expr]) => { - val (callerTemp, callerStack) = estimateTemporaries(caller) - args.foldLeft(((callerTemp, Math.max(args.length, callerStack)), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - } - - case IfExpr(cond: Expr, thenn: Expr, elze: Expr) => { - val (condTemp, condStack) = estimateTemporaries(cond) - val (thennTemp, thennStack) = estimateTemporaries(thenn) - val (elzeTemp, elzeStack) = estimateTemporaries(elze) - - (condTemp + thennTemp + elzeTemp, - Math.max(condStack, Math.max(thennStack, elzeStack))) - } - - case Tuple (exprs: Seq[Expr]) => { - val (temp, stack) = - exprs.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - - (temp, stack + 2) - } - - case MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) => { - - // FIXME - def estimateTemporariesMatchPattern(pattern: Pattern): (Int, Int) = { - pattern match { - case InstanceOfPattern(binder: Option[Identifier], ct: ClassType) => { // c: Class - (0,1) - } - - case WildcardPattern(binder: Option[Identifier]) => { // c @ _ - (if (binder.isDefined) 1 else 0, 0) - } - - case CaseClassPattern(binder: Option[Identifier], ct: CaseClassType, subPatterns: Seq[Pattern]) => { - val (temp, stack) = - subPatterns.foldLeft((1 /* create a new var for matching */, 1))((t: (Int, Int), currPattern) => { - t match { - case acc: (Int, Int) => { - val (patTemp, patStack) = estimateTemporariesMatchPattern(currPattern) - (acc._1 + patTemp, Math.max(acc._2, patStack)) - } - } - }) - - (temp, stack) - } - - case TuplePattern(binder: Option[Identifier], subPatterns: Seq[Pattern]) => { - val (temp, stack) = - subPatterns.foldLeft((1 /* create a new var for matching */, 1))((t: (Int, Int), currPattern) => { - t match { - case acc: (Int, Int) => { - val (patTemp, patStack) = estimateTemporariesMatchPattern(currPattern) - (acc._1 + patTemp, Math.max(acc._2, patStack)) - } - } - }) - - (temp, stack) - } - - case LiteralPattern(binder, lit) => { - (0,2) - } - case _ => - throw new NotImplementedError(s"Pattern $pattern not handled yet!") - } - } - - val (scrTemp, scrStack) = estimateTemporaries(scrutinee) - - val res = cases.foldLeft(((scrTemp + 1 /* create a new var for matching */, Math.max(scrStack, 3 /*MatchError*/))))((t: (Int, Int), currCase: MatchCase) => { - t match { - case acc: (Int, Int) => - val (patTemp, patStack) = estimateTemporariesMatchPattern(currCase.pattern) - val (rhsTemp, rhsStack) = estimateTemporaries(currCase.rhs) - val (guardTemp, guardStack) = - if (currCase.optGuard.isDefined) estimateTemporaries(currCase.optGuard.get) else (0,0) - - (patTemp + rhsTemp + guardTemp + acc._1, - Math.max(acc._2, Math.max(patStack, Math.max(guardStack, rhsStack)))) - } - }) - res - } - - /* Propositional logic */ - case Implies(lhs: Expr, rhs: Expr) => { - val (lhsTemp, lhsStack)= estimateTemporaries(lhs) - val (rhsTemp, rhsStack)= estimateTemporaries(rhs) - (rhsTemp + lhsTemp, 1 + Math.max(lhsStack, rhsStack)) - } - - case Not(expr: Expr) => estimateTemporaries(expr) - - case Equals(lhs: Expr, rhs: Expr) => { - val (lhsTemp, lhsStack)= estimateTemporaries(lhs) - val (rhsTemp, rhsStack)= estimateTemporaries(rhs) - (rhsTemp + lhsTemp + - // If object ref, check for non nullity - 1, - //(if (!(lhs.getType == IntegerType && rhs.getType == IntegerType)) 1 else 0), - 1 + Math.max(lhsStack, rhsStack)) - } - - case CaseClass(ct: CaseClassType, args: Seq[Expr]) => { - val (temp, stack) = - args.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - - (temp, stack + 2) - } - - case _: Literal[_] => (0, 1) - - case Variable(id: Identifier) => (0, 1) - - case Lambda(args: Seq[ValDef], body: Expr) => (0, 0) - - case TupleSelect(tuple: Expr, index: Int) => (0, 1) - - /*case BinaryOperator(s1,s2,_) => { - val (s1Temp, s1Stack)= estimateTemporaries(s1) - val (s2Temp, s2Stack)= estimateTemporaries(s2) - (s1Temp + s2Temp, Math.max(s1Stack, 1 + s2Stack)) - }*/ - - case Operator(exprs, _) => { - exprs.foldLeft(((0, 0), 0))((t: ((Int, Int),Int), currExpr) => { - t match { - case (acc: (Int, Int), currExprNum: Int) => - val (seTemp, seStack) = estimateTemporaries(currExpr) - ((acc._1 + seTemp, Math.max(acc._2, currExprNum + seStack)), 1 + currExprNum) - } - })._1 - } - - case _ => (0, 0) - } - } -} diff --git a/src/main/scala/leon/transformations/TimeStepsPhase.scala b/src/main/scala/leon/transformations/TimeStepsPhase.scala deleted file mode 100644 index 9180978ba63d33295aea1988602a48c8fdef8feb..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/transformations/TimeStepsPhase.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package transformations - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Extractors._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.utils._ -import leon.invariant.util.Util._ - -object timeCostModel { - def costOf(e: Expr): Int = e match { - case FunctionInvocation(fd, _) if !fd.hasBody => 0 // uninterpreted functions - case FunctionInvocation(fd, args) => 1 - case t: Terminal => 0 - case _ => 1 - } - - def costOfExpr(e: Expr) = InfiniteIntegerLiteral(costOf(e)) -} - -class TimeInstrumenter(p: Program, si: SerialInstrumenter) extends Instrumenter(p, si) { - import timeCostModel._ - - def inst = Time - - def functionsToInstrument(): Map[FunDef, List[Instrumentation]] = { - //find all functions transitively called from rootFuncs (here ignore functions called via pre/post conditions) - val instFunSet = getRootFuncs().foldLeft(Set[FunDef]())((acc, fd) => acc ++ cg.transitiveCallees(fd)).filter(_.hasBody) // ignore uninterpreted functions - instFunSet.map(x => (x, List(Time))).toMap - } - - def additionalfunctionsToAdd() = Seq() - - def instrumentMatchCase( - me: MatchExpr, - mc: MatchCase, - caseExprCost: Expr, - scrutineeCost: Expr): Expr = { - val costMatch = costOfExpr(me) - - def totalCostOfMatchPatterns(me: MatchExpr, mc: MatchCase): BigInt = { - - def patCostRecur(pattern: Pattern, innerPat: Boolean, countLeafs: Boolean): Int = { - pattern match { - case InstanceOfPattern(_, _) => { - if (innerPat) 2 else 1 - } - case WildcardPattern(None) => 0 - case WildcardPattern(Some(id)) => { - if (countLeafs && innerPat) 1 - else 0 - } - case CaseClassPattern(_, _, subPatterns) => { - (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => - acc + patCostRecur(subPat, true, countLeafs)) - } - case TuplePattern(_, subPatterns) => { - (if (innerPat) 2 else 1) + subPatterns.foldLeft(0)((acc, subPat) => - acc + patCostRecur(subPat, true, countLeafs)) - } - case LiteralPattern(_, _) => if (innerPat) 2 else 1 - case _ => - throw new NotImplementedError(s"Pattern $pattern not handled yet!") - } - } - - me.cases.take(me.cases.indexOf(mc)).foldLeft(0)( - (acc, currCase) => acc + patCostRecur(currCase.pattern, false, false)) + - patCostRecur(mc.pattern, false, true) - } - - Plus(costMatch, Plus( - Plus(InfiniteIntegerLiteral(totalCostOfMatchPatterns(me, mc)), - caseExprCost), - scrutineeCost)) - } - - def instrument(e: Expr, subInsts: Seq[Expr], funInvResVar: Option[Variable] = None) - (implicit fd: FunDef, letIdMap: Map[Identifier,Identifier]): Expr = e match { - case t: Terminal => costOfExpr(t) - case _ => - subInsts.foldLeft(costOfExpr(e) : Expr)( - (acc: Expr, subeTime: Expr) => Plus(subeTime, acc)) - } - - def instrumentIfThenElseExpr(e: IfExpr, condInst: Option[Expr], - thenInst: Option[Expr], elzeInst: Option[Expr]): (Expr, Expr) = { - val costIf = costOfExpr(e) - (Plus(costIf, Plus(condInst.get, thenInst.get)), - Plus(costIf, Plus(condInst.get, elzeInst.get))) - } -} \ No newline at end of file diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala deleted file mode 100644 index d1fe992821c17e2a5b4676e537b41a90297aaf03..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package utils - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.ScopeSimplifier -import purescala.Path -import solvers._ - -object Simplifiers { - - def bestEffort(ctx: LeonContext, p: Program)(e: Expr, pc: Path = Path.empty): Expr = { - val solverf = SolverFactory.uninterpreted(ctx, p) - - try { - val simplifiers = (simplifyLets _). - andThen(simplifyPaths(solverf, pc)). - andThen(simplifyArithmetic). - andThen(evalGround(ctx, p)). - andThen(normalizeExpression) - - // Simplify first using stable simplifiers - val s = fixpoint(simplifiers, 5)(e) - - // Clean up ids/names - (new ScopeSimplifier).transform(s) - } finally { - solverf.shutdown() - } - } - - def namePreservingBestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = { - val solverf = SolverFactory.uninterpreted(ctx, p) - - try { - val simplifiers = (simplifyArithmetic _). - andThen(evalGround(ctx, p)). - andThen(normalizeExpression) - - // Simplify first using stable simplifiers - fixpoint(simplifiers, 5)(e) - } finally { - solverf.shutdown() - } - } -} diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala deleted file mode 100644 index 4d03e20e3c968e346a4ed4b7cbd28b4c3b6f6781..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Definitions._ -import purescala.Constructors._ - -class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { - val description = "Default verification condition generation approach" - - def generatePostconditions(fd: FunDef): Seq[VC] = { - (fd.postcondition, fd.body) match { - case (Some(post), Some(body)) => - val vc = implies(fd.precOrTrue, application(post, Seq(body))) - Seq(VC(vc, fd, VCKinds.Postcondition).setPos(post)) - case _ => - Nil - } - } - - def generatePreconditions(fd: FunDef): Seq[VC] = { - - val calls = collectWithPC { - case c @ FunctionInvocation(tfd, _) if tfd.hasPrecondition => - c - }(fd.fullBody) - - calls.map { - case (fi @ FunctionInvocation(tfd, args), path) => - val pre = tfd.withParamSubst(args, tfd.precondition.get) - val vc = path implies pre - val fiS = sizeLimit(fi.asString, 40) - VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS")).setPos(fi) - } - - } - - def generateCorrectnessConditions(fd: FunDef): Seq[VC] = { - - def eToVCKind(e: Expr) = e match { - case _ : MatchExpr => - VCKinds.ExhaustiveMatch - - case Assert(_, Some(err), _) => - if (err.startsWith("Map ")) { - VCKinds.MapUsage - } else if (err.startsWith("Array ")) { - VCKinds.ArrayUsage - } else if (err.startsWith("Division ")) { - VCKinds.DivisionByZero - } else if (err.startsWith("Modulo ")) { - VCKinds.ModuloByZero - } else if (err.startsWith("Remainder ")) { - VCKinds.RemainderByZero - } else if (err.startsWith("Cast ")) { - VCKinds.CastError - } else { - VCKinds.Assert - } - - case _ => - VCKinds.Assert - } - - // We don't collect preconditions here, because these are handled by generatePreconditions - val calls = collectCorrectnessConditions(fd.fullBody) - - calls.map { - case (e, correctnessCond) => - VC(correctnessCond, fd, eToVCKind(e)).setPos(e) - } - } - - -} diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala deleted file mode 100644 index a81081bc7d6dd973f4bf1a5613250747aa7a9749..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Types._ -import purescala.Definitions._ -import purescala.Constructors._ - -class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { - override val description = "Induction tactic for suitable functions" - - val reporter = vctx.reporter - - private def firstAbsClassDef(args: Seq[ValDef]): Option[(AbstractClassType, ValDef)] = { - args.map(vd => (vd.getType, vd)).collect { - case (act: AbstractClassType, vd) => (act, vd) - }.headOption - } - - private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = { - val childrenOfSameType = (cct.classDef.fields zip cct.fieldsTypes).collect { case (vd, tpe) if tpe == parentType => vd } - for (field <- childrenOfSameType) yield { - caseClassSelector(cct, expr, field.id) - } - } - - override def generatePostconditions(fd: FunDef): Seq[VC] = { - (fd.body, firstAbsClassDef(fd.params), fd.postcondition) match { - case (Some(body), Some((parentType, arg)), Some(post)) => - for (cct <- parentType.knownCCDescendants) yield { - val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) - - val subCases = selectors.map { sel => - replace(Map(arg.toVariable -> sel), - implies(fd.precOrTrue, application(post, Seq(body))) - ) - } - - val vc = implies( - and(IsInstanceOf(arg.toVariable, cct), fd.precOrTrue), - implies(andJoin(subCases), application(post, Seq(body))) - ) - - VC(vc, fd, VCKinds.Info(VCKinds.Postcondition, s"ind. on ${arg.asString} / ${cct.classDef.id.asString}")).setPos(fd) - } - - case (body, _, post) => - if (post.isDefined && body.isDefined) { - reporter.warning(fd.getPos, "Could not find abstract class type argument to induct on") - } - super.generatePostconditions(fd) - } - } - - override def generatePreconditions(fd: FunDef): Seq[VC] = { - (fd.body, firstAbsClassDef(fd.params)) match { - case (Some(b), Some((parentType, arg))) => - val body = b - - val calls = collectWithPC { - case fi @ FunctionInvocation(tfd, _) if tfd.hasPrecondition => (fi, tfd.precondition.get) - }(body) - - for { - ((fi @ FunctionInvocation(tfd, args), pre), path) <- calls - cct <- parentType.knownCCDescendants - } yield { - val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) - - val subCases = selectors.map { sel => - replace(Map(arg.toVariable -> sel), - implies(fd.precOrTrue, tfd.withParamSubst(args, pre)) - ) - } - - val vc = path - .withConds(Seq(IsInstanceOf(arg.toVariable, cct), fd.precOrTrue) ++ subCases) - .implies(tfd.withParamSubst(args, pre)) - - // Crop the call to display it properly - val fiS = sizeLimit(fi.asString, 25) - - VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS, ind. on (${arg.asString} : ${cct.classDef.id.asString})")).setPos(fi) - } - - case (body, _) => - if (body.isDefined) { - reporter.warning(fd.getPos, "Could not find abstract class type argument to induct on") - } - super.generatePreconditions(fd) - } - } -} diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala deleted file mode 100644 index faa5701f7a7f80bc29c81282696a898febcc6c80..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala._ -import Expressions._ -import ExprOps._ -import Definitions._ -import Constructors._ - -object InjectAsserts extends SimpleLeonPhase[Program, Program] { - - val name = "Asserts" - val description = "Inject asserts for various correctness conditions (map accesses, array accesses, divisions by zero,..)" - - def apply(ctx: LeonContext, pgm: Program): Program = { - def indexUpTo(i: Expr, e: Expr) = { - and(GreaterEquals(i, IntLiteral(0)), LessThan(i, e)) - } - - pgm.definedFunctions.foreach(fd => { - fd.body = fd.body.map(postMap { - case e @ ArraySelect(a, i) => - Some(Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), e).setPos(e)) - case e @ ArrayUpdated(a, i, v) => - Some(ArrayUpdated( - a, - Assert(indexUpTo(i, ArrayLength(a)), Some("Array index out of range"), i).setPos(i), - v - ).setPos(e)) - case e @ MapApply(m,k) => - Some(Assert(MapIsDefinedAt(m, k), Some("Map undefined at this index"), e).setPos(e)) - - case e @ AsInstanceOf(expr, ct) => - Some(Assert(IsInstanceOf(expr, ct), - Some("Cast error"), - e - ).setPos(e)) - - case e @ Division(_, d) => - Some(assertion(not(equality(d, InfiniteIntegerLiteral(0))), - Some("Division by zero"), - e - ).setPos(e)) - case e @ Remainder(_, d) => - Some(assertion(not(equality(d, InfiniteIntegerLiteral(0))), - Some("Remainder by zero"), - e - ).setPos(e)) - case e @ Modulo(_, d) => - Some(assertion(not(equality(d, InfiniteIntegerLiteral(0))), - Some("Modulo by zero"), - e - ).setPos(e)) - - case e @ BVDivision(_, d) => - Some(assertion(not(equality(d, IntLiteral(0))), - Some("Division by zero"), - e - ).setPos(e)) - case e @ BVRemainder(_, d) => - Some(assertion(not(equality(d, IntLiteral(0))), - Some("Remainder by zero"), - e - ).setPos(e)) - - case e @ RealDivision(_, d) => - Some(assertion(not(equality(d, FractionalLiteral(0, 1))), - Some("Division by zero"), - e - ).setPos(e)) - - case e @ CaseClass(cct, args) if cct.root.classDef.hasInvariant => - Some(assertion(FunctionInvocation(cct.root.invariant.get, Seq(e)), - Some("ADT invariant"), - e - ).setPos(e)) - - case _ => - None - }) - }) - - pgm - } -} diff --git a/src/main/scala/leon/verification/Tactic.scala b/src/main/scala/leon/verification/Tactic.scala deleted file mode 100644 index 75910e625432dce2b4084c93a95b8c8ce8d3638b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/Tactic.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Definitions._ -import purescala.Expressions._ - -abstract class Tactic(vctx: VerificationContext) { - val description : String - - implicit val ctx = vctx - - def generateVCs(fd: FunDef): Seq[VC] = { - generatePostconditions(fd) ++ - generatePreconditions(fd) ++ - generateCorrectnessConditions(fd) - } - - def generatePostconditions(function: FunDef): Seq[VC] - def generatePreconditions(function: FunDef): Seq[VC] - def generateCorrectnessConditions(function: FunDef): Seq[VC] - - protected def sizeLimit(s: String, limit: Int) = { - require(limit > 3) - // Crop the call to display it properly - val res = s.takeWhile(_ != '\n').take(limit) - if (res == s) { - res - } else { - res + " ..." - } - } -} diff --git a/src/main/scala/leon/verification/TraceInductionTactic.scala b/src/main/scala/leon/verification/TraceInductionTactic.scala deleted file mode 100644 index 73e9f464bf379c742a9f89346a5d7c6eff44d7d4..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/TraceInductionTactic.scala +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.DefOps._ -import purescala.Common._ -import purescala.Types._ -import purescala.TypeOps.instantiateType -import purescala.Extractors._ -import invariant.util.PredicateUtil._ -import leon.utils._ - -/** - * This tactic applies only to non-recursive functions. - * Inducts over the recursive calls of the first recursive procedure in the body of `funDef` - */ -class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { - val description: String = "A tactic that performs induction over the recursions of a function." - - val cg = vctx.program.callGraph - val defaultTactic = new DefaultTactic(vctx) - val deepInduct = true // a flag for enabling deep induction pattern discovery - - def generatePostconditions(function: FunDef): Seq[VC] = { - assert(!cg.isRecursive(function) && function.body.isDefined) - val inductFunname = function.extAnnotations("traceInduct") match { - case Seq(Some(arg: String)) => Some(arg) - case a => None - } - // print debug info - if(inductFunname.isDefined) - ctx.reporter.debug("Extracting induction pattern from: "+inductFunname.get)(DebugSectionVerification) - - // helper function - def selfRecs(fd: FunDef): Set[FunctionInvocation] = { - if(fd.body.isDefined){ - collect{ - case fi@FunctionInvocation(tfd, _) if tfd.fd == fd => - Set(fi) - case _ => Set.empty[FunctionInvocation] - }(fd.body.get) - } else Set() - } - - if (function.hasPostcondition) { - // construct post(body) - val prop = application(function.postcondition.get, Seq(function.body.get)) - val paramVars = function.paramIds.map(_.toVariable) - // extract the first recursive call when scanning `prop` AST from left to right - var funInv: Option[FunctionInvocation] = None - preTraversal { - case _ if funInv.isDefined => - // do nothing - case fi @ FunctionInvocation(tfd, args) if cg.isRecursive(tfd.fd) // function is recursive - => - val argCheck = - if (deepInduct) { - // here we do a much deeper check - // collect all arguments that are not `paramVars` - val rest = args.zipWithIndex.filterNot(p => paramVars.contains(p._1)) - // check if 'rest' is invariant in all recursive calls - val calleeParams = tfd.fd.params.map(_.id.toVariable) - val restInv = selfRecs(tfd.fd).forall { - case FunctionInvocation(_, recArgs) => - rest.forall { case (_, i) => calleeParams(i) == recArgs(i) } - } - val paramArgs = args.filter(paramVars.contains) - paramArgs.toSet.size == paramArgs.size && // paramArgs are unique ? - restInv - } else { - args.forall(paramVars.contains) && // all arguments are parameters - args.toSet.size == args.size // all arguments are unique - } - if (argCheck) { - if (inductFunname.isDefined) { - if (inductFunname.get == tfd.fd.id.name) - funInv = Some(fi) - } else { - funInv = Some(fi) - } - } - case _ => - }(prop) - funInv match { - case None => - ctx.reporter.warning("Cannot discover induction pattern! Falling back to normal tactic.") - defaultTactic.generatePostconditions(function) - case Some(finv) => - // create a new function that realizes the tactic - val tactFun = new FunDef(FreshIdentifier(function.id.name + "-VCTact"), function.tparams, - function.params, BooleanType) - tactFun.precondition = function.precondition - // the body of tactFun is a conjunction of induction pattern of finv, and the property - val callee = finv.tfd.fd - val paramIndex = paramVars.zipWithIndex.toMap - val framePositions = finv.args.zipWithIndex.collect { // note: the index here is w.r.t calleeArgs - case (v: Variable, i) if paramVars.contains(v) => (v, i) - }.toMap - val footprint = paramVars.filterNot(framePositions.keySet.contains) - val indexedFootprint = footprint.map { a => paramIndex(a) -> a }.toMap // index here is w.r.t params - - // the returned expression will have boolean value - def inductPattern(e: Expr): Expr = { - e match { - case IfExpr(c, th, el) => - createAnd(Seq(inductPattern(c), - IfExpr(c, inductPattern(th), inductPattern(el)))) - - case MatchExpr(scr, cases) => - val scrpat = inductPattern(scr) - val casePats = cases.map{ - case MatchCase(pat, optGuard, rhs) => - val guardPat = optGuard.toSeq.map(inductPattern _) - (guardPat, MatchCase(pat, optGuard, inductPattern(rhs))) - } - val pats = scrpat +: casePats.flatMap(_._1) :+ MatchExpr(scr, casePats.map(_._2)) - createAnd(pats) - - case Let(i, v, b) => - createAnd(Seq(inductPattern(v), Let(i, v, inductPattern(b)))) - - case FunctionInvocation(tfd, args) => - val argPattern = createAnd(args.map(inductPattern)) - if (tfd.fd == callee) { // self recursive call ? - // create a tactFun invocation to mimic the recursion pattern - val indexedArgs = framePositions.map { - case (f, i) => paramIndex(f) -> args(i) - }.toMap ++ indexedFootprint - val recArgs = (0 until indexedArgs.size).map(indexedArgs) - val recCall = FunctionInvocation(TypedFunDef(tactFun, tactFun.tparams.map(_.tp)), recArgs) - createAnd(Seq(argPattern, recCall)) - } else { - argPattern - } - - case Operator(args, op) => - // conjoin all the expressions and return them - createAnd(args.map(inductPattern)) - } - } - val argsMap = callee.params.map(_.id).zip(finv.args).toMap - val tparamMap = callee.typeArgs.zip(finv.tfd.tps).toMap - val inlinedBody = instantiateType(replaceFromIDs(argsMap, callee.body.get), tparamMap, Map()) - val inductScheme = inductPattern(inlinedBody) - // add body, pre and post for the tactFun - tactFun.body = Some(createAnd(Seq(inductScheme, prop))) - tactFun.precondition = function.precondition - // postcondition is `holds` - val resid = FreshIdentifier("holds", BooleanType) - tactFun.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) - - // print debug info if needed - ctx.reporter.debug("Autogenerated tactic fun: "+tactFun)(DebugSectionVerification) - - // generate vcs using the tactfun - (defaultTactic.generatePostconditions(tactFun) ++ - defaultTactic.generatePreconditions(tactFun) ++ - defaultTactic.generateCorrectnessConditions(tactFun)) map { - // rename the VCs to a user-friendly name - case VC(cond, _, vcinfo) => - VC(cond, function, VCKinds.Info(VCKinds.PostTactVC, vcinfo.toString)) - } - } - } else Seq() - } - - def generatePreconditions(function: FunDef): Seq[VC] = - defaultTactic.generatePreconditions(function) - - def generateCorrectnessConditions(function: FunDef): Seq[VC] = - defaultTactic.generateCorrectnessConditions(function) -} diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala deleted file mode 100644 index 117380725d27e06afddfa577ed105186982a0216..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.verification - -import leon.purescala.Expressions._ -import leon.purescala.Definitions._ -import leon.purescala.Types._ -import leon.purescala.PrettyPrinter -import leon.purescala.ExprOps -import leon.utils.Positioned -import leon.solvers._ -import leon.LeonContext -import leon.purescala.SelfPrettyPrinter - -/** This is just to hold some history information. */ -case class VC(condition: Expr, fd: FunDef, kind: VCKind) extends Positioned { - override def toString = { - fd.id.name +" - " +kind.toString - } - // If the two functions are the same but have different positions, used to transfer one to the other. - def copyTo(newFun: FunDef) = { - val thisPos = this.getPos - val newPos = ExprOps.lookup(_.getPos == thisPos, _.getPos)(fd.fullBody, newFun.fullBody) match { - case Some(position) => position - case None => newFun.getPos - } - val newCondition = ExprOps.lookup(condition == _, i => i)(fd.fullBody, newFun.fullBody).getOrElse(condition) - VC(newCondition, newFun, kind).setPos(newPos) - } -} - -abstract class VCKind(val name: String, val abbrv: String) { - override def toString = name - - def underlying = this -} - -object VCKinds { - case class Info(k: VCKind, info: String) extends VCKind(k.abbrv+" ("+info+")", k.abbrv) { - override def underlying = k - } - - case object Precondition extends VCKind("precondition", "precond.") - case object Postcondition extends VCKind("postcondition", "postcond.") - case object Assert extends VCKind("body assertion", "assert.") - case object ExhaustiveMatch extends VCKind("match exhaustiveness", "match.") - case object MapUsage extends VCKind("map usage", "map use") - case object ArrayUsage extends VCKind("array usage", "arr. use") - case object DivisionByZero extends VCKind("division by zero", "div 0") - case object ModuloByZero extends VCKind("modulo by zero", "mod 0") - case object RemainderByZero extends VCKind("remainder by zero", "rem 0") - case object CastError extends VCKind("cast correctness", "cast") - case object PostTactVC extends VCKind("Postcondition Tactic", "tact") -} - -case class VCResult(status: VCStatus, solvedWith: Option[Solver], timeMs: Option[Long]) { - def isValid = status == VCStatus.Valid - def isInvalid = status.isInstanceOf[VCStatus.Invalid] - def isInconclusive = !isValid && !isInvalid - - def report(vctx: VerificationContext) { - import vctx.reporter - - status match { - case VCStatus.Valid => - reporter.info(" => VALID") - - case VCStatus.Invalid(cex) => - reporter.warning(" => INVALID") - reporter.warning("Found counter-example:") - - // We use PrettyPrinter explicitly and not ScalaPrinter: printing very - // large arrays faithfully in ScalaPrinter is hard, while PrettyPrinter - // is free to simplify - val strings = cex.toSeq.sortBy(_._1.name).map { - case (id, v) => - (id.asString(vctx), SelfPrettyPrinter.print(v, PrettyPrinter(v))(vctx, vctx.program)) - } - - if (strings.nonEmpty) { - val max = strings.map(_._1.length).max - - for ((id, v) <- strings) { - reporter.warning((" %-"+max+"s -> %s").format(id, v)) - } - } else { - reporter.warning(f" (Empty counter-example)") - } - case _ => - reporter.warning(" => "+status.name.toUpperCase) - } - } -} - -object VCResult { - def unknown = VCResult(VCStatus.Unknown, None, None) -} - -sealed abstract class VCStatus(val name: String) { - override def toString = name -} - -object VCStatus { - case class Invalid(cex: Model) extends VCStatus("invalid") - case object Valid extends VCStatus("valid") - case object Unknown extends VCStatus("unknown") - case object Timeout extends VCStatus("timeout") - case object Cancelled extends VCStatus("cancelled") - case object Crashed extends VCStatus("crashed") -} diff --git a/src/main/scala/leon/verification/VerificationContext.scala b/src/main/scala/leon/verification/VerificationContext.scala deleted file mode 100644 index 87811eb8955cf2b1e1e1ac91c4c4c14adae01a4d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/VerificationContext.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Definitions.Program -import solvers._ - -class VerificationContext( - context: LeonContext, - val program: Program, - val solverFactory: SolverFactory[Solver] -) extends LeonContext( - context.reporter, - context.interruptManager, - context.options, - context.files, - context.classDir, - context.timers -) { - lazy val checkInParallel: Boolean = context.findOptionOrDefault(VerificationPhase.optParallelVCs) -} diff --git a/src/main/scala/leon/verification/VerificationPhase.scala b/src/main/scala/leon/verification/VerificationPhase.scala deleted file mode 100644 index 01df60771b9c602272a28f3967e338c5f8f2ac34..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/VerificationPhase.scala +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Definitions._ -import purescala.ExprOps._ - -import scala.concurrent.duration._ - -import solvers._ - -object VerificationPhase extends SimpleLeonPhase[Program,VerificationReport] { - val name = "Verification" - val description = "Verification of function contracts" - - val optParallelVCs = LeonFlagOptionDef("parallel", "Check verification conditions in parallel", default = false) - - override val definedOptions: Set[LeonOptionDef[Any]] = Set(optParallelVCs) - - implicit val debugSection = utils.DebugSectionVerification - - def apply(ctx: LeonContext, program: Program): VerificationReport = { - val filterFuns: Option[Seq[String]] = ctx.findOption(GlobalOptions.optFunctions) - val timeout: Option[Long] = ctx.findOption(GlobalOptions.optTimeout) - - val reporter = ctx.reporter - - // Solvers selection and validation - val baseSolverF = SolverFactory.getFromSettings(ctx, program) - - val solverF = timeout match { - case Some(sec) => - baseSolverF.withTimeout(sec.seconds) - case None => - baseSolverF - } - - val vctx = new VerificationContext(ctx, program, solverF) - - reporter.debug("Generating Verification Conditions...") - - def excludeByDefault(fd: FunDef): Boolean = fd.annotations contains "library" - - val fdFilter = { - import OptionsHelpers._ - - filterInclusive(filterFuns.map(fdMatcher(program)), Some(excludeByDefault _)) - } - - val toVerify = program.definedFunctions.filter(fdFilter).sortWith((fd1, fd2) => fd1.getPos < fd2.getPos) - - for(funDef <- toVerify) { - if (excludeByDefault(funDef)) { - reporter.warning("Forcing verification of " + funDef.qualifiedName(program) + " which was assumed verified") - } - } - - try { - val vcs = generateVCs(vctx, toVerify) - - reporter.debug("Checking Verification Conditions...") - - checkVCs(vctx, vcs) - } finally { - solverF.shutdown() - } - } - - def generateVCs(vctx: VerificationContext, toVerify: Seq[FunDef]): Seq[VC] = { - val defaultTactic = new DefaultTactic(vctx) - val inductionTactic = new InductionTactic(vctx) - val trInductTactic = new TraceInductionTactic(vctx) - - val vcs = for(funDef <- toVerify) yield { - val tactic: Tactic = - if (funDef.annotations.contains("induct")) { - inductionTactic - } else if(funDef.annotations.contains("traceInduct")){ - trInductTactic - }else { - defaultTactic - } - - if(funDef.body.isDefined) { - tactic.generateVCs(funDef) - } else { - Nil - } - } - - vcs.flatten - } - - def checkVCs( - vctx: VerificationContext, - vcs: Seq[VC], - stopWhen: VCResult => Boolean = _ => false - ): VerificationReport = { - val interruptManager = vctx.interruptManager - - var stop = false - - val initMap: Map[VC, Option[VCResult]] = vcs.map(v => v -> None).toMap - - val results = if (vctx.checkInParallel) { - for (vc <- vcs.par if !stop) yield { - val r = checkVC(vctx, vc) - if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() - stop = stopWhen(r) - vc -> Some(r) - } - } else { - for (vc <- vcs.toSeq.sortWith((a,b) => a.fd.getPos < b.fd.getPos) if !interruptManager.isInterrupted && !stop) yield { - val r = checkVC(vctx, vc) - if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() - stop = stopWhen(r) - vc -> Some(r) - } - } - - VerificationReport(vctx.program, initMap ++ results) - } - - def checkVC(vctx: VerificationContext, vc: VC): VCResult = { - import vctx.reporter - import vctx.solverFactory - - val interruptManager = vctx.interruptManager - - val vcCond = vc.condition - - val s = solverFactory.getNewSolver() - - try { - reporter.synchronized { - reporter.info(s" - Now considering '${vc.kind}' VC for ${vc.fd.id} @${vc.getPos}...") - reporter.debug(simplifyLets(vcCond).asString(vctx)) - reporter.debug("Solving with: " + s.name) - } - - val tStart = System.currentTimeMillis - - s.assertVC(vc) - - val satResult = s.check - - val dt = System.currentTimeMillis - tStart - - val res = satResult match { - case _ if interruptManager.isInterrupted => - VCResult(VCStatus.Cancelled, Some(s), Some(dt)) - - case None => - val status = s match { - case ts: TimeoutSolver => - ts.optTimeout match { - case Some(t) if t < dt => - VCStatus.Timeout - case _ => - VCStatus.Unknown - } - case _ => - VCStatus.Unknown - } - VCResult(status, Some(s), Some(dt)) - - case Some(false) => - VCResult(VCStatus.Valid, s.getResultSolver, Some(dt)) - - case Some(true) => - VCResult(VCStatus.Invalid(s.getModel), s.getResultSolver, Some(dt)) - } - - reporter.synchronized { - if (vctx.checkInParallel) { - reporter.info(s" - Result for '${vc.kind}' VC for ${vc.fd.id} @${vc.getPos}:") - } - res.report(vctx) - } - - res - - } finally { - s.free() - } - } -} diff --git a/src/main/scala/leon/verification/VerificationReport.scala b/src/main/scala/leon/verification/VerificationReport.scala deleted file mode 100644 index 374122f72781e4e0c24ee05e6a96fa18cfed5be5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/verification/VerificationReport.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package verification - -import purescala.Definitions.Program -import utils.Report - -case class VerificationReport(program: Program, results: Map[VC, Option[VCResult]]) extends Report { - val vrs: Seq[(VC, VCResult)] = results.toSeq.sortBy { case (vc, _) => (vc.fd.id.name, vc.kind.toString) }.map { - case (vc, or) => (vc, or.getOrElse(VCResult.unknown)) - } - - lazy val totalConditions : Int = vrs.size - - lazy val totalTime: Long = vrs.map(_._2.timeMs.getOrElse(0l)).sum - - lazy val totalValid: Int = vrs.count(_._2.isValid) - lazy val totalInvalid: Int = vrs.count(_._2.isInvalid) - lazy val totalUnknown: Int = vrs.count(_._2.isInconclusive) - - def summaryString : String = if(totalConditions >= 0) { - import utils.ASCIIHelpers._ - - var t = Table("Verification Summary") - - t ++= vrs.map { case (vc, vr) => - val timeStr = vr.timeMs.map(t => f"${t/1000d}%-3.3f").getOrElse("") - Row(Seq( - Cell(vc.fd.qualifiedName(program)), - Cell(vc.kind.name), - Cell(vc.getPos), - Cell(vr.status), - Cell(vr.solvedWith.map(_.name).getOrElse("")), - Cell(timeStr, align = Right) - )) - } - - t += Separator - - t += Row(Seq( - Cell(f"total: $totalConditions%-4d valid: $totalValid%-4d invalid: $totalInvalid%-4d unknown $totalUnknown%-4d", 5), - Cell(f"${totalTime/1000d}%7.3f", align = Right) - )) - - t.render - - } else { - "No verification conditions were analyzed." - } -} diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala deleted file mode 100644 index 8f4fb2a13cd263e1e597ac7c7bc5fb44271d6692..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ /dev/null @@ -1,650 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ -package leon.xlang - -import leon.TransformationPhase -import leon.LeonContext -import leon.purescala.Common._ -import leon.purescala.Definitions._ -import leon.purescala.Expressions._ -import leon.purescala.ExprOps._ -import leon.purescala.DefOps._ -import leon.purescala.Types._ -import leon.purescala.Constructors._ -import leon.purescala.Extractors._ -import leon.purescala.DependencyFinder -import leon.purescala.DefinitionTransformer -import leon.utils.Bijection -import leon.xlang.Expressions._ -import leon.xlang.ExprOps._ - -object AntiAliasingPhase extends TransformationPhase { - - val name = "Anti-Aliasing" - val description = "Make aliasing explicit" - - - override def apply(ctx: LeonContext, program: Program): Program = { - - val effectsAnalysis = new EffectsAnalysis - - //we need to perform this now, because as soon as we apply the def transformer - //some types will become Untyped and the checkAliasing won't be reliable anymore - allFunDefs(program).foreach(fd => checkAliasing(fd, effectsAnalysis)(ctx)) - - //mapping for case classes that needs to be replaced - //var ccdMap: Map[CaseClassDef, CaseClassDef] = - // (for { - // ccd <- program.singleCaseClasses - // } yield (ccd, updateCaseClassDef(ccd))).toMap.filter(p => p._1 != p._2) - - - //println("ccdMap: " + ccdMap) - //val ccdBijection: Bijection[ClassDef, ClassDef] = Bijection(ccdMap) - //val (pgm, _, _, _) = replaceDefs(program)(fd => None, cd => ccdBijection.getB(cd)) - //println(pgm) - - //val dependencies = new DependencyFinder - //ccdMap = updateCaseClassDefs(ccdMap, dependencies, pgm) - - //val idsMap: Map[Identifier, Identifier] = ccdMap.flatMap(p => { - // p._1.fields.zip(p._2.fields).filter(pvd => pvd._1.id != pvd._2).map(p => (p._1.id, p._2.id)) - //}).toMap - val transformer = new DefinitionTransformer { - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case (ft: FunctionType) => Some(makeFunctionTypeExplicit(ft, effectsAnalysis)) - case _ => None - } - //override def transformClassDef(cd: ClassDef): Option[ClassDef] = ccdBijection.getB(cd) - } - //pgm.singleCaseClasses.foreach(ccd => println(ccd + " -> " + defTf.transform(ccd))) - val cdsMap = program.definedClasses.map(cd => cd -> transformer.transform(cd)).toMap - val fdsMap = program.definedFunctions.map(fd => fd -> transformer.transform(fd)).toMap - val pgm = replaceDefsInProgram(program)(fdsMap, cdsMap) - //println(leon.purescala.ScalaPrinter(pgm, ctx))//leon.purescala.ScalaPrinter.create(leon.purescala.PrinterOptions(printTypes = true), Some(pgm)).pp(pgm)) - //println(pgm) - - val fds = allFunDefs(pgm) - - var updatedFunctions: Map[FunDef, FunDef] = Map() - - //for each fun def, all the vars the the body captures. Only - //mutable types. - val varsInScope: Map[FunDef, Set[Identifier]] = (for { - fd <- fds - } yield { - val allFreeVars = fd.body.map(bd => variablesOf(bd)).getOrElse(Set()) - val freeVars = allFreeVars -- fd.params.map(_.id) - val mutableFreeVars = freeVars.filter(id => effectsAnalysis.isMutableType(id.getType)) - (fd, mutableFreeVars) - }).toMap - - /* - * The first pass will introduce all new function definitions, - * so that in the next pass we can update function invocations - */ - for { - fd <- fds - } { - updatedFunctions += (fd -> updateFunDef(fd, effectsAnalysis)(ctx)) - } - //println(updatedFunctions.filter(p => p._1.id != p._2.id).mkString("\n")) - - for { - fd <- fds - } { - updateBody(fd, effectsAnalysis, updatedFunctions, varsInScope)(ctx) - } - - replaceDefsInProgram(pgm)(updatedFunctions, Map[ClassDef, ClassDef]()) - - //pgm.copy(units = for (u <- pgm.units) yield { - // u.copy(defs = u.defs.map { - // case m : ModuleDef => - // m.copy(defs = for (df <- m.defs) yield { - // df match { - // case cd : CaseClassDef => ccdBijection.getBorElse(cd, cd) - // case fd : FunDef => updatedFunctions.getOrElse(fd, fd) - // case d => d - // } - // }) - // case cd: CaseClassDef => ccdBijection.getBorElse(cd, cd) - // case d => d - // }) - //}) - } - - /* - * Create a new FunDef for a given FunDef in the program. - * Adapt the signature to express its effects. In case the - * function has no effect, this will still return the original - * fundef. - * - * Also update FunctionType parameters that need to become explicit - * about the effect they could perform (returning any mutable type that - * they receive). - */ - private def updateFunDef(fd: FunDef, effects: EffectsAnalysis)(ctx: LeonContext): FunDef = { - - val ownEffects = effects(fd) - val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ - case (vd, i) if ownEffects.contains(i) => Some(vd) - case _ => None - }.map(_.id) - - //val newParams = fd.params.map(vd => vd.getType match { - // case (ft: FunctionType) => { - // val nft = makeFunctionTypeExplicit(ft) - // if(ft == nft) vd else ValDef(vd.id.duplicate(tpe = nft)) - // } - // case (cct: CaseClassType) => ccdMap.get(cct.classDef) match { - // case Some(ncd) => ValDef(vd.id.duplicate(tpe = ncd.typed)) - // case None => vd - // } - // case _ => vd - //}) - - - fd.body.foreach(body => getReturnedExpr(body).foreach{ - case v@Variable(id) if aliasedParams.contains(id) => - ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object") - case _ => () - }) - - if(aliasedParams.isEmpty) fd else { - val newReturnType: TypeTree = tupleTypeWrap(fd.returnType +: aliasedParams.map(_.getType)) - val newFunDef = new FunDef(fd.id.freshen, fd.tparams, fd.params, newReturnType) - newFunDef.addFlags(fd.flags) - newFunDef.setPos(fd) - newFunDef - } - } - - //private def updateCaseClassDef(ccd: CaseClassDef): CaseClassDef = { - // val newFields = ccd.fields.map(vd => vd.getType match { - // case (ft: FunctionType) => { - // val nft = makeFunctionTypeExplicit(ft) - // if(nft == ft) vd else { - // ValDef(vd.id.duplicate(tpe = nft)) - // } - // } - // case _ => vd - // }) - // if(newFields != ccd.fields) { - // ccd.duplicate(fields = newFields) - // } else { - // ccd - // } - //} - - //recursively update until fixpoint reach - //private def updateCaseClassDefs(ccdMap: Map[CaseClassDef, CaseClassDef], deps: DependencyFinder, pgm: Program): Map[CaseClassDef, CaseClassDef] = { - // for { - // ccd <- pgm.singleCaseClasses - // } { - // if(deps(ccd).exists(_ == - // (ccd, updateCaseClassDef(ccd))).toMap.filter(p => p._1 != p._2) - // } - // for - //} - - private def updateBody(fd: FunDef, effects: EffectsAnalysis, updatedFunDefs: Map[FunDef, FunDef], - varsInScope: Map[FunDef, Set[Identifier]]) - (ctx: LeonContext): Unit = { - //println("updating: " + fd) - - val ownEffects = effects(fd) - val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ - case (vd, i) if ownEffects.contains(i) => Some(vd) - case _ => None - }.map(_.id) - - val newFunDef = updatedFunDefs(fd) - - if(aliasedParams.isEmpty) { - val newBody = fd.body.map(body => { - makeSideEffectsExplicit(body, fd, Seq(), effects, updatedFunDefs, varsInScope)(ctx) - }) - newFunDef.body = newBody - newFunDef.precondition = fd.precondition - newFunDef.postcondition = fd.postcondition - } else { - val freshLocalVars: Seq[Identifier] = aliasedParams.map(v => v.freshen) - val rewritingMap: Map[Identifier, Identifier] = aliasedParams.zip(freshLocalVars).toMap - - val newBody = fd.body.map(body => { - - val freshBody = rewriteIDs(rewritingMap, body) - val explicitBody = makeSideEffectsExplicit(freshBody, fd, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx) - - //only now we rewrite function parameters that changed names when the new function was introduced - val paramRewritings: Map[Identifier, Identifier] = - fd.params.zip(newFunDef.params).filter({ - case (ValDef(oid), ValDef(nid)) if oid != nid => true - case _ => false - }).map(p => (p._1.id, p._2.id)).toMap - val explicitBody2 = replaceFromIDs(paramRewritings.map(p => (p._1, p._2.toVariable)), explicitBody) - - - //WARNING: only works if side effects in Tuples are extracted from left to right, - // in the ImperativeTransformation phase. - val finalBody: Expr = Tuple(explicitBody2 +: freshLocalVars.map(_.toVariable)) - - freshLocalVars.zip(aliasedParams).foldLeft(finalBody)((bd, vp) => { - LetVar(vp._1, Variable(vp._2), bd) - }) - - }) - - val newPostcondition = fd.postcondition.map(post => { - val Lambda(Seq(res), postBody) = post - val newRes = ValDef(FreshIdentifier(res.id.name, newFunDef.returnType)) - val newBody = - replace( - aliasedParams.zipWithIndex.map{ case (id, i) => - (id.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ - aliasedParams.map(id => - (Old(id), id.toVariable): (Expr, Expr)).toMap + - (res.toVariable -> TupleSelect(newRes.toVariable, 1)), - postBody) - Lambda(Seq(newRes), newBody).setPos(post) - }) - - newFunDef.body = newBody - newFunDef.precondition = fd.precondition - newFunDef.postcondition = newPostcondition - } - } - - //We turn all local val of mutable objects into vars and explicit side effects - //using assignments. We also make sure that no aliasing is being done. - private def makeSideEffectsExplicit - (body: Expr, originalFd: FunDef, aliasedParams: Seq[Identifier], effects: EffectsAnalysis, - updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) - (ctx: LeonContext): Expr = { - - val newFunDef = updatedFunDefs(originalFd) - - def mapApplication(args: Seq[Expr], nfi: Expr, nfiType: TypeTree, fiEffects: Set[Int], rewritings: Map[Identifier, Expr]): Expr = { - if(fiEffects.nonEmpty) { - val modifiedArgs: Seq[(Identifier, Expr)] = - args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) } - .map(arg => { - val rArg = replaceFromIDs(rewritings, arg._1) - (findReceiverId(rArg).get, rArg) - }) - - val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct - if(duplicatedParams.nonEmpty) - ctx.reporter.fatalError(nfi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head) - - val freshRes = FreshIdentifier("res", nfiType) - - val extractResults = Block( - modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => { - val resSelect = TupleSelect(freshRes.toVariable, index + 2) - expr match { - case cs@CaseClassSelector(_, obj, mid) => - Assignment(id, deepCopy(cs, resSelect)) - case _ => - Assignment(id, resSelect) - } - }}, - TupleSelect(freshRes.toVariable, 1)) - - - val newExpr = Let(freshRes, nfi, extractResults) - newExpr - } else { - nfi - } - } - - //println("aliased params: " + aliasedParams) - preMapWithContext[(Set[Identifier], Map[Identifier, Expr], Set[Expr])]((expr, context) => { - val bindings = context._1 - val rewritings = context._2 - val toIgnore = context._3 - expr match { - - case l@Let(id, v, b) if effects.isMutableType(id.getType) => { - val varDecl = LetVar(id, v, b).setPos(l) - (Some(varDecl), (bindings + id, rewritings, toIgnore)) - } - - //case l@Let(id, IsTyped(v, CaseClassType(ccd, _)), b) if ccdMap.contains(ccd) => { - // val ncd = ccdMap(ccd) - // val freshId = id.duplicate(tpe = ncd.typed) - // val freshBody = replaceFromIDs(Map(id -> freshId.toVariable), b) - // (Some(Let(freshId, v, freshBody).copiedFrom(l)), context) - //} - - case l@LetVar(id, IsTyped(v, tpe), b) if effects.isMutableType(tpe) => { - (None, (bindings + id, rewritings, toIgnore)) - } - - case m@MatchExpr(scrut, cses) if effects.isMutableType(scrut.getType) => { - - val tmp: Map[Identifier, Expr] = cses.flatMap{ case MatchCase(pattern, guard, rhs) => { - mapForPattern(scrut, pattern) - //val binder = pattern.binder.get - //binder -> scrut - }}.toMap - - (None, (bindings, rewritings ++ tmp, toIgnore)) - } - - case up@ArrayUpdate(a, i, v) => { - val ra = replaceFromIDs(rewritings, a) - - ra match { - case x@Variable(id) => - if(bindings.contains(id)) - (Some(Assignment(id, ArrayUpdated(x, i, v).setPos(up)).setPos(up)), context) - else - (None, context) - case CaseClassSelector(_, o, id) => //should be a path in an object, with id the array to update in the object - findReceiverId(o) match { - case None => - ctx.reporter.fatalError(up.getPos, "Unsupported form of array update: " + up) - case Some(oid) => { - if(bindings.contains(oid)) - (Some(Assignment(oid, deepCopy(ArraySelect(ra, i), v).setPos(up))), context) - else - (None, context) - } - } - case _ => - ctx.reporter.fatalError(up.getPos, "Unsupported form of array update: " + up) - } - } - - case as@FieldAssignment(o, id, v) => { - val so = replaceFromIDs(rewritings, o) - findReceiverId(so) match { - case None => - ctx.reporter.fatalError(as.getPos, "Unsupported form of field assignment: " + as) - case Some(oid) => { - if(bindings.contains(oid)) - (Some(Assignment(oid, deepCopy(CaseClassSelector(o.getType.asInstanceOf[CaseClassType], so, id), v))), context) - else - (None, context) - } - } - } - - //we need to replace local fundef by the new updated fun defs. - case l@LetDef(fds, body) => { - //this might be traversed several time in case of doubly nested fundef, - //so we need to ignore the second times by checking if updatedFunDefs - //contains a mapping or not - val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd)) - (Some(LetDef(nfds, body).copiedFrom(l)), context) - } - - case lambda@Lambda(params, body) => { - val ft@FunctionType(_, _) = lambda.getType - val ownEffects = effects.functionTypeEffects(ft) - val aliasedParams: Seq[Identifier] = params.zipWithIndex.flatMap{ - case (vd, i) if ownEffects.contains(i) => Some(vd) - case _ => None - }.map(_.id) - - if(aliasedParams.isEmpty) { - (None, context) - } else { - val freshLocalVars: Seq[Identifier] = aliasedParams.map(v => v.freshen) - val rewritingMap: Map[Identifier, Identifier] = aliasedParams.zip(freshLocalVars).toMap - val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body) - val explicitBody = makeSideEffectsExplicit(freshBody, originalFd, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx) - - //WARNING: only works if side effects in Tuples are extracted from left to right, - // in the ImperativeTransformation phase. - val finalBody: Expr = Tuple(explicitBody +: freshLocalVars.map(_.toVariable)) - - val wrappedBody: Expr = freshLocalVars.zip(aliasedParams).foldLeft(finalBody)((bd, vp) => { - LetVar(vp._1, Variable(vp._2), bd) - }) - val finalLambda = Lambda(params, wrappedBody).copiedFrom(lambda) - - (Some(finalLambda), context) - } - - } - - case fi@FunctionInvocation(fd, args) => { - - val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set()) - args.find({ - case Variable(id) => vis.contains(id) - case _ => false - }).foreach(aliasedArg => - ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg)) - - - - updatedFunDefs.get(fd.fd) match { - case None => (None, context) - case Some(nfd) => { - val nfi = FunctionInvocation(nfd.typed(fd.tps), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(fi) - val fiEffects = effects(fd.fd) - (Some(mapApplication(args, nfi, nfd.typed(fd.tps).returnType, fiEffects, rewritings)), context) - } - } - } - - case app@Application(callee, args) => { - if(toIgnore(app)) (None, context) else { - val ft@FunctionType(_, to) = callee.getType - to match { - case TupleType(tps) if effects.isMutableType(tps.last) => { - val nfi = Application(callee, args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(app) - val fiEffects = effects.functionTypeEffects(ft) - (Some(mapApplication(args, nfi, to, fiEffects, rewritings)), (bindings, rewritings, toIgnore + nfi)) - } - case _ => (None, context) - } - } - } - - //case app@Application(callee@Variable(id), args) => { - // originalFd.params.zip(newFunDef.params) - // .find(p => p._1.id == id) - // .map(p => p._2.id) match { - // case Some(newId) => - // val ft@FunctionType(_, _) = callee.getType - // val nft = makeFunctionTypeExplicit(ft) - - // if(ft == nft) (None, context) else { - // val nfi = Application(Variable(newId).copiedFrom(callee), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(app) - // val fiEffects = functionTypeEffects(ft) - // (Some(mapApplication(args, nfi, nft.to, fiEffects, rewritings)), context) - // } - // case None => (None, context) - // } - //} - - //case app@Application(callee@CaseClassSelector(cct, obj, id), args) => { - // ccdMap.get(cct.classDef) match { - // case None => - // (None, context) - // case Some(ncd) => { - // val nid = cct.classDef.fields.zip(ncd.fields).find(p => id == p._1.id).map(_._2.id).get - // val nccs = CaseClassSelector(ncd.typed, obj, nid).copiedFrom(callee) - // val nfi = Application(nccs, args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(app) - // val ft@FunctionType(_, _) = callee.getType - // val nft = makeFunctionTypeExplicit(ft) - // val fiEffects = functionTypeEffects(ft) - // (Some(mapApplication(args, nfi, nft.to, fiEffects, rewritings)), context) - // } - // } - //} - - //case CaseClass(cct, args) => { - // ccdMap.get(cct.classDef) match { - // case None => - // (None, context) - // case Some(ncd) => { - // (Some(CaseClass(ncd.typed, args)), context) - // } - // } - //} - - case _ => (None, context) - } - - })(body, (aliasedParams.toSet, Map(), Set())) - } - - - //convert a function type with mutable parameters, into a function type - //that returns the mutable parameters. This makes explicit all possible - //effects of the function. This should be used for higher order functions - //declared as parameters. - private def makeFunctionTypeExplicit(tpe: FunctionType, effects: EffectsAnalysis): FunctionType = { - val newReturnTypes = tpe.from.filter(t => effects.isMutableType(t)) - if(newReturnTypes.isEmpty) - tpe - else { - FunctionType(tpe.from, TupleType(tpe.to +: newReturnTypes)) - } - } - - - private def checkAliasing(fd: FunDef, effects: EffectsAnalysis)(ctx: LeonContext): Unit = { - def checkReturnValue(body: Expr, bindings: Set[Identifier]): Unit = { - getReturnedExpr(body).foreach{ - case expr if effects.isMutableType(expr.getType) => - findReceiverId(expr).foreach(id => - if(bindings.contains(id)) - ctx.reporter.fatalError(expr.getPos, "Cannot return a shared reference to a mutable object: " + expr) - ) - case _ => () - } - } - - if(fd.canBeField && effects.isMutableType(fd.returnType)) - ctx.reporter.fatalError(fd.getPos, "A global field cannot refer to a mutable object") - - fd.body.foreach(bd => { - val params = fd.params.map(_.id).toSet - checkReturnValue(bd, params) - preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { - case l@Let(id, v, b) if effects.isMutableType(v.getType) => { - if(!isExpressionFresh(v, effects)) - ctx.reporter.fatalError(v.getPos, "Illegal aliasing: " + v) - (None, bindings + id) - } - case l@LetVar(id, v, b) if effects.isMutableType(v.getType) => { - if(!isExpressionFresh(v, effects)) - ctx.reporter.fatalError(v.getPos, "Illegal aliasing: " + v) - (None, bindings + id) - } - case l@LetDef(fds, body) => { - fds.foreach(fd => fd.body.foreach(bd => checkReturnValue(bd, bindings))) - (None, bindings) - } - - case _ => (None, bindings) - })(bd, params) - }) - } - - /* - * A bit hacky, but not sure of the best way to do something like that - * currently. - */ - private def getReturnedExpr(expr: Expr): Seq[Expr] = expr match { - case Let(_, _, rest) => getReturnedExpr(rest) - case LetVar(_, _, rest) => getReturnedExpr(rest) - case Block(_, rest) => getReturnedExpr(rest) - case IfExpr(_, thenn, elze) => getReturnedExpr(thenn) ++ getReturnedExpr(elze) - case MatchExpr(_, cses) => cses.flatMap{ cse => getReturnedExpr(cse.rhs) } - case e => Seq(expr) - } - - - /* - * returns all fun def in the program, including local definitions inside - * other functions (LetDef). - */ - private def allFunDefs(pgm: Program): Seq[FunDef] = - pgm.definedFunctions.flatMap(fd => - fd.body.toSet.flatMap((bd: Expr) => - nestedFunDefsOf(bd)) + fd) - - - private def findReceiverId(o: Expr): Option[Identifier] = o match { - case Variable(id) => Some(id) - case CaseClassSelector(_, e, _) => findReceiverId(e) - case AsInstanceOf(e, ct) => findReceiverId(e) - case ArraySelect(a, _) => findReceiverId(a) - case _ => None - } - - - //private def extractFieldPath(o: Expr): (Expr, List[Identifier]) = { - // def rec(o: Expr): List[Identifier] = o match { - // case CaseClassSelector(_, r, i) => - // val res = toFieldPath(r) - // (res._1, i::res) - // case expr => (expr, Nil) - // } - // val res = rec(o) - // (res._1, res._2.reverse) - //} - - - def deepCopy(rec: Expr, nv: Expr): Expr = { - rec match { - case CaseClassSelector(_, r, id) => - val sub = copy(r, id, nv) - deepCopy(r, sub) - case as@ArraySelect(a, index) => - deepCopy(a, ArrayUpdated(a, index, nv).setPos(as)) - case expr => nv - } - } - - private def copy(cc: Expr, id: Identifier, nv: Expr): Expr = { - val ct@CaseClassType(ccd, _) = cc.getType - val newFields = ccd.fields.map(vd => - if(vd.id == id) - nv - else - CaseClassSelector(CaseClassType(ct.classDef, ct.tps), cc, vd.id) - ) - CaseClass(CaseClassType(ct.classDef, ct.tps), newFields).setPos(cc.getPos) - } - - - /* - * A fresh expression is an expression that is newly created - * and does not share memory with existing values and variables. - * - * If the expression is made of existing immutable variables (Int or - * immutable case classes), it is considered fresh as we consider all - * non mutable objects to have a value-copy semantics. - * - * It turns out that an expression of non-mutable type is always fresh, - * as it can not contains reference to a mutable object, by definition - */ - private def isExpressionFresh(expr: Expr, effects: EffectsAnalysis): Boolean = { - !effects.isMutableType(expr.getType) || (expr match { - case v@Variable(_) => !effects.isMutableType(v.getType) - case CaseClass(_, args) => args.forall(arg => isExpressionFresh(arg, effects)) - - case FiniteArray(elems, default, _) => elems.forall(p => isExpressionFresh(p._2, effects)) && default.forall(e => isExpressionFresh(e, effects)) - - //function invocation always return a fresh expression, by hypothesis (global assumption) - case FunctionInvocation(_, _) => true - - //ArrayUpdated returns a mutable array, which by definition is a clone of the original - case ArrayUpdated(_, _, _) => true - - //any other expression is conservately assumed to be non-fresh if - //any sub-expression is non-fresh (i.e. an if-then-else with a reference in one branch) - case Operator(args, _) => args.forall(arg => isExpressionFresh(arg, effects)) - }) - } - -} diff --git a/src/main/scala/leon/xlang/Constructors.scala b/src/main/scala/leon/xlang/Constructors.scala deleted file mode 100644 index c55842f920a81cd06d508f3f4e76e29ceecbafaf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/Constructors.scala +++ /dev/null @@ -1,36 +0,0 @@ -package leon -package xlang - -import purescala.Expressions._ -import purescala.Types._ -import xlang.Expressions._ -import xlang.ExprOps._ - -object Constructors { - - def block(exprs: Seq[Expr]): Expr = { - require(exprs.nonEmpty) - - val flat = exprs.flatMap{ - case Block(es2, el) => es2 :+ el - case e2 => Seq(e2) - } - - val init = flat.init - val last = flat.last - val filtered = init.filter{ - case UnitLiteral() => false - case _ => true - } - - val finalSeq = - if(last == UnitLiteral() && filtered.last.getType == UnitType) filtered else (filtered :+ last) - - finalSeq match { - case Seq() => UnitLiteral() - case Seq(e) => e - case es => Block(es.init, es.last) - } - } - -} diff --git a/src/main/scala/leon/xlang/EffectsAnalysis.scala b/src/main/scala/leon/xlang/EffectsAnalysis.scala deleted file mode 100644 index 68861435a2bf3e10752fca3cd9d6b4d5ae3aa6d7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/EffectsAnalysis.scala +++ /dev/null @@ -1,243 +0,0 @@ -package leon -package xlang - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import purescala.DependencyFinder -import xlang.Expressions._ - -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} - - -/** Provides effect analysis for full Leon language - * - * This holds state for caching the current state of the analysis, so if - * you modify your program you may want to create a new EffectsAnalysis - * instance. - * - * You can use it lazily by only querying effects for the functions you need. - * The internals make sure to compute only the part of the dependencies graph - * that is needed to get the effect of the function. - * - * An effect is defined as the impact that a funcion can have on its environment. - * In the Leon language, there are no global variables that aren't explicit, so - * the effect of a function is defined as the set of its parameters that are mutated - * by executing the body. It is a conservative over-abstraction, as some update operations - * might actually not modify the object, but this will still be considered as an - * effect. This is in contrast to the upcomming @pure instruction, that will actually - * emit VC to prove that the body does not modify the parameter, even with the presence - * of mutation operators. - * - * There are actually a very limited number of features relying on global state (epsilon). - * EffectsAnalysis will not take such effects into account. You need to make sure the - * program either does not rely on epsilon, or has been rewriting with the IntroduceGlobalStatePhase - * phase that introduce any global state explicitly as function parameters. In the future, - * if we do end up supporting global variables, it is likely that we will rely on another - * phase to introduce that global state explicitly into the list of parameters of functions - * that make use of it. - * - * An effect is detected by a FieldAssignment to one of the parameters that are mutable. It - * can come from transitively calling a function that perform a field assignment as well. - * If the function uses higher-order functions that take mutable types as parameters, they - * will be conservatively assumed to perform an effect as well. A function type is not by itself - * a mutable type, but if it is applied on a mutable type that is taken as a parameter as well, - * it will be detected as an effect by the EffectsAnalysis. - * TODO: maybe we could have "conditional effects", effects depending on the effects of the lambda. - * - * The EffectsAnalysis also provides functions to analyze the the mutability of a type and expression. - * The isMutableType function need to perform a graph traversal (explore all the fields recursively - * to find if one is mutable) - */ -class EffectsAnalysis { - - - private val dependencies = new DependencyFinder - private var mutableTypes: MutableMap[TypeTree, Boolean] = MutableMap.empty - - //for each fundef, the set of modified params (by index) - //once set, the value is final and won't be modified further - private val cachedEffects: MutableMap[FunDef, Set[Int]] = MutableMap.empty - - def apply(fd: FunDef): Set[Int] = cachedEffects.getOrElseUpdate(fd, { - effectsAnalysis(fd) - }) - - - - /** Determine if the type is mutable - * - * In Leon, we classify types as either mutable or immutable. Immutable - * type can be referenced freely, while mutable types must be treated with - * care. This function uses a cache, so make sure to not update your class - * def and use the same instance of EffectsAnalysis. It is fine to add - * new ClassDef types on the fly, granted that they use fresh identifiers. - */ - def isMutableType(tpe: TypeTree): Boolean = { - def rec(tpe: TypeTree, abstractClasses: Set[ClassType]): Boolean = mutableTypes.getOrElseUpdate(tpe, tpe match { - case (ct: ClassType) if abstractClasses.contains(ct) => false - case (arr: ArrayType) => true - case ct@CaseClassType(ccd, _) => ccd.fields.exists(vd => vd.isVar || rec(vd.getType, abstractClasses + ct)) - case (ct: ClassType) => ct.knownDescendants.exists(c => rec(c, abstractClasses + ct)) - case _ => false - }) - rec(tpe, Set()) - } - - /** Effects at the level of types for a function - * - * This disregards the actual implementation of a function, and considers only - * its type to determine a conservative abstraction of its effects. - */ - def functionTypeEffects(ft: FunctionType): Set[Int] = { - ft.from.zipWithIndex.flatMap{ case (vd, i) => - if(isMutableType(vd.getType)) Some(i) else None - }.toSet - } - - /* - * Check if expr is mutating variable id. This only checks if the expression - * is the mutation operation, and will not perform expression traversal to - * see if a sub-expression mutates something. - * TODO: clarify this with a function that look at the whole expression - */ - def isMutationOf(expr: Expr, id: Identifier): Boolean = expr match { - case ArrayUpdate(o, _, _) => findReceiverId(o).exists(_ == id) - case FieldAssignment(obj, _, _) => findReceiverId(obj).exists(_ == id) - case Application(callee, args) => { - val ft@FunctionType(_, _) = callee.getType - val effects = functionTypeEffects(ft) - args.map(findReceiverId(_)).zipWithIndex.exists{ - case (Some(argId), index) => argId == id && effects.contains(index) - case _ => false - } - } - case _ => false - } - - - /* - * compute effects for each function that from depends on, including any nested - * functions (LetDef). - * While computing effects for from, it might have to compute transitive effects - * of dependencies. It will update the global effects map while doing so. - */ - private def effectsAnalysis(from: FunDef): Set[Int] = { - - //all the FunDef to consider to compute the effects of from - val fds: Set[FunDef] = dependencies(from).collect{ case (fd: FunDef) => fd } + from - - //currently computed effects (incremental) - var effects: Map[FunDef, Set[Int]] = Map()//cachedEffects.filterKeys(fds.contains) - //missing dependencies for a function for its effects to be complete - var missingEffects: Map[FunDef, Set[FunctionInvocation]] = Map() - - def effectsFullyComputed(fd: FunDef): Boolean = !missingEffects.isDefinedAt(fd) - - for { - fd <- fds - } { - cachedEffects.get(fd) match { - case Some(efcts) => - effects += (fd -> efcts) - case None => - fd.body match { - case None => - effects += (fd -> Set()) - case Some(body) => { - val mutableParams = fd.params.filter(vd => isMutableType(vd.getType)) - val localAliases: Map[ValDef, Set[Identifier]] = mutableParams.map(vd => (vd, computeLocalAliases(vd.id, body))).toMap - val mutatedParams = mutableParams.filter(vd => exists(expr => localAliases(vd).exists(id => isMutationOf(expr, id)))(body)) - val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{ - case (vd, i) if mutatedParams.contains(vd) => Some(i) - case _ => None - }.toSet - effects = effects + (fd -> mutatedParamsIndices) - - val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd) - if(missingCalls.nonEmpty) - missingEffects += (fd -> missingCalls) - } - } - } - } - - def rec(): Unit = { - val previousMissingEffects = missingEffects - - for{ (fd, calls) <- missingEffects } { - var newMissingCalls: Set[FunctionInvocation] = calls - for(fi <- calls) { - val mutatedArgs = invocEffects(fi) - val mutatedFunParams: Set[Int] = fd.params.zipWithIndex.flatMap{ - case (vd, i) if mutatedArgs.contains(vd.id) => Some(i) - case _ => None - }.toSet - effects += (fd -> (effects(fd) ++ mutatedFunParams)) - - if(effectsFullyComputed(fi.tfd.fd)) { - newMissingCalls -= fi - } - } - if(newMissingCalls.isEmpty) - missingEffects = missingEffects - fd - else - missingEffects += (fd -> newMissingCalls) - } - - if(missingEffects != previousMissingEffects) { - rec() - } - } - - def invocEffects(fi: FunctionInvocation): Set[Identifier] = { - //TODO: the require should be fine once we consider nested functions as well - //require(effects.isDefinedAt(fi.tfd.fd) - val mutatedParams: Set[Int] = effects.get(fi.tfd.fd).getOrElse(Set()) - functionInvocationEffects(fi, mutatedParams).toSet - } - - rec() - - effects.foreach{ case (fd, efcts) => if(!cachedEffects.isDefinedAt(fd)) cachedEffects(fd) = efcts } - - effects(from) - } - - //for a given id, compute the identifiers that alias it or some part of the object refered by id - private def computeLocalAliases(id: Identifier, body: Expr): Set[Identifier] = { - def pre(expr: Expr, ids: Set[Identifier]): Set[Identifier] = expr match { - case l@Let(i, Variable(v), _) if ids.contains(v) => ids + i - case m@MatchExpr(Variable(v), cses) if ids.contains(v) => { - val newIds = cses.flatMap(mc => mc.pattern.binders) - ids ++ newIds - } - case e => ids - } - def combiner(e: Expr, ctx: Set[Identifier], ids: Seq[Set[Identifier]]): Set[Identifier] = ctx ++ ids.toSet.flatten + id - val res = preFoldWithContext(pre, combiner)(body, Set(id)) - res - } - - - private def findReceiverId(o: Expr): Option[Identifier] = o match { - case Variable(id) => Some(id) - case CaseClassSelector(_, e, _) => findReceiverId(e) - case AsInstanceOf(e, ct) => findReceiverId(e) - case ArraySelect(a, _) => findReceiverId(a) - case _ => None - } - - //return the set of modified variables arguments to a function invocation, - //given the effect of the fun def invoked. - private def functionInvocationEffects(fi: FunctionInvocation, effects: Set[Int]): Seq[Identifier] = { - fi.args.map(arg => findReceiverId(arg)).zipWithIndex.flatMap{ - case (Some(id), i) if effects.contains(i) => Some(id) - case _ => None - } - } - -} diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala deleted file mode 100644 index 4d77d89dd578604ab10f322997118436ac83810d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/EpsilonElimination.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon.xlang - -import leon.{UnitPhase, LeonContext} -import leon.purescala.Common._ -import leon.purescala.Definitions._ -import leon.purescala.Expressions._ -import leon.purescala.ExprOps._ -import leon.xlang.Expressions._ - -object EpsilonElimination extends UnitPhase[Program] { - - val name = "Epsilon Elimination" - val description = "Remove all epsilons from the program" - - def apply(ctx: LeonContext, pgm: Program) = { - - for (fd <- pgm.definedFunctions) { - fd.fullBody = preTransformWithBinders({ - case (eps@Epsilon(pred, tpe), binders) => - val freshName = FreshIdentifier("epsilon") - val bSeq = binders.toSeq - val freshParams = bSeq.map { _.freshen } - val newFunDef = new FunDef(freshName, Nil, freshParams map (ValDef(_)), tpe) - val epsilonVar = EpsilonVariable(eps.getPos, tpe) - val resId = FreshIdentifier("res", tpe) - val eMap: Map[Expr, Expr] = bSeq.zip(freshParams).map { - case (from, to) => (Variable(from), Variable(to)) - }.toMap ++ Seq((epsilonVar, Variable(resId))) - val postcondition = replace(eMap, pred) - newFunDef.postcondition = Some(Lambda(Seq(ValDef(resId)), postcondition)) - newFunDef.addFlags(fd.flags) - newFunDef.addFlag(Annotation("extern", Seq())) - LetDef(Seq(newFunDef), FunctionInvocation(newFunDef.typed, bSeq map Variable)) - - case (other, _) => other - }, fd.paramIds.toSet)(fd.fullBody) - } - } - -} diff --git a/src/main/scala/leon/xlang/ExprOps.scala b/src/main/scala/leon/xlang/ExprOps.scala deleted file mode 100644 index 35d7914e241d5dbf2bdb26e10bb2b249dfb0b336..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/ExprOps.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import purescala.Expressions._ -import xlang.Expressions._ -import purescala.ExprOps._ -import purescala.Common._ - -object ExprOps { - - def isXLang(expr: Expr): Boolean = exists { - _.isInstanceOf[XLangExpr] - }(expr) - - def containsEpsilon(e: Expr) = exists{ - case _ : Epsilon => true - case _ => false - }(e) - - def flattenBlocks(expr: Expr): Expr = { - postMap({ - case Block(exprs, last) => - val filtered = exprs.filter{ - case UnitLiteral() => false - case _ => true - } - val nexprs = (filtered :+ last).flatMap{ - case Block(es2, el) => es2 :+ el - case e2 => Seq(e2) - } - Some(nexprs match { - case Seq() => UnitLiteral() - case Seq(e) => e - case es => Block(es.init, es.last) - }) - case _ => - None - })(expr) - } - - def rewriteIDs(substs: Map[Identifier, Identifier], expr: Expr) : Expr = { - postMap({ - case Assignment(i, v) => substs.get(i).map(ni => Assignment(ni, v)) - case FieldAssignment(o, i, v) => substs.get(i).map(ni => FieldAssignment(o, ni, v)) - case Variable(i) => substs.get(i).map(ni => Variable(ni)) - case _ => None - })(expr) - } -} - diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala deleted file mode 100644 index a1726b5ec64b910190bb2e95c5dea0af041c1120..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/Expressions.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import purescala.Common._ -import purescala.Types._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.{PrettyPrintable, PrinterContext} -import utils._ - -object Expressions { - import purescala.PrinterHelpers._ - - trait XLangExpr extends Expr - - case class Old(id: Identifier) extends XLangExpr with Terminal with PrettyPrintable { - val getType = id.getType - - def printWith(implicit pctx: PrinterContext): Unit = { - p"old($id)" - } - } - case class OldThis(ct: ClassType) extends XLangExpr with Terminal with PrettyPrintable { - val getType = ct - - def printWith(implicit pctx: PrinterContext): Unit = { - p"old(this)" - } - } - - case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with Extractable with PrettyPrintable { - def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { - Some((exprs :+ last, exprs => Block(exprs.init, exprs.last))) - } - - override def getPos = exprs.headOption match { - case Some(head) => Position.between(head.getPos, last.getPos) - case None => last.getPos - } - - def printWith(implicit pctx: PrinterContext) { - p"${nary(exprs :+ last, "\n")}" - } - - val getType = last.getType - - override def isSimpleExpr = false - } - - case class Assignment(varId: Identifier, expr: Expr) extends XLangExpr with Extractable with PrettyPrintable { - val getType = UnitType - - def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some((Seq(expr), (es: Seq[Expr]) => Assignment(varId, es.head))) - } - - def printWith(implicit pctx: PrinterContext) { - p"$varId = $expr" - } - } - - case class FieldAssignment(obj: Expr, varId: Identifier, expr: Expr) extends XLangExpr with Extractable with PrettyPrintable { - val getType = UnitType - - def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some((Seq(obj, expr), (es: Seq[Expr]) => FieldAssignment(es(0), varId, es(1)))) - } - - def printWith(implicit pctx: PrinterContext) { - p"${obj}.${varId} = ${expr}" - } - } - - case class While(cond: Expr, body: Expr) extends XLangExpr with Extractable with PrettyPrintable { - val getType = UnitType - var invariant: Option[Expr] = None - - def getInvariant: Expr = invariant.get - def setInvariant(inv: Expr) = { invariant = Some(inv); this } - def setInvariant(inv: Option[Expr]) = { invariant = inv; this } - - def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some((Seq(cond, body) ++ invariant, { (es:Seq[Expr]) => es match { - case Seq(e1, e2) => While(e1, e2).setPos(this) - case Seq(e1, e2, e3) => While(e1, e2).setInvariant(e3).setPos(this) - }})) - } - - def printWith(implicit pctx: PrinterContext) { - invariant match { - case Some(inv) => - p"""|@invariant($inv) - |""" - case None => - } - - p"""|while($cond) { - | $body - |}""" - } - } - - case class Epsilon(pred: Expr, tpe: TypeTree) extends XLangExpr with Extractable with PrettyPrintable { - def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some((Seq(pred), (es: Seq[Expr]) => Epsilon(es.head, this.getType).setPos(this))) - } - - def printWith(implicit pctx: PrinterContext) { - p"epsilon(x${getPos.line}_${getPos.col} => $pred)" - } - - val getType = tpe - } - - case class EpsilonVariable(pos: Position, tpe: TypeTree) extends XLangExpr with Terminal with PrettyPrintable { - - def printWith(implicit pctx: PrinterContext) { - p"x${pos.line}_${pos.col}" - } - - val getType = tpe - } - - //same as let, buf for mutable variable declaration - case class LetVar(binder: Identifier, value: Expr, body: Expr) extends XLangExpr with Extractable with PrettyPrintable { - val getType = body.getType - - def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { - Some( Seq(value, body), (es:Seq[Expr]) => LetVar(binder, es(0), es(1)) ) - } - - def printWith(implicit pctx: PrinterContext) { - p"""|var $binder = $value - |$body""" - } - - override def isSimpleExpr = false - } - - case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends XLangExpr with Extractable with PrettyPrintable { - val getType = UnitType - - def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { - val ArrayUpdate(t1, t2, t3) = this - Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2)))) - } - - def printWith(implicit pctx: PrinterContext) { - p"$array($index) = $newValue" - } - } - -} diff --git a/src/main/scala/leon/xlang/FixReportLabels.scala b/src/main/scala/leon/xlang/FixReportLabels.scala deleted file mode 100644 index 7615ec437b487aa73c2d28257179b705c80fac18..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/FixReportLabels.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import leon.purescala.Definitions.IsLoop -import leon.verification._ - -object FixReportLabels extends SimpleLeonPhase[VerificationReport, VerificationReport]{ - - override val name: String = "fixReportLabels" - override val description: String = "Fix verification report labels to reflect the original imperative VCs" - - // TODO: something of this sort should be included - // case object InvariantEntry extends VCKind("invariant init", "inv. init.") - case object InvariantPost extends VCKind("invariant postcondition", "inv. post.") - case object InvariantInd extends VCKind("invariant inductive", "inv. ind.") - - def apply(ctx: LeonContext, vr: VerificationReport): VerificationReport = { - //this is enough to convert invariant postcondition and inductive conditions. However the initial validity - //of the invariant (before entering the loop) will still appear as a regular function precondition - //To fix this, we need some information in the VCs about which function the precondtion refers to - //although we could arguably say that the term precondition is general enough to refer both to a loop invariant - //precondition and a function invocation precondition - - val newResults = for ((vc, ovr) <- vr.results) yield { - val (vcKind, fd) = vc.fd.flags.collectFirst { case IsLoop(orig) => orig } match { - case None => (vc.kind, vc.fd) - case Some(owner) => (vc.kind.underlying match { - case VCKinds.Precondition => InvariantInd - case VCKinds.Postcondition => InvariantPost - case _ => vc.kind - }, owner) - } - - val nvc = VC( - vc.condition, - fd, - vcKind - ).setPos(vc.getPos) - - nvc -> ovr - - } - - VerificationReport(vr.program, newResults) - - } - -} diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala deleted file mode 100644 index cf70d2eb13856bccfb819d63d113a85616e8a902..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ /dev/null @@ -1,383 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import leon.purescala.Common._ -import leon.purescala.Definitions._ -import leon.purescala.Expressions._ -import leon.purescala.Extractors._ -import leon.purescala.Constructors._ -import leon.purescala.ExprOps._ -import leon.purescala.TypeOps.leastUpperBound -import leon.purescala.Types._ -import leon.xlang.Expressions._ - -object ImperativeCodeElimination extends UnitPhase[Program] { - - val name = "Imperative Code Elimination" - val description = "Transform imperative constructs into purely functional code" - - def apply(ctx: LeonContext, pgm: Program): Unit = { - for { - fd <- pgm.definedFunctions - body <- fd.body if exists(requireRewriting)(body) - } { - val (res, scope, _) = toFunction(body)(State(fd, Set(), Map())) - fd.body = Some(scope(res)) - - } - - //probably not the cleanest way to do it, but if somehow we still have Old - //expressions at that point, they can be safely removed as the object is - //equals to its original value - for { - fd <- pgm.definedFunctions - } { - fd.postcondition = fd.postcondition.map(post => { - preMap{ - case Old(v) => Some(v.toVariable) - case _ => None - }(post) - }) - } - - } - - /* varsInScope refers to variable declared in the same level scope. - Typically, when entering a nested function body, the scope should be - reset to empty */ - private case class State( - parent: FunDef, - varsInScope: Set[Identifier], - funDefsMapping: Map[FunDef, (FunDef, List[Identifier])] - ) { - def withVar(i: Identifier) = copy(varsInScope = varsInScope + i) - def withFunDef(fd: FunDef, nfd: FunDef, ids: List[Identifier]) = - copy(funDefsMapping = funDefsMapping + (fd -> (nfd, ids))) - } - - //return a "scope" consisting of purely functional code that defines potentially needed - //new variables (val, not var) and a mapping for each modified variable (var, not val :) ) - //to their new name defined in the scope. The first returned valued is the value of the expression - //that should be introduced as such in the returned scope (the val already refers to the new names) - private def toFunction(expr: Expr)(implicit state: State): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { - import state._ - expr match { - case LetVar(id, e, b) => - val newId = id.freshen - val (rhsVal, rhsScope, rhsFun) = toFunction(e) - val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withVar(id)) - val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body))).copiedFrom(expr)) - (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun) - - case Assignment(id, e) => - assert(varsInScope.contains(id)) - val newId = id.freshen - val (rhsVal, rhsScope, rhsFun) = toFunction(e) - val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body).copiedFrom(expr)) - (UnitLiteral(), scope, rhsFun + (id -> newId)) - - case ite@IfExpr(cond, tExpr, eExpr) => - val (cRes, cScope, cFun) = toFunction(cond) - val (tRes, tScope, tFun) = toFunction(tExpr) - val (eRes, eScope, eFun) = toFunction(eExpr) - - val iteRType = leastUpperBound(tRes.getType, eRes.getType).getOrElse(Untyped) - - val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varsInScope).toSeq - val resId = FreshIdentifier("res", iteRType) - val freshIds = modifiedVars.map( _.freshen ) - val iteType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - - val thenVal = tupleWrap(tRes +: modifiedVars.map(vId => tFun.getOrElse(vId, vId).toVariable)) - val elseVal = tupleWrap(eRes +: modifiedVars.map(vId => eFun.getOrElse(vId, vId).toVariable)) - val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).copiedFrom(ite) - - val scope = (body: Expr) => { - val tupleId = FreshIdentifier("t", iteType) - cScope(Let(tupleId, iteExpr, Let( - resId, - tupleSelect(tupleId.toVariable, 1, modifiedVars.nonEmpty), - freshIds.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 2, true), b) - )) - ).copiedFrom(expr)) - } - - (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) - - case m @ MatchExpr(scrut, cses) => - val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure - val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 - val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) - - val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).intersect(varsInScope).toSeq - val resId = FreshIdentifier("res", m.getType) - val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val matchType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - - val csesVals = csesRes.zip(csesFun).map { - case (cRes, cFun) => tupleWrap(cRes +: modifiedVars.map(vId => cFun.getOrElse(vId, vId).toVariable)) - } - - val newRhs = csesVals.zip(csesScope).map { - case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)) - } - val matchE = matchExpr(scrutRes, cses.zip(newRhs).map { - case (mc @ MatchCase(pat, guard, _), newRhs) => - MatchCase(pat, guard map { replaceNames(scrutFun, _) }, newRhs).setPos(mc) - }).setPos(m) - - val scope = (body: Expr) => { - val tupleId = FreshIdentifier("t", matchType) - scrutScope( - Let(tupleId, matchE, - Let(resId, tupleSelect(tupleId.toVariable, 1, freshIds.nonEmpty), - freshIds.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 2, true), b) - ) - ) - ) - ) - } - - (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) - - case wh@While(cond, body) => - val whileFunDef = new FunDef(parent.id.duplicate(name = (parent.id.name + "While")), Nil, Nil, UnitType).setPos(wh) - whileFunDef.addFlag(IsLoop(parent)) - whileFunDef.body = Some( - IfExpr(cond, - Block(Seq(body), FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)), - UnitLiteral())) - whileFunDef.precondition = wh.invariant - whileFunDef.postcondition = Some( - Lambda( - Seq(ValDef(FreshIdentifier("bodyRes", UnitType))), - and(Not(getFunctionalResult(cond)), wh.invariant.getOrElse(BooleanLiteral(true))).setPos(wh) - ).setPos(wh) - ) - - val newExpr = LetDef(Seq(whileFunDef), FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)).setPos(wh) - toFunction(newExpr) - - case Block(Seq(), expr) => - toFunction(expr) - - case Block(exprs, expr) => - val (scope, fun) = exprs.foldRight((body: Expr) => body, Map[Identifier, Identifier]())((e, acc) => { - val (accScope, accFun) = acc - val (rVal, rScope, rFun) = toFunction(e) - val scope = (body: Expr) => { - rVal match { - case FunctionInvocation(tfd, args) => - rScope(replaceNames(rFun, Let(FreshIdentifier("tmp", tfd.returnType), rVal, accScope(body)))) - case _ => - rScope(replaceNames(rFun, accScope(body))) - } - - } - (scope, rFun ++ accFun) - }) - val (lastRes, lastScope, lastFun) = toFunction(expr) - val finalFun = fun ++ lastFun - ( - replaceNames(finalFun, lastRes), - (body: Expr) => scope(replaceNames(fun, lastScope(body))), - finalFun - ) - - //pure expression (that could still contain side effects as a subexpression) (evaluation order is from left to right) - case Let(id, e, b) => - val (bindRes, bindScope, bindFun) = toFunction(e) - val (bodyRes, bodyScope, bodyFun) = toFunction(b) - ( - bodyRes, - (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2))).copiedFrom(expr)), - bindFun ++ bodyFun - ) - - //a function invocation can update variables in scope. - case fi@FunctionInvocation(tfd, args) => - - - val (recArgs, argScope, argFun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { - val (accArgs, accScope, accFun) = acc - val (argVal, argScope, argFun) = toFunction(arg) - val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) - (argVal +: accArgs, newScope, argFun ++ accFun) - }) - - val fd = tfd.fd - state.funDefsMapping.get(fd) match { - case Some((newFd, modifiedVars)) => { - val newInvoc = FunctionInvocation(newFd.typed, recArgs ++ modifiedVars.map(id => id.toVariable)).setPos(fi) - val freshNames = modifiedVars.map(id => id.freshen) - val tmpTuple = FreshIdentifier("t", newFd.returnType) - - val scope = (body: Expr) => { - argScope(Let(tmpTuple, newInvoc, - freshNames.zipWithIndex.foldRight(body)((p, b) => - Let(p._1, TupleSelect(tmpTuple.toVariable, p._2 + 2), b)) - )) - } - val newMap = argFun ++ modifiedVars.zip(freshNames).toMap - - (TupleSelect(tmpTuple.toVariable, 1), scope, newMap) - } - case None => - (FunctionInvocation(tfd, recArgs).copiedFrom(fi), argScope, argFun) - } - - - case LetDef(fds, b) => - - if(fds.size > 1) { - //TODO: no support for true mutual recursion - toFunction(LetDef(Seq(fds.head), LetDef(fds.tail, b))) - } else { - - val fd = fds.head - - def fdWithoutSideEffects = { - fd.body.foreach { bd => - val (fdRes, fdScope, _) = toFunction(bd) - fd.body = Some(fdScope(fdRes)) - } - val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(Seq(fd), bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun) - } - - fd.body match { - case Some(bd) => { - - val modifiedVars: List[Identifier] = - collect[Identifier]({ - case Assignment(v, _) => Set(v) - case FunctionInvocation(tfd, _) => state.funDefsMapping.get(tfd.fd).map(p => p._2.toSet).getOrElse(Set()) - case _ => Set() - })(bd).intersect(state.varsInScope).toList - - if(modifiedVars.isEmpty) fdWithoutSideEffects else { - - val freshNames: List[Identifier] = modifiedVars.map(id => id.freshen) - - val newParams: Seq[ValDef] = fd.params ++ freshNames.map(n => ValDef(n)) - val freshVarDecls: List[Identifier] = freshNames.map(id => id.freshen) - - val rewritingMap: Map[Identifier, Identifier] = - modifiedVars.zip(freshVarDecls).toMap - val freshBody = - preMap({ - case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) - case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid)) - case _ => None - })(bd) - val wrappedBody = freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => { - LetVar(p._2, Variable(p._1), body) - }) - - val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType)) - - val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd) - newFd.addFlags(fd.flags) - - val (fdRes, fdScope, fdFun) = - toFunction(wrappedBody)( - State(state.parent, - Set(), - state.funDefsMapping.map{case (fd, (nfd, mvs)) => (fd, (nfd, mvs.map(v => rewritingMap.getOrElse(v, v))))} + - (fd -> ((newFd, freshVarDecls)))) - ) - val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) - val newBody = fdScope(newRes) - - newFd.body = Some(newBody) - newFd.precondition = fd.precondition.map(prec => { - replace(modifiedVars.zip(freshNames).map(p => (p._1.toVariable, p._2.toVariable)).toMap, prec) - }) - newFd.postcondition = fd.postcondition.map(post => { - val Lambda(Seq(res), postBody) = post - val newRes = ValDef(FreshIdentifier(res.id.name, newFd.returnType)) - - val newBody = - replace( - modifiedVars.zipWithIndex.map{ case (v, i) => - (v.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ - modifiedVars.zip(freshNames).map{ case (ov, nv) => - (Old(ov), nv.toVariable)}.toMap + - (res.toVariable -> TupleSelect(newRes.toVariable, 1)), - postBody) - Lambda(Seq(newRes), newBody).setPos(post) - }) - - val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDef(fd, newFd, modifiedVars)) - (bodyRes, (b2: Expr) => LetDef(Seq(newFd), bodyScope(b2)).copiedFrom(expr), bodyFun) - } - } - case None => fdWithoutSideEffects - } - } - - //TODO: handle vars in scope, just like LetDef - case ld@Lambda(params, body) => - val (bodyVal, bodyScope, bodyFun) = toFunction(body) - (Lambda(params, bodyScope(bodyVal)).copiedFrom(ld), (e: Expr) => e, Map()) - - case c @ Choose(b) => - //Recall that Choose cannot mutate variables from the scope - (c, (b2: Expr) => b2, Map()) - - case And(args) => - val ifExpr = args.reduceRight((el, acc) => IfExpr(el, acc, BooleanLiteral(false))) - toFunction(ifExpr) - - case Or(args) => - val ifExpr = args.reduceRight((el, acc) => IfExpr(el, BooleanLiteral(true), acc)) - toFunction(ifExpr) - - //TODO: this should be handled properly by the Operator case, but there seems to be a subtle bug in the way Let's are lifted - // which leads to Assert refering to the wrong value of a var in some cases. - case a@Assert(cond, msg, body) => - val (condVal, condScope, condFun) = toFunction(cond) - val (bodyRes, bodyScope, bodyFun) = toFunction(body) - val scope = (body: Expr) => condScope(Assert(condVal, msg, replaceNames(condFun, bodyScope(body))).copiedFrom(a)) - (bodyRes, scope, condFun ++ bodyFun) - - - case n @ Operator(args, recons) => - val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { - val (accArgs, accScope, accFun) = acc - val (argVal, argScope, argFun) = toFunction(arg) - val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) - (argVal +: accArgs, newScope, argFun ++ accFun) - }) - - (recons(recArgs).copiedFrom(n), scope, fun) - - case _ => - sys.error("not supported: " + expr) - } - } - - def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replaceFromIDs(fun mapValues Variable, expr) - - - /* Extract functional result value. Useful to remove side effect from conditions when moving it to post-condition */ - private def getFunctionalResult(expr: Expr): Expr = { - preMap({ - case Block(_, res) => Some(res) - case _ => None - })(expr) - } - - private def requireRewriting(expr: Expr) = expr match { - case (e: Block) => true - case (e: Assignment) => true - case (e: While) => true - case (e: LetVar) => true - case _ => false - } - -} diff --git a/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala b/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala deleted file mode 100644 index 28e385a9a2308c86cf5f15e395fd8d904e914629..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ -package leon.xlang - -import leon.TransformationPhase -import leon.LeonContext -import leon.purescala.Common._ -import leon.purescala.Definitions._ -import leon.purescala.Expressions._ -import leon.purescala.ExprOps._ -import leon.purescala.DefOps._ -import leon.purescala.Types._ -import leon.purescala.Constructors._ -import leon.purescala.Extractors._ -import leon.xlang.Expressions._ - -object IntroduceGlobalStatePhase extends TransformationPhase { - - val name = "introduce-global-state" - val description = "Introduce a global state passed around to all functions that depend on it" - - - override def apply(ctx: LeonContext, pgm: Program): Program = { - - val globalStateCCD = new CaseClassDef(FreshIdentifier("GlobalState"), Seq(), None, false) - val epsilonSeed = FreshIdentifier("epsilonSeed", IntegerType) - globalStateCCD.setFields(Seq(ValDef(epsilonSeed).setIsVar(true))) - - val fds = allFunDefs(pgm) - var updatedFunctions: Map[FunDef, FunDef] = Map() - - val statefulFunDefs = funDefsNeedingState(pgm) - - /* - * The first pass will introduce all new function definitions, - * so that in the next pass we can update function invocations - */ - for { - fd <- fds if statefulFunDefs.contains(fd) - } { - updatedFunctions += (fd -> extendFunDefWithState(fd, globalStateCCD)(ctx)) - } - - for { - fd <- fds if statefulFunDefs.contains(fd) - } { - updateBody(fd, updatedFunctions, globalStateCCD, epsilonSeed)(ctx) - } - - val pgm0 = replaceDefsInProgram(pgm)(updatedFunctions) - - val globalStateModule = ModuleDef(FreshIdentifier("GlobalStateModule"), Seq(globalStateCCD), false) - val globalStateUnit = UnitDef(FreshIdentifier("GlobalStateUnit"), List("leon", "internal"), Seq(), Seq(globalStateModule), false) - pgm0.copy(units = globalStateUnit :: pgm0.units) - } - - private def extendFunDefWithState(fd: FunDef, stateCCD: CaseClassDef)(ctx: LeonContext): FunDef = { - val newParams = fd.params :+ ValDef(FreshIdentifier("globalState", stateCCD.typed)) - val newFunDef = new FunDef(fd.id.freshen, fd.tparams, newParams, fd.returnType) - newFunDef.addFlags(fd.flags) - newFunDef.setPos(fd) - newFunDef - } - - private def updateBody(fd: FunDef, updatedFunctions: Map[FunDef, FunDef], globalStateCCD: CaseClassDef, epsilonSeed: Identifier)(ctx: LeonContext): FunDef = { - val nfd = updatedFunctions(fd) - val stateParam: ValDef = nfd.params.last - - nfd.body = fd.body.map(body => postMap{ - case fi@FunctionInvocation(efd, args) if updatedFunctions.contains(efd.fd) => { - Some(FunctionInvocation(updatedFunctions(efd.fd).typed(efd.tps), args :+ stateParam.id.toVariable)) - } - case eps@Epsilon(pred, _) => { - val nextEpsilonSeed = Plus( - CaseClassSelector(globalStateCCD.typed, stateParam.id.toVariable, epsilonSeed), - InfiniteIntegerLiteral(1)) - Some(Block(Seq(FieldAssignment(stateParam.id.toVariable, epsilonSeed, nextEpsilonSeed)), eps)) - } - case _ => None - }(body)) - - nfd.precondition = fd.precondition - nfd.postcondition = fd.postcondition - - nfd - } - - def funDefsNeedingState(pgm: Program): Set[FunDef] = { - - def compute(body: Expr): Boolean = exists{ - case Epsilon(_, _) => true - case _ => false - }(body) - - def combine(b1: Boolean, b2: Boolean) = b1 || b2 - - programFixpoint(pgm, compute, combine).filter(p => p._2).keySet - } - - /* - * compute some A for each function in the program, including any nested - * functions (LetDef). The A value is transitive, combined with the A value - * of all called functions in the body. The combine function combines the current - * value computed with a new value from a function invocation. - */ - private def programFixpoint[A](pgm: Program, compute: (Expr) => A, combine: (A, A) => A): Map[FunDef, A] = { - - //currently computed results (incremental) - var res: Map[FunDef, A] = Map() - //missing dependencies for a function - var missingDependencies: Map[FunDef, Set[FunctionInvocation]] = Map() - - def fullyComputed(fd: FunDef): Boolean = !missingDependencies.isDefinedAt(fd) - - for { - fd <- allFunDefs(pgm) - } { - fd.body match { - case None => - () //TODO: maybe some default value? res += (fd -> Set()) - case Some(body) => { - res = res + (fd -> compute(body)) - val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd) - if(missingCalls.nonEmpty) - missingDependencies += (fd -> missingCalls) - } - } - } - - def rec(): Unit = { - val previousMissingDependencies = missingDependencies - - for{ (fd, calls) <- missingDependencies } { - var newMissingCalls: Set[FunctionInvocation] = calls - for(fi <- calls) { - val newA = res.get(fi.tfd.fd).map(ra => combine(res(fd), ra)).getOrElse(res(fd)) - res += (fd -> newA) - - if(fullyComputed(fi.tfd.fd)) { - newMissingCalls -= fi - } - } - if(newMissingCalls.isEmpty) - missingDependencies = missingDependencies - fd - else - missingDependencies += (fd -> newMissingCalls) - } - - if(missingDependencies != previousMissingDependencies) { - rec() - } - } - - rec() - res - } - - - /* - * returns all fun def in the program, including local definitions inside - * other functions (LetDef). - */ - private def allFunDefs(pgm: Program): Seq[FunDef] = - pgm.definedFunctions.flatMap(fd => - fd.body.toSet.flatMap((bd: Expr) => - nestedFunDefsOf(bd)) + fd) - - -} diff --git a/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala b/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala deleted file mode 100644 index fc7c3f76b933b4b4d38da2391eaf5c6fcad9813e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import purescala.ExprOps.collect -import purescala.Definitions._ - -import utils.Position - -import xlang.Expressions._ - -object NoXLangFeaturesChecking extends UnitPhase[Program] { - - val name = "no-xlang" - val description = "Ensure and enforce that no xlang features are used" - - override def apply(ctx: LeonContext, pgm: Program): Unit = { - val errors = pgm.definedFunctions.flatMap(fd => collect[(Position, String)]{ - case (e: Block) => - Set((e.getPos, "Block expressions require xlang desugaring")) - case (e: Assignment) => - Set((e.getPos, "Mutating variables requires xlang desugaring")) - case (e: While) => - Set((e.getPos, "While expressions require xlang desugaring")) - case (e: Epsilon) => - Set((e.getPos, "Usage of epsilons requires xlang desugaring")) - case (e: EpsilonVariable) => - Set((e.getPos, "Usage of epsilons requires xlang desugaring")) - case (e: LetVar) => - Set((e.getPos, "Mutable variables (e.g. 'var x' instead of 'val x') require xlang desugaring")) - case (e: ArrayUpdate) => - Set((e.getPos, "In-place updates of arrays require xlang desugaring")) - case _ => - Set() - }(fd.fullBody)) - - for ((p, msg) <- errors) { - ctx.reporter.error(p, msg) - } - } - -} - diff --git a/src/main/scala/leon/xlang/XLangCleanupPhase.scala b/src/main/scala/leon/xlang/XLangCleanupPhase.scala deleted file mode 100644 index 10feb89fe68ad880300c0b0f88387d280ee2dd8a..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/XLangCleanupPhase.scala +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import purescala.Common._ -import purescala.Definitions._ -import purescala.DefinitionTransformer -import purescala.DefOps._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Constructors._ -import purescala.Types._ - -/** Cleanup the program after running XLang desugaring. - * - * This functions simplifies away typical pattern of expressions - * that can be generated during xlang desugaring phase. The most - * common case is the generation of function returning tuple with - * Unit in it, which can be safely eliminated. - */ -object XLangCleanupPhase extends TransformationPhase { - - val name = "xlang cleanup" - val description = "Cleanup program after running xlang desugaring" - - //private var fun2FreshFun: Map[FunDef, FunDef] = Map() - //private var id2FreshId: Map[Identifier, Identifier] = Map() - - override def apply(ctx: LeonContext, program: Program): Program = { - - val transformer = new DefinitionTransformer { - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case (tt: TupleType) if tt.bases.exists(_ == UnitType) => - Some(tupleTypeWrap(tt.bases.filterNot(_ == UnitType))) - case _ => None - } - - override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = expr match { - case sel@TupleSelect(IsTyped(t, TupleType(bases)), index) => - if(bases(index-1) == UnitType) - Some(UnitLiteral()) - else { - val nbUnitsUntilIndex = bases.take(index).count(_ == UnitType) - if(nbUnitsUntilIndex == 0) - None - else if(bases.count(_ != UnitType) == 1) - Some(t) - else - Some(TupleSelect(t, index - nbUnitsUntilIndex).copiedFrom(sel)) - } - case tu@Tuple(es) if es.exists(_.getType == UnitType) => - Some(tupleWrap(es.filterNot(_.getType == UnitType)).copiedFrom(tu)) - case let@Let(id, IsTyped(t, tt@TupleType(bases)), rest) if bases.exists(_.getType == UnitType) => - val ntt = tupleTypeFilterUnits(tt) - val nid = id.duplicate(tpe=ntt) - Some(Let(nid, t, transform(rest)(bindings + (id -> nid))).copiedFrom(let)) - - case _ => None - } - } - - val cdsMap = program.definedClasses.map(cd => cd -> transformer.transform(cd)).toMap - val fdsMap = program.definedFunctions.map(fd => fd -> transformer.transform(fd)).toMap - val pgm = replaceDefsInProgram(program)(fdsMap, cdsMap) - pgm - } - - private def tupleTypeFilterUnits(tt: TupleType): TypeTree = tupleTypeWrap(tt.bases.filterNot(_ == UnitType)) -} - -// val newUnits = pgm.units map { u => u.copy(defs = u.defs.map { -// case m: ModuleDef => -// fun2FreshFun = Map() -// val allFuns = m.definedFunctions -// //first introduce new signatures without Unit parameters -// allFuns.foreach(fd => { -// if(fd.returnType != UnitType && fd.params.exists(vd => vd.getType == UnitType)) { -// val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) -// fun2FreshFun += (fd -> freshFunDef) -// } else { -// fun2FreshFun += (fd -> fd) //this will make the next step simpler -// } -// }) -// -// //then apply recursively to the bodies -// val newFuns = allFuns.collect{ case fd if fd.returnType != UnitType => -// val newFd = fun2FreshFun(fd) -// newFd.fullBody = removeUnit(fd.fullBody) -// newFd -// } -// -// ModuleDef(m.id, m.definedClasses ++ newFuns, m.isPackageObject ) -// case d => -// d -// })} -// -// -// Program(newUnits) -// } -// -// private def simplifyType(tpe: TypeTree): TypeTree = tpe match { -// case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ _ == UnitType }) -// case t => t -// } -// -// //remove unit value as soon as possible, so expr should never be equal to a unit -// private def removeUnit(expr: Expr): Expr = { -// assert(expr.getType != UnitType) -// expr match { -// case fi@FunctionInvocation(tfd, args) => -// val newArgs = args.filterNot(arg => arg.getType == UnitType) -// FunctionInvocation(fun2FreshFun(tfd.fd).typed(tfd.tps), newArgs).setPos(fi) -// -// case IsTyped(Tuple(args), TupleType(tpes)) => -// val newArgs = tpes.zip(args).collect { -// case (tp, arg) if tp != UnitType => arg -// } -// tupleWrap(newArgs.map(removeUnit)) // @mk: FIXME this may actually return a Unit, is that cool? -// -// case ts@TupleSelect(t, index) => -// val TupleType(tpes) = t.getType -// val simpleTypes = tpes map simplifyType -// val newArity = tpes.count(_ != UnitType) -// val newIndex = simpleTypes.take(index).count(_ != UnitType) -// tupleSelect(removeUnit(t), newIndex, newArity) -// -// case Let(id, e, b) => -// if(id.getType == UnitType) -// removeUnit(b) -// else { -// id.getType match { -// case TupleType(tpes) if tpes.contains(UnitType) => { -// val newTupleType = tupleTypeWrap(tpes.filterNot(_ == UnitType)) -// val freshId = FreshIdentifier(id.name, newTupleType) -// id2FreshId += (id -> freshId) -// val newBody = removeUnit(b) -// id2FreshId -= id -// Let(freshId, removeUnit(e), newBody) -// } -// case _ => Let(id, removeUnit(e), removeUnit(b)) -// } -// } -// -// case LetDef(fds, b) => -// val nonUnits = fds.filter(fd => fd.returnType != UnitType) -// if(nonUnits.isEmpty) { -// removeUnit(b) -// } else { -// val fdtoFreshFd = for(fd <- nonUnits) yield { -// val m = if(fd.params.exists(vd => vd.getType == UnitType)) { -// val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) -// fd -> freshFunDef -// } else { -// fd -> fd -// } -// fun2FreshFun += m -// m -// } -// for((fd, freshFunDef) <- fdtoFreshFd) { -// if(fd.params.exists(vd => vd.getType == UnitType)) { -// freshFunDef.fullBody = removeUnit(fd.fullBody) -// } else { -// fd.body = fd.body.map(b => removeUnit(b)) -// } -// } -// val rest = removeUnit(b) -// val newFds = for((fd, freshFunDef) <- fdtoFreshFd) yield { -// fun2FreshFun -= fd -// if(fd.params.exists(vd => vd.getType == UnitType)) { -// freshFunDef -// } else { -// fd -// } -// } -// -// letDef(newFds, rest) -// } -// -// case ite@IfExpr(cond, tExpr, eExpr) => -// val thenRec = removeUnit(tExpr) -// val elseRec = removeUnit(eExpr) -// IfExpr(removeUnit(cond), thenRec, elseRec) -// -// case v @ Variable(id) => -// if(id2FreshId.isDefinedAt(id)) -// Variable(id2FreshId(id)) -// else v -// -// case m @ MatchExpr(scrut, cses) => -// val scrutRec = removeUnit(scrut) -// val csesRec = cses.map{ cse => -// MatchCase(cse.pattern, cse.optGuard map removeUnit, removeUnit(cse.rhs)) -// } -// matchExpr(scrutRec, csesRec).setPos(m) -// -// case Operator(args, recons) => -// recons(args.map(removeUnit)) -// -// case _ => sys.error("not supported: " + expr) -// } -// } -// -//} diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala deleted file mode 100644 index ae216dc62c8f92f715b6d464bd88e6a0fff307cd..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package leon -package xlang - -import utils._ -import purescala.Definitions.Program - -object XLangDesugaringPhase extends LeonPhase[Program, Program] { - - val name = "xlang desugaring" - val description = "Desugar xlang features into PureScala" - - override def run(ctx: LeonContext, pgm: Program): (LeonContext, Program) = { - - def debugTrees(title: String) = - PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees)) - - val phases = - debugTrees("Program before starting xlang-desugaring") andThen - IntroduceGlobalStatePhase andThen - debugTrees("Program after introduce-global-state") andThen - AntiAliasingPhase andThen - debugTrees("Program after anti-aliasing") andThen - EpsilonElimination andThen - ImperativeCodeElimination - - phases.run(ctx, pgm) - } - -}