Commit 455f71d

[mlir] Convert expand_shape to more static form (llvm#112265)
Add a pattern that converts a `tensor.expand_shape` op to a more static form. It matches the sequence `tensor.cast` -> `tensor.expand_shape` when the `tensor.cast` can be folded into its consumer and some of the `output_shape` operands of the `tensor.expand_shape` are constant-foldable. This makes the `tensor.expand_shape` more static and allows the static shape information to be propagated further down the program.
Parent: 8c2e8b5
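As an illustration of the rewrite (a sketch distilled from the `@fold_expand_of_cast` test added below, assuming `%arg0` is a `tensor<10x10xf32>` function argument), IR such as:

  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>

canonicalizes, together with the existing cast-folding patterns, to roughly:

  %1 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>

When only some `output_shape` operands are constant, the remaining dimensions stay dynamic and the pattern inserts a consumer `tensor.cast` back to the original result type, as the `@sink_expand_of_cast` and `@partial_sink_expand_of_cast` tests below show.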

2 files changed: +136, -1 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 82 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1982,14 +1983,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
+    }
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
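Because `ConvertToStaticExpandShape` is registered through `getCanonicalizationPatterns`, it runs as part of the generic canonicalizer rather than a dedicated pass. The new tests below rely on that; the exact RUN line of canonicalize.mlir is not shown in this diff, but a rough, assumed way to exercise these cases locally would be a RUN line along the lines of:

  // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s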

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 54 additions & 0 deletions
@@ -2741,3 +2741,57 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<10x1x10xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+  return %2 : tensor<10x1x10xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @sink_expand_of_cast
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+    -> tensor<?x?x?xf32> {
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @partial_sink_expand_of_cast
+// CHECK: %[[CAST:.+]] = tensor.cast
+// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[RES]]
