Skip to content

Commit 5b66b6a

Browse files
authored
[mlir][pass] Add composite pass utility (#87166)
Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.
1 parent 1d06f41 commit 5b66b6a

File tree

8 files changed

+197
-0
lines changed

8 files changed

+197
-0
lines changed

mlir/include/mlir/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
4343
#define GEN_PASS_DECL_SYMBOLDCE
4444
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
4545
#define GEN_PASS_DECL_TOPOLOGICALSORT
46+
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
4647
#include "mlir/Transforms/Passes.h.inc"
4748

4849
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -130,6 +131,12 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
130131
/// their producers.
131132
std::unique_ptr<Pass> createTopologicalSortPass();
132133

134+
/// Create composite pass, which runs provided set of passes until fixed point
135+
/// or maximum number of iterations reached.
136+
std::unique_ptr<Pass> createCompositeFixedPointPass(
137+
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
138+
int maxIterations = 10);
139+
133140
//===----------------------------------------------------------------------===//
134141
// Registration
135142
//===----------------------------------------------------------------------===//

mlir/include/mlir/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,4 +552,21 @@ def TopologicalSort : Pass<"topological-sort"> {
552552
let constructor = "mlir::createTopologicalSortPass()";
553553
}
554554

555+
def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
556+
let summary = "Composite fixed point pass";
557+
let description = [{
558+
Composite pass runs provided set of passes until fixed point or maximum
559+
number of iterations reached.
560+
}];
561+
562+
let options = [
563+
Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
564+
"Composite pass display name">,
565+
Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
566+
"Composite pass inner pipeline">,
567+
Option<"maxIter", "max-iterations", "int", /*default=*/"10",
568+
"Maximum number of iterations if inner pipeline">,
569+
];
570+
}
571+
555572
#endif // MLIR_TRANSFORMS_PASSES

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
22

33
add_mlir_library(MLIRTransforms
44
Canonicalizer.cpp
5+
CompositePass.cpp
56
ControlFlowSink.cpp
67
CSE.cpp
78
GenerateRuntimeVerification.cpp

mlir/lib/Transforms/CompositePass.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//===- CompositePass.cpp - Composite pass code ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// CompositePass allows to run set of passes until fixed point is reached.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Transforms/Passes.h"
14+
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Pass/PassManager.h"
17+
18+
namespace mlir {
19+
#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20+
#include "mlir/Transforms/Passes.h.inc"
21+
} // namespace mlir
22+
23+
using namespace mlir;
24+
25+
namespace {
26+
struct CompositeFixedPointPass final
27+
: public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28+
using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29+
30+
CompositeFixedPointPass(
31+
std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32+
int maxIterations) {
33+
name = std::move(name_);
34+
maxIter = maxIterations;
35+
populateFunc(dynamicPM);
36+
37+
llvm::raw_string_ostream os(pipelineStr);
38+
dynamicPM.printAsTextualPipeline(os);
39+
}
40+
41+
LogicalResult initializeOptions(
42+
StringRef options,
43+
function_ref<LogicalResult(const Twine &)> errorHandler) override {
44+
if (failed(CompositeFixedPointPassBase::initializeOptions(options,
45+
errorHandler)))
46+
return failure();
47+
48+
if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
49+
return errorHandler("Failed to parse composite pass pipeline");
50+
51+
return success();
52+
}
53+
54+
LogicalResult initialize(MLIRContext *context) override {
55+
if (maxIter <= 0)
56+
return emitError(UnknownLoc::get(context))
57+
<< "Invalid maxIterations value: " << maxIter << "\n";
58+
59+
return success();
60+
}
61+
62+
void getDependentDialects(DialectRegistry &registry) const override {
63+
dynamicPM.getDependentDialects(registry);
64+
}
65+
66+
void runOnOperation() override {
67+
auto op = getOperation();
68+
OperationFingerPrint fp(op);
69+
70+
int currentIter = 0;
71+
int maxIterVal = maxIter;
72+
while (true) {
73+
if (failed(runPipeline(dynamicPM, op)))
74+
return signalPassFailure();
75+
76+
if (currentIter++ >= maxIterVal) {
77+
op->emitWarning("Composite pass \"" + llvm::Twine(name) +
78+
"\"+ didn't converge in " + llvm::Twine(maxIterVal) +
79+
" iterations");
80+
break;
81+
}
82+
83+
OperationFingerPrint newFp(op);
84+
if (newFp == fp)
85+
break;
86+
87+
fp = newFp;
88+
}
89+
}
90+
91+
protected:
92+
llvm::StringRef getName() const override { return name; }
93+
94+
private:
95+
OpPassManager dynamicPM;
96+
};
97+
} // namespace
98+
99+
std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
100+
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
101+
int maxIterations) {
102+
103+
return std::make_unique<CompositeFixedPointPass>(std::move(name),
104+
populateFunc, maxIterations);
105+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
3+
4+
// CHECK-LABEL: running `TestCompositePass`
5+
// CHECK: running `Canonicalizer`
6+
// CHECK: running `CSE`
7+
// CHECK-NOT: running `Canonicalizer`
8+
// CHECK-NOT: running `CSE`
9+
func.func @test() {
10+
return
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: running `TestCompositePass`
16+
// CHECK: running `Canonicalizer`
17+
// CHECK: running `CSE`
18+
// CHECK: running `Canonicalizer`
19+
// CHECK: running `CSE`
20+
// CHECK-NOT: running `Canonicalizer`
21+
// CHECK-NOT: running `CSE`
22+
func.func @test() {
23+
// this constant will be canonicalized away, causing another pass iteration
24+
%0 = arith.constant 1.5 : f32
25+
return
26+
}

mlir/test/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ endif()
2020
# Exclude tests from libMLIR.so
2121
add_mlir_library(MLIRTestTransforms
2222
TestCommutativityUtils.cpp
23+
TestCompositePass.cpp
2324
TestConstantFold.cpp
2425
TestControlFlowSink.cpp
2526
TestInlining.cpp
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===------ TestCompositePass.cpp --- composite test pass -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to test the composite pass utility.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Pass/Pass.h"
14+
#include "mlir/Pass/PassManager.h"
15+
#include "mlir/Pass/PassRegistry.h"
16+
#include "mlir/Transforms/Passes.h"
17+
18+
namespace mlir {
19+
namespace test {
20+
void registerTestCompositePass() {
21+
registerPassPipeline(
22+
"test-composite-fixed-point-pass", "Test composite pass",
23+
[](OpPassManager &pm, StringRef optionsStr,
24+
function_ref<LogicalResult(const Twine &)> errorHandler) {
25+
if (!optionsStr.empty())
26+
return failure();
27+
28+
pm.addPass(createCompositeFixedPointPass(
29+
"TestCompositePass", [](OpPassManager &p) {
30+
p.addPass(createCanonicalizerPass());
31+
p.addPass(createCSEPass());
32+
}));
33+
return success();
34+
},
35+
[](function_ref<void(const detail::PassOptions &)>) {});
36+
}
37+
} // namespace test
38+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
6868
void registerVectorizerTestPass();
6969

7070
namespace test {
71+
void registerTestCompositePass();
7172
void registerCommutativityUtils();
7273
void registerConvertCallOpPass();
7374
void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
195196
registerVectorizerTestPass();
196197
registerTosaTestQuantUtilAPIPass();
197198

199+
mlir::test::registerTestCompositePass();
198200
mlir::test::registerCommutativityUtils();
199201
mlir::test::registerConvertCallOpPass();
200202
mlir::test::registerInliner();

0 commit comments

Comments
 (0)