1
- // ===- Utils.cpp - Transform utilities ------------------------------------===//
2
- //
3
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
- // See https://llvm.org/LICENSE.txt for license information.
5
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
- //
7
- // ===----------------------------------------------------------------------===//
8
-
9
1
#include " mlir/Dialect/AMDGPU/Transforms/Utils.h"
10
2
11
3
#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15
7
using namespace mlir ;
16
8
using namespace mlir ::amdgpu;
17
9
10
+ // Define an interface for operations with indices
11
+ class IndicesInterface {
12
+ public:
13
+ virtual std::optional<Operation::operand_range> getIndices () = 0;
14
+ virtual void setIndices (ArrayRef<Value> indices) = 0;
15
+ virtual ~IndicesInterface () = default ;
16
+ };
17
+
18
+ // Implement a generic class that uses IndicesInterface
19
+ class OperationWithIndices : public IndicesInterface {
20
+ private:
21
+ Operation *op;
22
+ template <typename OpType>
23
+ static std::optional<Operation::operand_range> getIndicesImpl (Operation *op) {
24
+ if (auto specificOp = dyn_cast<OpType>(op))
25
+ return specificOp.getIndices ();
26
+ return std::nullopt;
27
+ }
28
+
29
+ template <typename OpType>
30
+ static void setIndicesImpl (Operation *op, ArrayRef<Value> indices) {
31
+ if (auto specificOp = dyn_cast<OpType>(op))
32
+ specificOp.getIndicesMutable ().assign (indices);
33
+ }
34
+
35
+ public:
36
+ OperationWithIndices (Operation *op) : op(op) {}
37
+
38
+ std::optional<Operation::operand_range> getIndices () override {
39
+ auto result = getIndicesImpl<memref::LoadOp>(op);
40
+ if (!result)
41
+ result = getIndicesImpl<memref::StoreOp>(op);
42
+ if (!result)
43
+ result = getIndicesImpl<vector::LoadOp>(op);
44
+ if (!result)
45
+ result = getIndicesImpl<vector::StoreOp>(op);
46
+ if (!result)
47
+ result = getIndicesImpl<vector::TransferReadOp>(op);
48
+ if (!result)
49
+ result = getIndicesImpl<vector::TransferWriteOp>(op);
50
+
51
+ return result;
52
+ }
53
+
54
+ void setIndices (ArrayRef<Value> indices) override {
55
+ setIndicesImpl<memref::LoadOp>(op, indices);
56
+ setIndicesImpl<memref::StoreOp>(op, indices);
57
+ setIndicesImpl<vector::LoadOp>(op, indices);
58
+ setIndicesImpl<vector::StoreOp>(op, indices);
59
+ setIndicesImpl<vector::TransferReadOp>(op, indices);
60
+ setIndicesImpl<vector::TransferWriteOp>(op, indices);
61
+ }
62
+ };
63
+
18
64
std::optional<Operation::operand_range> amdgpu::getIndices (Operation *op) {
19
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
20
- return loadOp.getIndices ();
21
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
22
- return storeOp.getIndices ();
23
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
24
- return vectorReadOp.getIndices ();
25
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
26
- return vectorStoreOp.getIndices ();
27
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
28
- return transferReadOp.getIndices ();
29
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
30
- return transferWriteOp.getIndices ();
31
- return std::nullopt;
65
+ OperationWithIndices operationWithIndices (op);
66
+ return operationWithIndices.getIndices ();
32
67
}
33
68
34
69
void amdgpu::setIndices (Operation *op, ArrayRef<Value> indices) {
35
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
36
- return loadOp.getIndicesMutable ().assign (indices);
37
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
38
- return storeOp.getIndicesMutable ().assign (indices);
39
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
40
- return vectorReadOp.getIndicesMutable ().assign (indices);
41
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
42
- return vectorStoreOp.getIndicesMutable ().assign (indices);
43
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
44
- return transferReadOp.getIndicesMutable ().assign (indices);
45
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
46
- return transferWriteOp.getIndicesMutable ().assign (indices);
47
- }
70
+ OperationWithIndices operationWithIndices (op);
71
+ operationWithIndices.setIndices (indices);
72
+ }
0 commit comments