Commit 2117677 (1 parent: c886d66)

[mlir] Fix bugs in expand_shape patterns after semantics changes (#94631)

After the `output_shape` field was added to `expand_shape` ops, dynamically sized expand shapes are now possible, but the folder for producer-consumer reshape pairs did not account for this: two values with identical dynamic types, such as tensor<?x?xf32>, may still have different runtime shapes. This PR tightens the constraints of the folder, and of the ComposeExpandOfCollapseOp pattern, so that the ops are only folded when the shapes are provably equal.
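For concreteness, the unsound fold this fixes is captured by the new @no_fold_expand_of_collapse_dynamic regression test added below. When a reassociation group covers more than one dynamic dimension, a collapse followed by an expand over the same groups is not an identity, because the expand's output_shape sizes are only constrained to multiply to the collapsed size:

  // %arg1 * %arg2 must equal the product of the first two dims of %arg0, but
  // %arg1 need not equal dim 0, so %1 can have a different shape than %arg0
  // and the pair must not be folded away.
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
      : tensor<?x?xf32> into tensor<?x?x?xf32>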

File tree: 2 files changed (+110 −12 lines)

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Lines changed: 40 additions & 10 deletions

@@ -85,21 +85,49 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  if (reassociations != reshapeSrcOp.getReassociationIndices())
+    return nullptr;
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  if (llvm::all_of(reassociations, [&](auto reInd) {
+        ArrayRef<int64_t> srcSlice =
+            srcType.getShape().slice(reInd.front(), reInd.size());
+        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
 }
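The asymmetry in the folder above deserves a note: expanding and then collapsing over the same reassociation groups is an identity even with arbitrarily many dynamic dimensions, because each collapsed size is, by construction, the product of the group it came from. The new @fold_collapse_of_expand_fully_dynamic test below covers this; a minimal sketch, assuming the expand's output_shape is consistent with %arg0 as the op semantics require:

  // The expand is only valid if %arg1 * %arg2 equals dim 0 of %arg0 and
  // %arg3 equals dim 1, so collapsing the same groups necessarily restores
  // the shape of %arg0 and %1 folds to %arg0.
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>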
@@ -360,10 +388,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         resultShape.slice(resultIndices.front(), resultIndices.size());
 
     if (srcSubShape.size() == resultSubShape.size()) {
-      if (srcSubShape == resultSubShape)
+      if (srcSubShape == resultSubShape &&
+          llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
         composedReassociation.push_back(srcIndices);
-      else
+      } else {
         return std::nullopt;
+      }
     }
 
     // Find reassociation to collapse `srcSubShape` into `resultSubShape`.

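The same fewer-than-two-dynamic-dimensions rule guards this compose pattern: when the matching sub-shapes of a reassociation group contain at most one dynamic size, that size is pinned by the static ones, so the group is provably an identity. The new @fold_expand_of_collapse_dynamic test below exercises exactly this case; a sketch:

  // Group [0, 1] has sub-shape [?, 4] on both sides. Since %arg1 * 4 must
  // equal dim 0 * 4 of %arg0, %arg1 equals dim 0 and the pair folds away.
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
      : tensor<?x?xf32> into tensor<?x4x?xf32>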
mlir/test/Dialect/Tensor/canonicalize.mlir
Lines changed: 70 additions & 2 deletions

@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+// CHECK: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
 
 // -----
 