Skip to content

Commit 085d629

Browse files
committed
Add interface for get and set indices functions
1 parent b722dd7 commit 085d629

File tree

1 file changed

+59
-34
lines changed

1 file changed

+59
-34
lines changed
Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
//===- Utils.cpp - Transform utilities ------------------------------------===//
2-
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4-
// See https://llvm.org/LICENSE.txt for license information.
5-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
7-
//===----------------------------------------------------------------------===//
8-
91
#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
102

113
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -15,33 +7,66 @@
157
using namespace mlir;
168
using namespace mlir::amdgpu;
179

10+
// Define an interface for operations with indices
11+
class IndicesInterface {
12+
public:
13+
virtual std::optional<Operation::operand_range> getIndices() = 0;
14+
virtual void setIndices(ArrayRef<Value> indices) = 0;
15+
virtual ~IndicesInterface() = default;
16+
};
17+
18+
// Implement a generic class that uses IndicesInterface
19+
class OperationWithIndices : public IndicesInterface {
20+
private:
21+
Operation *op;
22+
template <typename OpType>
23+
static std::optional<Operation::operand_range> getIndicesImpl(Operation *op) {
24+
if (auto specificOp = dyn_cast<OpType>(op))
25+
return specificOp.getIndices();
26+
return std::nullopt;
27+
}
28+
29+
template <typename OpType>
30+
static void setIndicesImpl(Operation *op, ArrayRef<Value> indices) {
31+
if (auto specificOp = dyn_cast<OpType>(op))
32+
specificOp.getIndicesMutable().assign(indices);
33+
}
34+
35+
public:
36+
OperationWithIndices(Operation *op) : op(op) {}
37+
38+
std::optional<Operation::operand_range> getIndices() override {
39+
auto result = getIndicesImpl<memref::LoadOp>(op);
40+
if (!result)
41+
result = getIndicesImpl<memref::StoreOp>(op);
42+
if (!result)
43+
result = getIndicesImpl<vector::LoadOp>(op);
44+
if (!result)
45+
result = getIndicesImpl<vector::StoreOp>(op);
46+
if (!result)
47+
result = getIndicesImpl<vector::TransferReadOp>(op);
48+
if (!result)
49+
result = getIndicesImpl<vector::TransferWriteOp>(op);
50+
51+
return result;
52+
}
53+
54+
void setIndices(ArrayRef<Value> indices) override {
55+
setIndicesImpl<memref::LoadOp>(op, indices);
56+
setIndicesImpl<memref::StoreOp>(op, indices);
57+
setIndicesImpl<vector::LoadOp>(op, indices);
58+
setIndicesImpl<vector::StoreOp>(op, indices);
59+
setIndicesImpl<vector::TransferReadOp>(op, indices);
60+
setIndicesImpl<vector::TransferWriteOp>(op, indices);
61+
}
62+
};
63+
1864
std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
19-
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
20-
return loadOp.getIndices();
21-
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
22-
return storeOp.getIndices();
23-
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
24-
return vectorReadOp.getIndices();
25-
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
26-
return vectorStoreOp.getIndices();
27-
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
28-
return transferReadOp.getIndices();
29-
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
30-
return transferWriteOp.getIndices();
31-
return std::nullopt;
65+
OperationWithIndices operationWithIndices(op);
66+
return operationWithIndices.getIndices();
3267
}
3368

3469
void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
35-
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
36-
return loadOp.getIndicesMutable().assign(indices);
37-
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
38-
return storeOp.getIndicesMutable().assign(indices);
39-
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
40-
return vectorReadOp.getIndicesMutable().assign(indices);
41-
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
42-
return vectorStoreOp.getIndicesMutable().assign(indices);
43-
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
44-
return transferReadOp.getIndicesMutable().assign(indices);
45-
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
46-
return transferWriteOp.getIndicesMutable().assign(indices);
47-
}
70+
OperationWithIndices operationWithIndices(op);
71+
operationWithIndices.setIndices(indices);
72+
}

0 commit comments

Comments
 (0)