16
16
using namespace mlir ;
17
17
using namespace mlir ::tensor;
18
18
19
- // / Compute a map that for a given dimension of the expanded type gives the
20
- // / dimension in the collapsed type it maps to. Essentially its the inverse of
21
- // / the `reassocation` maps.
22
- static llvm::DenseMap<int64_t , int64_t >
23
- getExpandedDimToCollapsedDimMap (ArrayRef<AffineMap> reassociation) {
24
- llvm::DenseMap<int64_t , int64_t > expandedDimToCollapsedDim;
25
- for (const auto &map : enumerate(reassociation)) {
26
- unsigned startPos =
27
- cast<AffineDimExpr>(map.value ().getResults ().front ()).getPosition ();
28
- unsigned endPos =
29
- cast<AffineDimExpr>(map.value ().getResults ().back ()).getPosition ();
30
- for (auto dim : llvm::seq_inclusive (startPos, endPos)) {
31
- expandedDimToCollapsedDim[dim] = map.index ();
32
- }
33
- }
34
- return expandedDimToCollapsedDim;
35
- }
36
-
37
19
// / For reshape op compute the shape at dimension `dimIndex` of the output in
38
20
// / terms of shape of the `src`, when the reshape op is a collapsing
39
21
// / operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,84 +58,15 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
76
58
}));
77
59
}
78
60
79
- // / For an expanding reshape op, compute the value for a dimension of the output
80
- // / from the shape of the input.
81
- static OpFoldResult getExpandedOutputDimFromInputShape (
82
- OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83
- ArrayRef<int64_t > dstStaticShape, ArrayRef<AffineMap> reassociation,
84
- llvm::DenseMap<int64_t , int64_t > &expandedDimToCollapsedDim) {
85
- if (!ShapedType::isDynamic (dstStaticShape[dimIndex])) {
86
- // Static dimension: return Attribute.
87
- return builder.getIndexAttr (dstStaticShape[dimIndex]);
88
- }
89
- unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90
- unsigned startPos =
91
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults ().front ())
92
- .getPosition ();
93
- unsigned endPos =
94
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults ().back ())
95
- .getPosition ();
96
- int64_t linearizedStaticDim = 1 ;
97
- for (auto d :
98
- llvm::enumerate (dstStaticShape.slice (startPos, endPos - startPos + 1 ))) {
99
- if (d.index () + startPos == static_cast <unsigned >(dimIndex))
100
- continue ;
101
- assert (!ShapedType::isDynamic (d.value ()) &&
102
- " single dimension cannot be expanded into multiple dynamic "
103
- " dimensions" );
104
- linearizedStaticDim *= d.value ();
105
- }
106
- OpFoldResult sourceDim =
107
- builder.create <tensor::DimOp>(loc, src, sourceDimPos).getResult ();
108
-
109
- // Dynamic dimension: return Value.
110
- return affine::makeComposedAffineApply (
111
- builder, loc,
112
- AffineMap::get (
113
- 0 , 1 ,
114
- builder.getAffineSymbolExpr (0 ).floorDiv (linearizedStaticDim)),
115
- sourceDim)
116
- ->getResult (0 );
117
- }
118
-
119
- // / Given the `src` of an expanding reshape op, the reassociation maps and the
120
- // / result type, compute the shape of the result of the reshape.
121
- static SmallVector<OpFoldResult, 4 > getExpandedOutputShapeFromInputShape (
122
- OpBuilder &builder, Location loc, Value src,
123
- ArrayRef<int64_t > dstStaticShape, ArrayRef<AffineMap> reassociation) {
124
- llvm::DenseMap<int64_t , int64_t > expandedDimToCollapsedDim =
125
- getExpandedDimToCollapsedDimMap (reassociation);
126
- return llvm::to_vector<4 >(llvm::map_range (
127
- llvm::seq<int64_t >(0 , dstStaticShape.size ()), [&](int64_t dim) {
128
- return getExpandedOutputDimFromInputShape (builder, loc, dim, src,
129
- dstStaticShape, reassociation,
130
- expandedDimToCollapsedDim);
131
- }));
132
- }
133
-
134
- static SmallVector<OpFoldResult, 4 >
135
- getReshapeOutputShapeFromInputShape (OpBuilder &builder, Location loc, Value src,
136
- ArrayRef<int64_t > dstStaticShape,
137
- ArrayRef<AffineMap> reassocation) {
138
- return dstStaticShape.size () >
139
- static_cast <size_t >(
140
- llvm::cast<ShapedType>(src.getType ()).getRank ())
141
- ? getExpandedOutputShapeFromInputShape (
142
- builder, loc, src, dstStaticShape, reassocation)
143
- : getCollapsedOutputShapeFromInputShape (
144
- builder, loc, src, dstStaticShape, reassocation);
145
- }
146
-
147
- template <typename OpTy>
148
- struct ReifyExpandOrCollapseShapeOp
61
+ struct ReifyCollapseShapeOp
149
62
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
150
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy > {
63
+ ReifyCollapseShapeOp, CollapseShapeOp > {
151
64
LogicalResult
152
65
reifyResultShapes (Operation *op, OpBuilder &b,
153
66
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154
67
auto loc = op->getLoc ();
155
- auto reshapeOp = cast<OpTy >(op);
156
- reifiedReturnShapes.push_back (getReshapeOutputShapeFromInputShape (
68
+ auto reshapeOp = cast<tensor::CollapseShapeOp >(op);
69
+ reifiedReturnShapes.push_back (getCollapsedOutputShapeFromInputShape (
157
70
b, loc, reshapeOp.getSrc (), reshapeOp.getResultType ().getShape (),
158
71
reshapeOp.getReassociationMaps ()));
159
72
return success ();
@@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp
162
75
163
76
namespace {
164
77
78
+ struct ReifyExpandShapeOp
79
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
80
+ ExpandShapeOp> {
81
+ LogicalResult
82
+ reifyResultShapes (Operation *op, OpBuilder &b,
83
+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
84
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
85
+ SmallVector<OpFoldResult> resultShapes =
86
+ expandShapeOp.getMixedOutputShape ();
87
+ reifyResultShapes.emplace_back (std::move (resultShapes));
88
+ return success ();
89
+ }
90
+ };
91
+
165
92
struct ReifyPadOp
166
93
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
167
94
PadOp> {
@@ -202,10 +129,8 @@ struct ReifyPadOp
202
129
void mlir::tensor::registerInferTypeOpInterfaceExternalModels (
203
130
DialectRegistry ®istry) {
204
131
registry.addExtension (+[](MLIRContext *ctx, TensorDialect *dialect) {
205
- ExpandShapeOp::attachInterface<
206
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207
- CollapseShapeOp::attachInterface<
208
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
132
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
133
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209
134
PadOp::attachInterface<ReifyPadOp>(*ctx);
210
135
});
211
136
}
0 commit comments