diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 11f5b23e62c66..58bd61b2ae8b8 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -43,6 +43,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_SYMBOLDCE #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT +#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default @@ -130,6 +131,12 @@ createSymbolPrivatizePass(ArrayRef excludeSymbols = {}); /// their producers. std::unique_ptr createTopologicalSortPass(); +/// Create composite pass, which runs provided set of passes until fixed point +/// or maximum number of iterations reached. +std::unique_ptr createCompositeFixedPointPass( + std::string name, llvm::function_ref populateFunc, + int maxIterations = 10); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 51b2a27da639d..1b40a87c63f27 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -552,4 +552,21 @@ def TopologicalSort : Pass<"topological-sort"> { let constructor = "mlir::createTopologicalSortPass()"; } +def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { + let summary = "Composite fixed point pass"; + let description = [{ + Composite pass runs provided set of passes until fixed point or maximum + number of iterations reached. + }]; + + let options = [ + Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"", + "Composite pass display name">, + Option<"pipelineStr", "pipeline", "std::string", /*default=*/"", + "Composite pass inner pipeline">, + Option<"maxIter", "max-iterations", "int", /*default=*/"10", + "Maximum number of iterations if inner pipeline">, + ]; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 6c32ecf8a2a2f..90c0298fb5e46 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Utils) add_mlir_library(MLIRTransforms Canonicalizer.cpp + CompositePass.cpp ControlFlowSink.cpp CSE.cpp GenerateRuntimeVerification.cpp diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp new file mode 100644 index 0000000000000..b388a28da6424 --- /dev/null +++ b/mlir/lib/Transforms/CompositePass.cpp @@ -0,0 +1,105 @@ +//===- CompositePass.cpp - Composite pass code ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// CompositePass allows to run set of passes until fixed point is reached. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct CompositeFixedPointPass final + : public impl::CompositeFixedPointPassBase { + using CompositeFixedPointPassBase::CompositeFixedPointPassBase; + + CompositeFixedPointPass( + std::string name_, llvm::function_ref populateFunc, + int maxIterations) { + name = std::move(name_); + maxIter = maxIterations; + populateFunc(dynamicPM); + + llvm::raw_string_ostream os(pipelineStr); + dynamicPM.printAsTextualPipeline(os); + } + + LogicalResult initializeOptions( + StringRef options, + function_ref errorHandler) override { + if (failed(CompositeFixedPointPassBase::initializeOptions(options, + errorHandler))) + return failure(); + + if (failed(parsePassPipeline(pipelineStr, dynamicPM))) + return errorHandler("Failed to parse composite pass pipeline"); + + return success(); + } + + LogicalResult initialize(MLIRContext *context) override { + if (maxIter <= 0) + return emitError(UnknownLoc::get(context)) + << "Invalid maxIterations value: " << maxIter << "\n"; + + return success(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + dynamicPM.getDependentDialects(registry); + } + + void runOnOperation() override { + auto op = getOperation(); + OperationFingerPrint fp(op); + + int currentIter = 0; + int maxIterVal = maxIter; + while (true) { + if (failed(runPipeline(dynamicPM, op))) + return signalPassFailure(); + + if (currentIter++ >= maxIterVal) { + op->emitWarning("Composite pass \"" + llvm::Twine(name) + + "\"+ didn't converge in " + llvm::Twine(maxIterVal) + + " iterations"); + break; + } + + OperationFingerPrint newFp(op); + if (newFp == fp) + break; + + fp = newFp; + } + } + +protected: + llvm::StringRef getName() const override { return name; } + +private: + OpPassManager dynamicPM; +}; +} // namespace + +std::unique_ptr mlir::createCompositeFixedPointPass( + std::string name, llvm::function_ref populateFunc, + int maxIterations) { + + return std::make_unique(std::move(name), + populateFunc, maxIterations); +} diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir new file mode 100644 index 0000000000000..829470c2c9aa6 --- /dev/null +++ b/mlir/test/Transforms/composite-pass.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s +// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s + +// CHECK-LABEL: running `TestCompositePass` +// CHECK: running `Canonicalizer` +// CHECK: running `CSE` +// CHECK-NOT: running `Canonicalizer` +// CHECK-NOT: running `CSE` +func.func @test() { + return +} + +// ----- + +// CHECK-LABEL: running `TestCompositePass` +// CHECK: running `Canonicalizer` +// CHECK: running `CSE` +// CHECK: running `Canonicalizer` +// CHECK: running `CSE` +// CHECK-NOT: running `Canonicalizer` +// CHECK-NOT: running `CSE` +func.func @test() { +// this constant will be canonicalized away, causing another pass iteration + %0 = arith.constant 1.5 : f32 + return +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 2a3a8608db544..a849b7ebd29e2 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -20,6 +20,7 @@ endif() # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestCommutativityUtils.cpp + TestCompositePass.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp new file mode 100644 index 0000000000000..5c0d93cc0d64e --- /dev/null +++ b/mlir/test/lib/Transforms/TestCompositePass.cpp @@ -0,0 +1,38 @@ +//===------ TestCompositePass.cpp --- composite test pass -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to test the composite pass utility. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace test { +void registerTestCompositePass() { + registerPassPipeline( + "test-composite-fixed-point-pass", "Test composite pass", + [](OpPassManager &pm, StringRef optionsStr, + function_ref errorHandler) { + if (!optionsStr.empty()) + return failure(); + + pm.addPass(createCompositeFixedPointPass( + "TestCompositePass", [](OpPassManager &p) { + p.addPass(createCanonicalizerPass()); + p.addPass(createCSEPass()); + })); + return success(); + }, + [](function_ref) {}); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 82b3881792bf3..6ce9f3041d6f4 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass(); void registerVectorizerTestPass(); namespace test { +void registerTestCompositePass(); void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); @@ -195,6 +196,7 @@ void registerTestPasses() { registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerTestCompositePass(); mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner();