 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1982,14 +1983,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
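+///
+/// For example (an illustrative sketch, not taken from the patch itself; the
+/// shapes and value names are hypothetical):
+///
+///   %cast = tensor.cast %src : tensor<10xf32> to tensor<?xf32>
+///   %c2 = arith.constant 2 : index
+///   %expand = tensor.expand_shape %cast [[0, 1]] output_shape [%c2, 5]
+///       : tensor<?xf32> into tensor<?x5xf32>
+///
+/// becomes, once the chained casts are composed by the existing cast
+/// canonicalizations:
+///
+///   %expand = tensor.expand_shape %src [[0, 1]] output_shape [2, 5]
+///       : tensor<10xf32> into tensor<2x5xf32>
+///   %res = tensor.cast %expand : tensor<2x5xf32> to tensor<?x5xf32>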
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
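+    // canFoldIntoConsumerOp returns false for a null cast, so this also
+    // guards against a source that is not produced by a tensor.cast; it only
+    // matches casts whose source is at least as static as their result.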
+
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
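+    // `getOutputShape()` holds one SSA value per *dynamic* result dimension,
+    // so `outputIt` is advanced below only when a dynamic dimension is
+    // visited.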
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't fold any output_shape operand to a constant; nothing to change.
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output: each source dimension is the
+    // product of its expanded result dimensions, and any dynamic result
+    // dimension keeps the source dimension dynamic.
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
+    }
+
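+    // Rebuild the expand_shape with the more static shapes computed above and
+    // cast its result back to the original result type so existing users are
+    // unaffected.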
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace

 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);