Skip to content

Commit 5ce271e

Browse files
authored
[MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d (llvm#123290)
This PR allows to lower **unsigned** `tosa.max_pool2d` to linalg. ``` // CHECK-LABEL: @max_pool_ui8 func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> { // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8> // CHECK: arith.constant 0 // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8> // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8> %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> return %0 : tensor<1x4x32x62xui8> } ``` It does this by - converting the MaxPool2dConverter from OpRewriterPattern to OpConversion Pattern - adjusting the padding value to the the minimum unsigned value when the max_pool is unsigned - lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned
1 parent d70f54f commit 5ce271e

File tree

4 files changed

+56
-18
lines changed

4 files changed

+56
-18
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
5252

5353
/// Populates conversion passes from TOSA dialect to Linalg named operations.
5454
void populateTosaToLinalgNamedConversionPatterns(
55-
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
55+
const TypeConverter &converter, RewritePatternSet *patterns,
56+
const TosaToLinalgNamedOptions &options);
5657

5758
} // namespace tosa
5859
} // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -695,17 +695,18 @@ class FullyConnectedConverter
695695
}
696696
};
697697

698-
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
698+
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
699699
public:
700-
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
700+
using OpConversionPattern::OpConversionPattern;
701701

702702
// Compute the dynamic output sizes of the maxpool operation.
703703
static SmallVector<Value>
704-
computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
704+
computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
705+
ConversionPatternRewriter &rewriter) {
705706
TensorType resultTy = op.getType();
706707
Location loc = op.getLoc();
707708

708-
TypedValue<TensorType> input = op.getInput();
709+
Value input = adaptor.getInput();
709710
ArrayRef<int64_t> kernel = op.getKernel();
710711
ArrayRef<int64_t> pad = op.getPad();
711712
ArrayRef<int64_t> stride = op.getStride();
@@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
744745
return dynamicDims;
745746
}
746747

747-
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
748-
PatternRewriter &rewriter) const final {
748+
LogicalResult
749+
matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
750+
ConversionPatternRewriter &rewriter) const final {
749751
Location loc = op.getLoc();
750-
TypedValue<TensorType> input = op.getInput();
751-
ShapedType inputTy = input.getType();
752+
Value input = adaptor.getInput();
753+
ShapedType inputTy = cast<ShapedType>(input.getType());
752754

753-
ShapedType resultTy = op.getType();
755+
bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
756+
ShapedType resultTy =
757+
cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
758+
if (!resultTy)
759+
return rewriter.notifyMatchFailure(op, "failed to convert type");
754760
Type resultETy = inputTy.getElementType();
755761

756-
SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
762+
SmallVector<Value> dynamicDims =
763+
computeDynamicOutputSizes(op, adaptor, rewriter);
757764

758765
// Determine what the initial value needs to be for the max pool op.
759766
TypedAttr initialAttr;
@@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
762769
resultETy, APFloat::getLargest(
763770
cast<FloatType>(resultETy).getFloatSemantics(), true));
764771

765-
if (isa<IntegerType>(resultETy))
772+
else if (isUnsigned)
773+
initialAttr = rewriter.getIntegerAttr(
774+
resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
775+
else if (isa<IntegerType>(resultETy))
766776
initialAttr = rewriter.getIntegerAttr(
767777
resultETy,
768778
APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
798808
Value fakeWindowDims =
799809
rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
800810

801-
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
802-
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
803-
filledEmptyTensor, strideAttr, dilationAttr);
811+
if (isUnsigned) {
812+
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
813+
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
814+
filledEmptyTensor, strideAttr, dilationAttr);
815+
} else {
816+
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
817+
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
818+
filledEmptyTensor, strideAttr, dilationAttr);
819+
}
804820
return success();
805821
}
806822
};
@@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
10701086
} // namespace
10711087

10721088
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1073-
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
1089+
const TypeConverter &converter, RewritePatternSet *patterns,
1090+
const TosaToLinalgNamedOptions &options) {
10741091
if (options.preferConv2DKernelLayoutHWCF) {
10751092
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
10761093
linalg::Conv2DNhwcHwcfQOp>>(
@@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
10851102
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
10861103
DepthwiseConvConverter,
10871104
MatMulConverter,
1088-
MaxPool2dConverter,
10891105
AvgPool2dConverter,
10901106
FullyConnectedConverter,
10911107
TransposeConverter
10921108
>(patterns->getContext());
1109+
1110+
patterns->add<
1111+
MaxPool2dConverter
1112+
>(converter, patterns->getContext());
10931113
// clang-format on
10941114
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
4747
}
4848

4949
void runOnOperation() override {
50+
TypeConverter converter;
51+
tosa::populateTosaTypeConversion(converter);
52+
5053
RewritePatternSet patterns(&getContext());
5154
ConversionTarget target(getContext());
5255
target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -67,7 +70,8 @@ struct TosaToLinalgNamed
6770
FunctionOpInterface func = getOperation();
6871
TosaToLinalgNamedOptions options;
6972
options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
70-
tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
73+
tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
74+
options);
7175
if (failed(applyFullConversion(func, target, std::move(patterns))))
7276
signalPassFailure();
7377
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
200200
return
201201
}
202202

203+
// CHECK-LABEL: @max_pool_ui8
204+
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
205+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
206+
// CHECK: arith.constant 0
207+
// CHECK: linalg.pooling_nhwc_max_unsigned
208+
// CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
209+
// CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
210+
// CHECK-SAME: -> tensor<1x4x32x62xi8>
211+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
212+
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
213+
return %0 : tensor<1x4x32x62xui8>
214+
}
215+
203216
// CHECK-LABEL: @max_pool_i16
204217
func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
205218
// CHECK: arith.constant -32768

0 commit comments

Comments
 (0)